summaryrefslogtreecommitdiffstats
path: root/tests/test_protocol.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_protocol.py')
-rw-r--r--tests/test_protocol.py660
1 files changed, 660 insertions, 0 deletions
diff --git a/tests/test_protocol.py b/tests/test_protocol.py
new file mode 100644
index 0000000..63b689f
--- /dev/null
+++ b/tests/test_protocol.py
@@ -0,0 +1,660 @@
+############################################################################
+# Copyright(c) Open Law Library. All rights reserved. #
+# See ThirdPartyNotices.txt in the project root for additional notices. #
+# #
+# Licensed under the Apache License, Version 2.0 (the "License") #
+# you may not use this file except in compliance with the License. #
+# You may obtain a copy of the License at #
+# #
+# http: // www.apache.org/licenses/LICENSE-2.0 #
+# #
+# Unless required by applicable law or agreed to in writing, software #
+# distributed under the License is distributed on an "AS IS" BASIS, #
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
+# See the License for the specific language governing permissions and #
+# limitations under the License. #
+############################################################################
+import io
+import json
+from concurrent.futures import Future
+from pathlib import Path
+from typing import Optional
+from unittest.mock import Mock
+
+import attrs
+import pytest
+
+from pygls.exceptions import JsonRpcException, JsonRpcInvalidParams
+from lsprotocol.types import (
+ PROGRESS,
+ TEXT_DOCUMENT_COMPLETION,
+ ClientCapabilities,
+ CompletionItem,
+ CompletionItemKind,
+ CompletionParams,
+ InitializeParams,
+ InitializeResult,
+ ProgressParams,
+ Position,
+ ShutdownResponse,
+ TextDocumentCompletionResponse,
+ TextDocumentIdentifier,
+ WorkDoneProgressBegin,
+)
+from pygls.protocol import (
+ default_converter,
+ JsonRPCProtocol,
+ JsonRPCRequestMessage,
+ JsonRPCResponseMessage,
+ JsonRPCNotification,
+)
+
+EXAMPLE_NOTIFICATION = "example/notification"
+EXAMPLE_REQUEST = "example/request"
+
+
+@attrs.define
+class IntResult:
+ id: str
+ result: int
+ jsonrpc: str = attrs.field(default="2.0")
+
+
+@attrs.define
+class ExampleParams:
+ @attrs.define
+ class InnerType:
+ inner_field: str
+
+ field_a: str
+ field_b: Optional[InnerType] = None
+
+
+@attrs.define
+class ExampleNotification:
+ jsonrpc: str = attrs.field(default="2.0")
+ method: str = EXAMPLE_NOTIFICATION
+ params: ExampleParams = attrs.field(default=None)
+
+
+@attrs.define
+class ExampleRequest:
+ id: str
+ jsonrpc: str = attrs.field(default="2.0")
+ method: str = EXAMPLE_REQUEST
+ params: ExampleParams = attrs.field(default=None)
+
+
+EXAMPLE_LSP_METHODS_MAP = {
+ EXAMPLE_NOTIFICATION: (ExampleNotification, None, ExampleParams, None),
+ EXAMPLE_REQUEST: (ExampleRequest, None, ExampleParams, None),
+}
+
+
+class ExampleProtocol(JsonRPCProtocol):
+ def get_message_type(self, method: str):
+ return EXAMPLE_LSP_METHODS_MAP.get(method, (None,))[0]
+
+
+@pytest.fixture()
+def protocol():
+ return ExampleProtocol(None, default_converter())
+
+
+def test_deserialize_notification_message_valid_params(protocol):
+ params = f"""
+ {{
+ "jsonrpc": "2.0",
+ "method": "{EXAMPLE_NOTIFICATION}",
+ "params": {{
+ "fieldA": "test_a",
+ "fieldB": {{
+ "innerField": "test_inner"
+ }}
+ }}
+ }}
+ """
+
+ result = json.loads(params, object_hook=protocol._deserialize_message)
+
+ assert isinstance(
+ result, ExampleNotification
+ ), f"Expected FeatureRequest instance, got {result}"
+ assert result.jsonrpc == "2.0"
+ assert result.method == EXAMPLE_NOTIFICATION
+
+ assert isinstance(result.params, ExampleParams)
+ assert result.params.field_a == "test_a"
+
+ assert isinstance(result.params.field_b, ExampleParams.InnerType)
+ assert result.params.field_b.inner_field == "test_inner"
+
+
+def test_deserialize_notification_message_unknown_type(protocol):
+ params = """
+ {
+ "jsonrpc": "2.0",
+ "method": "random",
+ "params": {
+ "field_a": "test_a",
+ "field_b": {
+ "inner_field": "test_inner"
+ }
+ }
+ }
+ """
+
+ result = json.loads(params, object_hook=protocol._deserialize_message)
+
+ assert isinstance(result, JsonRPCNotification)
+ assert result.jsonrpc == "2.0"
+ assert result.method == "random"
+
+ assert result.params.field_a == "test_a"
+ assert result.params.field_b.inner_field == "test_inner"
+
+
+def test_deserialize_notification_message_bad_params_should_raise_error(protocol):
+ params = f"""
+ {{
+ "jsonrpc": "2.0",
+ "method": "{EXAMPLE_NOTIFICATION}",
+ "params": {{
+ "field_a": "test_a",
+ "field_b": {{
+ "wrong_field_name": "test_inner"
+ }}
+ }}
+ }}
+ """
+
+ with pytest.raises(JsonRpcInvalidParams):
+ json.loads(params, object_hook=protocol._deserialize_message)
+
+
+def test_deserialize_response_message_custom_converter():
+ params = """
+ {
+ "jsonrpc": "2.0",
+ "id": "id",
+ "result": "1"
+ }
+ """
+
+ # Just for fun, let's create a converter that reverses all the keys in a dict.
+ #
+ @attrs.define
+ class egasseM:
+ cprnosj: str
+ di: str
+ tluser: str
+
+ def structure_hook(obj, cls):
+ params = {k[::-1]: v for k, v in obj.items()}
+ return cls(**params)
+
+ def custom_converter():
+ converter = default_converter()
+ converter.register_structure_hook(egasseM, structure_hook)
+ return converter
+
+ protocol = JsonRPCProtocol(None, custom_converter())
+ protocol._result_types["id"] = egasseM
+ result = json.loads(params, object_hook=protocol._deserialize_message)
+
+ assert isinstance(result, egasseM)
+ assert result.cprnosj == "2.0"
+ assert result.di == "id"
+ assert result.tluser == "1"
+
+
+@pytest.mark.parametrize(
+ "method, params, expected",
+ [
+ (
+ # Known notification type.
+ PROGRESS,
+ ProgressParams(
+ token="id1",
+ value=WorkDoneProgressBegin(
+ title="Begin progress",
+ percentage=0,
+ ),
+ ),
+ {
+ "jsonrpc": "2.0",
+ "method": "$/progress",
+ "params": {
+ "token": "id1",
+ "value": {
+ "kind": "begin",
+ "percentage": 0,
+ "title": "Begin progress",
+ },
+ },
+ },
+ ),
+ (
+ # Custom notification type.
+ EXAMPLE_NOTIFICATION,
+ ExampleParams(
+ field_a="field one",
+ field_b=ExampleParams.InnerType(inner_field="field two"),
+ ),
+ {
+ "jsonrpc": "2.0",
+ "method": EXAMPLE_NOTIFICATION,
+ "params": {
+ "fieldA": "field one",
+ "fieldB": {
+ "innerField": "field two",
+ },
+ },
+ },
+ ),
+ (
+ # Custom notification with dict params.
+ EXAMPLE_NOTIFICATION,
+ {"fieldA": "field one", "fieldB": {"innerField": "field two"}},
+ {
+ "jsonrpc": "2.0",
+ "method": EXAMPLE_NOTIFICATION,
+ "params": {
+ "fieldA": "field one",
+ "fieldB": {
+ "innerField": "field two",
+ },
+ },
+ },
+ ),
+ ],
+)
+def test_serialize_notification_message(method, params, expected):
+ """
+ Ensure that we can serialize notification messages, retaining all
+ expected fields.
+ """
+
+ buffer = io.StringIO()
+
+ protocol = JsonRPCProtocol(None, default_converter())
+ protocol._send_only_body = True
+ protocol.connection_made(buffer)
+
+ protocol.notify(method, params=params)
+ actual = json.loads(buffer.getvalue())
+
+ assert actual == expected
+
+
+def test_deserialize_response_message(protocol):
+ params = """
+ {
+ "jsonrpc": "2.0",
+ "id": "id",
+ "result": "1"
+ }
+ """
+ protocol._result_types["id"] = IntResult
+ result = json.loads(params, object_hook=protocol._deserialize_message)
+
+ assert isinstance(result, IntResult)
+ assert result.jsonrpc == "2.0"
+ assert result.id == "id"
+ assert result.result == 1
+
+
+def test_deserialize_response_message_unknown_type(protocol):
+ params = """
+ {
+ "jsonrpc": "2.0",
+ "id": "id",
+ "result": {
+ "field_a": "test_a",
+ "field_b": {
+ "inner_field": "test_inner"
+ }
+ }
+ }
+ """
+ protocol._result_types["id"] = JsonRPCResponseMessage
+ result = json.loads(params, object_hook=protocol._deserialize_message)
+
+ assert isinstance(result, JsonRPCResponseMessage)
+ assert result.jsonrpc == "2.0"
+ assert result.id == "id"
+
+ assert result.result.field_a == "test_a"
+ assert result.result.field_b.inner_field == "test_inner"
+
+
+def test_deserialize_request_message_with_registered_type(protocol):
+ params = f"""
+ {{
+ "jsonrpc": "2.0",
+ "id": "id",
+ "method": "{EXAMPLE_REQUEST}",
+ "params": {{
+ "fieldA": "test_a",
+ "fieldB": {{
+ "innerField": "test_inner"
+ }}
+ }}
+ }}
+ """
+ result = json.loads(params, object_hook=protocol._deserialize_message)
+
+ assert isinstance(result, ExampleRequest)
+ assert result.jsonrpc == "2.0"
+ assert result.id == "id"
+ assert result.method == EXAMPLE_REQUEST
+
+ assert isinstance(result.params, ExampleParams)
+ assert result.params.field_a == "test_a"
+
+ assert isinstance(result.params.field_b, ExampleParams.InnerType)
+ assert result.params.field_b.inner_field == "test_inner"
+
+
+def test_deserialize_request_message_without_registered_type(protocol):
+ params = """
+ {
+ "jsonrpc": "2.0",
+ "id": "id",
+ "method": "random",
+ "params": {
+ "field_a": "test_a",
+ "field_b": {
+ "inner_field": "test_inner"
+ }
+ }
+ }
+ """
+ result = json.loads(params, object_hook=protocol._deserialize_message)
+
+ assert isinstance(result, JsonRPCRequestMessage)
+ assert result.jsonrpc == "2.0"
+ assert result.id == "id"
+ assert result.method == "random"
+
+ assert result.params.field_a == "test_a"
+ assert result.params.field_b.inner_field == "test_inner"
+
+
+@pytest.mark.parametrize(
+ "msg_type, result, expected",
+ [
+ (ShutdownResponse, None, {"jsonrpc": "2.0", "id": "1", "result": None}),
+ (
+ TextDocumentCompletionResponse,
+ [
+ CompletionItem(label="example-one"),
+ CompletionItem(
+ label="example-two",
+ kind=CompletionItemKind.Class,
+ preselect=False,
+ deprecated=True,
+ ),
+ ],
+ {
+ "jsonrpc": "2.0",
+ "id": "1",
+ "result": [
+ {"label": "example-one"},
+ {
+ "label": "example-two",
+ "kind": 7, # CompletionItemKind.Class
+ "preselect": False,
+ "deprecated": True,
+ },
+ ],
+ },
+ ),
+ ( # Unknown type with object params.
+ JsonRPCResponseMessage,
+ ExampleParams(
+ field_a="field one",
+ field_b=ExampleParams.InnerType(inner_field="field two"),
+ ),
+ {
+ "jsonrpc": "2.0",
+ "id": "1",
+ "result": {
+ "fieldA": "field one",
+ "fieldB": {"innerField": "field two"},
+ },
+ },
+ ),
+ ( # Unknown type with dict params.
+ JsonRPCResponseMessage,
+ {"fieldA": "field one", "fieldB": {"innerField": "field two"}},
+ {
+ "jsonrpc": "2.0",
+ "id": "1",
+ "result": {
+ "fieldA": "field one",
+ "fieldB": {"innerField": "field two"},
+ },
+ },
+ ),
+ ],
+)
+def test_serialize_response_message(msg_type, result, expected):
+ """
+ Ensure that we can serialize response messages, retaining all expected
+ fields.
+ """
+
+ buffer = io.StringIO()
+
+ protocol = JsonRPCProtocol(None, default_converter())
+ protocol._send_only_body = True
+ protocol.connection_made(buffer)
+
+ protocol._result_types["1"] = msg_type
+
+ protocol._send_response("1", result=result)
+ actual = json.loads(buffer.getvalue())
+
+ assert actual == expected
+
+
+@pytest.mark.parametrize(
+ "method, params, expected",
+ [
+ (
+ TEXT_DOCUMENT_COMPLETION,
+ CompletionParams(
+ text_document=TextDocumentIdentifier(uri="file:///file.txt"),
+ position=Position(line=1, character=0),
+ ),
+ {
+ "jsonrpc": "2.0",
+ "id": "1",
+ "method": TEXT_DOCUMENT_COMPLETION,
+ "params": {
+ "textDocument": {"uri": "file:///file.txt"},
+ "position": {"line": 1, "character": 0},
+ },
+ },
+ ),
+ ( # Unknown type with object params.
+ EXAMPLE_REQUEST,
+ ExampleParams(
+ field_a="field one",
+ field_b=ExampleParams.InnerType(inner_field="field two"),
+ ),
+ {
+ "jsonrpc": "2.0",
+ "id": "1",
+ "method": EXAMPLE_REQUEST,
+ "params": {
+ "fieldA": "field one",
+ "fieldB": {"innerField": "field two"},
+ },
+ },
+ ),
+ ( # Unknown type with dict params.
+ EXAMPLE_REQUEST,
+ {"fieldA": "field one", "fieldB": {"innerField": "field two"}},
+ {
+ "jsonrpc": "2.0",
+ "id": "1",
+ "method": EXAMPLE_REQUEST,
+ "params": {
+ "fieldA": "field one",
+ "fieldB": {"innerField": "field two"},
+ },
+ },
+ ),
+ ],
+)
+def test_serialize_request_message(method, params, expected):
+ """
+ Ensure that we can serialize request messages, retaining all expected
+ fields.
+ """
+
+ buffer = io.StringIO()
+
+ protocol = JsonRPCProtocol(None, default_converter())
+ protocol._send_only_body = True
+ protocol.connection_made(buffer)
+
+ protocol.send_request(method, params, callback=None, msg_id="1")
+ actual = json.loads(buffer.getvalue())
+
+ assert actual == expected
+
+
+def test_data_received_without_content_type(client_server):
+ _, server = client_server
+ body = json.dumps(
+ {
+ "jsonrpc": "2.0",
+ "method": "test",
+ "params": 1,
+ }
+ )
+ message = "\r\n".join(
+ (
+ "Content-Length: " + str(len(body)),
+ "",
+ body,
+ )
+ )
+ data = bytes(message, "utf-8")
+ server.lsp.data_received(data)
+
+
+def test_data_received_content_type_first_should_handle_message(client_server):
+ _, server = client_server
+ body = json.dumps(
+ {
+ "jsonrpc": "2.0",
+ "method": "test",
+ "params": 1,
+ }
+ )
+ message = "\r\n".join(
+ (
+ "Content-Type: application/vscode-jsonrpc; charset=utf-8",
+ "Content-Length: " + str(len(body)),
+ "",
+ body,
+ )
+ )
+ data = bytes(message, "utf-8")
+ server.lsp.data_received(data)
+
+
+def dummy_message(param=1):
+ body = json.dumps(
+ {
+ "jsonrpc": "2.0",
+ "method": "test",
+ "params": param,
+ }
+ )
+ message = "\r\n".join(
+ (
+ "Content-Length: " + str(len(body)),
+ "Content-Type: application/vscode-jsonrpc; charset=utf-8",
+ "",
+ body,
+ )
+ )
+ return bytes(message, "utf-8")
+
+
+def test_data_received_single_message_should_handle_message(client_server):
+ _, server = client_server
+ data = dummy_message()
+ server.lsp.data_received(data)
+
+
+def test_data_received_partial_message_should_handle_message(client_server):
+ _, server = client_server
+ data = dummy_message()
+ partial = len(data) - 5
+ server.lsp.data_received(data[:partial])
+ server.lsp.data_received(data[partial:])
+
+
+def test_data_received_multi_message_should_handle_messages(client_server):
+ _, server = client_server
+ messages = (dummy_message(i) for i in range(3))
+ data = b"".join(messages)
+ server.lsp.data_received(data)
+
+
+def test_data_received_error_should_raise_jsonrpc_error(client_server):
+ _, server = client_server
+ body = json.dumps(
+ {
+ "jsonrpc": "2.0",
+ "id": "err",
+ "error": {
+ "code": -1,
+ "message": "message for you sir",
+ },
+ }
+ )
+ message = "\r\n".join(
+ [
+ "Content-Length: " + str(len(body)),
+ "Content-Type: application/vscode-jsonrpc; charset=utf-8",
+ "",
+ body,
+ ]
+ ).encode("utf-8")
+ future = server.lsp._request_futures["err"] = Future()
+ server.lsp.data_received(message)
+ with pytest.raises(JsonRpcException, match="message for you sir"):
+ future.result()
+
+
+def test_initialize_should_return_server_capabilities(client_server):
+ _, server = client_server
+ params = InitializeParams(
+ process_id=1234,
+ root_uri=Path(__file__).parent.as_uri(),
+ capabilities=ClientCapabilities(),
+ )
+
+ server_capabilities = server.lsp.lsp_initialize(params)
+
+ assert isinstance(server_capabilities, InitializeResult)
+
+
+def test_ignore_unknown_notification(client_server):
+ _, server = client_server
+
+ fn = server.lsp._execute_notification
+ server.lsp._execute_notification = Mock()
+
+ server.lsp._handle_notification("random/notification", None)
+ assert not server.lsp._execute_notification.called
+
+ # Remove mock
+ server.lsp._execute_notification = fn