diff options
Diffstat (limited to 'third_party/rust/jsparagus/jsparagus/emit/rust.py')
-rw-r--r-- | third_party/rust/jsparagus/jsparagus/emit/rust.py | 903 |
1 files changed, 903 insertions, 0 deletions
diff --git a/third_party/rust/jsparagus/jsparagus/emit/rust.py b/third_party/rust/jsparagus/jsparagus/emit/rust.py new file mode 100644 index 0000000000..ca8382954e --- /dev/null +++ b/third_party/rust/jsparagus/jsparagus/emit/rust.py @@ -0,0 +1,903 @@ +"""Emit code and parser tables in Rust.""" + +import json +import re +import unicodedata +import sys +import itertools +import collections +from contextlib import contextmanager + +from ..runtime import (ERROR, ErrorToken, SPECIAL_CASE_TAG) +from ..ordered import OrderedSet + +from ..grammar import (Some, Nt, InitNt, End, ErrorSymbol) +from ..actions import (Accept, Action, Replay, Unwind, Reduce, CheckNotOnNewLine, FilterStates, + PushFlag, PopFlag, FunCall, Seq) + +from .. import types + + +TERMINAL_NAMES = { + '&&=': 'LogicalAndAssign', + '||=': 'LogicalOrAssign', + '??=': 'CoalesceAssign', + '{': 'OpenBrace', + '}': 'CloseBrace', + '(': 'OpenParenthesis', + ')': 'CloseParenthesis', + '[': 'OpenBracket', + ']': 'CloseBracket', + '+': 'Plus', + '-': 'Minus', + '~': 'BitwiseNot', + '!': 'LogicalNot', + '++': 'Increment', + '--': 'Decrement', + ':': 'Colon', + '=>': 'Arrow', + '=': 'EqualSign', + '*=': 'MultiplyAssign', + '/=': 'DivideAssign', + '%=': 'RemainderAssign', + '+=': 'AddAssign', + '-=': 'SubtractAssign', + '<<=': 'LeftShiftAssign', + '>>=': 'SignedRightShiftAssign', + '>>>=': 'UnsignedRightShiftAssign', + '&=': 'BitwiseAndAssign', + '^=': 'BitwiseXorAssign', + '|=': 'BitwiseOrAssign', + '**=': 'ExponentiateAssign', + '.': 'Dot', + '**': 'Exponentiate', + '?.': 'OptionalChain', + '?': 'QuestionMark', + '??': 'Coalesce', + '*': 'Star', + '/': 'Divide', + '%': 'Remainder', + '<<': 'LeftShift', + '>>': 'SignedRightShift', + '>>>': 'UnsignedRightShift', + '<': 'LessThan', + '>': 'GreaterThan', + '<=': 'LessThanOrEqualTo', + '>=': 'GreaterThanOrEqualTo', + '==': 'LaxEqual', + '!=': 'LaxNotEqual', + '===': 'StrictEqual', + '!==': 'StrictNotEqual', + '&': 'BitwiseAnd', + '^': 'BitwiseXor', + '|': 'BitwiseOr', + '&&': 'LogicalAnd', + '||': 'LogicalOr', + ',': 'Comma', + '...': 'Ellipsis', +} + + +@contextmanager +def indent(writer): + """This function is meant to be used with the `with` keyword of python, and + allow the user of it to add an indentation level to the code which is + enclosed in the `with` statement. + + This has the advantage that the indentation of the python code is reflected + to the generated code when `with indent(self):` is used. """ + writer.indent += 1 + yield None + writer.indent -= 1 + +def extract_ranges(iterator): + """Given a sorted iterator of integer, yield the contiguous ranges""" + # Identify contiguous ranges of states. + ranges = collections.defaultdict(list) + # A sorted list of contiguous integers implies that elements are separated + # by 1, as well as their indexes. Thus we can categorize them into buckets + # of contiguous integers using the base, which is the value v from which we + # remove the index i. + for i, v in enumerate(iterator): + ranges[v - i].append(v) + for l in ranges.values(): + yield (l[0], l[-1]) + +def rust_range(riter): + """Prettify a list of tuple of (min, max) of matched ranges into Rust + syntax.""" + def minmax_join(rmin, rmax): + if rmin == rmax: + return str(rmin) + else: + return "{}..={}".format(rmin, rmax) + return " | ".join(minmax_join(rmin, rmax) for rmin, rmax in riter) + +class RustActionWriter: + """Write epsilon state transitions for a given action function.""" + ast_builder = types.Type("AstBuilderDelegate", (types.Lifetime("alloc"),)) + + def __init__(self, writer, mode, traits, indent): + self.states = writer.states + self.writer = writer + self.mode = mode + self.traits = traits + self.indent = indent + self.has_ast_builder = self.ast_builder in traits + self.used_variables = set() + self.replay_args = [] + + def implement_trait(self, funcall): + "Returns True if this function call should be encoded" + ty = funcall.trait + if ty.name == "AstBuilder": + return "AstBuilderDelegate<'alloc>" in map(str, self.traits) + if ty in self.traits: + return True + if len(ty.args) == 0: + return ty.name in map(lambda t: t.name, self.traits) + return False + + def reset(self, act): + "Traverse all action to collect preliminary information." + self.used_variables = set(self.collect_uses(act)) + + def collect_uses(self, act): + "Generator which visit all used variables." + assert isinstance(act, Action) + if isinstance(act, (Reduce, Unwind)): + yield "value" + elif isinstance(act, FunCall): + arg_offset = act.offset + if arg_offset < 0: + # See write_funcall. + arg_offset = 0 + def map_with_offset(args): + for a in args: + if isinstance(a, int): + yield a + arg_offset + if isinstance(a, str): + yield a + elif isinstance(a, Some): + for offset in map_with_offset([a.inner]): + yield offset + if self.implement_trait(act): + for var in map_with_offset(act.args): + yield var + elif isinstance(act, Seq): + for a in act.actions: + for var in self.collect_uses(a): + yield var + + def write(self, string, *format_args): + "Delegate to the RustParserWriter.write function" + self.writer.write(self.indent, string, *format_args) + + def write_state_transitions(self, state, replay_args): + "Given a state, generate the code corresponding to all outgoing epsilon edges." + try: + self.replay_args = replay_args + assert not state.is_inconsistent() + assert len(list(state.shifted_edges())) == 0 + for ctx in self.writer.parse_table.debug_context(state.index, None): + self.write("// {}", ctx) + first, dest = next(state.edges(), (None, None)) + if first is None: + return + self.reset(first) + if first.is_condition(): + self.write_condition(state, first) + else: + assert len(list(state.edges())) == 1 + self.write_action(first, dest) + except Exception as exc: + print("Error while writing code for {}\n\n".format(state)) + self.writer.parse_table.debug_info = True + print(self.writer.parse_table.debug_context(state.index, "\n", "# ")) + raise exc + + def write_replay_args(self, n): + rp_args = self.replay_args[:n] + rp_stck = self.replay_args[n:] + for tv in rp_stck: + self.write("parser.replay({});", tv) + return rp_args + + + def write_epsilon_transition(self, dest): + # Replay arguments which are not accepted as input of the next state. + dest = self.states[dest] + rp_args = self.write_replay_args(dest.arguments) + self.write("// --> {}", dest.index) + if dest.index >= self.writer.shift_count: + self.write("{}_{}(parser{})", self.mode, dest.index, "".join(map(lambda v: ", " + v, rp_args))) + else: + assert dest.arguments == 0 + self.write("parser.epsilon({});", dest.index) + self.write("Ok(false)") + + def write_condition(self, state, first_act): + "Write code to test a conditions, and dispatch to the matching destination" + # NOTE: we already asserted that this state is consistent, this implies + # that the first state check the same variables as all remaining + # states. Thus we use the first action to produce the match statement. + assert isinstance(first_act, Action) + assert first_act.is_condition() + if isinstance(first_act, CheckNotOnNewLine): + # TODO: At the moment this is Action is implemented as a single + # operation with a single destination. However, we should implement + # it in the future as 2 branches, one which is verifying the lack + # of new lines, and one which is shifting an extra error token. + # This might help remove the overhead of backtracking in addition + # to make this backtracking visible through APS. + assert len(list(state.edges())) == 1 + act, dest = next(state.edges()) + assert len(self.replay_args) == 0 + assert -act.offset > 0 + self.write("// {}", str(act)) + self.write("if !parser.check_not_on_new_line({})? {{", -act.offset) + with indent(self): + self.write("return Ok(false);") + self.write("}") + self.write_epsilon_transition(dest) + elif isinstance(first_act, FilterStates): + if len(state.epsilon) == 1: + # This is an attempt to avoid huge unending compilations. + _, dest = next(iter(state.epsilon), (None, None)) + pattern = rust_range(extract_ranges(first_act.states)) + self.write("// parser.top_state() in ({})", pattern) + self.write_epsilon_transition(dest) + else: + self.write("match parser.top_state() {") + with indent(self): + # Consider the branch which has the largest number of + # potential top-states to be most likely, and therefore the + # default branch to go to if all other fail to match. + default_weight = max(len(act.states) for act, dest in state.edges()) + default_states = [] + default_dest = None + for act, dest in state.edges(): + assert first_act.check_same_variable(act) + if default_dest is None and default_weight == len(act.states): + # This range has the same weight as the default + # branch. Ignore it and use it as the default + # branch which would be generated at the end. + default_states = act.states + default_dest = dest + continue + pattern = rust_range(extract_ranges(act.states)) + self.write("{} => {{", pattern) + with indent(self): + self.write_epsilon_transition(dest) + self.write("}") + # Generate code for the default branch, which got skipped + # while producing the loop. + self.write("_ => {") + with indent(self): + pattern = rust_range(extract_ranges(default_states)) + self.write("// {}", pattern) + self.write_epsilon_transition(default_dest) + self.write("}") + self.write("}") + else: + raise ValueError("Unexpected action type") + + def write_action(self, act, dest): + assert isinstance(act, Action) + assert not act.is_condition() + is_packed = {} + + # Do not pop any of the stack elements if the reduce action has an + # accept function call. Ideally we should be returning the result + # instead of keeping it on the parser stack. + if act.update_stack() and not act.contains_accept(): + stack_diff = act.update_stack_with() + start = 0 + depth = stack_diff.pop + args = len(self.replay_args) + replay = stack_diff.replay + if replay < 0: + # At the moment, we do not handle having more arguments than + # what is being popped and replay, thus write back the extra + # arguments and continue. + if stack_diff.pop + replay < 0: + self.replay_args = self.write_replay_args(replay) + replay = 0 + if replay + stack_diff.pop - args > 0: + assert (replay >= 0 and args == 0) or \ + (replay == 0 and args >= 0) + if replay > 0: + # At the moment, assume that arguments are only added once we + # consumed all replayed terms. Thus the replay_args can only be + # non-empty once replay is 0. Otherwise some of the replay_args + # would have to be replayed. + assert args == 0 + self.write("parser.rewind({});", replay) + start = replay + depth += start + + inputs = [] + for i in range(start, depth): + name = 's{}'.format(i + 1) + if i + 1 not in self.used_variables: + name = '_' + name + inputs.append(name) + if stack_diff.pop > 0: + args_pop = min(len(self.replay_args), stack_diff.pop) + # Pop by moving arguments of the action function. + for i, name in enumerate(inputs[:args_pop]): + self.write("let {} = {};", name, self.replay_args[-i - 1]) + # Pop by removing elements from the parser stack. + for name in inputs[args_pop:]: + self.write("let {} = parser.pop();", name) + if args_pop > 0: + del self.replay_args[-args_pop:] + + if isinstance(act, Seq): + for a in act.actions: + self.write_single_action(a, is_packed) + if a.contains_accept(): + break + else: + self.write_single_action(act, is_packed) + + # If we fallthrough the execution of the action, then generate an + # epsilon transition. + if act.follow_edge() and not act.contains_accept(): + assert 0 <= dest < self.writer.shift_count + self.writer.action_count + self.write_epsilon_transition(dest) + + def write_single_action(self, act, is_packed): + self.write("// {}", str(act)) + if isinstance(act, Replay): + self.write_replay(act) + elif isinstance(act, (Reduce, Unwind)): + self.write_reduce(act, is_packed) + elif isinstance(act, Accept): + self.write_accept() + elif isinstance(act, PushFlag): + raise ValueError("NYI: PushFlag action") + elif isinstance(act, PopFlag): + raise ValueError("NYI: PopFlag action") + elif isinstance(act, FunCall): + self.write_funcall(act, is_packed) + else: + raise ValueError("Unexpected action type") + + def write_replay(self, act): + assert len(self.replay_args) == 0 + for shift_state in act.replay_steps: + self.write("parser.shift_replayed({});", shift_state) + + def write_reduce(self, act, is_packed): + value = "value" + if value in is_packed: + packed = is_packed[value] + else: + packed = False + value = "None" + + if packed: + # Extract the StackValue from the packed TermValue + value = "{}.value".format(value) + elif self.has_ast_builder: + # Convert into a StackValue + value = "TryIntoStack::try_into_stack({})?".format(value) + else: + # Convert into a StackValue (when no ast-builder) + value = "value" + + stack_diff = act.update_stack_with() + assert stack_diff.nt is not None + self.write("let term = NonterminalId::{}.into();", + self.writer.nonterminal_to_camel(stack_diff.nt)) + if value != "value": + self.write("let value = {};", value) + self.write("let reduced = TermValue { term, value };") + self.replay_args.append("reduced") + + def write_accept(self): + self.write("return Ok(true);") + + def write_funcall(self, act, is_packed): + arg_offset = act.offset + if arg_offset < 0: + # NOTE: When replacing replayed stack elements by arguments, the + # offset is reduced by -1, and can become negative for cases where + # we read the value associated with an argument instead of the + # value read from the stack. However, write_action shift everything + # as-if we had replayed all the necessary terms, and therefore + # variables are named as-if the offset were 0. + arg_offset = 0 + + def no_unpack(val): + return val + + def unpack(val): + if val in is_packed: + packed = is_packed[val] + else: + packed = True + if packed: + return "{}.value.to_ast()?".format(val) + return val + + def map_with_offset(args, unpack): + get_value = "s{}" + for a in args: + if isinstance(a, int): + yield unpack(get_value.format(a + arg_offset)) + elif isinstance(a, str): + yield unpack(a) + elif isinstance(a, Some): + yield "Some({})".format(next(map_with_offset([a.inner], unpack))) + elif a is None: + yield "None" + else: + raise ValueError(a) + + packed = False + # If the variable is used, then generate the let binding. + set_var = "" + if act.set_to in self.used_variables: + set_var = "let {} = ".format(act.set_to) + + # If the function cannot be call as the generated action function does + # not use the trait on which this function is implemented, then replace + # the value by `()`. + if not self.implement_trait(act): + self.write("{}();", set_var) + return + + # NOTE: Currently "AstBuilder" is implemented through the + # AstBuilderDelegate which returns a mutable reference to the + # AstBuilder. This would call the specific special case method to get + # the actual AstBuilder. + delegate = "" + if str(act.trait) == "AstBuilder": + delegate = "ast_builder_refmut()." + + # NOTE: Currently "AstBuilder" functions are made fallible + # using the fallible_methods taken from some Rust code + # which extract this information to produce a JSON file. + forward_errors = "" + if act.fallible or act.method in self.writer.fallible_methods: + forward_errors = "?" + + # By default generate a method call, with the method name. However, + # there is a special case for the "id" function which is an artifact, + # which does not have to unpack the content of its argument. + value = "parser.{}{}({})".format( + delegate, act.method, + ", ".join(map_with_offset(act.args, unpack))) + packed = False + if act.method == "id": + assert len(act.args) == 1 + value = next(map_with_offset(act.args, no_unpack)) + if isinstance(act.args[0], str): + packed = is_packed[act.args[0]] + else: + assert isinstance(act.args[0], int) + packed = True + + self.write("{}{}{};", set_var, value, forward_errors) + is_packed[act.set_to] = packed + + +class RustParserWriter: + def __init__(self, out, pt, fallible_methods): + self.out = out + self.fallible_methods = fallible_methods + assert pt.exec_modes is not None + self.parse_table = pt + self.states = pt.states + self.shift_count = pt.count_shift_states() + self.action_count = pt.count_action_states() + self.action_from_shift_count = pt.count_action_from_shift_states() + self.init_state_map = pt.named_goals + self.terminals = list(OrderedSet(pt.terminals)) + # This extra terminal is used to represent any ErrorySymbol transition, + # knowing that we assert that there is only one ErrorSymbol kind per + # state. + self.terminals.append("ErrorToken") + self.nonterminals = list(OrderedSet(pt.nonterminals)) + + def emit(self): + self.header() + self.terms_id() + self.shift() + self.error_codes() + self.check_camel_case() + self.actions() + self.entry() + + def write(self, indentation, string, *format_args): + if len(format_args) == 0: + formatted = string + else: + formatted = string.format(*format_args) + self.out.write(" " * indentation + formatted + "\n") + + def header(self): + self.write(0, "// WARNING: This file is autogenerated.") + self.write(0, "") + self.write(0, "use crate::ast_builder::AstBuilderDelegate;") + self.write(0, "use crate::stack_value_generated::{StackValue, TryIntoStack};") + self.write(0, "use crate::traits::{TermValue, ParserTrait};") + self.write(0, "use crate::error::Result;") + traits = OrderedSet() + for mode_traits in self.parse_table.exec_modes.values(): + traits |= mode_traits + traits = list(traits) + traits = [ty for ty in traits if ty.name != "AstBuilderDelegate"] + traits = [ty for ty in traits if ty.name != "ParserTrait"] + if traits == []: + pass + elif len(traits) == 1: + self.write(0, "use crate::traits::{};", traits[0].name) + else: + self.write(0, "use crate::traits::{{{}}};", ", ".join(ty.name for ty in traits)) + self.write(0, "") + self.write(0, "const ERROR: i64 = {};", hex(ERROR)) + self.write(0, "") + + def terminal_name(self, value): + if isinstance(value, End) or value is None: + return "End" + elif isinstance(value, ErrorSymbol) or value is ErrorToken: + return "ErrorToken" + elif value in TERMINAL_NAMES: + return TERMINAL_NAMES[value] + elif value.isalpha(): + if value.islower(): + return value.capitalize() + else: + return value + else: + raw_name = " ".join((unicodedata.name(c) for c in value)) + snake_case = raw_name.replace("-", " ").replace(" ", "_").lower() + camel_case = self.to_camel_case(snake_case) + return camel_case + + def terminal_name_camel(self, value): + return self.to_camel_case(self.terminal_name(value)) + + def terms_id(self): + self.write(0, "#[derive(Copy, Clone, Debug, PartialEq)]") + self.write(0, "#[repr(u32)]") + self.write(0, "pub enum TerminalId {") + for i, t in enumerate(self.terminals): + name = self.terminal_name(t) + self.write(1, "{} = {}, // {}", name, i, repr(t)) + self.write(0, "}") + self.write(0, "") + self.write(0, "#[derive(Clone, Copy, Debug, PartialEq)]") + self.write(0, "#[repr(u32)]") + self.write(0, "pub enum NonterminalId {") + offset = len(self.terminals) + for i, nt in enumerate(self.nonterminals): + self.write(1, "{} = {},", self.nonterminal_to_camel(nt), i + offset) + self.write(0, "}") + self.write(0, "") + self.write(0, "#[derive(Clone, Copy, Debug, PartialEq)]") + self.write(0, "pub struct Term(u32);") + self.write(0, "") + self.write(0, "impl Term {") + self.write(1, "pub fn is_terminal(&self) -> bool {") + self.write(2, "self.0 < {}", offset) + self.write(1, "}") + self.write(1, "pub fn to_terminal(&self) -> TerminalId {") + self.write(2, "assert!(self.is_terminal());") + self.write(2, "unsafe { std::mem::transmute(self.0) }") + self.write(1, "}") + self.write(0, "}") + self.write(0, "") + self.write(0, "impl From<TerminalId> for Term {") + self.write(1, "fn from(t: TerminalId) -> Self {") + self.write(2, "Term(t as _)") + self.write(1, "}") + self.write(0, "}") + self.write(0, "") + self.write(0, "impl From<NonterminalId> for Term {") + self.write(1, "fn from(nt: NonterminalId) -> Self {") + self.write(2, "Term(nt as _)") + self.write(1, "}") + self.write(0, "}") + self.write(0, "") + self.write(0, "impl From<Term> for usize {") + self.write(1, "fn from(term: Term) -> Self {") + self.write(2, "term.0 as _") + self.write(1, "}") + self.write(0, "}") + self.write(0, "") + self.write(0, "impl From<Term> for &'static str {") + self.write(1, "fn from(term: Term) -> Self {") + self.write(2, "match term.0 {") + for i, t in enumerate(self.terminals): + self.write(3, "{} => &\"{}\",", i, repr(t)) + for j, nt in enumerate(self.nonterminals): + i = j + offset + self.write(3, "{} => &\"{}\",", i, str(nt.name)) + self.write(3, "_ => panic!(\"unknown Term\")", i, str(nt.name)) + self.write(2, "}") + self.write(1, "}") + self.write(0, "}") + self.write(0, "") + + def shift(self): + self.write(0, "#[rustfmt::skip]") + width = len(self.terminals) + len(self.nonterminals) + num_shifted_edges = 0 + + def state_get(state, t): + nonlocal num_shifted_edges + res = state.get(t, "ERROR") + if res == "ERROR": + error_symbol = state.get_error_symbol() + if t == "ErrorToken" and error_symbol: + res = state[error_symbol] + num_shifted_edges += 1 + else: + num_shifted_edges += 1 + return res + + self.write(0, "static SHIFT: [i64; {}] = [", self.shift_count * width) + assert self.terminals[-1] == "ErrorToken" + for i, state in enumerate(self.states[:self.shift_count]): + num_shifted_edges = 0 + self.write(1, "// {}.", i) + for ctx in self.parse_table.debug_context(state.index, None): + self.write(1, "// {}", ctx) + self.write(1, "{}", + ' '.join("{},".format(state_get(state, t)) for t in self.terminals)) + self.write(1, "{}", + ' '.join("{},".format(state_get(state, t)) for t in self.nonterminals)) + try: + assert sum(1 for _ in state.shifted_edges()) == num_shifted_edges + except Exception: + print("Some edges are not encoded.") + print("List of terminals: {}".format(', '.join(map(repr, self.terminals)))) + print("List of nonterminals: {}".format(', '.join(map(repr, self.nonterminals)))) + print("State having the issue: {}".format(str(state))) + raise + self.write(0, "];") + self.write(0, "") + + def render_action(self, action): + if isinstance(action, tuple): + if action[0] == 'IfSameLine': + _, a1, a2 = action + if a1 is None: + a1 = 'ERROR' + if a2 is None: + a2 = 'ERROR' + index = self.add_special_case( + "if token.is_on_new_line { %s } else { %s }" + % (a2, a1)) + else: + raise ValueError("unrecognized kind of special case: {!r}".format(action)) + return SPECIAL_CASE_TAG + index + elif action == 'ERROR': + return action + else: + assert isinstance(action, int) + return action + + def emit_special_cases(self): + self.write(0, "static SPECIAL_CASES: [fn(&Token) -> i64; {}] = [", + len(self.special_cases)) + for i, code in enumerate(self.special_cases): + self.write(1, "|token| {{ {} }},", code) + self.write(0, "];") + self.write(0, "") + + def error_codes(self): + self.write(0, "#[derive(Clone, Copy, Debug, PartialEq)]") + self.write(0, "pub enum ErrorCode {") + error_symbols = (s.get_error_symbol() for s in self.states[:self.shift_count]) + error_codes = (e.error_code for e in error_symbols if e is not None) + for error_code in OrderedSet(error_codes): + self.write(1, "{},", self.to_camel_case(error_code)) + self.write(0, "}") + self.write(0, "") + + self.write(0, "static STATE_TO_ERROR_CODE: [Option<ErrorCode>; {}] = [", + self.shift_count) + for i, state in enumerate(self.states[:self.shift_count]): + error_symbol = state.get_error_symbol() + if error_symbol is None: + self.write(1, "None,") + else: + self.write(1, "// {}.", i) + for ctx in self.parse_table.debug_context(state.index, None): + self.write(1, "// {}", ctx) + self.write(1, "Some(ErrorCode::{}),", + self.to_camel_case(error_symbol.error_code)) + self.write(0, "];") + self.write(0, "") + + def nonterminal_to_snake(self, ident): + if isinstance(ident, Nt): + if isinstance(ident.name, InitNt): + name = "Start" + ident.name.goal.name + else: + name = ident.name + base_name = self.to_snek_case(name) + args = ''.join((("_" + self.to_snek_case(name)) + for name, value in ident.args if value)) + return base_name + args + else: + assert isinstance(ident, str) + return self.to_snek_case(ident) + + def nonterminal_to_camel(self, nt): + return self.to_camel_case(self.nonterminal_to_snake(nt)) + + def to_camel_case(self, ident): + if '_' in ident: + return ''.join(word.capitalize() for word in ident.split('_')) + elif ident.islower(): + return ident.capitalize() + else: + return ident + + def check_camel_case(self): + seen = {} + for nt in self.nonterminals: + cc = self.nonterminal_to_camel(nt) + if cc in seen: + raise ValueError("{} and {} have the same camel-case spelling ({})".format( + seen[cc], nt, cc)) + seen[cc] = nt + + def to_snek_case(self, ident): + # https://stackoverflow.com/questions/1175208 + s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', ident) + return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower() + + def type_to_rust(self, ty, namespace="", boxed=False): + """ + Convert a jsparagus type (see types.py) to Rust. + + Pass boxed=True if the type needs to be boxed. + """ + if isinstance(ty, types.Lifetime): + assert not boxed + rty = "'" + ty.name + elif ty == types.UnitType: + assert not boxed + rty = '()' + elif ty == types.TokenType: + rty = "Token" + elif ty.name == 'Option' and len(ty.args) == 1: + # We auto-translate `Box<Option<T>>` to `Option<Box<T>>` since + # that's basically the same thing but more efficient. + [arg] = ty.args + return 'Option<{}>'.format(self.type_to_rust(arg, namespace, boxed)) + elif ty.name == 'Vec' and len(ty.args) == 1: + [arg] = ty.args + rty = "Vec<'alloc, {}>".format(self.type_to_rust(arg, namespace, boxed=False)) + else: + if namespace == "": + rty = ty.name + else: + rty = namespace + '::' + ty.name + if ty.args: + rty += '<{}>'.format(', '.join(self.type_to_rust(arg, namespace, boxed) + for arg in ty.args)) + if boxed: + return "Box<'alloc, {}>".format(rty) + else: + return rty + + def actions(self): + # For each execution mode, add a corresponding function which + # implements various traits. The trait list is used for filtering which + # function is added in the generated code. + for mode, traits in self.parse_table.exec_modes.items(): + action_writer = RustActionWriter(self, mode, traits, 2) + start_at = self.shift_count + end_at = start_at + self.action_from_shift_count + assert len(self.states[self.shift_count:]) == self.action_count + traits_text = ' + '.join(map(self.type_to_rust, traits)) + table_holder_name = self.to_camel_case(mode) + table_holder_type = table_holder_name + "<'alloc, Handler>" + # As we do not have default associated types yet in Rust + # (rust-lang#29661), we have to peak from the parameter of the + # ParserTrait. + assert list(traits)[0].name == "ParserTrait" + arg_type = "TermValue<" + self.type_to_rust(list(traits)[0].args[1]) + ">" + self.write(0, "struct {} {{", table_holder_type) + self.write(1, "fns: [fn(&mut Handler) -> Result<'alloc, bool>; {}]", + self.action_from_shift_count) + self.write(0, "}") + self.write(0, "impl<'alloc, Handler> {}", table_holder_type) + self.write(0, "where") + self.write(1, "Handler: {}", traits_text) + self.write(0, "{") + self.write(1, "const TABLE : {} = {} {{", table_holder_type, table_holder_name) + self.write(2, "fns: [") + for state in self.states[start_at:end_at]: + assert state.arguments == 0 + self.write(3, "{}_{},", mode, state.index) + self.write(2, "],") + self.write(1, "};") + self.write(0, "}") + self.write(0, "") + self.write(0, + "pub fn {}<'alloc, Handler>(parser: &mut Handler, state: usize) " + "-> Result<'alloc, bool>", + mode) + self.write(0, "where") + self.write(1, "Handler: {}", traits_text) + self.write(0, "{") + self.write(1, "{}::<'alloc, Handler>::TABLE.fns[state - {}](parser)", + table_holder_name, start_at) + self.write(0, "}") + self.write(0, "") + for state in self.states[self.shift_count:]: + state_args = "" + for i in range(state.arguments): + state_args += ", v{}: {}".format(i, arg_type) + replay_args = ["v{}".format(i) for i in range(state.arguments)] + self.write(0, "#[inline]") + self.write(0, "#[allow(unused)]") + self.write(0, + "pub fn {}_{}<'alloc, Handler>(parser: &mut Handler{}) " + "-> Result<'alloc, bool>", + mode, state.index, state_args) + self.write(0, "where") + self.write(1, "Handler: {}", ' + '.join(map(self.type_to_rust, traits))) + self.write(0, "{") + action_writer.write_state_transitions(state, replay_args) + self.write(0, "}") + + def entry(self): + self.write(0, "#[derive(Clone, Copy)]") + self.write(0, "pub struct ParseTable<'a> {") + self.write(1, "pub shift_count: usize,") + self.write(1, "pub action_count: usize,") + self.write(1, "pub action_from_shift_count: usize,") + self.write(1, "pub shift_table: &'a [i64],") + self.write(1, "pub shift_width: usize,") + self.write(1, "pub error_codes: &'a [Option<ErrorCode>],") + self.write(0, "}") + self.write(0, "") + + self.write(0, "impl<'a> ParseTable<'a> {") + self.write(1, "pub fn check(&self) {") + self.write(2, "assert_eq!(") + self.write(3, "self.shift_table.len(),") + self.write(3, "(self.shift_count * self.shift_width) as usize") + self.write(2, ");") + self.write(1, "}") + self.write(0, "}") + self.write(0, "") + + self.write(0, "pub static TABLES: ParseTable<'static> = ParseTable {") + self.write(1, "shift_count: {},", self.shift_count) + self.write(1, "action_count: {},", self.action_count) + self.write(1, "action_from_shift_count: {},", self.action_from_shift_count) + self.write(1, "shift_table: &SHIFT,") + self.write(1, "shift_width: {},", len(self.terminals) + len(self.nonterminals)) + self.write(1, "error_codes: &STATE_TO_ERROR_CODE,") + self.write(0, "};") + self.write(0, "") + + for init_nt, index in self.init_state_map: + assert init_nt.args == () + self.write(0, "pub static START_STATE_{}: usize = {};", + self.nonterminal_to_snake(init_nt).upper(), index) + self.write(0, "") + + +def write_rust_parse_table(out, parse_table, handler_info): + if not handler_info: + print("WARNING: info.json is not provided", file=sys.stderr) + fallible_methods = [] + else: + with open(handler_info, "r") as json_file: + handler_info_json = json.load(json_file) + fallible_methods = handler_info_json["fallible-methods"] + + RustParserWriter(out, parse_table, fallible_methods).emit() |