diff options
Diffstat (limited to 'tests/test_protocol.py')
-rw-r--r-- | tests/test_protocol.py | 660 |
1 files changed, 660 insertions, 0 deletions
diff --git a/tests/test_protocol.py b/tests/test_protocol.py new file mode 100644 index 0000000..63b689f --- /dev/null +++ b/tests/test_protocol.py @@ -0,0 +1,660 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ +import io +import json +from concurrent.futures import Future +from pathlib import Path +from typing import Optional +from unittest.mock import Mock + +import attrs +import pytest + +from pygls.exceptions import JsonRpcException, JsonRpcInvalidParams +from lsprotocol.types import ( + PROGRESS, + TEXT_DOCUMENT_COMPLETION, + ClientCapabilities, + CompletionItem, + CompletionItemKind, + CompletionParams, + InitializeParams, + InitializeResult, + ProgressParams, + Position, + ShutdownResponse, + TextDocumentCompletionResponse, + TextDocumentIdentifier, + WorkDoneProgressBegin, +) +from pygls.protocol import ( + default_converter, + JsonRPCProtocol, + JsonRPCRequestMessage, + JsonRPCResponseMessage, + JsonRPCNotification, +) + +EXAMPLE_NOTIFICATION = "example/notification" +EXAMPLE_REQUEST = "example/request" + + +@attrs.define +class IntResult: + id: str + result: int + jsonrpc: str = attrs.field(default="2.0") + + +@attrs.define +class ExampleParams: + @attrs.define + class InnerType: + inner_field: str + + field_a: str + field_b: Optional[InnerType] = None + + +@attrs.define +class ExampleNotification: + jsonrpc: str = attrs.field(default="2.0") + method: str = EXAMPLE_NOTIFICATION + params: ExampleParams = attrs.field(default=None) + + +@attrs.define +class ExampleRequest: + id: str + jsonrpc: str = attrs.field(default="2.0") + method: str = EXAMPLE_REQUEST + params: ExampleParams = attrs.field(default=None) + + +EXAMPLE_LSP_METHODS_MAP = { + EXAMPLE_NOTIFICATION: (ExampleNotification, None, ExampleParams, None), + EXAMPLE_REQUEST: (ExampleRequest, None, ExampleParams, None), +} + + +class ExampleProtocol(JsonRPCProtocol): + def get_message_type(self, method: str): + return EXAMPLE_LSP_METHODS_MAP.get(method, (None,))[0] + + +@pytest.fixture() +def protocol(): + return ExampleProtocol(None, default_converter()) + + +def test_deserialize_notification_message_valid_params(protocol): + params = f""" + {{ + "jsonrpc": "2.0", + "method": "{EXAMPLE_NOTIFICATION}", + "params": {{ + "fieldA": "test_a", + "fieldB": {{ + "innerField": "test_inner" + }} + }} + }} + """ + + result = json.loads(params, object_hook=protocol._deserialize_message) + + assert isinstance( + result, ExampleNotification + ), f"Expected FeatureRequest instance, got {result}" + assert result.jsonrpc == "2.0" + assert result.method == EXAMPLE_NOTIFICATION + + assert isinstance(result.params, ExampleParams) + assert result.params.field_a == "test_a" + + assert isinstance(result.params.field_b, ExampleParams.InnerType) + assert result.params.field_b.inner_field == "test_inner" + + +def test_deserialize_notification_message_unknown_type(protocol): + params = """ + { + "jsonrpc": "2.0", + "method": "random", + "params": { + "field_a": "test_a", + "field_b": { + "inner_field": "test_inner" + } + } + } + """ + + result = json.loads(params, object_hook=protocol._deserialize_message) + + assert isinstance(result, JsonRPCNotification) + assert result.jsonrpc == "2.0" + assert result.method == "random" + + assert result.params.field_a == "test_a" + assert result.params.field_b.inner_field == "test_inner" + + +def test_deserialize_notification_message_bad_params_should_raise_error(protocol): + params = f""" + {{ + "jsonrpc": "2.0", + "method": "{EXAMPLE_NOTIFICATION}", + "params": {{ + "field_a": "test_a", + "field_b": {{ + "wrong_field_name": "test_inner" + }} + }} + }} + """ + + with pytest.raises(JsonRpcInvalidParams): + json.loads(params, object_hook=protocol._deserialize_message) + + +def test_deserialize_response_message_custom_converter(): + params = """ + { + "jsonrpc": "2.0", + "id": "id", + "result": "1" + } + """ + + # Just for fun, let's create a converter that reverses all the keys in a dict. + # + @attrs.define + class egasseM: + cprnosj: str + di: str + tluser: str + + def structure_hook(obj, cls): + params = {k[::-1]: v for k, v in obj.items()} + return cls(**params) + + def custom_converter(): + converter = default_converter() + converter.register_structure_hook(egasseM, structure_hook) + return converter + + protocol = JsonRPCProtocol(None, custom_converter()) + protocol._result_types["id"] = egasseM + result = json.loads(params, object_hook=protocol._deserialize_message) + + assert isinstance(result, egasseM) + assert result.cprnosj == "2.0" + assert result.di == "id" + assert result.tluser == "1" + + +@pytest.mark.parametrize( + "method, params, expected", + [ + ( + # Known notification type. + PROGRESS, + ProgressParams( + token="id1", + value=WorkDoneProgressBegin( + title="Begin progress", + percentage=0, + ), + ), + { + "jsonrpc": "2.0", + "method": "$/progress", + "params": { + "token": "id1", + "value": { + "kind": "begin", + "percentage": 0, + "title": "Begin progress", + }, + }, + }, + ), + ( + # Custom notification type. + EXAMPLE_NOTIFICATION, + ExampleParams( + field_a="field one", + field_b=ExampleParams.InnerType(inner_field="field two"), + ), + { + "jsonrpc": "2.0", + "method": EXAMPLE_NOTIFICATION, + "params": { + "fieldA": "field one", + "fieldB": { + "innerField": "field two", + }, + }, + }, + ), + ( + # Custom notification with dict params. + EXAMPLE_NOTIFICATION, + {"fieldA": "field one", "fieldB": {"innerField": "field two"}}, + { + "jsonrpc": "2.0", + "method": EXAMPLE_NOTIFICATION, + "params": { + "fieldA": "field one", + "fieldB": { + "innerField": "field two", + }, + }, + }, + ), + ], +) +def test_serialize_notification_message(method, params, expected): + """ + Ensure that we can serialize notification messages, retaining all + expected fields. + """ + + buffer = io.StringIO() + + protocol = JsonRPCProtocol(None, default_converter()) + protocol._send_only_body = True + protocol.connection_made(buffer) + + protocol.notify(method, params=params) + actual = json.loads(buffer.getvalue()) + + assert actual == expected + + +def test_deserialize_response_message(protocol): + params = """ + { + "jsonrpc": "2.0", + "id": "id", + "result": "1" + } + """ + protocol._result_types["id"] = IntResult + result = json.loads(params, object_hook=protocol._deserialize_message) + + assert isinstance(result, IntResult) + assert result.jsonrpc == "2.0" + assert result.id == "id" + assert result.result == 1 + + +def test_deserialize_response_message_unknown_type(protocol): + params = """ + { + "jsonrpc": "2.0", + "id": "id", + "result": { + "field_a": "test_a", + "field_b": { + "inner_field": "test_inner" + } + } + } + """ + protocol._result_types["id"] = JsonRPCResponseMessage + result = json.loads(params, object_hook=protocol._deserialize_message) + + assert isinstance(result, JsonRPCResponseMessage) + assert result.jsonrpc == "2.0" + assert result.id == "id" + + assert result.result.field_a == "test_a" + assert result.result.field_b.inner_field == "test_inner" + + +def test_deserialize_request_message_with_registered_type(protocol): + params = f""" + {{ + "jsonrpc": "2.0", + "id": "id", + "method": "{EXAMPLE_REQUEST}", + "params": {{ + "fieldA": "test_a", + "fieldB": {{ + "innerField": "test_inner" + }} + }} + }} + """ + result = json.loads(params, object_hook=protocol._deserialize_message) + + assert isinstance(result, ExampleRequest) + assert result.jsonrpc == "2.0" + assert result.id == "id" + assert result.method == EXAMPLE_REQUEST + + assert isinstance(result.params, ExampleParams) + assert result.params.field_a == "test_a" + + assert isinstance(result.params.field_b, ExampleParams.InnerType) + assert result.params.field_b.inner_field == "test_inner" + + +def test_deserialize_request_message_without_registered_type(protocol): + params = """ + { + "jsonrpc": "2.0", + "id": "id", + "method": "random", + "params": { + "field_a": "test_a", + "field_b": { + "inner_field": "test_inner" + } + } + } + """ + result = json.loads(params, object_hook=protocol._deserialize_message) + + assert isinstance(result, JsonRPCRequestMessage) + assert result.jsonrpc == "2.0" + assert result.id == "id" + assert result.method == "random" + + assert result.params.field_a == "test_a" + assert result.params.field_b.inner_field == "test_inner" + + +@pytest.mark.parametrize( + "msg_type, result, expected", + [ + (ShutdownResponse, None, {"jsonrpc": "2.0", "id": "1", "result": None}), + ( + TextDocumentCompletionResponse, + [ + CompletionItem(label="example-one"), + CompletionItem( + label="example-two", + kind=CompletionItemKind.Class, + preselect=False, + deprecated=True, + ), + ], + { + "jsonrpc": "2.0", + "id": "1", + "result": [ + {"label": "example-one"}, + { + "label": "example-two", + "kind": 7, # CompletionItemKind.Class + "preselect": False, + "deprecated": True, + }, + ], + }, + ), + ( # Unknown type with object params. + JsonRPCResponseMessage, + ExampleParams( + field_a="field one", + field_b=ExampleParams.InnerType(inner_field="field two"), + ), + { + "jsonrpc": "2.0", + "id": "1", + "result": { + "fieldA": "field one", + "fieldB": {"innerField": "field two"}, + }, + }, + ), + ( # Unknown type with dict params. + JsonRPCResponseMessage, + {"fieldA": "field one", "fieldB": {"innerField": "field two"}}, + { + "jsonrpc": "2.0", + "id": "1", + "result": { + "fieldA": "field one", + "fieldB": {"innerField": "field two"}, + }, + }, + ), + ], +) +def test_serialize_response_message(msg_type, result, expected): + """ + Ensure that we can serialize response messages, retaining all expected + fields. + """ + + buffer = io.StringIO() + + protocol = JsonRPCProtocol(None, default_converter()) + protocol._send_only_body = True + protocol.connection_made(buffer) + + protocol._result_types["1"] = msg_type + + protocol._send_response("1", result=result) + actual = json.loads(buffer.getvalue()) + + assert actual == expected + + +@pytest.mark.parametrize( + "method, params, expected", + [ + ( + TEXT_DOCUMENT_COMPLETION, + CompletionParams( + text_document=TextDocumentIdentifier(uri="file:///file.txt"), + position=Position(line=1, character=0), + ), + { + "jsonrpc": "2.0", + "id": "1", + "method": TEXT_DOCUMENT_COMPLETION, + "params": { + "textDocument": {"uri": "file:///file.txt"}, + "position": {"line": 1, "character": 0}, + }, + }, + ), + ( # Unknown type with object params. + EXAMPLE_REQUEST, + ExampleParams( + field_a="field one", + field_b=ExampleParams.InnerType(inner_field="field two"), + ), + { + "jsonrpc": "2.0", + "id": "1", + "method": EXAMPLE_REQUEST, + "params": { + "fieldA": "field one", + "fieldB": {"innerField": "field two"}, + }, + }, + ), + ( # Unknown type with dict params. + EXAMPLE_REQUEST, + {"fieldA": "field one", "fieldB": {"innerField": "field two"}}, + { + "jsonrpc": "2.0", + "id": "1", + "method": EXAMPLE_REQUEST, + "params": { + "fieldA": "field one", + "fieldB": {"innerField": "field two"}, + }, + }, + ), + ], +) +def test_serialize_request_message(method, params, expected): + """ + Ensure that we can serialize request messages, retaining all expected + fields. + """ + + buffer = io.StringIO() + + protocol = JsonRPCProtocol(None, default_converter()) + protocol._send_only_body = True + protocol.connection_made(buffer) + + protocol.send_request(method, params, callback=None, msg_id="1") + actual = json.loads(buffer.getvalue()) + + assert actual == expected + + +def test_data_received_without_content_type(client_server): + _, server = client_server + body = json.dumps( + { + "jsonrpc": "2.0", + "method": "test", + "params": 1, + } + ) + message = "\r\n".join( + ( + "Content-Length: " + str(len(body)), + "", + body, + ) + ) + data = bytes(message, "utf-8") + server.lsp.data_received(data) + + +def test_data_received_content_type_first_should_handle_message(client_server): + _, server = client_server + body = json.dumps( + { + "jsonrpc": "2.0", + "method": "test", + "params": 1, + } + ) + message = "\r\n".join( + ( + "Content-Type: application/vscode-jsonrpc; charset=utf-8", + "Content-Length: " + str(len(body)), + "", + body, + ) + ) + data = bytes(message, "utf-8") + server.lsp.data_received(data) + + +def dummy_message(param=1): + body = json.dumps( + { + "jsonrpc": "2.0", + "method": "test", + "params": param, + } + ) + message = "\r\n".join( + ( + "Content-Length: " + str(len(body)), + "Content-Type: application/vscode-jsonrpc; charset=utf-8", + "", + body, + ) + ) + return bytes(message, "utf-8") + + +def test_data_received_single_message_should_handle_message(client_server): + _, server = client_server + data = dummy_message() + server.lsp.data_received(data) + + +def test_data_received_partial_message_should_handle_message(client_server): + _, server = client_server + data = dummy_message() + partial = len(data) - 5 + server.lsp.data_received(data[:partial]) + server.lsp.data_received(data[partial:]) + + +def test_data_received_multi_message_should_handle_messages(client_server): + _, server = client_server + messages = (dummy_message(i) for i in range(3)) + data = b"".join(messages) + server.lsp.data_received(data) + + +def test_data_received_error_should_raise_jsonrpc_error(client_server): + _, server = client_server + body = json.dumps( + { + "jsonrpc": "2.0", + "id": "err", + "error": { + "code": -1, + "message": "message for you sir", + }, + } + ) + message = "\r\n".join( + [ + "Content-Length: " + str(len(body)), + "Content-Type: application/vscode-jsonrpc; charset=utf-8", + "", + body, + ] + ).encode("utf-8") + future = server.lsp._request_futures["err"] = Future() + server.lsp.data_received(message) + with pytest.raises(JsonRpcException, match="message for you sir"): + future.result() + + +def test_initialize_should_return_server_capabilities(client_server): + _, server = client_server + params = InitializeParams( + process_id=1234, + root_uri=Path(__file__).parent.as_uri(), + capabilities=ClientCapabilities(), + ) + + server_capabilities = server.lsp.lsp_initialize(params) + + assert isinstance(server_capabilities, InitializeResult) + + +def test_ignore_unknown_notification(client_server): + _, server = client_server + + fn = server.lsp._execute_notification + server.lsp._execute_notification = Mock() + + server.lsp._handle_notification("random/notification", None) + assert not server.lsp._execute_notification.called + + # Remove mock + server.lsp._execute_notification = fn |