summaryrefslogtreecommitdiffstats
path: root/tests/client.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/client.py')
-rw-r--r--tests/client.py180
1 files changed, 180 insertions, 0 deletions
diff --git a/tests/client.py b/tests/client.py
new file mode 100644
index 0000000..9be8752
--- /dev/null
+++ b/tests/client.py
@@ -0,0 +1,180 @@
+############################################################################
+# Original work Copyright 2017 Palantir Technologies, Inc. #
+# Original work licensed under the MIT License. #
+# See ThirdPartyNotices.txt in the project root for license information. #
+# All modifications Copyright (c) Open Law Library. All rights reserved. #
+# #
+# Licensed under the Apache License, Version 2.0 (the "License") #
+# you may not use this file except in compliance with the License. #
+# You may obtain a copy of the License at #
+# #
+# http: // www.apache.org/licenses/LICENSE-2.0 #
+# #
+# Unless required by applicable law or agreed to in writing, software #
+# distributed under the License is distributed on an "AS IS" BASIS, #
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
+# See the License for the specific language governing permissions and #
+# limitations under the License. #
+############################################################################
+import asyncio
+import logging
+import pathlib
+import sys
+from concurrent.futures import Future
+from typing import Dict
+from typing import List
+from typing import Type
+
+import pytest
+import pytest_asyncio
+from lsprotocol import types
+
+from pygls import IS_PYODIDE
+from pygls import uris
+from pygls.exceptions import JsonRpcMethodNotFound
+from pygls.lsp.client import BaseLanguageClient
+from pygls.protocol import LanguageServerProtocol
+from pygls.protocol import default_converter
+
+logger = logging.getLogger(__name__)
+
+
+class LanguageClientProtocol(LanguageServerProtocol):
+ """An extended protocol class with extra methods that are useful for testing."""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self._notification_futures = {}
+
+ def _handle_notification(self, method_name, params):
+ if method_name == types.CANCEL_REQUEST:
+ self._handle_cancel_notification(params.id)
+ return
+
+ future = self._notification_futures.pop(method_name, None)
+ if future:
+ future.set_result(params)
+
+ try:
+ handler = self._get_handler(method_name)
+ self._execute_notification(handler, params)
+ except (KeyError, JsonRpcMethodNotFound):
+ logger.warning("Ignoring notification for unknown method '%s'", method_name)
+ except Exception:
+ logger.exception(
+ "Failed to handle notification '%s': %s", method_name, params
+ )
+
+ def wait_for_notification(self, method: str, callback=None):
+ future: Future = Future()
+ if callback:
+
+ def wrapper(future: Future):
+ result = future.result()
+ callback(result)
+
+ future.add_done_callback(wrapper)
+
+ self._notification_futures[method] = future
+ return future
+
+ def wait_for_notification_async(self, method: str):
+ future = self.wait_for_notification(method)
+ return asyncio.wrap_future(future)
+
+
+class LanguageClient(BaseLanguageClient):
+ """Language client used to drive test cases."""
+
+ def __init__(
+ self,
+ protocol_cls: Type[LanguageClientProtocol] = LanguageClientProtocol,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(
+ "pygls-test-client", "v1", protocol_cls=protocol_cls, *args, **kwargs
+ )
+
+ self.diagnostics: Dict[str, List[types.Diagnostic]] = {}
+ """Used to hold any recieved diagnostics."""
+
+ self.messages: List[types.ShowMessageParams] = []
+ """Holds any received ``window/showMessage`` requests."""
+
+ self.log_messages: List[types.LogMessageParams] = []
+ """Holds any received ``window/logMessage`` requests."""
+
+ async def wait_for_notification(self, method: str):
+ """Block until a notification with the given method is received.
+
+ Parameters
+ ----------
+ method
+ The notification method to wait for, e.g. ``textDocument/publishDiagnostics``
+ """
+ return await self.protocol.wait_for_notification_async(method)
+
+
+def make_test_lsp_client() -> LanguageClient:
+ """Construct a new test client instance with the handlers needed to capture
+ additional responses from the server."""
+
+ client = LanguageClient(converter_factory=default_converter)
+
+ @client.feature(types.TEXT_DOCUMENT_PUBLISH_DIAGNOSTICS)
+ def publish_diagnostics(
+ client: LanguageClient, params: types.PublishDiagnosticsParams
+ ):
+ client.diagnostics[params.uri] = params.diagnostics
+
+ @client.feature(types.WINDOW_LOG_MESSAGE)
+ def log_message(client: LanguageClient, params: types.LogMessageParams):
+ client.log_messages.append(params)
+
+ levels = ["ERROR: ", "WARNING: ", "INFO: ", "LOG: "]
+ log_level = levels[params.type.value - 1]
+
+ print(log_level, params.message)
+
+ @client.feature(types.WINDOW_SHOW_MESSAGE)
+ def show_message(client: LanguageClient, params):
+ client.messages.append(params)
+
+ return client
+
+
+def create_client_for_server(server_name: str):
+ """Automate the process of creating a language client connected to the given server
+ and tearing it down again.
+ """
+
+ @pytest_asyncio.fixture
+ async def fixture_func():
+ if IS_PYODIDE:
+ pytest.skip("not available in pyodide")
+
+ client = make_test_lsp_client()
+ server_dir = pathlib.Path(__file__, "..", "..", "examples", "servers").resolve()
+ root_dir = pathlib.Path(__file__, "..", "..", "examples", "workspace").resolve()
+
+ await client.start_io(sys.executable, str(server_dir / server_name))
+
+ # Initialize the server
+ response = await client.initialize_async(
+ types.InitializeParams(
+ capabilities=types.ClientCapabilities(),
+ root_uri=uris.from_fs_path(root_dir),
+ )
+ )
+ assert response is not None
+
+ yield client, response
+
+ await client.shutdown_async(None)
+ client.exit(None)
+
+ await client.stop()
+
+ return fixture_func