summaryrefslogtreecommitdiffstats
path: root/generator/plugins/rust/rust_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'generator/plugins/rust/rust_utils.py')
-rw-r--r--generator/plugins/rust/rust_utils.py68
1 files changed, 68 insertions, 0 deletions
diff --git a/generator/plugins/rust/rust_utils.py b/generator/plugins/rust/rust_utils.py
new file mode 100644
index 0000000..d817070
--- /dev/null
+++ b/generator/plugins/rust/rust_utils.py
@@ -0,0 +1,68 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import pathlib
+from typing import List
+
+from generator import model
+
+from .rust_commons import TypeData, generate_commons
+from .rust_enum import generate_enums
+from .rust_file_header import license_header
+from .rust_lang_utils import lines_to_comments
+from .rust_structs import (
+ generate_notifications,
+ generate_requests,
+ generate_structures,
+ generate_type_aliases,
+)
+
+PACKAGE_DIR_NAME = "lsprotocol"
+
+
+def generate_from_spec(spec: model.LSPModel, output_dir: str) -> None:
+ code = generate_package_code(spec)
+
+ output_path = pathlib.Path(output_dir, PACKAGE_DIR_NAME)
+ if not output_path.exists():
+ output_path.mkdir(parents=True, exist_ok=True)
+ (output_path / "src").mkdir(parents=True, exist_ok=True)
+
+ for file_name in code:
+ (output_path / file_name).write_text(code[file_name], encoding="utf-8")
+
+
+def generate_package_code(spec: model.LSPModel) -> List[str]:
+ return {
+ "src/lib.rs": generate_lib_rs(spec),
+ }
+
+
+def generate_lib_rs(spec: model.LSPModel) -> List[str]:
+ lines = lines_to_comments(license_header())
+ lines += [
+ "",
+ "// ****** THIS IS A GENERATED FILE, DO NOT EDIT. ******",
+ "// Steps to generate:",
+ "// 1. Checkout https://github.com/microsoft/lsprotocol",
+ "// 2. Install nox: `python -m pip install nox`",
+ "// 3. Run command: `python -m nox --session build_lsp`",
+ "",
+ ]
+ lines += [
+ "use serde::{Serialize, Deserialize};",
+ "use std::collections::HashMap;",
+ "use rust_decimal::Decimal;" "",
+ ]
+
+ type_data = TypeData()
+ generate_commons(spec, type_data)
+ generate_enums(spec.enumerations, type_data)
+
+ generate_type_aliases(spec, type_data)
+ generate_structures(spec, type_data)
+ generate_notifications(spec, type_data)
+ generate_requests(spec, type_data)
+
+ lines += type_data.get_lines()
+ return "\n".join(lines)