diff options
Diffstat (limited to 'scripts/generate_client.py')
-rw-r--r-- | scripts/generate_client.py | 207 |
1 files changed, 207 insertions, 0 deletions
diff --git a/scripts/generate_client.py b/scripts/generate_client.py new file mode 100644 index 0000000..6aab8f2 --- /dev/null +++ b/scripts/generate_client.py @@ -0,0 +1,207 @@ +"""Script to automatically generate a lanaguge client from `lsprotocol` type definitons +""" +import argparse +import inspect +import pathlib +import re +import sys +import textwrap +from typing import Optional +from typing import Set +from typing import Tuple +from typing import Type + +from lsprotocol._hooks import _resolve_forward_references +from lsprotocol.types import METHOD_TO_TYPES +from lsprotocol.types import message_direction + +cli = argparse.ArgumentParser( + description="generate language client from lsprotocol types." +) +cli.add_argument("-o", "--output", default=None) + + +def write_imports(imports: Set[Tuple[str, str]]) -> str: + lines = [] + + for import_ in sorted(list(imports), key=lambda i: (i[0], i[1])): + if isinstance(import_, tuple): + mod, name = import_ + lines.append(f"from {mod} import {name}") + continue + + lines.append(f"import {import_}") + + return "\n".join(lines) + + +def to_snake_case(string: str) -> str: + return "".join(f"_{c.lower()}" if c.isupper() else c for c in string) + + +def write_notification( + method: str, + request: Type, + params: Optional[Type], + imports: Set[Tuple[str, str]], +) -> str: + python_name = to_snake_case(method).replace("/", "_").replace("$_", "") + + if params is None: + param_name = "None" + param_mod = "" + else: + param_mod, param_name = params.__module__, params.__name__ + param_mod = param_mod.replace("lsprotocol.types", "types") + "." + + return "\n".join( + [ + f"def {python_name}(self, params: {param_mod}{param_name}) -> None:", + f' """Send a :lsp:`{method}` notification.', + "", + textwrap.indent(inspect.getdoc(request) or "", " "), + ' """', + " if self.stopped:", + ' raise RuntimeError("Client has been stopped.")', + "", + f' self.protocol.notify("{method}", params)', + "", + ] + ) + + +def get_response_type(response: Type, imports: Set[Tuple[str, str]]) -> str: + # Find the response type. + result_field = [f for f in response.__attrs_attrs__ if f.name == "result"][0] + result = re.sub(r"<class '([\w.]+)'>", r"\1", str(result_field.type)) + result = re.sub(r"ForwardRef\('([\w.]+)'\)", r"lsprotocol.types.\1", result) + result = result.replace("NoneType", "None") + + # Replace any typing imports with their short name. + for match in re.finditer(r"typing.([\w]+)", result): + imports.add(("typing", match.group(1))) + + result = result.replace("lsprotocol.types.", "types.") + result = result.replace("typing.", "") + + return result + + +def write_method( + method: str, + request: Type, + params: Optional[Type], + response: Type, + imports: Set[Tuple[str, str]], +) -> str: + python_name = to_snake_case(method).replace("/", "_").replace("$_", "") + + if params is None: + param_name = "None" + param_mod = "" + else: + param_mod, param_name = params.__module__, params.__name__ + param_mod = param_mod.replace("lsprotocol.types", "types") + "." + + result_type = get_response_type(response, imports) + + return "\n".join( + [ + f"def {python_name}(", + " self,", + f" params: {param_mod}{param_name},", + f" callback: Optional[Callable[[{result_type}], None]] = None,", + ") -> Future:", + f' """Make a :lsp:`{method}` request.', + "", + textwrap.indent(inspect.getdoc(request) or "", " "), + ' """', + " if self.stopped:", + ' raise RuntimeError("Client has been stopped.")', + "", + f' return self.protocol.send_request("{method}", params, callback)', + "", + f"async def {python_name}_async(", + " self,", + f" params: {param_mod}{param_name},", + f") -> {result_type}:", + f' """Make a :lsp:`{method}` request.', + "", + textwrap.indent(inspect.getdoc(request) or "", " "), + ' """', + " if self.stopped:", + ' raise RuntimeError("Client has been stopped.")', + "", + f' return await self.protocol.send_request_async("{method}", params)', + "", + ] + ) + + +def generate_client() -> str: + methods = [] + imports = { + ("concurrent.futures", "Future"), + ("lsprotocol", "types"), + ("pygls.protocol", "LanguageServerProtocol"), + ("pygls.protocol", "default_converter"), + ("pygls.client", "JsonRPCClient"), + ("typing", "Callable"), + ("typing", "Optional"), + } + + for method_name, types in METHOD_TO_TYPES.items(): + # Skip any requests that come from the server. + if message_direction(method_name) == "serverToClient": + continue + + request, response, params, _ = types + + if response is None: + method = write_notification(method_name, request, params, imports) + else: + method = write_method(method_name, request, params, response, imports) + + methods.append(textwrap.indent(method, " ")) + + code = [ + "# GENERATED FROM scripts/gen-client.py -- DO NOT EDIT", + "# flake8: noqa", + write_imports(imports), + "", + "", + "class BaseLanguageClient(JsonRPCClient):", + "", + " def __init__(", + " self,", + " name: str,", + " version: str,", + " protocol_cls=LanguageServerProtocol,", + " converter_factory=default_converter,", + " **kwargs,", + " ):", + " self.name = name", + " self.version = version", + " super().__init__(protocol_cls, converter_factory, **kwargs)", + "", + *methods, + ] + return "\n".join(code) + + +def main(): + args = cli.parse_args() + + # Make sure all the type annotations in lsprotocol are resolved correctly. + _resolve_forward_references() + client = generate_client() + + if args.output is None: + sys.stdout.write(client) + else: + output = pathlib.Path(args.output) + output.write_text(client) + + +if __name__ == "__main__": + main() |