summaryrefslogtreecommitdiffstats
path: root/third_party/rust/jsparagus/jsparagus/extension.py
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/jsparagus/jsparagus/extension.py')
-rw-r--r--third_party/rust/jsparagus/jsparagus/extension.py108
1 files changed, 108 insertions, 0 deletions
diff --git a/third_party/rust/jsparagus/jsparagus/extension.py b/third_party/rust/jsparagus/jsparagus/extension.py
new file mode 100644
index 0000000000..515fc68c1a
--- /dev/null
+++ b/third_party/rust/jsparagus/jsparagus/extension.py
@@ -0,0 +1,108 @@
+"""Data structure extracted from parsing the EDSL which are added within the
+Rust code."""
+
+from __future__ import annotations
+# mypy: disallow-untyped-defs, disallow-incomplete-defs, disallow-untyped-calls
+
+import typing
+import os
+
+from dataclasses import dataclass
+from .utils import keep_until
+from .grammar import Element, Grammar, LenientNt, NtDef, Production
+
+
+@dataclass(frozen=True)
+class ImplFor:
+ __slots__ = ['param', 'trait', 'for_type']
+ param: str
+ trait: str
+ for_type: str
+
+
+def eq_productions(grammar: Grammar, prod1: Production, prod2: Production) -> bool:
+ s1 = tuple(e for e in prod1.body if grammar.is_shifted_element(e))
+ s2 = tuple(e for e in prod2.body if grammar.is_shifted_element(e))
+ return s1 == s2
+
+
+def merge_productions(grammar: Grammar, prod1: Production, prod2: Production) -> Production:
+ # Consider all shifted elements as non-moveable elements, and insert other
+ # around these.
+ assert eq_productions(grammar, prod1, prod2)
+ l1 = list(prod1.body)
+ l2 = list(prod2.body)
+ body: typing.List[Element] = []
+ while l1 != [] and l2 != []:
+ front1 = list(keep_until(l1, grammar.is_shifted_element))
+ front2 = list(keep_until(l2, grammar.is_shifted_element))
+ assert front1[-1] == front2[-1]
+ l1 = l1[len(front1):]
+ l2 = l2[len(front2):]
+ if len(front1) == 1:
+ body = body + front2
+ elif len(front2) == 1:
+ body = body + front1
+ else:
+ raise ValueError("We do not know how to sort operations yet.")
+ return prod1.copy_with(body=body)
+
+
+@dataclass(frozen=True)
+class ExtPatch:
+ "Patch an existing grammar rule by adding Code"
+
+ prod: typing.Tuple[LenientNt, str, NtDef]
+
+ def apply_patch(
+ self,
+ filename: os.PathLike,
+ grammar: Grammar,
+ nonterminals: typing.Dict[LenientNt, NtDef]
+ ) -> None:
+ # - name: non-terminal.
+ # - namespace: ":" for syntactic or "::" for lexical. Always ":" as
+ # defined by rust_nt_def.
+ # - nt_def: A single non-terminal definition with a single production.
+ (name, namespace, nt_def) = self.prod
+ gnt_def = nonterminals[name]
+ # Find a matching production in the grammar.
+ assert nt_def.params == gnt_def.params
+ new_rhs_list = []
+ assert len(nt_def.rhs_list) == 1
+ patch_prod = nt_def.rhs_list[0]
+ applied = False
+ for grammar_prod in gnt_def.rhs_list:
+ if eq_productions(grammar, grammar_prod, patch_prod):
+ grammar_prod = merge_productions(grammar, grammar_prod, patch_prod)
+ applied = True
+ new_rhs_list.append(grammar_prod)
+ if not applied:
+ raise ValueError("{}: Unable to find a matching production for {} in the grammar:\n {}"
+ .format(filename, name, grammar.production_to_str(name, patch_prod)))
+ result = gnt_def.with_rhs_list(new_rhs_list)
+ nonterminals[name] = result
+
+
+@dataclass
+class GrammarExtension:
+ """A collection of grammar extensions, with added code, added traits for the
+ action functions.
+
+ """
+
+ target: None
+ grammar: typing.List[ExtPatch]
+ filename: os.PathLike
+
+ def apply_patch(
+ self,
+ grammar: Grammar,
+ nonterminals: typing.Dict[LenientNt, NtDef]
+ ) -> None:
+ # A grammar extension is composed of multiple production patches.
+ for ext in self.grammar:
+ if isinstance(ext, ExtPatch):
+ ext.apply_patch(self.filename, grammar, nonterminals)
+ else:
+ raise ValueError("Extension of type {} not yet supported.".format(ext.__class__))