diff options
Diffstat (limited to 'tests')
59 files changed, 7817 insertions, 0 deletions
diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..ea4bd6c --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,28 @@ +############################################################################ +# Original work Copyright 2017 Palantir Technologies, Inc. # +# Original work licensed under the MIT License. # +# See ThirdPartyNotices.txt in the project root for license information. # +# All modifications Copyright (c) Open Law Library. All rights reserved. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ +import pytest + +from pygls import IS_WIN + +unix_only = pytest.mark.skipif(IS_WIN, reason="Unix only") +windows_only = pytest.mark.skipif(not IS_WIN, reason="Windows only") + +CMD_ASYNC = "cmd_async" +CMD_SYNC = "cmd_sync" +CMD_THREAD = "cmd_thread" diff --git a/tests/_init_server_stall_fix_hack.py b/tests/_init_server_stall_fix_hack.py new file mode 100644 index 0000000..04895b0 --- /dev/null +++ b/tests/_init_server_stall_fix_hack.py @@ -0,0 +1,33 @@ +""" +It would be great to find the real underlying issue here, but without these +retries we get annoying flakey test errors. So it's preferable to hack this +fix to actually guarantee it doesn't generate false negatives in the test +suite. +""" +import os +import concurrent + +RETRIES = 3 + + +def retry_stalled_init_fix_hack(): + if "DISABLE_TIMEOUT" in os.environ: + return lambda f: f + + def decorator(func): + def newfn(*args, **kwargs): + attempt = 0 + while attempt < RETRIES: + try: + return func(*args, **kwargs) + except concurrent.futures._base.TimeoutError: + print( + "\n\nRetrying timeouted test server init " + "%d of %d\n" % (attempt, RETRIES) + ) + attempt += 1 + return func(*args, **kwargs) + + return newfn + + return decorator diff --git a/tests/client.py b/tests/client.py new file mode 100644 index 0000000..9be8752 --- /dev/null +++ b/tests/client.py @@ -0,0 +1,180 @@ +############################################################################ +# Original work Copyright 2017 Palantir Technologies, Inc. # +# Original work licensed under the MIT License. # +# See ThirdPartyNotices.txt in the project root for license information. # +# All modifications Copyright (c) Open Law Library. All rights reserved. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ +import asyncio +import logging +import pathlib +import sys +from concurrent.futures import Future +from typing import Dict +from typing import List +from typing import Type + +import pytest +import pytest_asyncio +from lsprotocol import types + +from pygls import IS_PYODIDE +from pygls import uris +from pygls.exceptions import JsonRpcMethodNotFound +from pygls.lsp.client import BaseLanguageClient +from pygls.protocol import LanguageServerProtocol +from pygls.protocol import default_converter + +logger = logging.getLogger(__name__) + + +class LanguageClientProtocol(LanguageServerProtocol): + """An extended protocol class with extra methods that are useful for testing.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._notification_futures = {} + + def _handle_notification(self, method_name, params): + if method_name == types.CANCEL_REQUEST: + self._handle_cancel_notification(params.id) + return + + future = self._notification_futures.pop(method_name, None) + if future: + future.set_result(params) + + try: + handler = self._get_handler(method_name) + self._execute_notification(handler, params) + except (KeyError, JsonRpcMethodNotFound): + logger.warning("Ignoring notification for unknown method '%s'", method_name) + except Exception: + logger.exception( + "Failed to handle notification '%s': %s", method_name, params + ) + + def wait_for_notification(self, method: str, callback=None): + future: Future = Future() + if callback: + + def wrapper(future: Future): + result = future.result() + callback(result) + + future.add_done_callback(wrapper) + + self._notification_futures[method] = future + return future + + def wait_for_notification_async(self, method: str): + future = self.wait_for_notification(method) + return asyncio.wrap_future(future) + + +class LanguageClient(BaseLanguageClient): + """Language client used to drive test cases.""" + + def __init__( + self, + protocol_cls: Type[LanguageClientProtocol] = LanguageClientProtocol, + *args, + **kwargs, + ): + super().__init__( + "pygls-test-client", "v1", protocol_cls=protocol_cls, *args, **kwargs + ) + + self.diagnostics: Dict[str, List[types.Diagnostic]] = {} + """Used to hold any recieved diagnostics.""" + + self.messages: List[types.ShowMessageParams] = [] + """Holds any received ``window/showMessage`` requests.""" + + self.log_messages: List[types.LogMessageParams] = [] + """Holds any received ``window/logMessage`` requests.""" + + async def wait_for_notification(self, method: str): + """Block until a notification with the given method is received. + + Parameters + ---------- + method + The notification method to wait for, e.g. ``textDocument/publishDiagnostics`` + """ + return await self.protocol.wait_for_notification_async(method) + + +def make_test_lsp_client() -> LanguageClient: + """Construct a new test client instance with the handlers needed to capture + additional responses from the server.""" + + client = LanguageClient(converter_factory=default_converter) + + @client.feature(types.TEXT_DOCUMENT_PUBLISH_DIAGNOSTICS) + def publish_diagnostics( + client: LanguageClient, params: types.PublishDiagnosticsParams + ): + client.diagnostics[params.uri] = params.diagnostics + + @client.feature(types.WINDOW_LOG_MESSAGE) + def log_message(client: LanguageClient, params: types.LogMessageParams): + client.log_messages.append(params) + + levels = ["ERROR: ", "WARNING: ", "INFO: ", "LOG: "] + log_level = levels[params.type.value - 1] + + print(log_level, params.message) + + @client.feature(types.WINDOW_SHOW_MESSAGE) + def show_message(client: LanguageClient, params): + client.messages.append(params) + + return client + + +def create_client_for_server(server_name: str): + """Automate the process of creating a language client connected to the given server + and tearing it down again. + """ + + @pytest_asyncio.fixture + async def fixture_func(): + if IS_PYODIDE: + pytest.skip("not available in pyodide") + + client = make_test_lsp_client() + server_dir = pathlib.Path(__file__, "..", "..", "examples", "servers").resolve() + root_dir = pathlib.Path(__file__, "..", "..", "examples", "workspace").resolve() + + await client.start_io(sys.executable, str(server_dir / server_name)) + + # Initialize the server + response = await client.initialize_async( + types.InitializeParams( + capabilities=types.ClientCapabilities(), + root_uri=uris.from_fs_path(root_dir), + ) + ) + assert response is not None + + yield client, response + + await client.shutdown_async(None) + client.exit(None) + + await client.stop() + + return fixture_func diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..5dd2ecb --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,122 @@ +############################################################################ +# Original work Copyright 2017 Palantir Technologies, Inc. # +# Original work licensed under the MIT License. # +# See ThirdPartyNotices.txt in the project root for license information. # +# All modifications Copyright (c) Open Law Library. All rights reserved. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ +import asyncio +import pathlib + +import pytest +from lsprotocol import types, converters + +from pygls import uris, IS_PYODIDE +from pygls.feature_manager import FeatureManager +from pygls.workspace import Workspace + +from .ls_setup import ( + NativeClientServer, + PyodideClientServer, + setup_ls_features, +) + +from .client import create_client_for_server + +DOC = """document +for +testing +with "😋" unicode. +""" +DOC_URI = uris.from_fs_path(__file__) or "" + + +ClientServer = NativeClientServer +if IS_PYODIDE: + ClientServer = PyodideClientServer + + +@pytest.fixture(autouse=False) +def client_server(request): + if hasattr(request, "param"): + ConfiguredClientServer = request.param + client_server = ConfiguredClientServer() + else: + client_server = ClientServer() + setup_ls_features(client_server.server) + + client_server.start() + client, server = client_server + + yield client, server + + client_server.stop() + + +@pytest.fixture(scope="session") +def uri_for(): + """Returns the uri corresponsing to a file in the example workspace.""" + base_dir = pathlib.Path( + __file__, "..", "..", "examples", "servers", "workspace" + ).resolve() + + def fn(*args): + fpath = pathlib.Path(base_dir, *args) + return uris.from_fs_path(str(fpath)) + + return fn + + +@pytest.fixture() +def event_loop(): + """Redefine `pytest-asyncio's default event_loop fixture to match the scope + of our client fixture.""" + + policy = asyncio.get_event_loop_policy() + + loop = policy.new_event_loop() + yield loop + + try: + # Not implemented on pyodide + loop.close() + except NotImplementedError: + pass + + +@pytest.fixture(scope="session") +def server_dir(): + """Returns the directory where all the example language servers live""" + path = pathlib.Path(__file__) / ".." / ".." / "examples" / "servers" + return path.resolve() + + +code_action_client = create_client_for_server("code_actions.py") +inlay_hints_client = create_client_for_server("inlay_hints.py") +json_server_client = create_client_for_server("json_server.py") + + +@pytest.fixture +def feature_manager(): + """Return a feature manager""" + return FeatureManager(None, converters.get_converter()) + + +@pytest.fixture +def workspace(tmpdir): + """Return a workspace.""" + return Workspace( + uris.from_fs_path(str(tmpdir)), + sync_kind=types.TextDocumentSyncKind.Incremental, + ) diff --git a/tests/ls_setup.py b/tests/ls_setup.py new file mode 100644 index 0000000..e52c06c --- /dev/null +++ b/tests/ls_setup.py @@ -0,0 +1,169 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ +import asyncio +import json +import os +import threading + +import pytest +from lsprotocol.types import ( + EXIT, + INITIALIZE, + SHUTDOWN, + ClientCapabilities, + InitializeParams, +) +from pygls.server import LanguageServer + + +from . import CMD_ASYNC, CMD_SYNC, CMD_THREAD +from ._init_server_stall_fix_hack import retry_stalled_init_fix_hack + + +CALL_TIMEOUT = 3 + + +def setup_ls_features(server): + # Commands + @server.command(CMD_ASYNC) + async def cmd_test3(ls, *args): # pylint: disable=unused-variable + return True, threading.get_ident() + + @server.thread() + @server.command(CMD_THREAD) + def cmd_test1(ls, *args): # pylint: disable=unused-variable + return True, threading.get_ident() + + @server.command(CMD_SYNC) + def cmd_test2(ls, *args): # pylint: disable=unused-variable + return True, threading.get_ident() + + +class PyodideTestTransportAdapter: + """Transort adapter that's only useful for tests in a pyodide environment.""" + + def __init__(self, dest: LanguageServer): + self.dest = dest + + def close(self): + ... + + def write(self, data): + object_hook = self.dest.lsp._deserialize_message + self.dest.lsp._procedure_handler(json.loads(data, object_hook=object_hook)) + + +class PyodideClientServer: + """Implementation of the `client_server` fixture for use in a pyodide + environment.""" + + def __init__(self, LS=LanguageServer): + self.server = LS("pygls-server", "v1") + self.client = LS("pygls-client", "v1") + + self.server.lsp.connection_made(PyodideTestTransportAdapter(self.client)) + self.server.lsp._send_only_body = True + + self.client.lsp.connection_made(PyodideTestTransportAdapter(self.server)) + self.client.lsp._send_only_body = True + + def start(self): + self.initialize() + + def stop(self): + ... + + @classmethod + def decorate(cls): + return pytest.mark.parametrize("client_server", [cls], indirect=True) + + def initialize(self): + response = self.client.lsp.send_request( + INITIALIZE, + InitializeParams( + process_id=12345, root_uri="file://", capabilities=ClientCapabilities() + ), + ).result(timeout=CALL_TIMEOUT) + + assert response.capabilities is not None + + def __iter__(self): + yield self.client + yield self.server + + +class NativeClientServer: + def __init__(self, LS=LanguageServer): + # Client to Server pipe + csr, csw = os.pipe() + # Server to client pipe + scr, scw = os.pipe() + + # Setup Server + self.server = LS("server", "v1") + self.server_thread = threading.Thread( + name="Server Thread", + target=self.server.start_io, + args=(os.fdopen(csr, "rb"), os.fdopen(scw, "wb")), + ) + self.server_thread.daemon = True + + # Setup client + self.client = LS("client", "v1", asyncio.new_event_loop()) + self.client_thread = threading.Thread( + name="Client Thread", + target=self.client.start_io, + args=(os.fdopen(scr, "rb"), os.fdopen(csw, "wb")), + ) + self.client_thread.daemon = True + + @classmethod + def decorate(cls): + return pytest.mark.parametrize("client_server", [cls], indirect=True) + + def start(self): + self.server_thread.start() + self.server.thread_id = self.server_thread.ident + self.client_thread.start() + self.initialize() + + def stop(self): + shutdown_response = self.client.lsp.send_request(SHUTDOWN).result() + assert shutdown_response is None + self.client.lsp.notify(EXIT) + self.server_thread.join() + self.client._stop_event.set() + try: + self.client.loop._signal_handlers.clear() # HACK ? + except AttributeError: + pass + self.client_thread.join() + + @retry_stalled_init_fix_hack() + def initialize(self): + timeout = None if "DISABLE_TIMEOUT" in os.environ else 1 + response = self.client.lsp.send_request( + INITIALIZE, + InitializeParams( + process_id=12345, root_uri="file://", capabilities=ClientCapabilities() + ), + ).result(timeout=timeout) + assert response.capabilities is not None + + def __iter__(self): + yield self.client + yield self.server diff --git a/tests/lsp/__init__.py b/tests/lsp/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/lsp/__init__.py diff --git a/tests/lsp/semantic_tokens/__init__.py b/tests/lsp/semantic_tokens/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/lsp/semantic_tokens/__init__.py diff --git a/tests/lsp/semantic_tokens/test_delta_missing_legend.py b/tests/lsp/semantic_tokens/test_delta_missing_legend.py new file mode 100644 index 0000000..a3069da --- /dev/null +++ b/tests/lsp/semantic_tokens/test_delta_missing_legend.py @@ -0,0 +1,92 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ +from typing import Optional, Union + +from lsprotocol.types import ( + TEXT_DOCUMENT_SEMANTIC_TOKENS_FULL_DELTA, +) +from lsprotocol.types import ( + SemanticTokens, + SemanticTokensDeltaParams, + SemanticTokensLegend, + SemanticTokensPartialResult, + SemanticTokensOptionsFullType1, + TextDocumentIdentifier, +) + +from ...conftest import ClientServer + + +class ConfiguredLS(ClientServer): + def __init__(self): + super().__init__() + + @self.server.feature( + TEXT_DOCUMENT_SEMANTIC_TOKENS_FULL_DELTA, + SemanticTokensLegend( + token_types=["keyword", "operator"], token_modifiers=["readonly"] + ), + ) + def f( + params: SemanticTokensDeltaParams, + ) -> Union[SemanticTokensPartialResult, Optional[SemanticTokens]]: + if params.text_document.uri == "file://return.tokens": + return SemanticTokens(data=[0, 0, 3, 0, 0]) + + +@ConfiguredLS.decorate() +def test_capabilities(client_server): + _, server = client_server + capabilities = server.server_capabilities + + provider = capabilities.semantic_tokens_provider + assert provider.full == SemanticTokensOptionsFullType1(delta=True) + assert provider.legend.token_types == [ + "keyword", + "operator", + ] + assert provider.legend.token_modifiers == ["readonly"] + + +@ConfiguredLS.decorate() +def test_semantic_tokens_full_delta_return_tokens(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_SEMANTIC_TOKENS_FULL_DELTA, + SemanticTokensDeltaParams( + text_document=TextDocumentIdentifier(uri="file://return.tokens"), + previous_result_id="id", + ), + ).result() + + assert response + + assert response.data == [0, 0, 3, 0, 0] + + +@ConfiguredLS.decorate() +def test_semantic_tokens_full_delta_return_none(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_SEMANTIC_TOKENS_FULL_DELTA, + SemanticTokensDeltaParams( + text_document=TextDocumentIdentifier(uri="file://return.none"), + previous_result_id="id", + ), + ).result() + + assert response is None diff --git a/tests/lsp/semantic_tokens/test_delta_missing_legend_none.py b/tests/lsp/semantic_tokens/test_delta_missing_legend_none.py new file mode 100644 index 0000000..6f4fa17 --- /dev/null +++ b/tests/lsp/semantic_tokens/test_delta_missing_legend_none.py @@ -0,0 +1,48 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ +from typing import Optional, Union + +from lsprotocol.types import ( + SemanticTokens, + SemanticTokensDeltaParams, + SemanticTokensPartialResult, +) +from lsprotocol.types import ( + TEXT_DOCUMENT_SEMANTIC_TOKENS_FULL_DELTA, +) + +from ...conftest import ClientServer + + +class ConfiguredLS(ClientServer): + def __init__(self): + super().__init__() + + @self.server.feature(TEXT_DOCUMENT_SEMANTIC_TOKENS_FULL_DELTA) + def f( + params: SemanticTokensDeltaParams, + ) -> Union[SemanticTokensPartialResult, Optional[SemanticTokens]]: + return SemanticTokens(data=[0, 0, 3, 0, 0]) + + +@ConfiguredLS.decorate() +def test_capabilities(client_server): + _, server = client_server + capabilities = server.server_capabilities + + assert capabilities.semantic_tokens_provider is None + assert capabilities.semantic_tokens_provider is None diff --git a/tests/lsp/semantic_tokens/test_full_missing_legend.py b/tests/lsp/semantic_tokens/test_full_missing_legend.py new file mode 100644 index 0000000..e18dbde --- /dev/null +++ b/tests/lsp/semantic_tokens/test_full_missing_legend.py @@ -0,0 +1,46 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ +from typing import Optional, Union + +from lsprotocol.types import ( + TEXT_DOCUMENT_SEMANTIC_TOKENS_FULL, +) +from lsprotocol.types import ( + SemanticTokens, + SemanticTokensPartialResult, + SemanticTokensParams, +) + +from ...conftest import ClientServer + + +class ConfiguredLS(ClientServer): + def __init__(self): + super().__init__() + + @self.server.feature(TEXT_DOCUMENT_SEMANTIC_TOKENS_FULL) + def f( + params: SemanticTokensParams, + ) -> Union[SemanticTokensPartialResult, Optional[SemanticTokens]]: + return SemanticTokens(data=[0, 0, 3, 0, 0]) + + +@ConfiguredLS.decorate() +def test_capabilities(client_server): + _, server = client_server + capabilities = server.server_capabilities + assert capabilities.semantic_tokens_provider is None diff --git a/tests/lsp/semantic_tokens/test_range.py b/tests/lsp/semantic_tokens/test_range.py new file mode 100644 index 0000000..a65504b --- /dev/null +++ b/tests/lsp/semantic_tokens/test_range.py @@ -0,0 +1,103 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ +from typing import Optional, Union + +from lsprotocol.types import ( + TEXT_DOCUMENT_SEMANTIC_TOKENS_RANGE, +) +from lsprotocol.types import ( + Position, + Range, + SemanticTokens, + SemanticTokensLegend, + SemanticTokensPartialResult, + SemanticTokensRangeParams, + TextDocumentIdentifier, +) + +from ...conftest import ClientServer + +SemanticTokenReturnType = Optional[ + Union[SemanticTokensPartialResult, Optional[SemanticTokens]] +] + + +class ConfiguredLS(ClientServer): + def __init__(self): + super().__init__() + + @self.server.feature( + TEXT_DOCUMENT_SEMANTIC_TOKENS_RANGE, + SemanticTokensLegend( + token_types=["keyword", "operator"], token_modifiers=["readonly"] + ), + ) + def f( + params: SemanticTokensRangeParams, + ) -> SemanticTokenReturnType: + if params.text_document.uri == "file://return.tokens": + return SemanticTokens(data=[0, 0, 3, 0, 0]) + + +@ConfiguredLS.decorate() +def test_capabilities(client_server): + _, server = client_server + capabilities = server.server_capabilities + + provider = capabilities.semantic_tokens_provider + assert provider.range + assert provider.legend.token_types == [ + "keyword", + "operator", + ] + assert provider.legend.token_modifiers == ["readonly"] + + +@ConfiguredLS.decorate() +def test_semantic_tokens_range_return_tokens(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_SEMANTIC_TOKENS_RANGE, + SemanticTokensRangeParams( + text_document=TextDocumentIdentifier(uri="file://return.tokens"), + range=Range( + start=Position(line=0, character=0), + end=Position(line=10, character=80), + ), + ), + ).result() + + assert response + + assert response.data == [0, 0, 3, 0, 0] + + +@ConfiguredLS.decorate() +def test_semantic_tokens_range_return_none(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_SEMANTIC_TOKENS_RANGE, + SemanticTokensRangeParams( + text_document=TextDocumentIdentifier(uri="file://return.none"), + range=Range( + start=Position(line=0, character=0), + end=Position(line=10, character=80), + ), + ), + ).result() + + assert response is None diff --git a/tests/lsp/semantic_tokens/test_range_missing_legends.py b/tests/lsp/semantic_tokens/test_range_missing_legends.py new file mode 100644 index 0000000..69780ef --- /dev/null +++ b/tests/lsp/semantic_tokens/test_range_missing_legends.py @@ -0,0 +1,47 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ +from typing import Optional, Union + +from lsprotocol.types import ( + TEXT_DOCUMENT_SEMANTIC_TOKENS_RANGE, +) +from lsprotocol.types import ( + SemanticTokens, + SemanticTokensParams, + SemanticTokensPartialResult, +) + +from ...conftest import ClientServer + + +class ConfiguredLS(ClientServer): + def __init__(self): + super().__init__() + + @self.server.feature(TEXT_DOCUMENT_SEMANTIC_TOKENS_RANGE) + def f( + params: SemanticTokensParams, + ) -> Union[SemanticTokensPartialResult, Optional[SemanticTokens]]: + return SemanticTokens(data=[0, 0, 3, 0, 0]) + + +@ConfiguredLS.decorate() +def test_capabilities(client_server): + _, server = client_server + capabilities = server.server_capabilities + + assert capabilities.semantic_tokens_provider is None diff --git a/tests/lsp/semantic_tokens/test_semantic_tokens_full.py b/tests/lsp/semantic_tokens/test_semantic_tokens_full.py new file mode 100644 index 0000000..dba9fa6 --- /dev/null +++ b/tests/lsp/semantic_tokens/test_semantic_tokens_full.py @@ -0,0 +1,93 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ +from typing import Optional, Union + +from lsprotocol.types import ( + TEXT_DOCUMENT_SEMANTIC_TOKENS_FULL, +) +from lsprotocol.types import ( + SemanticTokens, + SemanticTokensLegend, + SemanticTokensParams, + SemanticTokensPartialResult, + TextDocumentIdentifier, +) + +from ...conftest import ClientServer + +SemanticTokenReturnType = Optional[ + Union[SemanticTokensPartialResult, Optional[SemanticTokens]] +] + + +class ConfiguredLS(ClientServer): + def __init__(self): + super().__init__() + + @self.server.feature( + TEXT_DOCUMENT_SEMANTIC_TOKENS_FULL, + SemanticTokensLegend( + token_types=["keyword", "operator"], token_modifiers=["readonly"] + ), + ) + def f( + params: SemanticTokensParams, + ) -> SemanticTokenReturnType: + if params.text_document.uri == "file://return.tokens": + return SemanticTokens(data=[0, 0, 3, 0, 0]) + + +@ConfiguredLS.decorate() +def test_capabilities(client_server): + _, server = client_server + capabilities = server.server_capabilities + + provider = capabilities.semantic_tokens_provider + assert provider.full + assert provider.legend.token_types == [ + "keyword", + "operator", + ] + assert provider.legend.token_modifiers == ["readonly"] + + +@ConfiguredLS.decorate() +def test_semantic_tokens_full_return_tokens(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_SEMANTIC_TOKENS_FULL, + SemanticTokensParams( + text_document=TextDocumentIdentifier(uri="file://return.tokens") + ), + ).result() + + assert response + + assert response.data == [0, 0, 3, 0, 0] + + +@ConfiguredLS.decorate() +def test_semantic_tokens_full_return_none(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_SEMANTIC_TOKENS_FULL, + SemanticTokensParams( + text_document=TextDocumentIdentifier(uri="file://return.none") + ), + ).result() + + assert response is None diff --git a/tests/lsp/test_call_hierarchy.py b/tests/lsp/test_call_hierarchy.py new file mode 100644 index 0000000..410a982 --- /dev/null +++ b/tests/lsp/test_call_hierarchy.py @@ -0,0 +1,192 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ +from typing import List, Optional + +from lsprotocol.types import ( + CALL_HIERARCHY_INCOMING_CALLS, + CALL_HIERARCHY_OUTGOING_CALLS, + TEXT_DOCUMENT_PREPARE_CALL_HIERARCHY, +) +from lsprotocol.types import ( + CallHierarchyIncomingCall, + CallHierarchyIncomingCallsParams, + CallHierarchyItem, + CallHierarchyOptions, + CallHierarchyOutgoingCall, + CallHierarchyOutgoingCallsParams, + CallHierarchyPrepareParams, + Position, + Range, + SymbolKind, + SymbolTag, + TextDocumentIdentifier, +) + +from ..conftest import ClientServer + +CALL_HIERARCHY_ITEM = CallHierarchyItem( + name="test_name", + kind=SymbolKind.File, + uri="test_uri", + range=Range( + start=Position(line=0, character=0), + end=Position(line=1, character=1), + ), + selection_range=Range( + start=Position(line=1, character=1), + end=Position(line=2, character=2), + ), + tags=[SymbolTag.Deprecated], + detail="test_detail", + data="test_data", +) + + +def check_call_hierarchy_item_response(item): + assert item.name == "test_name" + assert item.kind == SymbolKind.File + assert item.uri == "test_uri" + assert item.range.start.line == 0 + assert item.range.start.character == 0 + assert item.range.end.line == 1 + assert item.range.end.character == 1 + assert item.selection_range.start.line == 1 + assert item.selection_range.start.character == 1 + assert item.selection_range.end.line == 2 + assert item.selection_range.end.character == 2 + assert len(item.tags) == 1 + assert item.tags[0] == SymbolTag.Deprecated + assert item.detail == "test_detail" + assert item.data == "test_data" + + +class ConfiguredLS(ClientServer): + def __init__(self): + super().__init__() + + @self.server.feature( + TEXT_DOCUMENT_PREPARE_CALL_HIERARCHY, + CallHierarchyOptions(), + ) + def f1(params: CallHierarchyPrepareParams) -> Optional[List[CallHierarchyItem]]: + if params.text_document.uri == "file://return.list": + return [CALL_HIERARCHY_ITEM] + else: + return None + + @self.server.feature(CALL_HIERARCHY_INCOMING_CALLS) + def f2( + params: CallHierarchyIncomingCallsParams, + ) -> Optional[List[CallHierarchyIncomingCall]]: + return [ + CallHierarchyIncomingCall( + from_=params.item, + from_ranges=[ + Range( + start=Position(line=2, character=2), + end=Position(line=3, character=3), + ), + ], + ), + ] + + @self.server.feature(CALL_HIERARCHY_OUTGOING_CALLS) + def f3( + params: CallHierarchyOutgoingCallsParams, + ) -> Optional[List[CallHierarchyOutgoingCall]]: + return [ + CallHierarchyOutgoingCall( + to=params.item, + from_ranges=[ + Range( + start=Position(line=3, character=3), + end=Position(line=4, character=4), + ), + ], + ), + ] + + +@ConfiguredLS.decorate() +def test_capabilities(client_server): + _, server = client_server + capabilities = server.server_capabilities + assert capabilities.call_hierarchy_provider + + +@ConfiguredLS.decorate() +def test_call_hierarchy_prepare_return_list(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_PREPARE_CALL_HIERARCHY, + CallHierarchyPrepareParams( + text_document=TextDocumentIdentifier(uri="file://return.list"), + position=Position(line=0, character=0), + ), + ).result() + + check_call_hierarchy_item_response(response[0]) + + +@ConfiguredLS.decorate() +def test_call_hierarchy_prepare_return_none(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_PREPARE_CALL_HIERARCHY, + CallHierarchyPrepareParams( + text_document=TextDocumentIdentifier(uri="file://return.none"), + position=Position(line=0, character=0), + ), + ).result() + + assert response is None + + +@ConfiguredLS.decorate() +def test_call_hierarchy_incoming_calls_return_list(client_server): + client, _ = client_server + response = client.lsp.send_request( + CALL_HIERARCHY_INCOMING_CALLS, + CallHierarchyIncomingCallsParams(item=CALL_HIERARCHY_ITEM), + ).result() + + item = response[0] + + check_call_hierarchy_item_response(item.from_) + + assert item.from_ranges[0].start.line == 2 + assert item.from_ranges[0].start.character == 2 + assert item.from_ranges[0].end.line == 3 + assert item.from_ranges[0].end.character == 3 + + +@ConfiguredLS.decorate() +def test_call_hierarchy_outgoing_calls_return_list(client_server): + client, _ = client_server + response = client.lsp.send_request( + CALL_HIERARCHY_OUTGOING_CALLS, + CallHierarchyOutgoingCallsParams(item=CALL_HIERARCHY_ITEM), + ).result() + + item = response[0] + + check_call_hierarchy_item_response(item.to) + + assert item.from_ranges[0].start.line == 3 + assert item.from_ranges[0].start.character == 3 + assert item.from_ranges[0].end.line == 4 + assert item.from_ranges[0].end.character == 4 diff --git a/tests/lsp/test_code_action.py b/tests/lsp/test_code_action.py new file mode 100644 index 0000000..c10dcd5 --- /dev/null +++ b/tests/lsp/test_code_action.py @@ -0,0 +1,60 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ +from typing import Tuple + +from lsprotocol import types + +from ..client import LanguageClient + + +async def test_code_actions( + code_action_client: Tuple[LanguageClient, types.InitializeResult], uri_for +): + """Ensure that the example code action server is working as expected.""" + client, initialize_result = code_action_client + + code_action_options = initialize_result.capabilities.code_action_provider + assert code_action_options.code_action_kinds == [types.CodeActionKind.QuickFix] + + test_uri = uri_for("sums.txt") + assert test_uri is not None + + response = await client.text_document_code_action_async( + types.CodeActionParams( + text_document=types.TextDocumentIdentifier(uri=test_uri), + range=types.Range( + start=types.Position(line=0, character=0), + end=types.Position(line=1, character=0), + ), + context=types.CodeActionContext(diagnostics=[]), + ) + ) + + assert len(response) == 1 + code_action = response[0] + + assert code_action.title == "Evaluate '1 + 1 ='" + assert code_action.kind == types.CodeActionKind.QuickFix + + fix = code_action.edit.changes[test_uri][0] + expected_range = types.Range( + start=types.Position(line=0, character=0), + end=types.Position(line=0, character=7), + ) + + assert fix.range == expected_range + assert fix.new_text == "1 + 1 = 2!" diff --git a/tests/lsp/test_code_lens.py b/tests/lsp/test_code_lens.py new file mode 100644 index 0000000..7024110 --- /dev/null +++ b/tests/lsp/test_code_lens.py @@ -0,0 +1,94 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ +from typing import List, Optional + +from lsprotocol.types import TEXT_DOCUMENT_CODE_LENS +from lsprotocol.types import ( + CodeLens, + CodeLensOptions, + CodeLensParams, + Command, + Position, + Range, + TextDocumentIdentifier, +) + +from ..conftest import ClientServer + + +class ConfiguredLS(ClientServer): + def __init__(self): + super().__init__() + + @self.server.feature( + TEXT_DOCUMENT_CODE_LENS, + CodeLensOptions(resolve_provider=False), + ) + def f(params: CodeLensParams) -> Optional[List[CodeLens]]: + if params.text_document.uri == "file://return.list": + return [ + CodeLens( + range=Range( + start=Position(line=0, character=0), + end=Position(line=1, character=1), + ), + command=Command( + title="cmd1", + command="cmd1", + ), + data="some data", + ), + ] + else: + return None + + +@ConfiguredLS.decorate() +def test_capabilities(client_server): + _, server = client_server + capabilities = server.server_capabilities + + assert capabilities.code_lens_provider + assert not capabilities.code_lens_provider.resolve_provider + + +@ConfiguredLS.decorate() +def test_code_lens_return_list(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_CODE_LENS, + CodeLensParams(text_document=TextDocumentIdentifier(uri="file://return.list")), + ).result() + + assert response[0].data == "some data" + assert response[0].range.start.line == 0 + assert response[0].range.start.character == 0 + assert response[0].range.end.line == 1 + assert response[0].range.end.character == 1 + assert response[0].command.title == "cmd1" + assert response[0].command.command == "cmd1" + + +@ConfiguredLS.decorate() +def test_code_lens_return_none(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_CODE_LENS, + CodeLensParams(text_document=TextDocumentIdentifier(uri="file://return.none")), + ).result() + + assert response is None diff --git a/tests/lsp/test_color_presentation.py b/tests/lsp/test_color_presentation.py new file mode 100644 index 0000000..6748e66 --- /dev/null +++ b/tests/lsp/test_color_presentation.py @@ -0,0 +1,103 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ + +from typing import List + +from lsprotocol.types import TEXT_DOCUMENT_COLOR_PRESENTATION +from lsprotocol.types import ( + Color, + ColorPresentation, + ColorPresentationParams, + Position, + Range, + TextDocumentIdentifier, + TextEdit, +) + +from ..conftest import ClientServer + + +class ConfiguredLS(ClientServer): + def __init__(self): + super().__init__() + + @self.server.feature(TEXT_DOCUMENT_COLOR_PRESENTATION) + def f(params: ColorPresentationParams) -> List[ColorPresentation]: + return [ + ColorPresentation( + label="label1", + text_edit=TextEdit( + range=Range( + start=Position(line=0, character=0), + end=Position(line=1, character=1), + ), + new_text="te", + ), + additional_text_edits=[ + TextEdit( + range=Range( + start=Position(line=1, character=1), + end=Position(line=2, character=2), + ), + new_text="ate1", + ), + TextEdit( + range=Range( + start=Position(line=2, character=2), + end=Position(line=3, character=3), + ), + new_text="ate2", + ), + ], + ) + ] + + +@ConfiguredLS.decorate() +def test_color_presentation(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_COLOR_PRESENTATION, + ColorPresentationParams( + text_document=TextDocumentIdentifier(uri="file://return.list"), + color=Color(red=0.6, green=0.2, blue=0.3, alpha=0.5), + range=Range( + start=Position(line=0, character=0), + end=Position(line=3, character=3), + ), + ), + ).result() + + assert response[0].label == "label1" + assert response[0].text_edit.new_text == "te" + + assert response[0].text_edit.range.start.line == 0 + assert response[0].text_edit.range.start.character == 0 + assert response[0].text_edit.range.end.line == 1 + assert response[0].text_edit.range.end.character == 1 + + range = response[0].additional_text_edits[0].range + assert range.start.line == 1 + assert range.start.character == 1 + assert range.end.line == 2 + assert range.end.character == 2 + + range = response[0].additional_text_edits[1].range + assert range.start.line == 2 + assert range.start.character == 2 + assert range.end.line == 3 + assert range.end.character == 3 diff --git a/tests/lsp/test_completion.py b/tests/lsp/test_completion.py new file mode 100644 index 0000000..dcb8124 --- /dev/null +++ b/tests/lsp/test_completion.py @@ -0,0 +1,48 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ +from typing import Tuple + +from lsprotocol import types + + +from ..client import LanguageClient + + +async def test_completion( + json_server_client: Tuple[LanguageClient, types.InitializeResult], + uri_for, +): + """Ensure that the completion methods are working as expected.""" + client, initialize_result = json_server_client + + completion_provider = initialize_result.capabilities.completion_provider + assert completion_provider + assert completion_provider.trigger_characters == [","] + assert completion_provider.all_commit_characters == [":"] + + test_uri = uri_for("example.json") + assert test_uri is not None + + response = await client.text_document_completion_async( + types.CompletionParams( + text_document=types.TextDocumentIdentifier(uri=test_uri), + position=types.Position(line=0, character=0), + ) + ) + + labels = {i.label for i in response.items} + assert labels == set(['"', "[", "]", "{", "}"]) diff --git a/tests/lsp/test_declaration.py b/tests/lsp/test_declaration.py new file mode 100644 index 0000000..221982d --- /dev/null +++ b/tests/lsp/test_declaration.py @@ -0,0 +1,161 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ + +from typing import List, Optional, Union + +from lsprotocol.types import TEXT_DOCUMENT_DECLARATION +from lsprotocol.types import ( + DeclarationOptions, + DeclarationParams, + Location, + LocationLink, + Position, + Range, + TextDocumentIdentifier, +) + +from ..conftest import ClientServer + + +class ConfiguredLS(ClientServer): + def __init__(self): + super().__init__() + + @self.server.feature(TEXT_DOCUMENT_DECLARATION, DeclarationOptions()) + def f( + params: DeclarationParams, + ) -> Optional[Union[Location, List[Location], List[LocationLink]]]: + location = Location( + uri="uri", + range=Range( + start=Position(line=0, character=0), + end=Position(line=1, character=1), + ), + ) + + location_link = LocationLink( + target_uri="uri", + target_range=Range( + start=Position(line=0, character=0), + end=Position(line=1, character=1), + ), + target_selection_range=Range( + start=Position(line=0, character=0), + end=Position(line=2, character=2), + ), + origin_selection_range=Range( + start=Position(line=0, character=0), + end=Position(line=3, character=3), + ), + ) + + return { # type: ignore + "file://return.location": location, + "file://return.location_list": [location], + "file://return.location_link_list": [location_link], + }.get(params.text_document.uri, None) + + +@ConfiguredLS.decorate() +def test_capabilities(client_server): + _, server = client_server + capabilities = server.server_capabilities + + assert capabilities.declaration_provider + + +@ConfiguredLS.decorate() +def test_declaration_return_location(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_DECLARATION, + DeclarationParams( + text_document=TextDocumentIdentifier(uri="file://return.location"), + position=Position(line=0, character=0), + ), + ).result() + + assert response.uri == "uri" + + assert response.range.start.line == 0 + assert response.range.start.character == 0 + assert response.range.end.line == 1 + assert response.range.end.character == 1 + + +@ConfiguredLS.decorate() +def test_declaration_return_location_list(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_DECLARATION, + DeclarationParams( + text_document=TextDocumentIdentifier(uri="file://return.location_list"), + position=Position(line=0, character=0), + ), + ).result() + + assert response[0].uri == "uri" + + assert response[0].range.start.line == 0 + assert response[0].range.start.character == 0 + assert response[0].range.end.line == 1 + assert response[0].range.end.character == 1 + + +@ConfiguredLS.decorate() +def test_declaration_return_location_link_list(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_DECLARATION, + DeclarationParams( + text_document=TextDocumentIdentifier( + uri="file://return.location_link_list" + ), + position=Position(line=0, character=0), + ), + ).result() + + assert response[0].target_uri == "uri" + + assert response[0].target_range.start.line == 0 + assert response[0].target_range.start.character == 0 + assert response[0].target_range.end.line == 1 + assert response[0].target_range.end.character == 1 + + assert response[0].target_selection_range.start.line == 0 + assert response[0].target_selection_range.start.character == 0 + assert response[0].target_selection_range.end.line == 2 + assert response[0].target_selection_range.end.character == 2 + + assert response[0].origin_selection_range.start.line == 0 + assert response[0].origin_selection_range.start.character == 0 + assert response[0].origin_selection_range.end.line == 3 + assert response[0].origin_selection_range.end.character == 3 + + +@ConfiguredLS.decorate() +def test_declaration_return_none(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_DECLARATION, + DeclarationParams( + text_document=TextDocumentIdentifier(uri="file://return.none"), + position=Position(line=0, character=0), + ), + ).result() + + assert response is None diff --git a/tests/lsp/test_definition.py b/tests/lsp/test_definition.py new file mode 100644 index 0000000..3ed2f96 --- /dev/null +++ b/tests/lsp/test_definition.py @@ -0,0 +1,164 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ + +from typing import List, Optional, Union + +from lsprotocol.types import TEXT_DOCUMENT_DEFINITION +from lsprotocol.types import ( + DefinitionOptions, + DefinitionParams, + Location, + LocationLink, + Position, + Range, + TextDocumentIdentifier, +) + +from ..conftest import ClientServer + + +class ConfiguredLS(ClientServer): + def __init__(self): + super().__init__() + + @self.server.feature( + TEXT_DOCUMENT_DEFINITION, + DefinitionOptions(), + ) + def f( + params: DefinitionParams, + ) -> Optional[Union[Location, List[Location], List[LocationLink]]]: + location = Location( + uri="uri", + range=Range( + start=Position(line=0, character=0), + end=Position(line=1, character=1), + ), + ) + + location_link = LocationLink( + target_uri="uri", + target_range=Range( + start=Position(line=0, character=0), + end=Position(line=1, character=1), + ), + target_selection_range=Range( + start=Position(line=0, character=0), + end=Position(line=2, character=2), + ), + origin_selection_range=Range( + start=Position(line=0, character=0), + end=Position(line=3, character=3), + ), + ) + + return { # type: ignore + "file://return.location": location, + "file://return.location_list": [location], + "file://return.location_link_list": [location_link], + }.get(params.text_document.uri, None) + + +@ConfiguredLS.decorate() +def test_capabilities(client_server): + _, server = client_server + capabilities = server.server_capabilities + + assert capabilities.definition_provider is not None + + +@ConfiguredLS.decorate() +def test_definition_return_location(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_DEFINITION, + DefinitionParams( + text_document=TextDocumentIdentifier(uri="file://return.location"), + position=Position(line=0, character=0), + ), + ).result() + + assert response.uri == "uri" + + assert response.range.start.line == 0 + assert response.range.start.character == 0 + assert response.range.end.line == 1 + assert response.range.end.character == 1 + + +@ConfiguredLS.decorate() +def test_definition_return_location_list(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_DEFINITION, + DefinitionParams( + text_document=TextDocumentIdentifier(uri="file://return.location_list"), + position=Position(line=0, character=0), + ), + ).result() + + assert response[0].uri == "uri" + + assert response[0].range.start.line == 0 + assert response[0].range.start.character == 0 + assert response[0].range.end.line == 1 + assert response[0].range.end.character == 1 + + +@ConfiguredLS.decorate() +def test_definition_return_location_link_list(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_DEFINITION, + DefinitionParams( + text_document=TextDocumentIdentifier( + uri="file://return.location_link_list" + ), + position=Position(line=0, character=0), + ), + ).result() + + assert response[0].target_uri == "uri" + + assert response[0].target_range.start.line == 0 + assert response[0].target_range.start.character == 0 + assert response[0].target_range.end.line == 1 + assert response[0].target_range.end.character == 1 + + assert response[0].target_selection_range.start.line == 0 + assert response[0].target_selection_range.start.character == 0 + assert response[0].target_selection_range.end.line == 2 + assert response[0].target_selection_range.end.character == 2 + + assert response[0].origin_selection_range.start.line == 0 + assert response[0].origin_selection_range.start.character == 0 + assert response[0].origin_selection_range.end.line == 3 + assert response[0].origin_selection_range.end.character == 3 + + +@ConfiguredLS.decorate() +def test_definition_return_none(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_DEFINITION, + DefinitionParams( + text_document=TextDocumentIdentifier(uri="file://return.none"), + position=Position(line=0, character=0), + ), + ).result() + + assert response is None diff --git a/tests/lsp/test_diagnostics.py b/tests/lsp/test_diagnostics.py new file mode 100644 index 0000000..c420942 --- /dev/null +++ b/tests/lsp/test_diagnostics.py @@ -0,0 +1,68 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ +import json +from typing import Tuple + +from lsprotocol import types + + +from ..client import LanguageClient + + +async def test_diagnostics( + json_server_client: Tuple[LanguageClient, types.InitializeResult], + uri_for, +): + """Ensure that diagnostics are working as expected.""" + client, _ = json_server_client + + test_uri = uri_for("example.json") + assert test_uri is not None + + # Get the expected error message + document_content = "text" + try: + json.loads(document_content) + except json.JSONDecodeError as err: + expected_message = err.msg + + client.text_document_did_open( + types.DidOpenTextDocumentParams( + text_document=types.TextDocumentItem( + uri=test_uri, language_id="json", version=1, text=document_content + ) + ) + ) + + await client.wait_for_notification(types.TEXT_DOCUMENT_PUBLISH_DIAGNOSTICS) + + diagnostics = client.diagnostics[test_uri] + assert diagnostics[0].message == expected_message + + result = await client.text_document_diagnostic_async( + types.DocumentDiagnosticParams( + text_document=types.TextDocumentIdentifier(test_uri) + ) + ) + diagnostics = result.items + assert diagnostics[0].message == expected_message + + workspace_result = await client.workspace_diagnostic_async( + types.WorkspaceDiagnosticParams(previous_result_ids=[]) + ) + diagnostics = workspace_result.items[0].items + assert diagnostics[0].message == expected_message diff --git a/tests/lsp/test_document_color.py b/tests/lsp/test_document_color.py new file mode 100644 index 0000000..460a60b --- /dev/null +++ b/tests/lsp/test_document_color.py @@ -0,0 +1,81 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ + +from typing import List + +from lsprotocol.types import TEXT_DOCUMENT_DOCUMENT_COLOR +from lsprotocol.types import ( + Color, + ColorInformation, + DocumentColorOptions, + DocumentColorParams, + Position, + Range, + TextDocumentIdentifier, +) + +from ..conftest import ClientServer + + +class ConfiguredLS(ClientServer): + def __init__(self): + super().__init__() + + @self.server.feature( + TEXT_DOCUMENT_DOCUMENT_COLOR, + DocumentColorOptions(), + ) + def f(params: DocumentColorParams) -> List[ColorInformation]: + return [ + ColorInformation( + range=Range( + start=Position(line=0, character=0), + end=Position(line=1, character=1), + ), + color=Color(red=0.5, green=0.5, blue=0.5, alpha=0.5), + ) + ] + + +@ConfiguredLS.decorate() +def test_capabilities(client_server): + _, server = client_server + capabilities = server.server_capabilities + + assert capabilities.color_provider + + +@ConfiguredLS.decorate() +def test_document_color(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_DOCUMENT_COLOR, + DocumentColorParams( + text_document=TextDocumentIdentifier(uri="file://return.list") + ), + ).result() + + assert response + assert response[0].color.red == 0.5 + assert response[0].color.green == 0.5 + assert response[0].color.blue == 0.5 + assert response[0].color.alpha == 0.5 + + assert response[0].range.start.line == 0 + assert response[0].range.start.character == 0 + assert response[0].range.end.line == 1 + assert response[0].range.end.character == 1 diff --git a/tests/lsp/test_document_highlight.py b/tests/lsp/test_document_highlight.py new file mode 100644 index 0000000..e4afc5f --- /dev/null +++ b/tests/lsp/test_document_highlight.py @@ -0,0 +1,108 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ + +from typing import List, Optional + +from lsprotocol.types import TEXT_DOCUMENT_DOCUMENT_HIGHLIGHT +from lsprotocol.types import ( + DocumentHighlight, + DocumentHighlightKind, + DocumentHighlightOptions, + DocumentHighlightParams, + Position, + Range, + TextDocumentIdentifier, +) + +from ..conftest import ClientServer + + +class ConfiguredLS(ClientServer): + def __init__(self): + super().__init__() + + @self.server.feature( + TEXT_DOCUMENT_DOCUMENT_HIGHLIGHT, + DocumentHighlightOptions(), + ) + def f(params: DocumentHighlightParams) -> Optional[List[DocumentHighlight]]: + if params.text_document.uri == "file://return.list": + return [ + DocumentHighlight( + range=Range( + start=Position(line=0, character=0), + end=Position(line=1, character=1), + ), + ), + DocumentHighlight( + range=Range( + start=Position(line=1, character=1), + end=Position(line=2, character=2), + ), + kind=DocumentHighlightKind.Write, + ), + ] + else: + return None + + +@ConfiguredLS.decorate() +def test_capabilities(client_server): + _, server = client_server + capabilities = server.server_capabilities + + assert capabilities.document_highlight_provider + + +@ConfiguredLS.decorate() +def test_document_highlight_return_list(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_DOCUMENT_HIGHLIGHT, + DocumentHighlightParams( + text_document=TextDocumentIdentifier(uri="file://return.list"), + position=Position(line=0, character=0), + ), + ).result() + + assert response + + assert response[0].range.start.line == 0 + assert response[0].range.start.character == 0 + assert response[0].range.end.line == 1 + assert response[0].range.end.character == 1 + assert response[0].kind is None + + assert response[1].range.start.line == 1 + assert response[1].range.start.character == 1 + assert response[1].range.end.line == 2 + assert response[1].range.end.character == 2 + assert response[1].kind == DocumentHighlightKind.Write + + +@ConfiguredLS.decorate() +def test_document_highlight_return_none(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_DOCUMENT_HIGHLIGHT, + DocumentHighlightParams( + text_document=TextDocumentIdentifier(uri="file://return.none"), + position=Position(line=0, character=0), + ), + ).result() + + assert response is None diff --git a/tests/lsp/test_document_link.py b/tests/lsp/test_document_link.py new file mode 100644 index 0000000..0602773 --- /dev/null +++ b/tests/lsp/test_document_link.py @@ -0,0 +1,98 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ + +from typing import List, Optional + +from lsprotocol.types import TEXT_DOCUMENT_DOCUMENT_LINK +from lsprotocol.types import ( + DocumentLink, + DocumentLinkOptions, + DocumentLinkParams, + Position, + Range, + TextDocumentIdentifier, +) + +from ..conftest import ClientServer + + +class ConfiguredLS(ClientServer): + def __init__(self): + super().__init__() + + @self.server.feature( + TEXT_DOCUMENT_DOCUMENT_LINK, + DocumentLinkOptions(resolve_provider=True), + ) + def f(params: DocumentLinkParams) -> Optional[List[DocumentLink]]: + if params.text_document.uri == "file://return.list": + return [ + DocumentLink( + range=Range( + start=Position(line=0, character=0), + end=Position(line=1, character=1), + ), + target="target", + tooltip="tooltip", + data="data", + ), + ] + else: + return None + + +@ConfiguredLS.decorate() +def test_capabilities(client_server): + _, server = client_server + capabilities = server.server_capabilities + + assert capabilities.document_link_provider + assert capabilities.document_link_provider.resolve_provider + + +@ConfiguredLS.decorate() +def test_document_link_return_list(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_DOCUMENT_LINK, + DocumentLinkParams( + text_document=TextDocumentIdentifier(uri="file://return.list"), + ), + ).result() + + assert response + + assert response[0].range.start.line == 0 + assert response[0].range.start.character == 0 + assert response[0].range.end.line == 1 + assert response[0].range.end.character == 1 + assert response[0].target == "target" + assert response[0].tooltip == "tooltip" + assert response[0].data == "data" + + +@ConfiguredLS.decorate() +def test_document_link_return_none(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_DOCUMENT_LINK, + DocumentLinkParams( + text_document=TextDocumentIdentifier(uri="file://return.none"), + ), + ).result() + + assert response is None diff --git a/tests/lsp/test_document_symbol.py b/tests/lsp/test_document_symbol.py new file mode 100644 index 0000000..251c8fb --- /dev/null +++ b/tests/lsp/test_document_symbol.py @@ -0,0 +1,168 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ + +from typing import List, Union + +from lsprotocol.types import TEXT_DOCUMENT_DOCUMENT_SYMBOL +from lsprotocol.types import ( + DocumentSymbol, + DocumentSymbolOptions, + DocumentSymbolParams, + Location, + Position, + Range, + SymbolInformation, + SymbolKind, + TextDocumentIdentifier, +) + +from ..conftest import ClientServer + + +class ConfiguredLS(ClientServer): + def __init__(self): + super().__init__() + + @self.server.feature( + TEXT_DOCUMENT_DOCUMENT_SYMBOL, + DocumentSymbolOptions(), + ) + def f( + params: DocumentSymbolParams, + ) -> Union[List[SymbolInformation], List[DocumentSymbol]]: + symbol_info = SymbolInformation( + name="symbol", + kind=SymbolKind.Namespace, + location=Location( + uri="uri", + range=Range( + start=Position(line=0, character=0), + end=Position(line=1, character=1), + ), + ), + container_name="container", + deprecated=False, + ) + + document_symbol_inner = DocumentSymbol( + name="inner_symbol", + kind=SymbolKind.Number, + range=Range( + start=Position(line=0, character=0), + end=Position(line=1, character=1), + ), + selection_range=Range( + start=Position(line=0, character=0), + end=Position(line=1, character=1), + ), + ) + + document_symbol = DocumentSymbol( + name="symbol", + kind=SymbolKind.Object, + range=Range( + start=Position(line=0, character=0), + end=Position(line=10, character=10), + ), + selection_range=Range( + start=Position(line=0, character=0), + end=Position(line=10, character=10), + ), + detail="detail", + children=[document_symbol_inner], + deprecated=True, + ) + + return { # type: ignore + "file://return.symbol_information_list": [symbol_info], + "file://return.document_symbol_list": [document_symbol], + }.get(params.text_document.uri, None) + + +@ConfiguredLS.decorate() +def test_capabilities(client_server): + _, server = client_server + capabilities = server.server_capabilities + + assert capabilities.document_symbol_provider + + +@ConfiguredLS.decorate() +def test_document_symbol_return_symbol_information_list(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_DOCUMENT_SYMBOL, + DocumentSymbolParams( + text_document=TextDocumentIdentifier( + uri="file://return.symbol_information_list" + ), + ), + ).result() + + assert response + + assert response[0].name == "symbol" + assert response[0].kind == SymbolKind.Namespace + assert response[0].location.uri == "uri" + assert response[0].location.range.start.line == 0 + assert response[0].location.range.start.character == 0 + assert response[0].location.range.end.line == 1 + assert response[0].location.range.end.character == 1 + assert response[0].container_name == "container" + assert not response[0].deprecated + + +@ConfiguredLS.decorate() +def test_document_symbol_return_document_symbol_list(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_DOCUMENT_SYMBOL, + DocumentSymbolParams( + text_document=TextDocumentIdentifier( + uri="file://return.document_symbol_list" + ), + ), + ).result() + + assert response + + assert response[0].name == "symbol" + assert response[0].kind == SymbolKind.Object + assert response[0].range.start.line == 0 + assert response[0].range.start.character == 0 + assert response[0].range.end.line == 10 + assert response[0].range.end.character == 10 + assert response[0].selection_range.start.line == 0 + assert response[0].selection_range.start.character == 0 + assert response[0].selection_range.end.line == 10 + assert response[0].selection_range.end.character == 10 + assert response[0].detail == "detail" + assert response[0].deprecated + + assert response[0].children[0].name == "inner_symbol" + assert response[0].children[0].kind == SymbolKind.Number + assert response[0].children[0].range.start.line == 0 + assert response[0].children[0].range.start.character == 0 + assert response[0].children[0].range.end.line == 1 + assert response[0].children[0].range.end.character == 1 + range = response[0].children[0].selection_range + assert range.start.line == 0 + assert range.start.character == 0 + assert range.end.line == 1 + assert range.end.character == 1 + + assert response[0].children[0].children is None diff --git a/tests/lsp/test_errors.py b/tests/lsp/test_errors.py new file mode 100644 index 0000000..1fbc05b --- /dev/null +++ b/tests/lsp/test_errors.py @@ -0,0 +1,135 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ + +from typing import Any, List, Union +import time + +import pytest + +from pygls.exceptions import JsonRpcInternalError, PyglsError, JsonRpcException +from lsprotocol.types import WINDOW_SHOW_MESSAGE, MessageType +from pygls.server import LanguageServer + +from ..conftest import ClientServer + +ERROR_TRIGGER = "test/triggerError" +ERROR_MESSAGE = "Testing errors" + + +class CustomLanguageServerSafe(LanguageServer): + def report_server_error( + self, error: Exception, source: Union[PyglsError, JsonRpcException] + ): + pass + + +class CustomLanguageServerPotentialRecursion(LanguageServer): + def report_server_error( + self, error: Exception, source: Union[PyglsError, JsonRpcException] + ): + raise Exception() + + +class CustomLanguageServerSendAll(LanguageServer): + def report_server_error( + self, error: Exception, source: Union[PyglsError, JsonRpcException] + ): + self.show_message(self.default_error_message, msg_type=MessageType.Error) + + +class ConfiguredLS(ClientServer): + def __init__(self, LS=LanguageServer): + super().__init__(LS) + self.init() + + def init(self): + self.client.messages: List[str] = [] + + @self.server.feature(ERROR_TRIGGER) + def f1(params: Any): + raise Exception(ERROR_MESSAGE) + + @self.client.feature(WINDOW_SHOW_MESSAGE) + def f2(params: Any): + self.client.messages.append(params.message) + + +class CustomConfiguredLSSafe(ConfiguredLS): + def __init__(self): + super().__init__(CustomLanguageServerSafe) + + +class CustomConfiguredLSPotentialRecusrion(ConfiguredLS): + def __init__(self): + super().__init__(CustomLanguageServerPotentialRecursion) + + +class CustomConfiguredLSSendAll(ConfiguredLS): + def __init__(self): + super().__init__(CustomLanguageServerSendAll) + + +@ConfiguredLS.decorate() +def test_request_error_reporting_default(client_server): + client, _ = client_server + assert len(client.messages) == 0 + + with pytest.raises(JsonRpcInternalError, match=ERROR_MESSAGE): + client.lsp.send_request(ERROR_TRIGGER).result() + + time.sleep(0.1) + assert len(client.messages) == 0 + + +@CustomConfiguredLSSendAll.decorate() +def test_request_error_reporting_override(client_server): + client, _ = client_server + assert len(client.messages) == 0 + + with pytest.raises(JsonRpcInternalError, match=ERROR_MESSAGE): + client.lsp.send_request(ERROR_TRIGGER).result() + + time.sleep(0.1) + assert len(client.messages) == 1 + + +@ConfiguredLS.decorate() +def test_notification_error_reporting(client_server): + client, _ = client_server + client.lsp.notify(ERROR_TRIGGER) + time.sleep(0.1) + + assert len(client.messages) == 1 + assert client.messages[0] == LanguageServer.default_error_message + + +@CustomConfiguredLSSafe.decorate() +def test_overriding_error_reporting(client_server): + client, _ = client_server + client.lsp.notify(ERROR_TRIGGER) + time.sleep(0.1) + + assert len(client.messages) == 0 + + +@CustomConfiguredLSPotentialRecusrion.decorate() +def test_overriding_error_reporting_with_potential_recursion(client_server): + client, _ = client_server + client.lsp.notify(ERROR_TRIGGER) + time.sleep(0.1) + + assert len(client.messages) == 0 diff --git a/tests/lsp/test_folding_range.py b/tests/lsp/test_folding_range.py new file mode 100644 index 0000000..8f9c749 --- /dev/null +++ b/tests/lsp/test_folding_range.py @@ -0,0 +1,92 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ + +from typing import List, Optional + +from lsprotocol.types import TEXT_DOCUMENT_FOLDING_RANGE +from lsprotocol.types import ( + FoldingRange, + FoldingRangeKind, + FoldingRangeOptions, + FoldingRangeParams, + TextDocumentIdentifier, +) + +from ..conftest import ClientServer + + +class ConfiguredLS(ClientServer): + def __init__(self): + super().__init__() + + @self.server.feature( + TEXT_DOCUMENT_FOLDING_RANGE, + FoldingRangeOptions(), + ) + def f(params: FoldingRangeParams) -> Optional[List[FoldingRange]]: + if params.text_document.uri == "file://return.list": + return [ + FoldingRange( + start_line=0, + end_line=0, + start_character=1, + end_character=1, + kind=FoldingRangeKind.Comment, + ), + ] + else: + return None + + +@ConfiguredLS.decorate() +def test_capabilities(client_server): + _, server = client_server + capabilities = server.server_capabilities + + assert capabilities.folding_range_provider + + +@ConfiguredLS.decorate() +def test_folding_range_return_list(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_FOLDING_RANGE, + FoldingRangeParams( + text_document=TextDocumentIdentifier(uri="file://return.list"), + ), + ).result() + + assert response + + assert response[0].start_line == 0 + assert response[0].end_line == 0 + assert response[0].start_character == 1 + assert response[0].end_character == 1 + assert response[0].kind == FoldingRangeKind.Comment + + +@ConfiguredLS.decorate() +def test_folding_range_return_none(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_FOLDING_RANGE, + FoldingRangeParams( + text_document=TextDocumentIdentifier(uri="file://return.none"), + ), + ).result() + + assert response is None diff --git a/tests/lsp/test_formatting.py b/tests/lsp/test_formatting.py new file mode 100644 index 0000000..0c3fbd6 --- /dev/null +++ b/tests/lsp/test_formatting.py @@ -0,0 +1,108 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ + +from typing import List, Optional + +from lsprotocol.types import TEXT_DOCUMENT_FORMATTING +from lsprotocol.types import ( + DocumentFormattingOptions, + DocumentFormattingParams, + FormattingOptions, + Position, + Range, + TextDocumentIdentifier, + TextEdit, +) + +from ..conftest import ClientServer + + +class ConfiguredLS(ClientServer): + def __init__(self): + super().__init__() + + @self.server.feature( + TEXT_DOCUMENT_FORMATTING, + DocumentFormattingOptions(), + ) + def f(params: DocumentFormattingParams) -> Optional[List[TextEdit]]: + if params.text_document.uri == "file://return.list": + return [ + TextEdit( + range=Range( + start=Position(line=0, character=0), + end=Position(line=1, character=1), + ), + new_text="text", + ) + ] + else: + return None + + +@ConfiguredLS.decorate() +def test_capabilities(client_server): + _, server = client_server + capabilities = server.server_capabilities + + assert capabilities.document_formatting_provider + + +@ConfiguredLS.decorate() +def test_document_formatting_return_list(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_FORMATTING, + DocumentFormattingParams( + text_document=TextDocumentIdentifier(uri="file://return.list"), + options=FormattingOptions( + tab_size=2, + insert_spaces=True, + trim_trailing_whitespace=True, + insert_final_newline=True, + trim_final_newlines=True, + ), + ), + ).result() + + assert response + + assert response[0].new_text == "text" + assert response[0].range.start.line == 0 + assert response[0].range.start.character == 0 + assert response[0].range.end.line == 1 + assert response[0].range.end.character == 1 + + +@ConfiguredLS.decorate() +def test_document_formatting_return_none(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_FORMATTING, + DocumentFormattingParams( + text_document=TextDocumentIdentifier(uri="file://return.none"), + options=FormattingOptions( + tab_size=2, + insert_spaces=True, + trim_trailing_whitespace=True, + insert_final_newline=True, + trim_final_newlines=True, + ), + ), + ).result() + + assert response is None diff --git a/tests/lsp/test_hover.py b/tests/lsp/test_hover.py new file mode 100644 index 0000000..9007c78 --- /dev/null +++ b/tests/lsp/test_hover.py @@ -0,0 +1,149 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ + +from typing import Optional + +from lsprotocol.types import TEXT_DOCUMENT_HOVER +from lsprotocol.types import ( + Hover, + HoverOptions, + HoverParams, + MarkedString_Type1, + MarkupContent, + MarkupKind, + Position, + Range, + TextDocumentIdentifier, +) + +from ..conftest import ClientServer + + +class ConfiguredLS(ClientServer): + def __init__(self): + super().__init__() + + @self.server.feature( + TEXT_DOCUMENT_HOVER, + HoverOptions(), + ) + def f(params: HoverParams) -> Optional[Hover]: + range = Range( + start=Position(line=0, character=0), + end=Position(line=1, character=1), + ) + + return { + "file://return.marked_string": Hover( + range=range, + contents=MarkedString_Type1( + language="language", + value="value", + ), + ), + "file://return.marked_string_list": Hover( + range=range, + contents=[ + MarkedString_Type1( + language="language", + value="value", + ), + "str type", + ], + ), + "file://return.markup_content": Hover( + range=range, + contents=MarkupContent(kind=MarkupKind.Markdown, value="value"), + ), + }.get(params.text_document.uri, None) + + +@ConfiguredLS.decorate() +def test_capabilities(client_server): + _, server = client_server + capabilities = server.server_capabilities + + assert capabilities.hover_provider + + +@ConfiguredLS.decorate() +def test_hover_return_marked_string(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_HOVER, + HoverParams( + text_document=TextDocumentIdentifier(uri="file://return.marked_string"), + position=Position(line=0, character=0), + ), + ).result() + + assert response + + assert response.contents.language == "language" + assert response.contents.value == "value" + + assert response.range.start.line == 0 + assert response.range.start.character == 0 + assert response.range.end.line == 1 + assert response.range.end.character == 1 + + +@ConfiguredLS.decorate() +def test_hover_return_marked_string_list(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_HOVER, + HoverParams( + text_document=TextDocumentIdentifier( + uri="file://return.marked_string_list" + ), + position=Position(line=0, character=0), + ), + ).result() + + assert response + + assert response.contents[0].language == "language" + assert response.contents[0].value == "value" + assert response.contents[1] == "str type" + + assert response.range.start.line == 0 + assert response.range.start.character == 0 + assert response.range.end.line == 1 + assert response.range.end.character == 1 + + +@ConfiguredLS.decorate() +def test_hover_return_markup_content(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_HOVER, + HoverParams( + text_document=TextDocumentIdentifier(uri="file://return.markup_content"), + position=Position(line=0, character=0), + ), + ).result() + + assert response + + assert response.contents.kind == MarkupKind.Markdown + assert response.contents.value == "value" + + assert response.range.start.line == 0 + assert response.range.start.character == 0 + assert response.range.end.line == 1 + assert response.range.end.character == 1 diff --git a/tests/lsp/test_implementation.py b/tests/lsp/test_implementation.py new file mode 100644 index 0000000..4fea3a9 --- /dev/null +++ b/tests/lsp/test_implementation.py @@ -0,0 +1,164 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ + +from typing import List, Optional, Union + +from lsprotocol.types import TEXT_DOCUMENT_IMPLEMENTATION +from lsprotocol.types import ( + ImplementationOptions, + ImplementationParams, + Location, + LocationLink, + Position, + Range, + TextDocumentIdentifier, +) + +from ..conftest import ClientServer + + +class ConfiguredLS(ClientServer): + def __init__(self): + super().__init__() + + @self.server.feature( + TEXT_DOCUMENT_IMPLEMENTATION, + ImplementationOptions(), + ) + def f( + params: ImplementationParams, + ) -> Optional[Union[Location, List[Location], List[LocationLink]]]: + location = Location( + uri="uri", + range=Range( + start=Position(line=0, character=0), + end=Position(line=1, character=1), + ), + ) + + location_link = LocationLink( + target_uri="uri", + target_range=Range( + start=Position(line=0, character=0), + end=Position(line=1, character=1), + ), + target_selection_range=Range( + start=Position(line=0, character=0), + end=Position(line=2, character=2), + ), + origin_selection_range=Range( + start=Position(line=0, character=0), + end=Position(line=3, character=3), + ), + ) + + return { # type: ignore + "file://return.location": location, + "file://return.location_list": [location], + "file://return.location_link_list": [location_link], + }.get(params.text_document.uri, None) + + +@ConfiguredLS.decorate() +def test_capabilities(client_server): + _, server = client_server + capabilities = server.server_capabilities + + assert capabilities.implementation_provider + + +@ConfiguredLS.decorate() +def test_type_definition_return_location(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_IMPLEMENTATION, + ImplementationParams( + text_document=TextDocumentIdentifier(uri="file://return.location"), + position=Position(line=0, character=0), + ), + ).result() + + assert response.uri == "uri" + + assert response.range.start.line == 0 + assert response.range.start.character == 0 + assert response.range.end.line == 1 + assert response.range.end.character == 1 + + +@ConfiguredLS.decorate() +def test_type_definition_return_location_list(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_IMPLEMENTATION, + ImplementationParams( + text_document=TextDocumentIdentifier(uri="file://return.location_list"), + position=Position(line=0, character=0), + ), + ).result() + + assert response[0].uri == "uri" + + assert response[0].range.start.line == 0 + assert response[0].range.start.character == 0 + assert response[0].range.end.line == 1 + assert response[0].range.end.character == 1 + + +@ConfiguredLS.decorate() +def test_type_definition_return_location_link_list(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_IMPLEMENTATION, + ImplementationParams( + text_document=TextDocumentIdentifier( + uri="file://return.location_link_list" + ), + position=Position(line=0, character=0), + ), + ).result() + + assert response[0].target_uri == "uri" + + assert response[0].target_range.start.line == 0 + assert response[0].target_range.start.character == 0 + assert response[0].target_range.end.line == 1 + assert response[0].target_range.end.character == 1 + + assert response[0].target_selection_range.start.line == 0 + assert response[0].target_selection_range.start.character == 0 + assert response[0].target_selection_range.end.line == 2 + assert response[0].target_selection_range.end.character == 2 + + assert response[0].origin_selection_range.start.line == 0 + assert response[0].origin_selection_range.start.character == 0 + assert response[0].origin_selection_range.end.line == 3 + assert response[0].origin_selection_range.end.character == 3 + + +@ConfiguredLS.decorate() +def test_type_definition_return_none(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_IMPLEMENTATION, + ImplementationParams( + text_document=TextDocumentIdentifier(uri="file://return.none"), + position=Position(line=0, character=0), + ), + ).result() + + assert response is None diff --git a/tests/lsp/test_inlay_hints.py b/tests/lsp/test_inlay_hints.py new file mode 100644 index 0000000..1146e1b --- /dev/null +++ b/tests/lsp/test_inlay_hints.py @@ -0,0 +1,56 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ +from typing import Tuple + +from lsprotocol import types + +from ..client import LanguageClient + + +async def test_code_actions( + inlay_hints_client: Tuple[LanguageClient, types.InitializeResult], uri_for +): + """Ensure that the example code action server is working as expected.""" + client, initialize_result = inlay_hints_client + + inlay_hint_provider = initialize_result.capabilities.inlay_hint_provider + assert inlay_hint_provider.resolve_provider is True + + test_uri = uri_for("sums.txt") + assert test_uri is not None + + response = await client.text_document_inlay_hint_async( + types.InlayHintParams( + text_document=types.TextDocumentIdentifier(uri=test_uri), + range=types.Range( + start=types.Position(line=3, character=0), + end=types.Position(line=4, character=0), + ), + ) + ) + + assert len(response) == 2 + two, three = response[0], response[1] + + assert two.label == ":10" + assert two.tooltip is None + + assert three.label == ":11" + assert three.tooltip is None + + resolved = await client.inlay_hint_resolve_async(three) + assert resolved.tooltip == "Binary representation of the number: 3" diff --git a/tests/lsp/test_inline_value.py b/tests/lsp/test_inline_value.py new file mode 100644 index 0000000..68d682d --- /dev/null +++ b/tests/lsp/test_inline_value.py @@ -0,0 +1,60 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ +from typing import Tuple + +from lsprotocol import types + + +from ..client import LanguageClient + + +async def test_inline_value( + json_server_client: Tuple[LanguageClient, types.InitializeResult], + uri_for, +): + """Ensure that inline values are working as expected.""" + client, _ = json_server_client + + test_uri = uri_for("example.json") + assert test_uri is not None + + document_content = '{\n"foo": "bar"\n}' + client.text_document_did_open( + types.DidOpenTextDocumentParams( + text_document=types.TextDocumentItem( + uri=test_uri, language_id="json", version=1, text=document_content + ) + ) + ) + + result = await client.text_document_inline_value_async( + types.InlineValueParams( + text_document=types.TextDocumentIdentifier(test_uri), + range=types.Range( + start=types.Position(line=1, character=0), + end=types.Position(line=1, character=6), + ), + context=types.InlineValueContext( + frame_id=1, + stopped_location=types.Range( + start=types.Position(line=1, character=0), + end=types.Position(line=1, character=6), + ), + ), + ) + ) + assert result[0].text == "Inline value" diff --git a/tests/lsp/test_linked_editing_range.py b/tests/lsp/test_linked_editing_range.py new file mode 100644 index 0000000..2650b7e --- /dev/null +++ b/tests/lsp/test_linked_editing_range.py @@ -0,0 +1,103 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ + +from typing import Optional + +from lsprotocol.types import TEXT_DOCUMENT_LINKED_EDITING_RANGE +from lsprotocol.types import ( + LinkedEditingRangeOptions, + LinkedEditingRangeParams, + LinkedEditingRanges, + Position, + Range, + TextDocumentIdentifier, +) + +from ..conftest import ClientServer + + +class ConfiguredLS(ClientServer): + def __init__(self): + super().__init__() + + @self.server.feature( + TEXT_DOCUMENT_LINKED_EDITING_RANGE, + LinkedEditingRangeOptions(), + ) + def f(params: LinkedEditingRangeParams) -> Optional[LinkedEditingRanges]: + if params.text_document.uri == "file://return.ranges": + return LinkedEditingRanges( + ranges=[ + Range( + start=Position(line=0, character=0), + end=Position(line=1, character=1), + ), + Range( + start=Position(line=1, character=1), + end=Position(line=2, character=2), + ), + ], + word_pattern="pattern", + ) + else: + return None + + +@ConfiguredLS.decorate() +def test_capabilities(client_server): + _, server = client_server + capabilities = server.server_capabilities + + assert capabilities.linked_editing_range_provider + + +@ConfiguredLS.decorate() +def test_linked_editing_ranges_return_ranges(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_LINKED_EDITING_RANGE, + LinkedEditingRangeParams( + text_document=TextDocumentIdentifier(uri="file://return.ranges"), + position=Position(line=0, character=0), + ), + ).result() + + assert response + + assert response.ranges[0].start.line == 0 + assert response.ranges[0].start.character == 0 + assert response.ranges[0].end.line == 1 + assert response.ranges[0].end.character == 1 + assert response.ranges[1].start.line == 1 + assert response.ranges[1].start.character == 1 + assert response.ranges[1].end.line == 2 + assert response.ranges[1].end.character == 2 + assert response.word_pattern == "pattern" + + +@ConfiguredLS.decorate() +def test_linked_editing_ranges_return_none(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_LINKED_EDITING_RANGE, + LinkedEditingRangeParams( + text_document=TextDocumentIdentifier(uri="file://return.none"), + position=Position(line=0, character=0), + ), + ).result() + + assert response is None diff --git a/tests/lsp/test_moniker.py b/tests/lsp/test_moniker.py new file mode 100644 index 0000000..09b962d --- /dev/null +++ b/tests/lsp/test_moniker.py @@ -0,0 +1,94 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ + +from typing import List, Optional + +from lsprotocol.types import TEXT_DOCUMENT_MONIKER +from lsprotocol.types import ( + Moniker, + MonikerKind, + MonikerOptions, + MonikerParams, + Position, + TextDocumentIdentifier, + UniquenessLevel, +) + +from ..conftest import ClientServer + + +class ConfiguredLS(ClientServer): + def __init__(self): + super().__init__() + + @self.server.feature( + TEXT_DOCUMENT_MONIKER, + MonikerOptions(), + ) + def f(params: MonikerParams) -> Optional[List[Moniker]]: + if params.text_document.uri == "file://return.list": + return [ + Moniker( + scheme="test_scheme", + identifier="test_identifier", + unique=UniquenessLevel.Global, + kind=MonikerKind.Local, + ), + ] + else: + return None + + +@ConfiguredLS.decorate() +def test_capabilities(client_server): + _, server = client_server + capabilities = server.server_capabilities + + assert capabilities.moniker_provider + + +@ConfiguredLS.decorate() +def test_moniker_return_list(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_MONIKER, + MonikerParams( + text_document=TextDocumentIdentifier(uri="file://return.list"), + position=Position(line=0, character=0), + ), + ).result() + + assert response + + assert response[0].scheme == "test_scheme" + assert response[0].identifier == "test_identifier" + assert response[0].unique == UniquenessLevel.Global + assert response[0].kind == MonikerKind.Local + + +@ConfiguredLS.decorate() +def test_references_return_none(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_MONIKER, + MonikerParams( + text_document=TextDocumentIdentifier(uri="file://return.none"), + position=Position(line=0, character=0), + ), + ).result() + + assert response is None diff --git a/tests/lsp/test_on_type_formatting.py b/tests/lsp/test_on_type_formatting.py new file mode 100644 index 0000000..2e7adc9 --- /dev/null +++ b/tests/lsp/test_on_type_formatting.py @@ -0,0 +1,122 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ + +from typing import List, Optional + +from lsprotocol.types import TEXT_DOCUMENT_ON_TYPE_FORMATTING +from lsprotocol.types import ( + DocumentOnTypeFormattingOptions, + DocumentOnTypeFormattingParams, + FormattingOptions, + Position, + Range, + TextDocumentIdentifier, + TextEdit, +) + +from ..conftest import ClientServer + + +class ConfiguredLS(ClientServer): + def __init__(self): + super().__init__() + + @self.server.feature( + TEXT_DOCUMENT_ON_TYPE_FORMATTING, + DocumentOnTypeFormattingOptions( + first_trigger_character=":", + more_trigger_character=[",", "."], + ), + ) + def f(params: DocumentOnTypeFormattingParams) -> Optional[List[TextEdit]]: + if params.text_document.uri == "file://return.list": + return [ + TextEdit( + range=Range( + start=Position(line=0, character=0), + end=Position(line=1, character=1), + ), + new_text="text", + ) + ] + else: + return None + + +@ConfiguredLS.decorate() +def test_capabilities(client_server): + _, server = client_server + capabilities = server.server_capabilities + + assert capabilities.document_on_type_formatting_provider + assert ( + capabilities.document_on_type_formatting_provider.first_trigger_character == ":" + ) + assert capabilities.document_on_type_formatting_provider.more_trigger_character == [ + ",", + ".", + ] + + +@ConfiguredLS.decorate() +def test_on_type_formatting_return_list(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_ON_TYPE_FORMATTING, + DocumentOnTypeFormattingParams( + text_document=TextDocumentIdentifier(uri="file://return.list"), + position=Position(line=0, character=0), + ch=":", + options=FormattingOptions( + tab_size=2, + insert_spaces=True, + trim_trailing_whitespace=True, + insert_final_newline=True, + trim_final_newlines=True, + ), + ), + ).result() + + assert response + + assert response[0].new_text == "text" + assert response[0].range.start.line == 0 + assert response[0].range.start.character == 0 + assert response[0].range.end.line == 1 + assert response[0].range.end.character == 1 + + +@ConfiguredLS.decorate() +def test_on_type_formatting_return_none(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_ON_TYPE_FORMATTING, + DocumentOnTypeFormattingParams( + text_document=TextDocumentIdentifier(uri="file://return.none"), + position=Position(line=0, character=0), + ch=":", + options=FormattingOptions( + tab_size=2, + insert_spaces=True, + trim_trailing_whitespace=True, + insert_final_newline=True, + trim_final_newlines=True, + ), + ), + ).result() + + assert response is None diff --git a/tests/lsp/test_prepare_rename.py b/tests/lsp/test_prepare_rename.py new file mode 100644 index 0000000..39d0712 --- /dev/null +++ b/tests/lsp/test_prepare_rename.py @@ -0,0 +1,111 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ + +from typing import Optional, Union + +from lsprotocol.types import TEXT_DOCUMENT_PREPARE_RENAME +from lsprotocol.types import ( + Position, + PrepareRenameResult, + PrepareRenameResult_Type1, + PrepareRenameParams, + Range, + TextDocumentIdentifier, +) + +from ..conftest import ClientServer + + +class ConfiguredLS(ClientServer): + def __init__(self): + super().__init__() + + @self.server.feature(TEXT_DOCUMENT_PREPARE_RENAME) + def f( + params: PrepareRenameParams, + ) -> Optional[Union[Range, PrepareRenameResult]]: + return { # type: ignore + "file://return.range": Range( + start=Position(line=0, character=0), + end=Position(line=1, character=1), + ), + "file://return.prepare_rename": PrepareRenameResult_Type1( + range=Range( + start=Position(line=0, character=0), + end=Position(line=1, character=1), + ), + placeholder="placeholder", + ), + }.get(params.text_document.uri, None) + + +@ConfiguredLS.decorate() +def test_capabilities(client_server): + pass + + +@ConfiguredLS.decorate() +def test_prepare_rename_return_range(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_PREPARE_RENAME, + PrepareRenameParams( + text_document=TextDocumentIdentifier(uri="file://return.range"), + position=Position(line=0, character=0), + ), + ).result() + + assert response + + assert response.start.line == 0 + assert response.start.character == 0 + assert response.end.line == 1 + assert response.end.character == 1 + + +@ConfiguredLS.decorate() +def test_prepare_rename_return_prepare_rename(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_PREPARE_RENAME, + PrepareRenameParams( + text_document=TextDocumentIdentifier(uri="file://return.prepare_rename"), + position=Position(line=0, character=0), + ), + ).result() + + assert response + + assert response.range.start.line == 0 + assert response.range.start.character == 0 + assert response.range.end.line == 1 + assert response.range.end.character == 1 + assert response.placeholder == "placeholder" + + +@ConfiguredLS.decorate() +def test_prepare_rename_return_none(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_PREPARE_RENAME, + PrepareRenameParams( + text_document=TextDocumentIdentifier(uri="file://return.none"), + position=Position(line=0, character=0), + ), + ).result() + + assert response is None diff --git a/tests/lsp/test_progress.py b/tests/lsp/test_progress.py new file mode 100644 index 0000000..4965772 --- /dev/null +++ b/tests/lsp/test_progress.py @@ -0,0 +1,225 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ + +import asyncio +from typing import List, Optional + +import pytest +from lsprotocol.types import ( + TEXT_DOCUMENT_CODE_LENS, + WINDOW_WORK_DONE_PROGRESS_CANCEL, + WINDOW_WORK_DONE_PROGRESS_CREATE, + PROGRESS, +) +from lsprotocol.types import ( + CodeLens, + CodeLensParams, + CodeLensOptions, + ProgressParams, + TextDocumentIdentifier, + WorkDoneProgressBegin, + WorkDoneProgressEnd, + WorkDoneProgressReport, + WorkDoneProgressCancelParams, + WorkDoneProgressCreateParams, +) +from ..conftest import ClientServer +from pygls import IS_PYODIDE + + +class ConfiguredLS(ClientServer): + def __init__(self): + super().__init__() + self.client.notifications: List[ProgressParams] = [] + self.client.method_calls: List[WorkDoneProgressCreateParams] = [] + + @self.server.feature( + TEXT_DOCUMENT_CODE_LENS, + CodeLensOptions(resolve_provider=False, work_done_progress=True), + ) + async def f1(params: CodeLensParams) -> Optional[List[CodeLens]]: + if "client_initiated_token" in params.text_document.uri: + token = params.work_done_token + else: + assert "server_initiated_token" in params.text_document.uri + token = params.text_document.uri[len("file://") :] + if "async" in params.text_document.uri: + await self.server.progress.create_async(token) + else: + f = self.server.progress.create(token) + await asyncio.sleep(0.1) + f.result() + + assert token + self.server.lsp.progress.begin( + token, + WorkDoneProgressBegin(kind="begin", title="starting", percentage=0), + ) + await asyncio.sleep(0.1) + if self.server.lsp.progress.tokens[token].cancelled(): + self.server.lsp.progress.end( + token, WorkDoneProgressEnd(kind="end", message="cancelled") + ) + else: + self.server.lsp.progress.report( + token, + WorkDoneProgressReport( + kind="report", message="doing", percentage=50 + ), + ) + self.server.lsp.progress.end( + token, WorkDoneProgressEnd(kind="end", message="done") + ) + return None + + @self.client.feature(PROGRESS) + def f2(params): + self.client.notifications.append(params) + if params.value["kind"] == "begin" and "cancel" in params.token: + # client cancels the progress token + self.client.lsp.notify( + WINDOW_WORK_DONE_PROGRESS_CANCEL, + WorkDoneProgressCancelParams(token=params.token), + ) + + @self.client.feature(WINDOW_WORK_DONE_PROGRESS_CREATE) + def f3(params: WorkDoneProgressCreateParams): + self.client.method_calls.append(params) + + +@ConfiguredLS.decorate() +def test_capabilities(client_server): + _, server = client_server + capabilities = server.server_capabilities + + provider = capabilities.code_lens_provider + assert provider + assert provider.work_done_progress + + +@pytest.mark.skipif(IS_PYODIDE, reason="threads are not available in pyodide.") +@ConfiguredLS.decorate() +async def test_progress_notifications(client_server): + client, _ = client_server + client.lsp.send_request( + TEXT_DOCUMENT_CODE_LENS, + CodeLensParams( + text_document=TextDocumentIdentifier(uri="file://client_initiated_token"), + work_done_token="token", + ), + ).result() + + assert [notif.value for notif in client.notifications] == [ + { + "kind": "begin", + "title": "starting", + "percentage": 0, + }, + { + "kind": "report", + "message": "doing", + "percentage": 50, + }, + {"kind": "end", "message": "done"}, + ] + assert {notif.token for notif in client.notifications} == {"token"} + + +@pytest.mark.skipif(IS_PYODIDE, reason="threads are not available in pyodide.") +@pytest.mark.parametrize("registration", ("sync", "async")) +@ConfiguredLS.decorate() +async def test_server_initiated_progress_notifications(client_server, registration): + client, _ = client_server + client.lsp.send_request( + TEXT_DOCUMENT_CODE_LENS, + CodeLensParams( + text_document=TextDocumentIdentifier( + uri=f"file://server_initiated_token_{registration}" + ), + work_done_token="token", + ), + ).result() + + assert [notif.value for notif in client.notifications] == [ + { + "kind": "begin", + "title": "starting", + "percentage": 0, + }, + { + "kind": "report", + "message": "doing", + "percentage": 50, + }, + {"kind": "end", "message": "done"}, + ] + assert {notif.token for notif in client.notifications} == { + f"server_initiated_token_{registration}" + } + assert [mc.token for mc in client.method_calls] == [ + f"server_initiated_token_{registration}" + ] + + +@pytest.mark.skipif(IS_PYODIDE, reason="threads are not available in pyodide.") +@ConfiguredLS.decorate() +def test_progress_cancel_notifications(client_server): + client, _ = client_server + client.lsp.send_request( + TEXT_DOCUMENT_CODE_LENS, + CodeLensParams( + text_document=TextDocumentIdentifier(uri="file://client_initiated_token"), + work_done_token="token_with_cancellation", + ), + ).result() + assert [notif.value for notif in client.notifications] == [ + { + "kind": "begin", + "title": "starting", + "percentage": 0, + }, + {"kind": "end", "message": "cancelled"}, + ] + assert {notif.token for notif in client.notifications} == { + "token_with_cancellation" + } + + +@pytest.mark.skipif(IS_PYODIDE, reason="threads are not available in pyodide.") +@pytest.mark.parametrize("registration", ("sync", "async")) +@ConfiguredLS.decorate() +def test_server_initiated_progress_progress_cancel_notifications( + client_server, registration +): + client, _ = client_server + client.lsp.send_request( + TEXT_DOCUMENT_CODE_LENS, + CodeLensParams( + text_document=TextDocumentIdentifier( + uri=f"file://server_initiated_token_{registration}_with_cancellation" + ), + ), + ).result() + + assert [notif.value for notif in client.notifications] == [ + { + "kind": "begin", + "title": "starting", + "percentage": 0, + }, + {"kind": "end", "message": "cancelled"}, + ] diff --git a/tests/lsp/test_range_formatting.py b/tests/lsp/test_range_formatting.py new file mode 100644 index 0000000..e7a118c --- /dev/null +++ b/tests/lsp/test_range_formatting.py @@ -0,0 +1,116 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ + +from typing import List, Optional + +from lsprotocol.types import TEXT_DOCUMENT_RANGE_FORMATTING +from lsprotocol.types import ( + DocumentRangeFormattingOptions, + DocumentRangeFormattingParams, + FormattingOptions, + Position, + Range, + TextDocumentIdentifier, + TextEdit, +) + +from ..conftest import ClientServer + + +class ConfiguredLS(ClientServer): + def __init__(self): + super().__init__() + + @self.server.feature( + TEXT_DOCUMENT_RANGE_FORMATTING, + DocumentRangeFormattingOptions(), + ) + def f(params: DocumentRangeFormattingParams) -> Optional[List[TextEdit]]: + if params.text_document.uri == "file://return.list": + return [ + TextEdit( + range=Range( + start=Position(line=0, character=0), + end=Position(line=1, character=1), + ), + new_text="text", + ) + ] + else: + return None + + +@ConfiguredLS.decorate() +def test_capabilities(client_server): + _, server = client_server + capabilities = server.server_capabilities + + assert capabilities.document_range_formatting_provider + + +@ConfiguredLS.decorate() +def test_range_formatting_return_list(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_RANGE_FORMATTING, + DocumentRangeFormattingParams( + text_document=TextDocumentIdentifier(uri="file://return.list"), + range=Range( + start=Position(line=0, character=0), + end=Position(line=1, character=1), + ), + options=FormattingOptions( + tab_size=2, + insert_spaces=True, + trim_trailing_whitespace=True, + insert_final_newline=True, + trim_final_newlines=True, + ), + ), + ).result() + + assert response + + assert response[0].new_text == "text" + assert response[0].range.start.line == 0 + assert response[0].range.start.character == 0 + assert response[0].range.end.line == 1 + assert response[0].range.end.character == 1 + + +@ConfiguredLS.decorate() +def test_range_formatting_return_none(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_RANGE_FORMATTING, + DocumentRangeFormattingParams( + text_document=TextDocumentIdentifier(uri="file://return.none"), + range=Range( + start=Position(line=0, character=0), + end=Position(line=1, character=1), + ), + options=FormattingOptions( + tab_size=2, + insert_spaces=True, + trim_trailing_whitespace=True, + insert_final_newline=True, + trim_final_newlines=True, + ), + ), + ).result() + + assert response is None diff --git a/tests/lsp/test_references.py b/tests/lsp/test_references.py new file mode 100644 index 0000000..5867e35 --- /dev/null +++ b/tests/lsp/test_references.py @@ -0,0 +1,103 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ + +from typing import List, Optional + +from lsprotocol.types import TEXT_DOCUMENT_REFERENCES +from lsprotocol.types import ( + Location, + Position, + Range, + ReferenceContext, + ReferenceOptions, + ReferenceParams, + TextDocumentIdentifier, +) + +from ..conftest import ClientServer + + +class ConfiguredLS(ClientServer): + def __init__(self): + super().__init__() + + @self.server.feature( + TEXT_DOCUMENT_REFERENCES, + ReferenceOptions(), + ) + def f(params: ReferenceParams) -> Optional[List[Location]]: + if params.text_document.uri == "file://return.list": + return [ + Location( + uri="uri", + range=Range( + start=Position(line=0, character=0), + end=Position(line=1, character=1), + ), + ), + ] + else: + return None + + +@ConfiguredLS.decorate() +def test_capabilities(client_server): + _, server = client_server + capabilities = server.server_capabilities + + assert capabilities.references_provider + + +@ConfiguredLS.decorate() +def test_references_return_list(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_REFERENCES, + ReferenceParams( + text_document=TextDocumentIdentifier(uri="file://return.list"), + position=Position(line=0, character=0), + context=ReferenceContext( + include_declaration=True, + ), + ), + ).result() + + assert response + + assert response[0].uri == "uri" + + assert response[0].range.start.line == 0 + assert response[0].range.start.character == 0 + assert response[0].range.end.line == 1 + assert response[0].range.end.character == 1 + + +@ConfiguredLS.decorate() +def test_references_return_none(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_REFERENCES, + ReferenceParams( + text_document=TextDocumentIdentifier(uri="file://return.none"), + position=Position(line=0, character=0), + context=ReferenceContext( + include_declaration=True, + ), + ), + ).result() + + assert response is None diff --git a/tests/lsp/test_rename.py b/tests/lsp/test_rename.py new file mode 100644 index 0000000..48cface --- /dev/null +++ b/tests/lsp/test_rename.py @@ -0,0 +1,195 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ + +from typing import Optional + +from lsprotocol.types import TEXT_DOCUMENT_RENAME +from lsprotocol.types import ( + CreateFile, + CreateFileOptions, + DeleteFile, + DeleteFileOptions, + OptionalVersionedTextDocumentIdentifier, + Position, + Range, + RenameFile, + RenameFileOptions, + RenameOptions, + RenameParams, + ResourceOperationKind, + TextDocumentEdit, + TextDocumentIdentifier, + TextEdit, + WorkspaceEdit, +) + +from ..conftest import ClientServer + +workspace_edit = { + "changes": { + "uri1": [ + TextEdit( + range=Range( + start=Position(line=0, character=0), + end=Position(line=1, character=1), + ), + new_text="text1", + ), + TextEdit( + range=Range( + start=Position(line=1, character=1), + end=Position(line=2, character=2), + ), + new_text="text2", + ), + ], + }, + "document_changes": [ + TextDocumentEdit( + text_document=OptionalVersionedTextDocumentIdentifier( + uri="uri", + version=3, + ), + edits=[ + TextEdit( + range=Range( + start=Position(line=2, character=2), + end=Position(line=3, character=3), + ), + new_text="text3", + ), + ], + ), + CreateFile( + kind=ResourceOperationKind.Create.value, + uri="create file", + options=CreateFileOptions( + overwrite=True, + ignore_if_exists=True, + ), + ), + RenameFile( + kind=ResourceOperationKind.Rename.value, + old_uri="rename old uri", + new_uri="rename new uri", + options=RenameFileOptions( + overwrite=True, + ignore_if_exists=True, + ), + ), + DeleteFile( + kind=ResourceOperationKind.Delete.value, + uri="delete file", + options=DeleteFileOptions( + recursive=True, + ignore_if_not_exists=True, + ), + ), + ], +} + + +class ConfiguredLS(ClientServer): + def __init__(self): + super().__init__() + + @self.server.feature( + TEXT_DOCUMENT_RENAME, + RenameOptions(prepare_provider=True), + ) + def f(params: RenameParams) -> Optional[WorkspaceEdit]: + if params.text_document.uri == "file://return.workspace_edit": + return WorkspaceEdit(**workspace_edit) + else: + return None + + +@ConfiguredLS.decorate() +def test_capabilities(client_server): + _, server = client_server + capabilities = server.server_capabilities + + assert capabilities.rename_provider + assert capabilities.rename_provider.prepare_provider + + +@ConfiguredLS.decorate() +def test_rename_return_workspace_edit(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_RENAME, + RenameParams( + text_document=TextDocumentIdentifier(uri="file://return.workspace_edit"), + position=Position(line=0, character=0), + new_name="new name", + ), + ).result() + + assert response + + changes = response.changes["uri1"] + assert changes[0].new_text == "text1" + assert changes[0].range.start.line == 0 + assert changes[0].range.start.character == 0 + assert changes[0].range.end.line == 1 + assert changes[0].range.end.character == 1 + + assert changes[1].new_text == "text2" + assert changes[1].range.start.line == 1 + assert changes[1].range.start.character == 1 + assert changes[1].range.end.line == 2 + assert changes[1].range.end.character == 2 + + changes = response.document_changes + assert changes[0].text_document.uri == "uri" + assert changes[0].text_document.version == 3 + assert changes[0].edits[0].new_text == "text3" + assert changes[0].edits[0].range.start.line == 2 + assert changes[0].edits[0].range.start.character == 2 + assert changes[0].edits[0].range.end.line == 3 + assert changes[0].edits[0].range.end.character == 3 + + assert changes[1].kind == ResourceOperationKind.Create.value + assert changes[1].uri == "create file" + assert changes[1].options.ignore_if_exists + assert changes[1].options.overwrite + + assert changes[2].kind == ResourceOperationKind.Rename.value + assert changes[2].new_uri == "rename new uri" + assert changes[2].old_uri == "rename old uri" + assert changes[2].options.ignore_if_exists + assert changes[2].options.overwrite + + assert changes[3].kind == ResourceOperationKind.Delete.value + assert changes[3].uri == "delete file" + assert changes[3].options.ignore_if_not_exists + assert changes[3].options.recursive + + +@ConfiguredLS.decorate() +def test_rename_return_none(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_RENAME, + RenameParams( + text_document=TextDocumentIdentifier(uri="file://return.none"), + position=Position(line=0, character=0), + new_name="new name", + ), + ).result() + + assert response is None diff --git a/tests/lsp/test_selection_range.py b/tests/lsp/test_selection_range.py new file mode 100644 index 0000000..5f669a2 --- /dev/null +++ b/tests/lsp/test_selection_range.py @@ -0,0 +1,110 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ +from typing import List, Optional + +from lsprotocol.types import TEXT_DOCUMENT_SELECTION_RANGE +from lsprotocol.types import ( + Position, + Range, + SelectionRange, + SelectionRangeOptions, + SelectionRangeParams, + TextDocumentIdentifier, +) + +from ..conftest import ClientServer + + +class ConfiguredLS(ClientServer): + def __init__(self): + super().__init__() + + @self.server.feature( + TEXT_DOCUMENT_SELECTION_RANGE, + SelectionRangeOptions(), + ) + def f(params: SelectionRangeParams) -> Optional[List[SelectionRange]]: + if params.text_document.uri == "file://return.list": + root = SelectionRange( + range=Range( + start=Position(line=0, character=0), + end=Position(line=10, character=10), + ), + ) + + inner_range = SelectionRange( + range=Range( + start=Position(line=0, character=0), + end=Position(line=1, character=1), + ), + parent=root, + ) + + return [root, inner_range] + else: + return None + + +@ConfiguredLS.decorate() +def test_capabilities(client_server): + _, server = client_server + capabilities = server.server_capabilities + + assert capabilities.selection_range_provider + + +@ConfiguredLS.decorate() +def test_selection_range_return_list(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_SELECTION_RANGE, + SelectionRangeParams( + # query="query", + text_document=TextDocumentIdentifier(uri="file://return.list"), + positions=[Position(line=0, character=0)], + ), + ).result() + + assert response + + root = response[0] + assert root.range.start.line == 0 + assert root.range.start.character == 0 + assert root.range.end.line == 10 + assert root.range.end.character == 10 + assert root.parent is None + + assert response[1].range.start.line == 0 + assert response[1].range.start.character == 0 + assert response[1].range.end.line == 1 + assert response[1].range.end.character == 1 + assert response[1].parent == root + + +@ConfiguredLS.decorate() +def test_selection_range_return_none(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_SELECTION_RANGE, + SelectionRangeParams( + # query="query", + text_document=TextDocumentIdentifier(uri="file://return.none"), + positions=[Position(line=0, character=0)], + ), + ).result() + + assert response is None diff --git a/tests/lsp/test_signature_help.py b/tests/lsp/test_signature_help.py new file mode 100644 index 0000000..e318120 --- /dev/null +++ b/tests/lsp/test_signature_help.py @@ -0,0 +1,161 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ + +from typing import Optional + +import pytest + +from lsprotocol.types import TEXT_DOCUMENT_SIGNATURE_HELP +from lsprotocol.types import ( + ParameterInformation, + Position, + SignatureHelp, + SignatureHelpContext, + SignatureHelpOptions, + SignatureHelpParams, + SignatureHelpTriggerKind, + SignatureInformation, + TextDocumentIdentifier, +) + +from ..conftest import ClientServer + + +class ConfiguredLS(ClientServer): + def __init__(self): + super().__init__() + + @self.server.feature( + TEXT_DOCUMENT_SIGNATURE_HELP, + SignatureHelpOptions( + trigger_characters=["a", "b"], + retrigger_characters=["c", "d"], + ), + ) + def f(params: SignatureHelpParams) -> Optional[SignatureHelp]: + if params.text_document.uri == "file://return.signature_help": + return SignatureHelp( + signatures=[ + SignatureInformation( + label="label", + documentation="documentation", + parameters=[ + ParameterInformation( + label=(0, 0), + documentation="documentation", + ), + ], + ), + ], + active_signature=0, + active_parameter=0, + ) + else: + return None + + +@ConfiguredLS.decorate() +def test_capabilities(client_server): + _, server = client_server + capabilities = server.server_capabilities + + provider = capabilities.signature_help_provider + assert provider + assert provider.trigger_characters == ["a", "b"] + assert provider.retrigger_characters == ["c", "d"] + + +@ConfiguredLS.decorate() +@pytest.mark.skip +def test_signature_help_return_signature_help(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_SIGNATURE_HELP, + SignatureHelpParams( + text_document=TextDocumentIdentifier(uri="file://return.signature_help"), + position=Position(line=0, character=0), + context=SignatureHelpContext( + trigger_kind=SignatureHelpTriggerKind.TriggerCharacter, + is_retrigger=True, + trigger_character="a", + active_signature_help=SignatureHelp( + signatures=[ + SignatureInformation( + label="label", + documentation="documentation", + parameters=[ + ParameterInformation( + label=(0, 0), + documentation="documentation", + ), + ], + ), + ], + active_signature=0, + active_parameter=0, + ), + ), + ), + ).result() + + assert response + + assert response["activeParameter"] == 0 + assert response["activeSignature"] == 0 + + assert response["signatures"][0]["label"] == "label" + assert response["signatures"][0]["documentation"] == "documentation" + assert response["signatures"][0]["parameters"][0]["label"] == [0, 0] + assert ( + response["signatures"][0]["parameters"][0]["documentation"] == "documentation" + ) + + +@ConfiguredLS.decorate() +@pytest.mark.skip +def test_signature_help_return_none(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_SIGNATURE_HELP, + SignatureHelpParams( + text_document=TextDocumentIdentifier(uri="file://return.none"), + position=Position(line=0, character=0), + context=SignatureHelpContext( + trigger_kind=SignatureHelpTriggerKind.TriggerCharacter, + is_retrigger=True, + trigger_character="a", + active_signature_help=SignatureHelp( + signatures=[ + SignatureInformation( + label="label", + documentation="documentation", + parameters=[ + ParameterInformation( + label=(0, 0), + documentation="documentation", + ), + ], + ), + ], + active_signature=0, + active_parameter=0, + ), + ), + ), + ).result() + + assert response is None diff --git a/tests/lsp/test_type_definition.py b/tests/lsp/test_type_definition.py new file mode 100644 index 0000000..b6d3eff --- /dev/null +++ b/tests/lsp/test_type_definition.py @@ -0,0 +1,163 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ +from typing import List, Optional, Union + +from lsprotocol.types import TEXT_DOCUMENT_TYPE_DEFINITION +from lsprotocol.types import ( + Location, + LocationLink, + Position, + Range, + TextDocumentIdentifier, + TypeDefinitionOptions, + TypeDefinitionParams, +) + +from ..conftest import ClientServer + + +class ConfiguredLS(ClientServer): + def __init__(self): + super().__init__() + + @self.server.feature( + TEXT_DOCUMENT_TYPE_DEFINITION, + TypeDefinitionOptions(), + ) + def f( + params: TypeDefinitionParams, + ) -> Optional[Union[Location, List[Location], List[LocationLink]]]: + location = Location( + uri="uri", + range=Range( + start=Position(line=0, character=0), + end=Position(line=1, character=1), + ), + ) + + location_link = LocationLink( + target_uri="uri", + target_range=Range( + start=Position(line=0, character=0), + end=Position(line=1, character=1), + ), + target_selection_range=Range( + start=Position(line=0, character=0), + end=Position(line=2, character=2), + ), + origin_selection_range=Range( + start=Position(line=0, character=0), + end=Position(line=3, character=3), + ), + ) + + return { # type: ignore + "file://return.location": location, + "file://return.location_list": [location], + "file://return.location_link_list": [location_link], + }.get(params.text_document.uri, None) + + +@ConfiguredLS.decorate() +def test_capabilities(client_server): + _, server = client_server + capabilities = server.server_capabilities + + assert capabilities.type_definition_provider + + +@ConfiguredLS.decorate() +def test_type_definition_return_location(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_TYPE_DEFINITION, + TypeDefinitionParams( + text_document=TextDocumentIdentifier(uri="file://return.location"), + position=Position(line=0, character=0), + ), + ).result() + + assert response.uri == "uri" + + assert response.range.start.line == 0 + assert response.range.start.character == 0 + assert response.range.end.line == 1 + assert response.range.end.character == 1 + + +@ConfiguredLS.decorate() +def test_type_definition_return_location_list(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_TYPE_DEFINITION, + TypeDefinitionParams( + text_document=TextDocumentIdentifier(uri="file://return.location_list"), + position=Position(line=0, character=0), + ), + ).result() + + assert response[0].uri == "uri" + + assert response[0].range.start.line == 0 + assert response[0].range.start.character == 0 + assert response[0].range.end.line == 1 + assert response[0].range.end.character == 1 + + +@ConfiguredLS.decorate() +def test_type_definition_return_location_link_list(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_TYPE_DEFINITION, + TypeDefinitionParams( + text_document=TextDocumentIdentifier( + uri="file://return.location_link_list" + ), + position=Position(line=0, character=0), + ), + ).result() + + assert response[0].target_uri == "uri" + + assert response[0].target_range.start.line == 0 + assert response[0].target_range.start.character == 0 + assert response[0].target_range.end.line == 1 + assert response[0].target_range.end.character == 1 + + assert response[0].target_selection_range.start.line == 0 + assert response[0].target_selection_range.start.character == 0 + assert response[0].target_selection_range.end.line == 2 + assert response[0].target_selection_range.end.character == 2 + + assert response[0].origin_selection_range.start.line == 0 + assert response[0].origin_selection_range.start.character == 0 + assert response[0].origin_selection_range.end.line == 3 + assert response[0].origin_selection_range.end.character == 3 + + +@ConfiguredLS.decorate() +def test_type_definition_return_none(client_server): + client, _ = client_server + response = client.lsp.send_request( + TEXT_DOCUMENT_TYPE_DEFINITION, + TypeDefinitionParams( + text_document=TextDocumentIdentifier(uri="file://return.none"), + position=Position(line=0, character=0), + ), + ).result() + + assert response is None diff --git a/tests/lsp/test_type_hierarchy.py b/tests/lsp/test_type_hierarchy.py new file mode 100644 index 0000000..e186c7f --- /dev/null +++ b/tests/lsp/test_type_hierarchy.py @@ -0,0 +1,127 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ +from typing import List, Optional + +from lsprotocol import types as lsp + +from ..conftest import ClientServer + + +TYPE_HIERARCHY_ITEM = lsp.TypeHierarchyItem( + name="test_name", + kind=lsp.SymbolKind.Class, + uri="test_uri", + range=lsp.Range( + start=lsp.Position(line=0, character=0), + end=lsp.Position(line=0, character=6), + ), + selection_range=lsp.Range( + start=lsp.Position(line=0, character=0), + end=lsp.Position(line=0, character=6), + ), +) + + +def check_type_hierarchy_item_response(item): + assert item.name == TYPE_HIERARCHY_ITEM.name + assert item.kind == TYPE_HIERARCHY_ITEM.kind + assert item.uri == TYPE_HIERARCHY_ITEM.uri + assert item.range == TYPE_HIERARCHY_ITEM.range + assert item.selection_range == TYPE_HIERARCHY_ITEM.selection_range + + +class ConfiguredLS(ClientServer): + def __init__(self): + super().__init__() + + @self.server.feature(lsp.TEXT_DOCUMENT_PREPARE_TYPE_HIERARCHY) + def f1( + params: lsp.TypeHierarchyPrepareParams, + ) -> Optional[List[lsp.TypeHierarchyItem]]: + if params.text_document.uri == "file://return.list": + return [TYPE_HIERARCHY_ITEM] + else: + return None + + @self.server.feature(lsp.TYPE_HIERARCHY_SUPERTYPES) + def f2( + params: lsp.TypeHierarchySupertypesParams, + ) -> Optional[List[lsp.TypeHierarchyItem]]: + return [TYPE_HIERARCHY_ITEM] + + @self.server.feature(lsp.TYPE_HIERARCHY_SUBTYPES) + def f3( + params: lsp.TypeHierarchySubtypesParams, + ) -> Optional[List[lsp.TypeHierarchyItem]]: + return [TYPE_HIERARCHY_ITEM] + + +@ConfiguredLS.decorate() +def test_capabilities(client_server): + _, server = client_server + capabilities = server.server_capabilities + assert capabilities.type_hierarchy_provider + + +@ConfiguredLS.decorate() +def test_type_hierarchy_prepare_return_list(client_server): + client, _ = client_server + response = client.lsp.send_request( + lsp.TEXT_DOCUMENT_PREPARE_TYPE_HIERARCHY, + lsp.TypeHierarchyPrepareParams( + text_document=lsp.TextDocumentIdentifier(uri="file://return.list"), + position=lsp.Position(line=0, character=0), + ), + ).result() + + check_type_hierarchy_item_response(response[0]) + + +@ConfiguredLS.decorate() +def test_type_hierarchy_prepare_return_none(client_server): + client, _ = client_server + response = client.lsp.send_request( + lsp.TEXT_DOCUMENT_PREPARE_TYPE_HIERARCHY, + lsp.TypeHierarchyPrepareParams( + text_document=lsp.TextDocumentIdentifier(uri="file://return.none"), + position=lsp.Position(line=0, character=0), + ), + ).result() + + assert response is None + + +@ConfiguredLS.decorate() +def test_type_hierarchy_supertypes(client_server): + client, _ = client_server + response = client.lsp.send_request( + lsp.TYPE_HIERARCHY_SUPERTYPES, + lsp.TypeHierarchySupertypesParams(item=TYPE_HIERARCHY_ITEM), + ).result() + + check_type_hierarchy_item_response(response[0]) + + +@ConfiguredLS.decorate() +def test_type_hierarchy_subtypes(client_server): + client, _ = client_server + response = client.lsp.send_request( + lsp.TYPE_HIERARCHY_SUBTYPES, + lsp.TypeHierarchySubtypesParams(item=TYPE_HIERARCHY_ITEM), + ).result() + + check_type_hierarchy_item_response(response[0]) diff --git a/tests/pyodide_testrunner/.gitignore b/tests/pyodide_testrunner/.gitignore new file mode 100644 index 0000000..704d307 --- /dev/null +++ b/tests/pyodide_testrunner/.gitignore @@ -0,0 +1 @@ +*.whl diff --git a/tests/pyodide_testrunner/index.html b/tests/pyodide_testrunner/index.html new file mode 100644 index 0000000..b3e1b8b --- /dev/null +++ b/tests/pyodide_testrunner/index.html @@ -0,0 +1,56 @@ +<!DOCTYPE html> +<html lang="en"> + +<head> + <meta charset="UTF-8"> + <meta http-equiv="X-UA-Compatible" content="IE=edge"> + <meta name="viewport" content="width=device-width, initial-scale=1.0"> + + <title>Pygls Testsuite</title> + + <style> + @media (prefers-color-scheme: dark) { + * { + background-color: #222; + color: white; + } + + } + </style> +</head> + +<body> + <div> + <pre id="console"></pre> + </div> + <button id="exit-code" disabled></button> + <script> + let log = document.getElementById("console") + let exitCode = document.getElementById("exit-code") + + function print(event) { + log.innerText += event.data + } + + // Use a web worker to prevent freezing the UI + function runTests(whl) { + let worker = new Worker(`test-runner.js?whl=${whl}`) + worker.addEventListener('message', (event) => { + + if (event.data.exitCode !== undefined) { + exitCode.innerText = event.data.exitCode + exitCode.disabled = false + return + } + + print(event) + }) + } + + let queryParams = new URLSearchParams(window.location.search) + runTests(queryParams.get('whl')) + + </script> +</body> + +</html> diff --git a/tests/pyodide_testrunner/run.py b/tests/pyodide_testrunner/run.py new file mode 100644 index 0000000..7b56374 --- /dev/null +++ b/tests/pyodide_testrunner/run.py @@ -0,0 +1,131 @@ +import os +import pathlib +import shutil +import subprocess +import sys +import tempfile + +from functools import partial +from http.server import SimpleHTTPRequestHandler, ThreadingHTTPServer +from multiprocessing import Process, Queue + +from selenium import webdriver +from selenium.webdriver.common.by import By +from selenium.webdriver.support.ui import WebDriverWait +from selenium.common.exceptions import WebDriverException +from selenium.webdriver.support import expected_conditions as EC + + +# Path to the root of the repo. +REPO = pathlib.Path(__file__).parent.parent.parent +BROWSERS = { + "chrome": (webdriver.Chrome, webdriver.ChromeOptions), + "firefox": (webdriver.Firefox, webdriver.FirefoxOptions), +} + + +def build_wheel() -> str: + """Build a wheel package of ``pygls`` and its testsuite. + + In order to test pygls under pyodide, we need to load the code for both pygls and its + testsuite. This is done by building a wheel. + + To avoid messing with the repo this is all done under a temp directory. + """ + + with tempfile.TemporaryDirectory() as tmpdir: + # Copy all required files. + dest = pathlib.Path(tmpdir) + + # So that we don't have to fuss with packaging, copy the test suite into `pygls` + # as a sub module. + directories = [("pygls", "pygls"), ("tests", "pygls/tests")] + + for src, target in directories: + shutil.copytree(REPO / src, dest / target) + + files = ["pyproject.toml", "poetry.lock", "README.md", "ThirdPartyNotices.txt"] + + for src in files: + shutil.copy(REPO / src, dest) + + # Convert the lock file to requirements.txt. + # Ensures reproducible behavour for testing. + subprocess.run( + [ + "poetry", + "export", + "-f", + "requirements.txt", + "--output", + "requirements.txt", + ], + cwd=dest, + ) + subprocess.run( + ["poetry", "run", "pip", "install", "-r", "requirements.txt"], cwd=dest + ) + # Build the wheel + subprocess.run(["poetry", "build", "--format", "wheel"], cwd=dest) + whl = list((dest / "dist").glob("*.whl"))[0] + shutil.copy(whl, REPO / "tests/pyodide_testrunner") + + return whl.name + + +def spawn_http_server(q: Queue, directory: str): + """A http server is needed to serve the files to the browser.""" + + handler_class = partial(SimpleHTTPRequestHandler, directory=directory) + server = ThreadingHTTPServer(("localhost", 0), handler_class) + q.put(server.server_port) + + server.serve_forever() + + +def main(): + exit_code = 1 + whl = build_wheel() + + q = Queue() + server_process = Process( + target=spawn_http_server, + args=(q, REPO / "tests/pyodide_testrunner"), + daemon=True, + ) + server_process.start() + port = q.get() + + print("Running tests...") + try: + driver_cls, options_cls = BROWSERS[os.environ.get("BROWSER", "chrome")] + + options = options_cls() + if "CI" in os.environ: + options.binary_location = "/usr/bin/google-chrome" + options.add_argument("--headless") + + driver = driver_cls(options=options) + driver.get(f"http://localhost:{port}?whl={whl}") + + wait = WebDriverWait(driver, 120) + try: + button = wait.until(EC.element_to_be_clickable((By.ID, "exit-code"))) + exit_code = int(button.text) + except WebDriverException as e: + print(f"Error while running test: {e!r}") + exit_code = 1 + + console = driver.find_element(By.ID, "console") + print(console.text) + finally: + if hasattr(server_process, "kill"): + server_process.kill() + else: + server_process.terminate() + + return exit_code + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/pyodide_testrunner/test-runner.js b/tests/pyodide_testrunner/test-runner.js new file mode 100644 index 0000000..dbbc01f --- /dev/null +++ b/tests/pyodide_testrunner/test-runner.js @@ -0,0 +1,38 @@ +importScripts("https://cdn.jsdelivr.net/pyodide/v0.21.3/full/pyodide.js") + +// Used to redirect pyodide's stdout to the webpage. +function patchedStdout(...args) { + postMessage(args[0]) +} + +async function runTests(whl) { + console.log("Loading pyodide") + let pyodide = await loadPyodide({ + indexURL: "https://cdn.jsdelivr.net/pyodide/v0.21.3/full/" + }) + + console.log("Installing dependencies") + await pyodide.loadPackage("micropip") + await pyodide.runPythonAsync(` + import sys + import micropip + + await micropip.install('pytest') + await micropip.install('pytest-asyncio') + await micropip.install('${whl}') + `) + + console.log('Running testsuite') + + // Patch stdout to redirect the output. + pyodide.globals.get('sys').stdout.write = patchedStdout + await pyodide.runPythonAsync(` + import pytest + exit_code = pytest.main(['--color', 'no', '--pyargs', 'pygls.tests']) + `) + + postMessage({ exitCode: pyodide.globals.get('exit_code') }) +} + +let queryParams = new URLSearchParams(self.location.search) +runTests(queryParams.get('whl')) diff --git a/tests/servers/invalid_json.py b/tests/servers/invalid_json.py new file mode 100644 index 0000000..5b40d7a --- /dev/null +++ b/tests/servers/invalid_json.py @@ -0,0 +1,27 @@ +"""This server does nothing but print invalid JSON.""" +import asyncio +import threading +import sys +from concurrent.futures import ThreadPoolExecutor + +from pygls.server import aio_readline + + +def handler(data): + content = 'Content-Length: 5\r\n\r\n{"ll}'.encode("utf8") + sys.stdout.buffer.write(content) + sys.stdout.flush() + + +async def main(): + await aio_readline( + asyncio.get_running_loop(), + ThreadPoolExecutor(), + threading.Event(), + sys.stdin.buffer, + handler, + ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/servers/large_response.py b/tests/servers/large_response.py new file mode 100644 index 0000000..fd85b62 --- /dev/null +++ b/tests/servers/large_response.py @@ -0,0 +1,36 @@ +"""This server returns a particuarly large response.""" +import asyncio +import threading +import sys +from concurrent.futures import ThreadPoolExecutor + +from pygls.server import aio_readline + + +def handler(data): + payload = dict( + jsonrpc="2.0", + id=1, + result=dict( + numbers=list(range(100_000)), + ), + ) + content = str(payload).replace("'", '"') + message = f"Content-Length: {len(content)}\r\n\r\n{content}".encode("utf8") + + sys.stdout.buffer.write(message) + sys.stdout.flush() + + +async def main(): + await aio_readline( + asyncio.get_running_loop(), + ThreadPoolExecutor(), + threading.Event(), + sys.stdin.buffer, + handler, + ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 0000000..cacd3e0 --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,102 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ +import asyncio +import pathlib +import sys +from typing import Union + +import pytest +from pygls import IS_PYODIDE + +from pygls.client import JsonRPCClient +from pygls.exceptions import JsonRpcException, PyglsError + + +SERVERS = pathlib.Path(__file__).parent / "servers" + + +@pytest.mark.asyncio +@pytest.mark.skipif(IS_PYODIDE, reason="Subprocesses are not available on pyodide.") +async def test_client_detect_server_exit(): + """Ensure that the client detects when the server process exits.""" + + class TestClient(JsonRPCClient): + server_exit_called = False + + async def server_exit(self, server: asyncio.subprocess.Process): + self.server_exit_called = True + assert server.returncode == 0 + + client = TestClient() + await client.start_io(sys.executable, "-c", "print('Hello, World!')") + await asyncio.sleep(1) + await client.stop() + + message = "Expected the `server_exit` method to have been called." + assert client.server_exit_called, message + + +@pytest.mark.asyncio +@pytest.mark.skipif(IS_PYODIDE, reason="Subprocesses are not available on pyodide.") +async def test_client_detect_invalid_json(): + """Ensure that the client can detect the case where the server returns invalid + json.""" + + class TestClient(JsonRPCClient): + report_error_called = False + future = None + + def report_server_error( + self, error: Exception, source: Union[PyglsError, JsonRpcException] + ): + self.report_error_called = True + self.future.cancel() + + self._server.kill() + self._stop_event.set() + + assert "Unterminated string" in str(error) + + client = TestClient() + await client.start_io(sys.executable, str(SERVERS / "invalid_json.py")) + + future = client.protocol.send_request_async("method/name", {}) + client.future = future + + try: + await future + except asyncio.CancelledError: + pass # Ignore the exception generated by cancelling the future + finally: + await client.stop() + + assert_message = "Expected `report_server_error` to have been called" + assert client.report_error_called, assert_message + + +@pytest.mark.asyncio +@pytest.mark.skipif(IS_PYODIDE, reason="Subprocesses are not available on pyodide.") +async def test_client_large_responses(): + """Ensure that the client can correctly handle large responses from a server.""" + + client = JsonRPCClient() + await client.start_io(sys.executable, str(SERVERS / "large_response.py")) + + result = await client.protocol.send_request_async("get/numbers", {}, msg_id=1) + assert len(result.numbers) == 100_000 + + await client.stop() diff --git a/tests/test_document.py b/tests/test_document.py new file mode 100644 index 0000000..f071a2f --- /dev/null +++ b/tests/test_document.py @@ -0,0 +1,390 @@ +############################################################################ +# Original work Copyright 2018 Palantir Technologies, Inc. # +# Original work licensed under the MIT License. # +# See ThirdPartyNotices.txt in the project root for license information. # +# All modifications Copyright (c) Open Law Library. All rights reserved. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ +import re + +from lsprotocol import types +from pygls.workspace import TextDocument, PositionCodec +from .conftest import DOC, DOC_URI + + +def test_document_empty_edit(): + doc = TextDocument("file:///uri", "") + change = types.TextDocumentContentChangeEvent_Type1( + range=types.Range( + start=types.Position(line=0, character=0), + end=types.Position(line=0, character=0), + ), + range_length=0, + text="f", + ) + doc.apply_change(change) + assert doc.source == "f" + + +def test_document_end_of_file_edit(): + old = ["print 'a'\n", "print 'b'\n"] + doc = TextDocument("file:///uri", "".join(old)) + + change = types.TextDocumentContentChangeEvent_Type1( + range=types.Range( + start=types.Position(line=2, character=0), + end=types.Position(line=2, character=0), + ), + range_length=0, + text="o", + ) + doc.apply_change(change) + + assert doc.lines == [ + "print 'a'\n", + "print 'b'\n", + "o", + ] + + +def test_document_full_edit(): + old = ["def hello(a, b):\n", " print a\n", " print b\n"] + doc = TextDocument( + "file:///uri", "".join(old), sync_kind=types.TextDocumentSyncKind.Full + ) + change = types.TextDocumentContentChangeEvent_Type1( + range=types.Range( + start=types.Position(line=1, character=4), + end=types.Position(line=2, character=11), + ), + range_length=0, + text="print a, b", + ) + doc.apply_change(change) + + assert doc.lines == ["print a, b"] + + doc = TextDocument( + "file:///uri", "".join(old), sync_kind=types.TextDocumentSyncKind.Full + ) + change = types.TextDocumentContentChangeEvent_Type1( + range=types.Range( + start=types.Position(line=0, character=0), + end=types.Position(line=0, character=0), + ), + text="print a, b", + ) + doc.apply_change(change) + + assert doc.lines == ["print a, b"] + + +def test_document_line_edit(): + doc = TextDocument("file:///uri", "itshelloworld") + change = types.TextDocumentContentChangeEvent_Type1( + range=types.Range( + start=types.Position(line=0, character=3), + end=types.Position(line=0, character=8), + ), + range_length=0, + text="goodbye", + ) + doc.apply_change(change) + assert doc.source == "itsgoodbyeworld" + + +def test_document_lines(): + doc = TextDocument(DOC_URI, DOC) + assert len(doc.lines) == 4 + assert doc.lines[0] == "document\n" + + +def test_document_multiline_edit(): + old = ["def hello(a, b):\n", " print a\n", " print b\n"] + doc = TextDocument( + "file:///uri", "".join(old), sync_kind=types.TextDocumentSyncKind.Incremental + ) + change = types.TextDocumentContentChangeEvent_Type1( + range=types.Range( + start=types.Position(line=1, character=4), + end=types.Position(line=2, character=11), + ), + range_length=0, + text="print a, b", + ) + doc.apply_change(change) + + assert doc.lines == ["def hello(a, b):\n", " print a, b\n"] + + doc = TextDocument( + "file:///uri", "".join(old), sync_kind=types.TextDocumentSyncKind.Incremental + ) + change = types.TextDocumentContentChangeEvent_Type1( + range=types.Range( + start=types.Position(line=1, character=4), + end=types.Position(line=2, character=11), + ), + text="print a, b", + ) + doc.apply_change(change) + + assert doc.lines == ["def hello(a, b):\n", " print a, b\n"] + + +def test_document_no_edit(): + old = ["def hello(a, b):\n", " print a\n", " print b\n"] + doc = TextDocument( + "file:///uri", "".join(old), sync_kind=types.TextDocumentSyncKind.None_ + ) + change = types.TextDocumentContentChangeEvent_Type1( + range=types.Range( + start=types.Position(line=1, character=4), + end=types.Position(line=2, character=11), + ), + range_length=0, + text="print a, b", + ) + doc.apply_change(change) + + assert doc.lines == old + + +def test_document_props(): + doc = TextDocument(DOC_URI, DOC) + + assert doc.uri == DOC_URI + assert doc.source == DOC + + +def test_document_source_unicode(): + document_mem = TextDocument(DOC_URI, "my source") + document_disk = TextDocument(DOC_URI) + assert isinstance(document_mem.source, type(document_disk.source)) + + +def test_position_from_utf16(): + codec = PositionCodec(encoding=types.PositionEncodingKind.Utf16) + assert codec.position_from_client_units( + ['x="😋"'], types.Position(line=0, character=3) + ) == types.Position(line=0, character=3) + assert codec.position_from_client_units( + ['x="😋"'], types.Position(line=0, character=5) + ) == types.Position(line=0, character=4) + + +def test_position_from_utf32(): + codec = PositionCodec(encoding=types.PositionEncodingKind.Utf32) + assert codec.position_from_client_units( + ['x="😋"'], types.Position(line=0, character=3) + ) == types.Position(line=0, character=3) + assert codec.position_from_client_units( + ['x="😋"'], types.Position(line=0, character=4) + ) == types.Position(line=0, character=4) + + +def test_position_from_utf8(): + codec = PositionCodec(encoding=types.PositionEncodingKind.Utf8) + assert codec.position_from_client_units( + ['x="😋"'], types.Position(line=0, character=3) + ) == types.Position(line=0, character=3) + assert codec.position_from_client_units( + ['x="😋"'], types.Position(line=0, character=7) + ) == types.Position(line=0, character=4) + + +def test_position_to_utf16(): + codec = PositionCodec(encoding=types.PositionEncodingKind.Utf16) + assert codec.position_to_client_units( + ['x="😋"'], types.Position(line=0, character=3) + ) == types.Position(line=0, character=3) + + assert codec.position_to_client_units( + ['x="😋"'], types.Position(line=0, character=4) + ) == types.Position(line=0, character=5) + + +def test_position_to_utf32(): + codec = PositionCodec(encoding=types.PositionEncodingKind.Utf32) + assert codec.position_to_client_units( + ['x="😋"'], types.Position(line=0, character=3) + ) == types.Position(line=0, character=3) + + assert codec.position_to_client_units( + ['x="😋"'], types.Position(line=0, character=4) + ) == types.Position(line=0, character=4) + + +def test_position_to_utf8(): + codec = PositionCodec(encoding=types.PositionEncodingKind.Utf8) + assert codec.position_to_client_units( + ['x="😋"'], types.Position(line=0, character=3) + ) == types.Position(line=0, character=3) + + assert codec.position_to_client_units( + ['x="😋"'], types.Position(line=0, character=4) + ) == types.Position(line=0, character=6) + + +def test_range_from_utf16(): + codec = PositionCodec(encoding=types.PositionEncodingKind.Utf16) + assert codec.range_from_client_units( + ['x="😋"'], + types.Range( + start=types.Position(line=0, character=3), + end=types.Position(line=0, character=5), + ), + ) == types.Range( + start=types.Position(line=0, character=3), + end=types.Position(line=0, character=4), + ) + + range = types.Range( + start=types.Position(line=0, character=3), + end=types.Position(line=0, character=5), + ) + actual = codec.range_from_client_units(['x="😋😋"'], range) + expected = types.Range( + start=types.Position(line=0, character=3), + end=types.Position(line=0, character=4), + ) + assert actual == expected + + +def test_range_to_utf16(): + codec = PositionCodec(encoding=types.PositionEncodingKind.Utf16) + assert codec.range_to_client_units( + ['x="😋"'], + types.Range( + start=types.Position(line=0, character=3), + end=types.Position(line=0, character=4), + ), + ) == types.Range( + start=types.Position(line=0, character=3), + end=types.Position(line=0, character=5), + ) + + range = types.Range( + start=types.Position(line=0, character=3), + end=types.Position(line=0, character=4), + ) + actual = codec.range_to_client_units(['x="😋😋"'], range) + expected = types.Range( + start=types.Position(line=0, character=3), + end=types.Position(line=0, character=5), + ) + assert actual == expected + + +def test_offset_at_position_utf16(): + doc = TextDocument(DOC_URI, DOC) + assert doc.offset_at_position(types.Position(line=0, character=8)) == 8 + assert doc.offset_at_position(types.Position(line=1, character=5)) == 12 + assert doc.offset_at_position(types.Position(line=2, character=0)) == 13 + assert doc.offset_at_position(types.Position(line=2, character=4)) == 17 + assert doc.offset_at_position(types.Position(line=3, character=6)) == 27 + assert doc.offset_at_position(types.Position(line=3, character=7)) == 28 + assert doc.offset_at_position(types.Position(line=3, character=8)) == 28 + assert doc.offset_at_position(types.Position(line=4, character=0)) == 40 + assert doc.offset_at_position(types.Position(line=5, character=0)) == 40 + + +def test_offset_at_position_utf32(): + doc = TextDocument( + DOC_URI, + DOC, + position_codec=PositionCodec(encoding=types.PositionEncodingKind.Utf32), + ) + assert doc.offset_at_position(types.Position(line=0, character=8)) == 8 + assert doc.offset_at_position(types.Position(line=5, character=0)) == 39 + + +def test_offset_at_position_utf8(): + doc = TextDocument( + DOC_URI, + DOC, + position_codec=PositionCodec(encoding=types.PositionEncodingKind.Utf8), + ) + assert doc.offset_at_position(types.Position(line=0, character=8)) == 8 + assert doc.offset_at_position(types.Position(line=5, character=0)) == 41 + + +def test_utf16_to_utf32_position_cast(): + codec = PositionCodec(encoding=types.PositionEncodingKind.Utf16) + lines = ["", "😋😋", ""] + assert codec.position_from_client_units( + lines, types.Position(line=0, character=0) + ) == types.Position(line=0, character=0) + assert codec.position_from_client_units( + lines, types.Position(line=0, character=1) + ) == types.Position(line=0, character=0) + assert codec.position_from_client_units( + lines, types.Position(line=1, character=0) + ) == types.Position(line=1, character=0) + assert codec.position_from_client_units( + lines, types.Position(line=1, character=2) + ) == types.Position(line=1, character=1) + assert codec.position_from_client_units( + lines, types.Position(line=1, character=3) + ) == types.Position(line=1, character=2) + assert codec.position_from_client_units( + lines, types.Position(line=1, character=4) + ) == types.Position(line=1, character=2) + assert codec.position_from_client_units( + lines, types.Position(line=1, character=100) + ) == types.Position(line=1, character=2) + assert codec.position_from_client_units( + lines, types.Position(line=3, character=0) + ) == types.Position(line=2, character=0) + assert codec.position_from_client_units( + lines, types.Position(line=4, character=10) + ) == types.Position(line=2, character=0) + + +def test_position_for_line_endings(): + codec = PositionCodec(encoding=types.PositionEncodingKind.Utf16) + lines = ["x\r\n", "y\n"] + assert codec.position_from_client_units( + lines, types.Position(line=0, character=10) + ) == types.Position(line=0, character=1) + assert codec.position_from_client_units( + lines, types.Position(line=1, character=10) + ) == types.Position(line=1, character=1) + + +def test_word_at_position(): + """ + Return word under the cursor (or last in line if past the end) + """ + doc = TextDocument(DOC_URI, DOC) + + assert doc.word_at_position(types.Position(line=0, character=8)) == "document" + assert doc.word_at_position(types.Position(line=0, character=1000)) == "document" + assert doc.word_at_position(types.Position(line=1, character=5)) == "for" + assert doc.word_at_position(types.Position(line=2, character=0)) == "testing" + assert doc.word_at_position(types.Position(line=3, character=10)) == "unicode" + assert doc.word_at_position(types.Position(line=4, character=0)) == "" + assert doc.word_at_position(types.Position(line=4, character=0)) == "" + re_start_word = re.compile(r"[A-Za-z_0-9.]*$") + re_end_word = re.compile(r"^[A-Za-z_0-9.]*") + assert ( + doc.word_at_position( + types.Position( + line=3, + character=10, + ), + re_start_word=re_start_word, + re_end_word=re_end_word, + ) + == "unicode." + ) diff --git a/tests/test_feature_manager.py b/tests/test_feature_manager.py new file mode 100644 index 0000000..f69f12a --- /dev/null +++ b/tests/test_feature_manager.py @@ -0,0 +1,782 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ +import asyncio +from typing import Any + +import pytest +from pygls.capabilities import ServerCapabilitiesBuilder +from pygls.exceptions import ( + CommandAlreadyRegisteredError, + FeatureAlreadyRegisteredError, + ValidationError, +) +from pygls.feature_manager import ( + FeatureManager, + has_ls_param_or_annotation, + wrap_with_server, +) +from lsprotocol import types as lsp + + +class Temp: + pass + + +def test_has_ls_param_or_annotation(): + def f1(ls, a, b, c): + pass + + def f2(temp: Temp, a, b, c): + pass + + def f3(temp: "Temp", a, b, c): + pass + + assert has_ls_param_or_annotation(f1, None) + assert has_ls_param_or_annotation(f2, Temp) + assert has_ls_param_or_annotation(f3, Temp) + + +def test_register_command_validation_error(feature_manager): + with pytest.raises(ValidationError): + + @feature_manager.command(" \n\t") + def cmd1(): # pylint: disable=unused-variable + pass + + +def test_register_commands(feature_manager): + cmd1_name = "cmd1" + cmd2_name = "cmd2" + + @feature_manager.command(cmd1_name) + def cmd1(): + pass + + @feature_manager.command(cmd2_name) + def cmd2(): + pass + + reg_commands = feature_manager.commands.keys() + + assert cmd1_name in reg_commands + assert cmd2_name in reg_commands + + assert feature_manager.commands[cmd1_name] is cmd1 + assert feature_manager.commands[cmd2_name] is cmd2 + + +def test_register_feature_with_valid_options(feature_manager): + options = lsp.CompletionOptions(trigger_characters=["!"]) + + @feature_manager.feature(lsp.TEXT_DOCUMENT_COMPLETION, options) + def completions(): + pass + + reg_features = feature_manager.features.keys() + reg_feature_options = feature_manager.feature_options.keys() + + assert lsp.TEXT_DOCUMENT_COMPLETION in reg_features + assert lsp.TEXT_DOCUMENT_COMPLETION in reg_feature_options + + assert feature_manager.features[lsp.TEXT_DOCUMENT_COMPLETION] is completions + assert feature_manager.feature_options[lsp.TEXT_DOCUMENT_COMPLETION] is options + + +def test_register_feature_with_wrong_options(feature_manager): + class Options: + pass + + with pytest.raises( + AttributeError, + match=("'Options' object has no attribute 'trigger_characters'"), # noqa + ): + + @feature_manager.feature(lsp.TEXT_DOCUMENT_COMPLETION, Options()) + def completions(): + pass + + +def test_register_features(feature_manager): + @feature_manager.feature(lsp.TEXT_DOCUMENT_COMPLETION) + def completions(): + pass + + @feature_manager.feature(lsp.TEXT_DOCUMENT_CODE_LENS) + def code_lens(): + pass + + reg_features = feature_manager.features.keys() + + assert lsp.TEXT_DOCUMENT_COMPLETION in reg_features + assert lsp.TEXT_DOCUMENT_CODE_LENS in reg_features + + assert feature_manager.features[lsp.TEXT_DOCUMENT_COMPLETION] is completions + assert feature_manager.features[lsp.TEXT_DOCUMENT_CODE_LENS] is code_lens + + +def test_register_same_command_twice_error(feature_manager): + with pytest.raises(CommandAlreadyRegisteredError): + + @feature_manager.command("cmd1") + def cmd1(): # pylint: disable=unused-variable + pass + + @feature_manager.command("cmd1") + def cmd2(): # pylint: disable=unused-variable + pass + + +def test_register_same_feature_twice_error(feature_manager): + with pytest.raises(FeatureAlreadyRegisteredError): + + @feature_manager.feature(lsp.TEXT_DOCUMENT_CODE_ACTION) + def code_action1(): # pylint: disable=unused-variable + pass + + @feature_manager.feature(lsp.TEXT_DOCUMENT_CODE_ACTION) + def code_action2(): # pylint: disable=unused-variable + pass + + +def test_wrap_with_server_async(): + class Server: + pass + + async def f(ls): + assert isinstance(ls, Server) + + wrapped = wrap_with_server(f, Server()) + assert asyncio.iscoroutinefunction(wrapped) + + +def test_wrap_with_server_sync(): + class Server: + pass + + def f(ls): + assert isinstance(ls, Server) + + wrapped = wrap_with_server(f, Server()) + wrapped() + + +def test_wrap_with_server_thread(): + class Server: + pass + + def f(ls): + assert isinstance(ls, Server) + + f.execute_in_thread = True + + wrapped = wrap_with_server(f, Server()) + assert wrapped.execute_in_thread is True + + +def server_capabilities(**kwargs): + """Helper to reduce the amount of boilerplate required to specify the expected + server capabilities by filling in some fields - unless they are explicitly + overriden.""" + + if "text_document_sync" not in kwargs: + kwargs["text_document_sync"] = lsp.TextDocumentSyncOptions( + open_close=False, + save=False, + ) + + if "execute_command_provider" not in kwargs: + kwargs["execute_command_provider"] = lsp.ExecuteCommandOptions(commands=[]) + + if "workspace" not in kwargs: + kwargs["workspace"] = lsp.ServerCapabilitiesWorkspaceType( + workspace_folders=lsp.WorkspaceFoldersServerCapabilities( + supported=True, change_notifications=True + ), + file_operations=lsp.FileOperationOptions(), + ) + + if "position_encoding" not in kwargs: + kwargs["position_encoding"] = lsp.PositionEncodingKind.Utf16 + + return lsp.ServerCapabilities(**kwargs) + + +@pytest.mark.parametrize( + "method, options, capabilities, expected", + [ + ( + lsp.INITIALIZE, + None, + lsp.ClientCapabilities( + general=lsp.GeneralClientCapabilities( + position_encodings=[lsp.PositionEncodingKind.Utf8] + ) + ), + server_capabilities(position_encoding=lsp.PositionEncodingKind.Utf8), + ), + ( + lsp.INITIALIZE, + None, + lsp.ClientCapabilities( + general=lsp.GeneralClientCapabilities( + position_encodings=[ + lsp.PositionEncodingKind.Utf8, + lsp.PositionEncodingKind.Utf32, + ] + ) + ), + server_capabilities(position_encoding=lsp.PositionEncodingKind.Utf32), + ), + ( + lsp.TEXT_DOCUMENT_DID_SAVE, + lsp.SaveOptions(include_text=True), + lsp.ClientCapabilities(), + server_capabilities( + text_document_sync=lsp.TextDocumentSyncOptions( + open_close=False, save=lsp.SaveOptions(include_text=True) + ) + ), + ), + ( + lsp.TEXT_DOCUMENT_DID_SAVE, + None, + lsp.ClientCapabilities(), + server_capabilities( + text_document_sync=lsp.TextDocumentSyncOptions( + open_close=False, save=True + ) + ), + ), + ( + lsp.TEXT_DOCUMENT_WILL_SAVE, + None, + lsp.ClientCapabilities(), + server_capabilities( + text_document_sync=lsp.TextDocumentSyncOptions( + open_close=False, save=False + ) + ), + ), + ( + lsp.TEXT_DOCUMENT_WILL_SAVE, + None, + lsp.ClientCapabilities( + text_document=lsp.TextDocumentClientCapabilities( + synchronization=lsp.TextDocumentSyncClientCapabilities( + will_save=True + ) + ) + ), + server_capabilities( + text_document_sync=lsp.TextDocumentSyncOptions( + open_close=False, save=False, will_save=True + ) + ), + ), + ( + lsp.TEXT_DOCUMENT_WILL_SAVE_WAIT_UNTIL, + None, + lsp.ClientCapabilities( + text_document=lsp.TextDocumentClientCapabilities( + synchronization=lsp.TextDocumentSyncClientCapabilities( + will_save_wait_until=True + ) + ) + ), + server_capabilities( + text_document_sync=lsp.TextDocumentSyncOptions( + open_close=False, save=False, will_save_wait_until=True + ) + ), + ), + ( + lsp.TEXT_DOCUMENT_DID_OPEN, + None, + lsp.ClientCapabilities(), + server_capabilities( + text_document_sync=lsp.TextDocumentSyncOptions( + open_close=True, save=False + ) + ), + ), + ( + lsp.TEXT_DOCUMENT_DID_CLOSE, + None, + lsp.ClientCapabilities(), + server_capabilities( + text_document_sync=lsp.TextDocumentSyncOptions( + open_close=True, save=False + ) + ), + ), + ( + lsp.TEXT_DOCUMENT_INLAY_HINT, + None, + lsp.ClientCapabilities(), + server_capabilities( + inlay_hint_provider=lsp.InlayHintOptions(resolve_provider=False), + ), + ), + ( + lsp.WORKSPACE_WILL_CREATE_FILES, + lsp.FileOperationRegistrationOptions( + filters=[ + lsp.FileOperationFilter( + pattern=lsp.FileOperationPattern(glob="**/*.py") + ) + ] + ), + lsp.ClientCapabilities(), + server_capabilities(), + ), + ( + lsp.WORKSPACE_WILL_CREATE_FILES, + lsp.FileOperationRegistrationOptions( + filters=[ + lsp.FileOperationFilter( + pattern=lsp.FileOperationPattern(glob="**/*.py") + ) + ] + ), + lsp.ClientCapabilities( + workspace=lsp.WorkspaceClientCapabilities( + file_operations=lsp.FileOperationClientCapabilities( + will_create=True + ) + ) + ), + server_capabilities( + workspace=lsp.ServerCapabilitiesWorkspaceType( + workspace_folders=lsp.WorkspaceFoldersServerCapabilities( + supported=True, change_notifications=True + ), + file_operations=lsp.FileOperationOptions( + will_create=lsp.FileOperationRegistrationOptions( + filters=[ + lsp.FileOperationFilter( + pattern=lsp.FileOperationPattern(glob="**/*.py") + ) + ] + ) + ), + ) + ), + ), + ( + lsp.WORKSPACE_DID_CREATE_FILES, + lsp.FileOperationRegistrationOptions( + filters=[ + lsp.FileOperationFilter( + pattern=lsp.FileOperationPattern(glob="**/*.py") + ) + ] + ), + lsp.ClientCapabilities(), + server_capabilities(), + ), + ( + lsp.WORKSPACE_DID_CREATE_FILES, + lsp.FileOperationRegistrationOptions( + filters=[ + lsp.FileOperationFilter( + pattern=lsp.FileOperationPattern(glob="**/*.py") + ) + ] + ), + lsp.ClientCapabilities( + workspace=lsp.WorkspaceClientCapabilities( + file_operations=lsp.FileOperationClientCapabilities(did_create=True) + ) + ), + server_capabilities( + workspace=lsp.ServerCapabilitiesWorkspaceType( + workspace_folders=lsp.WorkspaceFoldersServerCapabilities( + supported=True, change_notifications=True + ), + file_operations=lsp.FileOperationOptions( + did_create=lsp.FileOperationRegistrationOptions( + filters=[ + lsp.FileOperationFilter( + pattern=lsp.FileOperationPattern(glob="**/*.py") + ) + ] + ) + ), + ) + ), + ), + ( + lsp.WORKSPACE_WILL_DELETE_FILES, + lsp.FileOperationRegistrationOptions( + filters=[ + lsp.FileOperationFilter( + pattern=lsp.FileOperationPattern(glob="**/*.py") + ) + ] + ), + lsp.ClientCapabilities(), + server_capabilities(), + ), + ( + lsp.WORKSPACE_WILL_DELETE_FILES, + lsp.FileOperationRegistrationOptions( + filters=[ + lsp.FileOperationFilter( + pattern=lsp.FileOperationPattern(glob="**/*.py") + ) + ] + ), + lsp.ClientCapabilities( + workspace=lsp.WorkspaceClientCapabilities( + file_operations=lsp.FileOperationClientCapabilities( + will_delete=True + ) + ) + ), + server_capabilities( + workspace=lsp.ServerCapabilitiesWorkspaceType( + workspace_folders=lsp.WorkspaceFoldersServerCapabilities( + supported=True, change_notifications=True + ), + file_operations=lsp.FileOperationOptions( + will_delete=lsp.FileOperationRegistrationOptions( + filters=[ + lsp.FileOperationFilter( + pattern=lsp.FileOperationPattern(glob="**/*.py") + ) + ] + ) + ), + ) + ), + ), + ( + lsp.WORKSPACE_DID_DELETE_FILES, + lsp.FileOperationRegistrationOptions( + filters=[ + lsp.FileOperationFilter( + pattern=lsp.FileOperationPattern(glob="**/*.py") + ) + ] + ), + lsp.ClientCapabilities(), + server_capabilities(), + ), + ( + lsp.WORKSPACE_DID_DELETE_FILES, + lsp.FileOperationRegistrationOptions( + filters=[ + lsp.FileOperationFilter( + pattern=lsp.FileOperationPattern(glob="**/*.py") + ) + ] + ), + lsp.ClientCapabilities( + workspace=lsp.WorkspaceClientCapabilities( + file_operations=lsp.FileOperationClientCapabilities(did_delete=True) + ) + ), + server_capabilities( + workspace=lsp.ServerCapabilitiesWorkspaceType( + workspace_folders=lsp.WorkspaceFoldersServerCapabilities( + supported=True, change_notifications=True + ), + file_operations=lsp.FileOperationOptions( + did_delete=lsp.FileOperationRegistrationOptions( + filters=[ + lsp.FileOperationFilter( + pattern=lsp.FileOperationPattern(glob="**/*.py") + ) + ] + ) + ), + ) + ), + ), + ( + lsp.WORKSPACE_WILL_RENAME_FILES, + lsp.FileOperationRegistrationOptions( + filters=[ + lsp.FileOperationFilter( + pattern=lsp.FileOperationPattern(glob="**/*.py") + ) + ] + ), + lsp.ClientCapabilities(), + server_capabilities(), + ), + ( + lsp.WORKSPACE_WILL_RENAME_FILES, + lsp.FileOperationRegistrationOptions( + filters=[ + lsp.FileOperationFilter( + pattern=lsp.FileOperationPattern(glob="**/*.py") + ) + ] + ), + lsp.ClientCapabilities( + workspace=lsp.WorkspaceClientCapabilities( + file_operations=lsp.FileOperationClientCapabilities( + will_rename=True + ) + ) + ), + server_capabilities( + workspace=lsp.ServerCapabilitiesWorkspaceType( + workspace_folders=lsp.WorkspaceFoldersServerCapabilities( + supported=True, change_notifications=True + ), + file_operations=lsp.FileOperationOptions( + will_rename=lsp.FileOperationRegistrationOptions( + filters=[ + lsp.FileOperationFilter( + pattern=lsp.FileOperationPattern(glob="**/*.py") + ) + ] + ) + ), + ) + ), + ), + ( + lsp.WORKSPACE_DID_RENAME_FILES, + lsp.FileOperationRegistrationOptions( + filters=[ + lsp.FileOperationFilter( + pattern=lsp.FileOperationPattern(glob="**/*.py") + ) + ] + ), + lsp.ClientCapabilities(), + server_capabilities(), + ), + ( + lsp.WORKSPACE_DID_RENAME_FILES, + lsp.FileOperationRegistrationOptions( + filters=[ + lsp.FileOperationFilter( + pattern=lsp.FileOperationPattern(glob="**/*.py") + ) + ] + ), + lsp.ClientCapabilities( + workspace=lsp.WorkspaceClientCapabilities( + file_operations=lsp.FileOperationClientCapabilities(did_rename=True) + ) + ), + server_capabilities( + workspace=lsp.ServerCapabilitiesWorkspaceType( + workspace_folders=lsp.WorkspaceFoldersServerCapabilities( + supported=True, change_notifications=True + ), + file_operations=lsp.FileOperationOptions( + did_rename=lsp.FileOperationRegistrationOptions( + filters=[ + lsp.FileOperationFilter( + pattern=lsp.FileOperationPattern(glob="**/*.py") + ) + ] + ) + ), + ) + ), + ), + ( + lsp.WORKSPACE_SYMBOL, + None, + lsp.ClientCapabilities(), + server_capabilities( + workspace_symbol_provider=lsp.WorkspaceSymbolOptions( + resolve_provider=False, + ), + ), + ), + ( + lsp.TEXT_DOCUMENT_DIAGNOSTIC, + None, + lsp.ClientCapabilities(), + server_capabilities( + diagnostic_provider=lsp.DiagnosticOptions( + inter_file_dependencies=False, + workspace_diagnostics=False, + ), + ), + ), + ( + lsp.TEXT_DOCUMENT_DIAGNOSTIC, + lsp.DiagnosticOptions( + workspace_diagnostics=True, + inter_file_dependencies=True, + ), + lsp.ClientCapabilities(), + server_capabilities( + diagnostic_provider=lsp.DiagnosticOptions( + inter_file_dependencies=True, + workspace_diagnostics=False, + ), + ), + ), + ( + lsp.TEXT_DOCUMENT_ON_TYPE_FORMATTING, + None, + lsp.ClientCapabilities(), + server_capabilities( + document_on_type_formatting_provider=None, + ), + ), + ( + lsp.TEXT_DOCUMENT_ON_TYPE_FORMATTING, + lsp.DocumentOnTypeFormattingOptions(first_trigger_character=":"), + lsp.ClientCapabilities(), + server_capabilities( + document_on_type_formatting_provider=lsp.DocumentOnTypeFormattingOptions( + first_trigger_character=":", + ), + ), + ), + ], +) +def test_register_feature( + feature_manager: FeatureManager, + method: str, + options: Any, + capabilities: lsp.ClientCapabilities, + expected: lsp.ServerCapabilities, +): + """Ensure that we can register features while specifying their associated + options and that `pygls` is able to correctly build the corresponding server + capabilities. + + Parameters + ---------- + feature_manager + The feature manager to use + + method + The method to register the feature handler for. + + options + The method options to use + + capabilities + The client capabilities to use when building the server's capabilities. + + expected + The expected server capabilties we are expecting to see. + """ + + @feature_manager.feature(method, options) + def _(): + pass + + actual = ServerCapabilitiesBuilder( + capabilities, + feature_manager.features.keys(), + feature_manager.feature_options, + [], + None, + None, + ).build() + + assert expected == actual + + +def test_register_inlay_hint_resolve(feature_manager: FeatureManager): + @feature_manager.feature(lsp.TEXT_DOCUMENT_INLAY_HINT) + def _(): + pass + + @feature_manager.feature(lsp.INLAY_HINT_RESOLVE) + def _(): + pass + + expected = server_capabilities( + inlay_hint_provider=lsp.InlayHintOptions(resolve_provider=True), + ) + + actual = ServerCapabilitiesBuilder( + lsp.ClientCapabilities(), + feature_manager.features.keys(), + feature_manager.feature_options, + [], + None, + None, + ).build() + + assert expected == actual + + +def test_register_workspace_symbol_resolve(feature_manager: FeatureManager): + @feature_manager.feature(lsp.WORKSPACE_SYMBOL) + def _(): + pass + + @feature_manager.feature(lsp.WORKSPACE_SYMBOL_RESOLVE) + def _(): + pass + + expected = server_capabilities( + workspace_symbol_provider=lsp.WorkspaceSymbolOptions(resolve_provider=True), + ) + + actual = ServerCapabilitiesBuilder( + lsp.ClientCapabilities(), + feature_manager.features.keys(), + feature_manager.feature_options, + [], + None, + None, + ).build() + + assert expected == actual + + +def test_register_workspace_diagnostics(feature_manager: FeatureManager): + @feature_manager.feature( + lsp.TEXT_DOCUMENT_DIAGNOSTIC, + lsp.DiagnosticOptions( + identifier="example", + inter_file_dependencies=False, + workspace_diagnostics=False, + ), + ) + def _(): + pass + + @feature_manager.feature(lsp.WORKSPACE_DIAGNOSTIC) + def _(): + pass + + expected = server_capabilities( + diagnostic_provider=lsp.DiagnosticOptions( + identifier="example", + inter_file_dependencies=False, + workspace_diagnostics=True, + ), + ) + + actual = ServerCapabilitiesBuilder( + lsp.ClientCapabilities(), + feature_manager.features.keys(), + feature_manager.feature_options, + [], + None, + None, + ).build() + + assert expected == actual diff --git a/tests/test_language_server.py b/tests/test_language_server.py new file mode 100644 index 0000000..5271ff5 --- /dev/null +++ b/tests/test_language_server.py @@ -0,0 +1,144 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ +import pathlib +from time import sleep + +import pytest + +from pygls import IS_PYODIDE +from lsprotocol.types import ( + INITIALIZE, + TEXT_DOCUMENT_DID_OPEN, + WORKSPACE_EXECUTE_COMMAND, +) +from lsprotocol.types import ( + ClientCapabilities, + DidOpenTextDocumentParams, + ExecuteCommandParams, + InitializeParams, + TextDocumentItem, +) +from pygls.protocol import LanguageServerProtocol +from pygls.server import LanguageServer +from . import CMD_ASYNC, CMD_SYNC, CMD_THREAD + + +def _initialize_server(server): + server.lsp.lsp_initialize( + InitializeParams( + process_id=1234, + root_uri=pathlib.Path(__file__).parent.as_uri(), + capabilities=ClientCapabilities(), + ) + ) + + +def test_bf_initialize(client_server): + client, server = client_server + root_uri = pathlib.Path(__file__).parent.as_uri() + process_id = 1234 + + response = client.lsp.send_request( + INITIALIZE, + InitializeParams( + process_id=process_id, + root_uri=root_uri, + capabilities=ClientCapabilities(), + ), + ).result() + + assert server.process_id == process_id + assert server.workspace.root_uri == root_uri + assert response.capabilities is not None + + +def test_bf_text_document_did_open(client_server): + client, server = client_server + + _initialize_server(server) + + client.lsp.notify( + TEXT_DOCUMENT_DID_OPEN, + DidOpenTextDocumentParams( + text_document=TextDocumentItem( + uri=__file__, language_id="python", version=1, text="test" + ) + ), + ) + + sleep(1) + + assert len(server.lsp.workspace.text_documents) == 1 + + document = server.workspace.get_text_document(__file__) + assert document.uri == __file__ + assert document.version == 1 + assert document.source == "test" + assert document.language_id == "python" + + +@pytest.mark.skipif(IS_PYODIDE, reason="threads are not available in pyodide.") +def test_command_async(client_server): + client, server = client_server + + is_called, thread_id = client.lsp.send_request( + WORKSPACE_EXECUTE_COMMAND, ExecuteCommandParams(command=CMD_ASYNC) + ).result() + + assert is_called + assert thread_id == server.thread_id + + +@pytest.mark.skipif(IS_PYODIDE, reason="threads are not available in pyodide.") +def test_command_sync(client_server): + client, server = client_server + + is_called, thread_id = client.lsp.send_request( + WORKSPACE_EXECUTE_COMMAND, ExecuteCommandParams(command=CMD_SYNC) + ).result() + + assert is_called + assert thread_id == server.thread_id + + +@pytest.mark.skipif(IS_PYODIDE, reason="threads are not available in pyodide.") +def test_command_thread(client_server): + client, server = client_server + + is_called, thread_id = client.lsp.send_request( + WORKSPACE_EXECUTE_COMMAND, ExecuteCommandParams(command=CMD_THREAD) + ).result() + + assert is_called + assert thread_id != server.thread_id + + +def test_allow_custom_protocol_derived_from_lsp(): + class CustomProtocol(LanguageServerProtocol): + pass + + server = LanguageServer("pygls-test", "v1", protocol_cls=CustomProtocol) + + assert isinstance(server.lsp, CustomProtocol) + + +def test_forbid_custom_protocol_not_derived_from_lsp(): + class CustomProtocol: + pass + + with pytest.raises(TypeError): + LanguageServer("pygls-test", "v1", protocol_cls=CustomProtocol) diff --git a/tests/test_protocol.py b/tests/test_protocol.py new file mode 100644 index 0000000..63b689f --- /dev/null +++ b/tests/test_protocol.py @@ -0,0 +1,660 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ +import io +import json +from concurrent.futures import Future +from pathlib import Path +from typing import Optional +from unittest.mock import Mock + +import attrs +import pytest + +from pygls.exceptions import JsonRpcException, JsonRpcInvalidParams +from lsprotocol.types import ( + PROGRESS, + TEXT_DOCUMENT_COMPLETION, + ClientCapabilities, + CompletionItem, + CompletionItemKind, + CompletionParams, + InitializeParams, + InitializeResult, + ProgressParams, + Position, + ShutdownResponse, + TextDocumentCompletionResponse, + TextDocumentIdentifier, + WorkDoneProgressBegin, +) +from pygls.protocol import ( + default_converter, + JsonRPCProtocol, + JsonRPCRequestMessage, + JsonRPCResponseMessage, + JsonRPCNotification, +) + +EXAMPLE_NOTIFICATION = "example/notification" +EXAMPLE_REQUEST = "example/request" + + +@attrs.define +class IntResult: + id: str + result: int + jsonrpc: str = attrs.field(default="2.0") + + +@attrs.define +class ExampleParams: + @attrs.define + class InnerType: + inner_field: str + + field_a: str + field_b: Optional[InnerType] = None + + +@attrs.define +class ExampleNotification: + jsonrpc: str = attrs.field(default="2.0") + method: str = EXAMPLE_NOTIFICATION + params: ExampleParams = attrs.field(default=None) + + +@attrs.define +class ExampleRequest: + id: str + jsonrpc: str = attrs.field(default="2.0") + method: str = EXAMPLE_REQUEST + params: ExampleParams = attrs.field(default=None) + + +EXAMPLE_LSP_METHODS_MAP = { + EXAMPLE_NOTIFICATION: (ExampleNotification, None, ExampleParams, None), + EXAMPLE_REQUEST: (ExampleRequest, None, ExampleParams, None), +} + + +class ExampleProtocol(JsonRPCProtocol): + def get_message_type(self, method: str): + return EXAMPLE_LSP_METHODS_MAP.get(method, (None,))[0] + + +@pytest.fixture() +def protocol(): + return ExampleProtocol(None, default_converter()) + + +def test_deserialize_notification_message_valid_params(protocol): + params = f""" + {{ + "jsonrpc": "2.0", + "method": "{EXAMPLE_NOTIFICATION}", + "params": {{ + "fieldA": "test_a", + "fieldB": {{ + "innerField": "test_inner" + }} + }} + }} + """ + + result = json.loads(params, object_hook=protocol._deserialize_message) + + assert isinstance( + result, ExampleNotification + ), f"Expected FeatureRequest instance, got {result}" + assert result.jsonrpc == "2.0" + assert result.method == EXAMPLE_NOTIFICATION + + assert isinstance(result.params, ExampleParams) + assert result.params.field_a == "test_a" + + assert isinstance(result.params.field_b, ExampleParams.InnerType) + assert result.params.field_b.inner_field == "test_inner" + + +def test_deserialize_notification_message_unknown_type(protocol): + params = """ + { + "jsonrpc": "2.0", + "method": "random", + "params": { + "field_a": "test_a", + "field_b": { + "inner_field": "test_inner" + } + } + } + """ + + result = json.loads(params, object_hook=protocol._deserialize_message) + + assert isinstance(result, JsonRPCNotification) + assert result.jsonrpc == "2.0" + assert result.method == "random" + + assert result.params.field_a == "test_a" + assert result.params.field_b.inner_field == "test_inner" + + +def test_deserialize_notification_message_bad_params_should_raise_error(protocol): + params = f""" + {{ + "jsonrpc": "2.0", + "method": "{EXAMPLE_NOTIFICATION}", + "params": {{ + "field_a": "test_a", + "field_b": {{ + "wrong_field_name": "test_inner" + }} + }} + }} + """ + + with pytest.raises(JsonRpcInvalidParams): + json.loads(params, object_hook=protocol._deserialize_message) + + +def test_deserialize_response_message_custom_converter(): + params = """ + { + "jsonrpc": "2.0", + "id": "id", + "result": "1" + } + """ + + # Just for fun, let's create a converter that reverses all the keys in a dict. + # + @attrs.define + class egasseM: + cprnosj: str + di: str + tluser: str + + def structure_hook(obj, cls): + params = {k[::-1]: v for k, v in obj.items()} + return cls(**params) + + def custom_converter(): + converter = default_converter() + converter.register_structure_hook(egasseM, structure_hook) + return converter + + protocol = JsonRPCProtocol(None, custom_converter()) + protocol._result_types["id"] = egasseM + result = json.loads(params, object_hook=protocol._deserialize_message) + + assert isinstance(result, egasseM) + assert result.cprnosj == "2.0" + assert result.di == "id" + assert result.tluser == "1" + + +@pytest.mark.parametrize( + "method, params, expected", + [ + ( + # Known notification type. + PROGRESS, + ProgressParams( + token="id1", + value=WorkDoneProgressBegin( + title="Begin progress", + percentage=0, + ), + ), + { + "jsonrpc": "2.0", + "method": "$/progress", + "params": { + "token": "id1", + "value": { + "kind": "begin", + "percentage": 0, + "title": "Begin progress", + }, + }, + }, + ), + ( + # Custom notification type. + EXAMPLE_NOTIFICATION, + ExampleParams( + field_a="field one", + field_b=ExampleParams.InnerType(inner_field="field two"), + ), + { + "jsonrpc": "2.0", + "method": EXAMPLE_NOTIFICATION, + "params": { + "fieldA": "field one", + "fieldB": { + "innerField": "field two", + }, + }, + }, + ), + ( + # Custom notification with dict params. + EXAMPLE_NOTIFICATION, + {"fieldA": "field one", "fieldB": {"innerField": "field two"}}, + { + "jsonrpc": "2.0", + "method": EXAMPLE_NOTIFICATION, + "params": { + "fieldA": "field one", + "fieldB": { + "innerField": "field two", + }, + }, + }, + ), + ], +) +def test_serialize_notification_message(method, params, expected): + """ + Ensure that we can serialize notification messages, retaining all + expected fields. + """ + + buffer = io.StringIO() + + protocol = JsonRPCProtocol(None, default_converter()) + protocol._send_only_body = True + protocol.connection_made(buffer) + + protocol.notify(method, params=params) + actual = json.loads(buffer.getvalue()) + + assert actual == expected + + +def test_deserialize_response_message(protocol): + params = """ + { + "jsonrpc": "2.0", + "id": "id", + "result": "1" + } + """ + protocol._result_types["id"] = IntResult + result = json.loads(params, object_hook=protocol._deserialize_message) + + assert isinstance(result, IntResult) + assert result.jsonrpc == "2.0" + assert result.id == "id" + assert result.result == 1 + + +def test_deserialize_response_message_unknown_type(protocol): + params = """ + { + "jsonrpc": "2.0", + "id": "id", + "result": { + "field_a": "test_a", + "field_b": { + "inner_field": "test_inner" + } + } + } + """ + protocol._result_types["id"] = JsonRPCResponseMessage + result = json.loads(params, object_hook=protocol._deserialize_message) + + assert isinstance(result, JsonRPCResponseMessage) + assert result.jsonrpc == "2.0" + assert result.id == "id" + + assert result.result.field_a == "test_a" + assert result.result.field_b.inner_field == "test_inner" + + +def test_deserialize_request_message_with_registered_type(protocol): + params = f""" + {{ + "jsonrpc": "2.0", + "id": "id", + "method": "{EXAMPLE_REQUEST}", + "params": {{ + "fieldA": "test_a", + "fieldB": {{ + "innerField": "test_inner" + }} + }} + }} + """ + result = json.loads(params, object_hook=protocol._deserialize_message) + + assert isinstance(result, ExampleRequest) + assert result.jsonrpc == "2.0" + assert result.id == "id" + assert result.method == EXAMPLE_REQUEST + + assert isinstance(result.params, ExampleParams) + assert result.params.field_a == "test_a" + + assert isinstance(result.params.field_b, ExampleParams.InnerType) + assert result.params.field_b.inner_field == "test_inner" + + +def test_deserialize_request_message_without_registered_type(protocol): + params = """ + { + "jsonrpc": "2.0", + "id": "id", + "method": "random", + "params": { + "field_a": "test_a", + "field_b": { + "inner_field": "test_inner" + } + } + } + """ + result = json.loads(params, object_hook=protocol._deserialize_message) + + assert isinstance(result, JsonRPCRequestMessage) + assert result.jsonrpc == "2.0" + assert result.id == "id" + assert result.method == "random" + + assert result.params.field_a == "test_a" + assert result.params.field_b.inner_field == "test_inner" + + +@pytest.mark.parametrize( + "msg_type, result, expected", + [ + (ShutdownResponse, None, {"jsonrpc": "2.0", "id": "1", "result": None}), + ( + TextDocumentCompletionResponse, + [ + CompletionItem(label="example-one"), + CompletionItem( + label="example-two", + kind=CompletionItemKind.Class, + preselect=False, + deprecated=True, + ), + ], + { + "jsonrpc": "2.0", + "id": "1", + "result": [ + {"label": "example-one"}, + { + "label": "example-two", + "kind": 7, # CompletionItemKind.Class + "preselect": False, + "deprecated": True, + }, + ], + }, + ), + ( # Unknown type with object params. + JsonRPCResponseMessage, + ExampleParams( + field_a="field one", + field_b=ExampleParams.InnerType(inner_field="field two"), + ), + { + "jsonrpc": "2.0", + "id": "1", + "result": { + "fieldA": "field one", + "fieldB": {"innerField": "field two"}, + }, + }, + ), + ( # Unknown type with dict params. + JsonRPCResponseMessage, + {"fieldA": "field one", "fieldB": {"innerField": "field two"}}, + { + "jsonrpc": "2.0", + "id": "1", + "result": { + "fieldA": "field one", + "fieldB": {"innerField": "field two"}, + }, + }, + ), + ], +) +def test_serialize_response_message(msg_type, result, expected): + """ + Ensure that we can serialize response messages, retaining all expected + fields. + """ + + buffer = io.StringIO() + + protocol = JsonRPCProtocol(None, default_converter()) + protocol._send_only_body = True + protocol.connection_made(buffer) + + protocol._result_types["1"] = msg_type + + protocol._send_response("1", result=result) + actual = json.loads(buffer.getvalue()) + + assert actual == expected + + +@pytest.mark.parametrize( + "method, params, expected", + [ + ( + TEXT_DOCUMENT_COMPLETION, + CompletionParams( + text_document=TextDocumentIdentifier(uri="file:///file.txt"), + position=Position(line=1, character=0), + ), + { + "jsonrpc": "2.0", + "id": "1", + "method": TEXT_DOCUMENT_COMPLETION, + "params": { + "textDocument": {"uri": "file:///file.txt"}, + "position": {"line": 1, "character": 0}, + }, + }, + ), + ( # Unknown type with object params. + EXAMPLE_REQUEST, + ExampleParams( + field_a="field one", + field_b=ExampleParams.InnerType(inner_field="field two"), + ), + { + "jsonrpc": "2.0", + "id": "1", + "method": EXAMPLE_REQUEST, + "params": { + "fieldA": "field one", + "fieldB": {"innerField": "field two"}, + }, + }, + ), + ( # Unknown type with dict params. + EXAMPLE_REQUEST, + {"fieldA": "field one", "fieldB": {"innerField": "field two"}}, + { + "jsonrpc": "2.0", + "id": "1", + "method": EXAMPLE_REQUEST, + "params": { + "fieldA": "field one", + "fieldB": {"innerField": "field two"}, + }, + }, + ), + ], +) +def test_serialize_request_message(method, params, expected): + """ + Ensure that we can serialize request messages, retaining all expected + fields. + """ + + buffer = io.StringIO() + + protocol = JsonRPCProtocol(None, default_converter()) + protocol._send_only_body = True + protocol.connection_made(buffer) + + protocol.send_request(method, params, callback=None, msg_id="1") + actual = json.loads(buffer.getvalue()) + + assert actual == expected + + +def test_data_received_without_content_type(client_server): + _, server = client_server + body = json.dumps( + { + "jsonrpc": "2.0", + "method": "test", + "params": 1, + } + ) + message = "\r\n".join( + ( + "Content-Length: " + str(len(body)), + "", + body, + ) + ) + data = bytes(message, "utf-8") + server.lsp.data_received(data) + + +def test_data_received_content_type_first_should_handle_message(client_server): + _, server = client_server + body = json.dumps( + { + "jsonrpc": "2.0", + "method": "test", + "params": 1, + } + ) + message = "\r\n".join( + ( + "Content-Type: application/vscode-jsonrpc; charset=utf-8", + "Content-Length: " + str(len(body)), + "", + body, + ) + ) + data = bytes(message, "utf-8") + server.lsp.data_received(data) + + +def dummy_message(param=1): + body = json.dumps( + { + "jsonrpc": "2.0", + "method": "test", + "params": param, + } + ) + message = "\r\n".join( + ( + "Content-Length: " + str(len(body)), + "Content-Type: application/vscode-jsonrpc; charset=utf-8", + "", + body, + ) + ) + return bytes(message, "utf-8") + + +def test_data_received_single_message_should_handle_message(client_server): + _, server = client_server + data = dummy_message() + server.lsp.data_received(data) + + +def test_data_received_partial_message_should_handle_message(client_server): + _, server = client_server + data = dummy_message() + partial = len(data) - 5 + server.lsp.data_received(data[:partial]) + server.lsp.data_received(data[partial:]) + + +def test_data_received_multi_message_should_handle_messages(client_server): + _, server = client_server + messages = (dummy_message(i) for i in range(3)) + data = b"".join(messages) + server.lsp.data_received(data) + + +def test_data_received_error_should_raise_jsonrpc_error(client_server): + _, server = client_server + body = json.dumps( + { + "jsonrpc": "2.0", + "id": "err", + "error": { + "code": -1, + "message": "message for you sir", + }, + } + ) + message = "\r\n".join( + [ + "Content-Length: " + str(len(body)), + "Content-Type: application/vscode-jsonrpc; charset=utf-8", + "", + body, + ] + ).encode("utf-8") + future = server.lsp._request_futures["err"] = Future() + server.lsp.data_received(message) + with pytest.raises(JsonRpcException, match="message for you sir"): + future.result() + + +def test_initialize_should_return_server_capabilities(client_server): + _, server = client_server + params = InitializeParams( + process_id=1234, + root_uri=Path(__file__).parent.as_uri(), + capabilities=ClientCapabilities(), + ) + + server_capabilities = server.lsp.lsp_initialize(params) + + assert isinstance(server_capabilities, InitializeResult) + + +def test_ignore_unknown_notification(client_server): + _, server = client_server + + fn = server.lsp._execute_notification + server.lsp._execute_notification = Mock() + + server.lsp._handle_notification("random/notification", None) + assert not server.lsp._execute_notification.called + + # Remove mock + server.lsp._execute_notification = fn diff --git a/tests/test_server_connection.py b/tests/test_server_connection.py new file mode 100644 index 0000000..1fda258 --- /dev/null +++ b/tests/test_server_connection.py @@ -0,0 +1,135 @@ +import asyncio +import json +import os +from threading import Thread +from unittest.mock import Mock + +import pytest + +from pygls import IS_PYODIDE +from pygls.server import LanguageServer + +try: + import websockets + + WEBSOCKETS_AVAILABLE = True +except ImportError: + WEBSOCKETS_AVAILABLE = False + + +@pytest.mark.asyncio +@pytest.mark.skipif(IS_PYODIDE, reason="threads are not available in pyodide.") +async def test_tcp_connection_lost(): + loop = asyncio.new_event_loop() + + server = LanguageServer("pygls-test", "v1", loop=loop) + + server.lsp.connection_made = Mock() + server.lsp.connection_lost = Mock() + + # Run the server over TCP in a separate thread + server_thread = Thread( + target=server.start_tcp, + args=( + "127.0.0.1", + 0, + ), + ) + server_thread.daemon = True + server_thread.start() + + # Wait for server to be ready + while server._server is None: + await asyncio.sleep(0.5) + + # Simulate client's connection + port = server._server.sockets[0].getsockname()[1] + reader, writer = await asyncio.open_connection("127.0.0.1", port) + await asyncio.sleep(1) + + assert server.lsp.connection_made.called + + # Socket is closed (client's process is terminated) + writer.close() + await asyncio.sleep(1) + + assert server.lsp.connection_lost.called + + +@pytest.mark.asyncio +@pytest.mark.skipif(IS_PYODIDE, reason="threads are not available in pyodide.") +async def test_io_connection_lost(): + # Client to Server pipe. + csr, csw = os.pipe() + # Server to client pipe. + scr, scw = os.pipe() + + server = LanguageServer("pygls-test", "v1", loop=asyncio.new_event_loop()) + server.lsp.connection_made = Mock() + server_thread = Thread( + target=server.start_io, args=(os.fdopen(csr, "rb"), os.fdopen(scw, "wb")) + ) + server_thread.daemon = True + server_thread.start() + + # Wait for server to be ready + while not server.lsp.connection_made.called: + await asyncio.sleep(0.5) + + # Pipe is closed (client's process is terminated) + os.close(csw) + server_thread.join() + + +@pytest.mark.asyncio +@pytest.mark.skipif( + IS_PYODIDE or not WEBSOCKETS_AVAILABLE, + reason="threads are not available in pyodide", +) +async def test_ws_server(): + """Smoke test to ensure we can send/receive messages over websockets""" + + loop = asyncio.new_event_loop() + server = LanguageServer("pygls-test", "v1", loop=loop) + + # Run the server over Websockets in a separate thread + server_thread = Thread( + target=server.start_ws, + args=( + "127.0.0.1", + 0, + ), + ) + server_thread.daemon = True + server_thread.start() + + # Wait for server to be ready + while server._server is None: + await asyncio.sleep(0.5) + + port = server._server.sockets[0].getsockname()[1] + # Simulate client's connection + async with websockets.connect(f"ws://127.0.0.1:{port}") as connection: + # Send an 'initialize' request + msg = dict( + jsonrpc="2.0", id=1, method="initialize", params=dict(capabilities=dict()) + ) + await connection.send(json.dumps(msg)) + + response = await connection.recv() + assert "result" in response + + # Shut the server down + msg = dict( + jsonrpc="2.0", id=2, method="shutdown", params=dict(capabilities=dict()) + ) + await connection.send(json.dumps(msg)) + + response = await connection.recv() + assert "result" in response + + # Finally, tell it to exit + msg = dict(jsonrpc="2.0", id=2, method="exit", params=None) + await connection.send(json.dumps(msg)) + + server_thread.join() diff --git a/tests/test_types.py b/tests/test_types.py new file mode 100644 index 0000000..0958493 --- /dev/null +++ b/tests/test_types.py @@ -0,0 +1,86 @@ +############################################################################ +# Original work Copyright 2018 Palantir Technologies, Inc. # +# Original work licensed under the MIT License. # +# See ThirdPartyNotices.txt in the project root for license information. # +# All modifications Copyright (c) Open Law Library. All rights reserved. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ +from lsprotocol.types import Location, Position, Range + + +def test_position(): + assert Position(line=1, character=2) == Position(line=1, character=2) + assert Position(line=1, character=2) != Position(line=2, character=2) + assert Position(line=1, character=2) <= Position(line=2, character=2) + assert Position(line=2, character=2) >= Position(line=2, character=0) + assert Position(line=1, character=2) != "something else" + assert "1:2" == repr(Position(line=1, character=2)) + + +def test_range(): + assert Range( + start=Position(line=1, character=2), end=Position(line=3, character=4) + ) == Range(start=Position(line=1, character=2), end=Position(line=3, character=4)) + assert Range( + start=Position(line=0, character=2), end=Position(line=3, character=4) + ) != Range(start=Position(line=1, character=2), end=Position(line=3, character=4)) + assert ( + Range(start=Position(line=0, character=2), end=Position(line=3, character=4)) + != "something else" + ) + assert "1:2-3:4" == repr( + Range(start=Position(line=1, character=2), end=Position(line=3, character=4)) + ) + + +def test_location(): + assert Location( + uri="file:///document.txt", + range=Range( + start=Position(line=1, character=2), end=Position(line=3, character=4) + ), + ) == Location( + uri="file:///document.txt", + range=Range( + start=Position(line=1, character=2), end=Position(line=3, character=4) + ), + ) + assert Location( + uri="file:///document.txt", + range=Range( + start=Position(line=1, character=2), end=Position(line=3, character=4) + ), + ) != Location( + uri="file:///another.txt", + range=Range( + start=Position(line=1, character=2), end=Position(line=3, character=4) + ), + ) + assert ( + Location( + uri="file:///document.txt", + range=Range( + start=Position(line=1, character=2), end=Position(line=3, character=4) + ), + ) + != "something else" + ) + assert "file:///document.txt:1:2-3:4" == repr( + Location( + uri="file:///document.txt", + range=Range( + start=Position(line=1, character=2), end=Position(line=3, character=4) + ), + ) + ) diff --git a/tests/test_uris.py b/tests/test_uris.py new file mode 100644 index 0000000..16c8bd4 --- /dev/null +++ b/tests/test_uris.py @@ -0,0 +1,87 @@ +############################################################################ +# Original work Copyright 2018 Palantir Technologies, Inc. # +# Original work licensed under the MIT License. # +# See ThirdPartyNotices.txt in the project root for license information. # +# All modifications Copyright (c) Open Law Library. All rights reserved. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ +import pytest + +from pygls import uris +from . import unix_only, windows_only + + +@unix_only +@pytest.mark.parametrize( + "path,uri", + [ + ("/foo/bar", "file:///foo/bar"), + ("/foo/space ?bar", "file:///foo/space%20%3Fbar"), + ], +) +def test_from_fs_path(path, uri): + assert uris.from_fs_path(path) == uri + + +@unix_only +@pytest.mark.parametrize( + "uri,path", + [ + ("file:///foo/bar#frag", "/foo/bar"), + ("file:/foo/bar#frag", "/foo/bar"), + ("file:/foo/space%20%3Fbar#frag", "/foo/space ?bar"), + ], +) +def test_to_fs_path(uri, path): + assert uris.to_fs_path(uri) == path + + +@pytest.mark.parametrize( + "uri,kwargs,new_uri", + [ + ("file:///foo/bar", {"path": "/baz/boo"}, "file:///baz/boo"), + ( + "file:///D:/hello%20world.py", + {"path": "D:/hello universe.py"}, + "file:///d:/hello%20universe.py", + ), + ], +) +def test_uri_with(uri, kwargs, new_uri): + assert uris.uri_with(uri, **kwargs) == new_uri + + +@windows_only +@pytest.mark.parametrize( + "path,uri", + [ + ("c:\\far\\boo", "file:///c:/far/boo"), + ("C:\\far\\space ?boo", "file:///c:/far/space%20%3Fboo"), + ], +) +def test_win_from_fs_path(path, uri): + assert uris.from_fs_path(path) == uri + + +@windows_only +@pytest.mark.parametrize( + "uri,path", + [ + ("file:///c:/far/boo", "c:\\far\\boo"), + ("file:///C:/far/boo", "c:\\far\\boo"), + ("file:///C:/far/space%20%3Fboo", "c:\\far\\space ?boo"), + ], +) +def test_win_to_fs_path(uri, path): + assert uris.to_fs_path(uri) == path diff --git a/tests/test_workspace.py b/tests/test_workspace.py new file mode 100644 index 0000000..52d50f3 --- /dev/null +++ b/tests/test_workspace.py @@ -0,0 +1,442 @@ +############################################################################ +# Original work Copyright 2017 Palantir Technologies, Inc. # +# Original work licensed under the MIT License. # +# See ThirdPartyNotices.txt in the project root for license information. # +# All modifications Copyright (c) Open Law Library. All rights reserved. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ +import os + +import pytest +from lsprotocol import types + +from pygls import uris +from pygls.workspace import Workspace + +DOC_URI = uris.from_fs_path(__file__) +DOC_TEXT = """test""" +DOC = types.TextDocumentItem( + uri=DOC_URI, language_id="plaintext", version=0, text=DOC_TEXT +) +NOTEBOOK = types.NotebookDocument( + uri="file:///path/to/notebook.ipynb", + notebook_type="jupyter-notebook", + version=0, + cells=[ + types.NotebookCell( + kind=types.NotebookCellKind.Code, + document="nb-cell-scheme://path/to/notebook.ipynb#cv32321", + ), + types.NotebookCell( + kind=types.NotebookCellKind.Code, + document="nb-cell-scheme://path/to/notebook.ipynb#cp897h32", + ), + ], +) +NB_CELL_1 = types.TextDocumentItem( + uri="nb-cell-scheme://path/to/notebook.ipynb#cv32321", + language_id="python", + version=0, + text="# cell 1", +) +NB_CELL_2 = types.TextDocumentItem( + uri="nb-cell-scheme://path/to/notebook.ipynb#cp897h32", + language_id="python", + version=0, + text="# cell 2", +) +NB_CELL_3 = types.TextDocumentItem( + uri="nb-cell-scheme://path/to/notebook.ipynb#cq343eeds", + language_id="python", + version=0, + text="# cell 3", +) + + +def test_add_folder(workspace): + dir_uri = os.path.dirname(DOC_URI) + dir_name = "test" + workspace.add_folder(types.WorkspaceFolder(uri=dir_uri, name=dir_name)) + assert workspace.folders[dir_uri].name == dir_name + + +def test_get_notebook_document_by_uri(workspace): + """Ensure that we can get a notebook given its uri.""" + params = types.DidOpenNotebookDocumentParams( + notebook_document=NOTEBOOK, + cell_text_documents=[ + NB_CELL_1, + NB_CELL_2, + ], + ) + workspace.put_notebook_document(params) + + notebook = workspace.get_notebook_document(notebook_uri=NOTEBOOK.uri) + assert notebook == NOTEBOOK + + +@pytest.mark.parametrize( + "cell,expected", + [ + (NB_CELL_1, NOTEBOOK), + (NB_CELL_2, NOTEBOOK), + (NB_CELL_3, None), + (DOC, None), + ], +) +def test_get_notebook_document_by_cell_uri(workspace, cell, expected): + """Ensure that we can get a notebook given a uri of one of its cells""" + params = types.DidOpenNotebookDocumentParams( + notebook_document=NOTEBOOK, + cell_text_documents=[ + NB_CELL_1, + NB_CELL_2, + ], + ) + workspace.put_notebook_document(params) + + notebook = workspace.get_notebook_document(cell_uri=cell.uri) + assert notebook == expected + + +def test_get_text_document(workspace): + workspace.put_text_document(DOC) + + assert workspace.get_text_document(DOC_URI).source == DOC_TEXT + + +def test_get_missing_document(tmpdir, workspace): + doc_path = tmpdir.join("test_document.py") + doc_path.write(DOC_TEXT) + doc_uri = uris.from_fs_path(str(doc_path)) + assert workspace.get_text_document(doc_uri).source == DOC_TEXT + + +def test_put_notebook_document(workspace): + """Ensure that we can add notebook documents to the workspace correctly.""" + params = types.DidOpenNotebookDocumentParams( + notebook_document=NOTEBOOK, + cell_text_documents=[ + NB_CELL_1, + NB_CELL_2, + ], + ) + workspace.put_notebook_document(params) + + assert NOTEBOOK.uri in workspace._notebook_documents + assert NB_CELL_1.uri in workspace._text_documents + assert NB_CELL_2.uri in workspace._text_documents + + +def test_put_text_document(workspace): + workspace.put_text_document(DOC) + assert DOC_URI in workspace._text_documents + + +def test_remove_folder(workspace): + dir_uri = os.path.dirname(DOC_URI) + dir_name = "test" + workspace.add_folder(types.WorkspaceFolder(uri=dir_uri, name=dir_name)) + workspace.remove_folder(dir_uri) + + assert dir_uri not in workspace.folders + + +def test_remove_notebook_document(workspace): + """Ensure that we can correctly remove a document from the workspace.""" + params = types.DidOpenNotebookDocumentParams( + notebook_document=NOTEBOOK, + cell_text_documents=[ + NB_CELL_1, + NB_CELL_2, + ], + ) + workspace.put_notebook_document(params) + + assert NOTEBOOK.uri in workspace._notebook_documents + assert NB_CELL_1.uri in workspace._text_documents + assert NB_CELL_2.uri in workspace._text_documents + + params = types.DidCloseNotebookDocumentParams( + notebook_document=types.NotebookDocumentIdentifier(uri=NOTEBOOK.uri), + cell_text_documents=[ + types.TextDocumentIdentifier(uri=NB_CELL_1.uri), + types.TextDocumentIdentifier(uri=NB_CELL_2.uri), + ], + ) + workspace.remove_notebook_document(params) + + assert NOTEBOOK.uri not in workspace._notebook_documents + assert NB_CELL_1.uri not in workspace._text_documents + assert NB_CELL_2.uri not in workspace._text_documents + + +def test_remove_text_document(workspace): + workspace.put_text_document(DOC) + assert workspace.get_text_document(DOC_URI).source == DOC_TEXT + workspace.remove_text_document(DOC_URI) + assert workspace.get_text_document(DOC_URI)._source is None + + +def test_update_notebook_metadata(workspace): + """Ensure we can update a notebook's metadata correctly.""" + params = types.DidOpenNotebookDocumentParams( + notebook_document=NOTEBOOK, + cell_text_documents=[ + NB_CELL_1, + NB_CELL_2, + ], + ) + workspace.put_notebook_document(params) + + notebook = workspace.get_notebook_document(notebook_uri=NOTEBOOK.uri) + assert notebook.version == 0 + assert notebook.metadata is None + + params = types.DidChangeNotebookDocumentParams( + notebook_document=types.VersionedNotebookDocumentIdentifier( + uri=NOTEBOOK.uri, version=31 + ), + change=types.NotebookDocumentChangeEvent( + metadata={"custom": "metadata"}, + ), + ) + workspace.update_notebook_document(params) + + notebook = workspace.get_notebook_document(notebook_uri=NOTEBOOK.uri) + assert notebook.version == 31 + assert notebook.metadata == {"custom": "metadata"} + + +def test_update_notebook_cell_data(workspace): + """Ensure we can update a notebook correctly when cell data changes.""" + params = types.DidOpenNotebookDocumentParams( + notebook_document=NOTEBOOK, + cell_text_documents=[ + NB_CELL_1, + NB_CELL_2, + ], + ) + workspace.put_notebook_document(params) + + notebook = workspace.get_notebook_document(notebook_uri=NOTEBOOK.uri) + assert notebook.version == 0 + + cell_1 = notebook.cells[0] + assert cell_1.metadata is None + assert cell_1.execution_summary is None + + cell_2 = notebook.cells[1] + assert cell_2.metadata is None + assert cell_2.execution_summary is None + + params = types.DidChangeNotebookDocumentParams( + notebook_document=types.VersionedNotebookDocumentIdentifier( + uri=NOTEBOOK.uri, version=31 + ), + change=types.NotebookDocumentChangeEvent( + cells=types.NotebookDocumentChangeEventCellsType( + data=[ + types.NotebookCell( + kind=types.NotebookCellKind.Code, + document=NB_CELL_1.uri, + metadata={"slideshow": {"slide_type": "skip"}}, + execution_summary=types.ExecutionSummary( + execution_order=2, success=True + ), + ), + types.NotebookCell( + kind=types.NotebookCellKind.Code, + document=NB_CELL_2.uri, + metadata={"slideshow": {"slide_type": "note"}}, + execution_summary=types.ExecutionSummary( + execution_order=3, success=False + ), + ), + ] + ) + ), + ) + workspace.update_notebook_document(params) + + notebook = workspace.get_notebook_document(notebook_uri=NOTEBOOK.uri) + assert notebook.version == 31 + + cell_1 = notebook.cells[0] + assert cell_1.metadata == {"slideshow": {"slide_type": "skip"}} + assert cell_1.execution_summary == types.ExecutionSummary( + execution_order=2, success=True + ) + + cell_2 = notebook.cells[1] + assert cell_2.metadata == {"slideshow": {"slide_type": "note"}} + assert cell_2.execution_summary == types.ExecutionSummary( + execution_order=3, success=False + ) + + +def test_update_notebook_cell_content(workspace): + """Ensure we can update a notebook correctly when the cell contents change.""" + params = types.DidOpenNotebookDocumentParams( + notebook_document=NOTEBOOK, + cell_text_documents=[ + NB_CELL_1, + NB_CELL_2, + ], + ) + workspace.put_notebook_document(params) + + notebook = workspace.get_notebook_document(notebook_uri=NOTEBOOK.uri) + assert notebook.version == 0 + + cell_1 = workspace.get_text_document(NB_CELL_1.uri) + assert cell_1.version == 0 + assert cell_1.source == "# cell 1" + + cell_2 = workspace.get_text_document(NB_CELL_2.uri) + assert cell_2.version == 0 + assert cell_2.source == "# cell 2" + + params = types.DidChangeNotebookDocumentParams( + notebook_document=types.VersionedNotebookDocumentIdentifier( + uri=NOTEBOOK.uri, version=31 + ), + change=types.NotebookDocumentChangeEvent( + cells=types.NotebookDocumentChangeEventCellsType( + text_content=[ + types.NotebookDocumentChangeEventCellsTypeTextContentType( + document=types.VersionedTextDocumentIdentifier( + uri=NB_CELL_1.uri, version=13 + ), + changes=[ + types.TextDocumentContentChangeEvent_Type1( + text="new text", + range=types.Range( + start=types.Position(line=0, character=0), + end=types.Position(line=0, character=8), + ), + ) + ], + ), + types.NotebookDocumentChangeEventCellsTypeTextContentType( + document=types.VersionedTextDocumentIdentifier( + uri=NB_CELL_2.uri, version=21 + ), + changes=[ + types.TextDocumentContentChangeEvent_Type1( + text="", + range=types.Range( + start=types.Position(line=0, character=0), + end=types.Position(line=0, character=8), + ), + ), + types.TextDocumentContentChangeEvent_Type1( + text="other text", + range=types.Range( + start=types.Position(line=0, character=0), + end=types.Position(line=0, character=0), + ), + ), + ], + ), + ] + ) + ), + ) + workspace.update_notebook_document(params) + + notebook = workspace.get_notebook_document(notebook_uri=NOTEBOOK.uri) + assert notebook.version == 31 + + cell_1 = workspace.get_text_document(NB_CELL_1.uri) + assert cell_1.version == 13 + assert cell_1.source == "new text" + + cell_2 = workspace.get_text_document(NB_CELL_2.uri) + assert cell_2.version == 21 + assert cell_2.source == "other text" + + +def test_update_notebook_new_cells(workspace): + """Ensure that we can correctly add new cells to an existing notebook.""" + + params = types.DidOpenNotebookDocumentParams( + notebook_document=NOTEBOOK, + cell_text_documents=[ + NB_CELL_1, + NB_CELL_2, + ], + ) + workspace.put_notebook_document(params) + + notebook = workspace.get_notebook_document(notebook_uri=NOTEBOOK.uri) + assert notebook.version == 0 + + cell_uris = [c.document for c in notebook.cells] + assert cell_uris == [NB_CELL_1.uri, NB_CELL_2.uri] + + cell_1 = workspace.get_text_document(NB_CELL_1.uri) + assert cell_1.version == 0 + assert cell_1.source == "# cell 1" + + cell_2 = workspace.get_text_document(NB_CELL_2.uri) + assert cell_2.version == 0 + assert cell_2.source == "# cell 2" + + params = types.DidChangeNotebookDocumentParams( + notebook_document=types.VersionedNotebookDocumentIdentifier( + uri=NOTEBOOK.uri, version=31 + ), + change=types.NotebookDocumentChangeEvent( + cells=types.NotebookDocumentChangeEventCellsType( + structure=types.NotebookDocumentChangeEventCellsTypeStructureType( + array=types.NotebookCellArrayChange( + start=1, + delete_count=0, + cells=[ + types.NotebookCell( + kind=types.NotebookCellKind.Code, document=NB_CELL_3.uri + ) + ], + ), + did_open=[NB_CELL_3], + ) + ) + ), + ) + workspace.update_notebook_document(params) + + notebook = workspace.get_notebook_document(cell_uri=NB_CELL_3.uri) + assert notebook.uri == NOTEBOOK.uri + assert NB_CELL_3.uri in workspace._text_documents + + cell_uris = [c.document for c in notebook.cells] + assert cell_uris == [NB_CELL_1.uri, NB_CELL_3.uri, NB_CELL_2.uri] + + +def test_workspace_folders(): + wf1 = types.WorkspaceFolder(uri="/ws/f1", name="ws1") + wf2 = types.WorkspaceFolder(uri="/ws/f2", name="ws2") + + workspace = Workspace("/ws", workspace_folders=[wf1, wf2]) + + assert workspace.folders["/ws/f1"] is wf1 + assert workspace.folders["/ws/f2"] is wf2 + + +def test_null_workspace(): + workspace = Workspace(None) + + assert workspace.root_uri is None + assert workspace.root_path is None |