diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 19:33:14 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 19:33:14 +0000 |
commit | 36d22d82aa202bb199967e9512281e9a53db42c9 (patch) | |
tree | 105e8c98ddea1c1e4784a60a5a6410fa416be2de /third_party/python/fluent.syntax/fluent | |
parent | Initial commit. (diff) | |
download | firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.tar.xz firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.zip |
Adding upstream version 115.7.0esr.upstream/115.7.0esr
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/python/fluent.syntax/fluent')
8 files changed, 1766 insertions, 0 deletions
diff --git a/third_party/python/fluent.syntax/fluent/syntax/__init__.py b/third_party/python/fluent.syntax/fluent/syntax/__init__.py new file mode 100644 index 0000000000..1ff31745e6 --- /dev/null +++ b/third_party/python/fluent.syntax/fluent/syntax/__init__.py @@ -0,0 +1,34 @@ +from typing import Any + +from . import ast +from .errors import ParseError +from .parser import FluentParser +from .serializer import FluentSerializer +from .stream import FluentParserStream +from .visitor import Transformer, Visitor + +__all__ = [ + 'FluentParser', + 'FluentParserStream', + 'FluentSerializer', + 'ParseError', + 'Transformer', + 'Visitor', + 'ast', + 'parse', + 'serialize' +] + + +def parse(source: str, **kwargs: Any) -> ast.Resource: + """Create an ast.Resource from a Fluent Syntax source. + """ + parser = FluentParser(**kwargs) + return parser.parse(source) + + +def serialize(resource: ast.Resource, **kwargs: Any) -> str: + """Serialize an ast.Resource to a unicode string. + """ + serializer = FluentSerializer(**kwargs) + return serializer.serialize(resource) diff --git a/third_party/python/fluent.syntax/fluent/syntax/ast.py b/third_party/python/fluent.syntax/fluent/syntax/ast.py new file mode 100644 index 0000000000..d2e4849079 --- /dev/null +++ b/third_party/python/fluent.syntax/fluent/syntax/ast.py @@ -0,0 +1,376 @@ +import re +import sys +import json +from typing import Any, Callable, Dict, List, TypeVar, Union, cast + +Node = TypeVar('Node', bound='BaseNode') +ToJsonFn = Callable[[Dict[str, Any]], Any] + + +def to_json(value: Any, fn: Union[ToJsonFn, None] = None) -> Any: + if isinstance(value, BaseNode): + return value.to_json(fn) + if isinstance(value, list): + return list(to_json(item, fn) for item in value) + if isinstance(value, tuple): + return list(to_json(item, fn) for item in value) + else: + return value + + +def from_json(value: Any) -> Any: + if isinstance(value, dict): + cls = getattr(sys.modules[__name__], value['type']) + args = { + k: from_json(v) + for k, v in value.items() + if k != 'type' + } + return cls(**args) + if isinstance(value, list): + return list(map(from_json, value)) + else: + return value + + +def scalars_equal(node1: Any, node2: Any, ignored_fields: List[str]) -> bool: + """Compare two nodes which are not lists.""" + + if type(node1) != type(node2): + return False + + if isinstance(node1, BaseNode): + return node1.equals(node2, ignored_fields) + + return cast(bool, node1 == node2) + + +class BaseNode: + """Base class for all Fluent AST nodes. + + All productions described in the ASDL subclass BaseNode, including Span and + Annotation. Implements __str__, to_json and traverse. + """ + + def clone(self: Node) -> Node: + """Create a deep clone of the current node.""" + def visit(value: Any) -> Any: + """Clone node and its descendants.""" + if isinstance(value, BaseNode): + return value.clone() + if isinstance(value, list): + return [visit(child) for child in value] + if isinstance(value, tuple): + return tuple(visit(child) for child in value) + return value + + # Use all attributes found on the node as kwargs to the constructor. + return self.__class__( + **{name: visit(value) for name, value in vars(self).items()} + ) + + def equals(self, other: 'BaseNode', ignored_fields: List[str] = ['span']) -> bool: + """Compare two nodes. + + Nodes are deeply compared on a field by field basis. If possible, False + is returned early. When comparing attributes and variants in + SelectExpressions, the order doesn't matter. By default, spans are not + taken into account. + """ + + self_keys = set(vars(self).keys()) + other_keys = set(vars(other).keys()) + + if ignored_fields: + for key in ignored_fields: + self_keys.discard(key) + other_keys.discard(key) + + if self_keys != other_keys: + return False + + for key in self_keys: + field1 = getattr(self, key) + field2 = getattr(other, key) + + # List-typed nodes are compared item-by-item. When comparing + # attributes and variants, the order of items doesn't matter. + if isinstance(field1, list) and isinstance(field2, list): + if len(field1) != len(field2): + return False + + for elem1, elem2 in zip(field1, field2): + if not scalars_equal(elem1, elem2, ignored_fields): + return False + + elif not scalars_equal(field1, field2, ignored_fields): + return False + + return True + + def to_json(self, fn: Union[ToJsonFn, None] = None) -> Any: + obj = { + name: to_json(value, fn) + for name, value in vars(self).items() + } + obj.update( + {'type': self.__class__.__name__} + ) + return fn(obj) if fn else obj + + def __str__(self) -> str: + return json.dumps(self.to_json()) + + +class SyntaxNode(BaseNode): + """Base class for AST nodes which can have Spans.""" + + def __init__(self, span: Union['Span', None] = None, **kwargs: Any): + super().__init__(**kwargs) + self.span = span + + def add_span(self, start: int, end: int) -> None: + self.span = Span(start, end) + + +class Resource(SyntaxNode): + def __init__(self, body: Union[List['EntryType'], None] = None, **kwargs: Any): + super().__init__(**kwargs) + self.body = body or [] + + +class Entry(SyntaxNode): + """An abstract base class for useful elements of Resource.body.""" + + +class Message(Entry): + def __init__(self, + id: 'Identifier', + value: Union['Pattern', None] = None, + attributes: Union[List['Attribute'], None] = None, + comment: Union['Comment', None] = None, + **kwargs: Any): + super().__init__(**kwargs) + self.id = id + self.value = value + self.attributes = attributes or [] + self.comment = comment + + +class Term(Entry): + def __init__(self, id: 'Identifier', value: 'Pattern', attributes: Union[List['Attribute'], None] = None, + comment: Union['Comment', None] = None, **kwargs: Any): + super().__init__(**kwargs) + self.id = id + self.value = value + self.attributes = attributes or [] + self.comment = comment + + +class Pattern(SyntaxNode): + def __init__(self, elements: List[Union['TextElement', 'Placeable']], **kwargs: Any): + super().__init__(**kwargs) + self.elements = elements + + +class PatternElement(SyntaxNode): + """An abstract base class for elements of Patterns.""" + + +class TextElement(PatternElement): + def __init__(self, value: str, **kwargs: Any): + super().__init__(**kwargs) + self.value = value + + +class Placeable(PatternElement): + def __init__(self, + expression: Union['InlineExpression', 'Placeable', 'SelectExpression'], + **kwargs: Any): + super().__init__(**kwargs) + self.expression = expression + + +class Expression(SyntaxNode): + """An abstract base class for expressions.""" + + +class Literal(Expression): + """An abstract base class for literals.""" + + def __init__(self, value: str, **kwargs: Any): + super().__init__(**kwargs) + self.value = value + + def parse(self) -> Dict[str, Any]: + return {'value': self.value} + + +class StringLiteral(Literal): + def parse(self) -> Dict[str, str]: + def from_escape_sequence(matchobj: Any) -> str: + c, codepoint4, codepoint6 = matchobj.groups() + if c: + return cast(str, c) + codepoint = int(codepoint4 or codepoint6, 16) + if codepoint <= 0xD7FF or 0xE000 <= codepoint: + return chr(codepoint) + # Escape sequences reresenting surrogate code points are + # well-formed but invalid in Fluent. Replace them with U+FFFD + # REPLACEMENT CHARACTER. + return '�' + + value = re.sub( + r'\\(?:(\\|")|u([0-9a-fA-F]{4})|U([0-9a-fA-F]{6}))', + from_escape_sequence, + self.value + ) + return {'value': value} + + +class NumberLiteral(Literal): + def parse(self) -> Dict[str, Union[float, int]]: + value = float(self.value) + decimal_position = self.value.find('.') + precision = 0 + if decimal_position >= 0: + precision = len(self.value) - decimal_position - 1 + return { + 'value': value, + 'precision': precision + } + + +class MessageReference(Expression): + def __init__(self, id: 'Identifier', attribute: Union['Identifier', None] = None, **kwargs: Any): + super().__init__(**kwargs) + self.id = id + self.attribute = attribute + + +class TermReference(Expression): + def __init__(self, + id: 'Identifier', + attribute: Union['Identifier', None] = None, + arguments: Union['CallArguments', None] = None, + **kwargs: Any): + super().__init__(**kwargs) + self.id = id + self.attribute = attribute + self.arguments = arguments + + +class VariableReference(Expression): + def __init__(self, id: 'Identifier', **kwargs: Any): + super().__init__(**kwargs) + self.id = id + + +class FunctionReference(Expression): + def __init__(self, id: 'Identifier', arguments: 'CallArguments', **kwargs: Any): + super().__init__(**kwargs) + self.id = id + self.arguments = arguments + + +class SelectExpression(Expression): + def __init__(self, selector: 'InlineExpression', variants: List['Variant'], **kwargs: Any): + super().__init__(**kwargs) + self.selector = selector + self.variants = variants + + +class CallArguments(SyntaxNode): + def __init__(self, + positional: Union[List[Union['InlineExpression', Placeable]], None] = None, + named: Union[List['NamedArgument'], None] = None, + **kwargs: Any): + super().__init__(**kwargs) + self.positional = [] if positional is None else positional + self.named = [] if named is None else named + + +class Attribute(SyntaxNode): + def __init__(self, id: 'Identifier', value: Pattern, **kwargs: Any): + super().__init__(**kwargs) + self.id = id + self.value = value + + +class Variant(SyntaxNode): + def __init__(self, key: Union['Identifier', NumberLiteral], value: Pattern, default: bool = False, **kwargs: Any): + super().__init__(**kwargs) + self.key = key + self.value = value + self.default = default + + +class NamedArgument(SyntaxNode): + def __init__(self, name: 'Identifier', value: Union[NumberLiteral, StringLiteral], **kwargs: Any): + super().__init__(**kwargs) + self.name = name + self.value = value + + +class Identifier(SyntaxNode): + def __init__(self, name: str, **kwargs: Any): + super().__init__(**kwargs) + self.name = name + + +class BaseComment(Entry): + def __init__(self, content: Union[str, None] = None, **kwargs: Any): + super().__init__(**kwargs) + self.content = content + + +class Comment(BaseComment): + def __init__(self, content: Union[str, None] = None, **kwargs: Any): + super().__init__(content, **kwargs) + + +class GroupComment(BaseComment): + def __init__(self, content: Union[str, None] = None, **kwargs: Any): + super().__init__(content, **kwargs) + + +class ResourceComment(BaseComment): + def __init__(self, content: Union[str, None] = None, **kwargs: Any): + super().__init__(content, **kwargs) + + +class Junk(SyntaxNode): + def __init__(self, + content: Union[str, None] = None, + annotations: Union[List['Annotation'], None] = None, + **kwargs: Any): + super().__init__(**kwargs) + self.content = content + self.annotations = annotations or [] + + def add_annotation(self, annot: 'Annotation') -> None: + self.annotations.append(annot) + + +class Span(BaseNode): + def __init__(self, start: int, end: int, **kwargs: Any): + super().__init__(**kwargs) + self.start = start + self.end = end + + +class Annotation(SyntaxNode): + def __init__(self, + code: str, + arguments: Union[List[Any], None] = None, + message: Union[str, None] = None, + **kwargs: Any): + super().__init__(**kwargs) + self.code = code + self.arguments = arguments or [] + self.message = message + + +EntryType = Union[Message, Term, Comment, GroupComment, ResourceComment, Junk] +InlineExpression = Union[NumberLiteral, StringLiteral, MessageReference, + TermReference, VariableReference, FunctionReference] diff --git a/third_party/python/fluent.syntax/fluent/syntax/errors.py b/third_party/python/fluent.syntax/fluent/syntax/errors.py new file mode 100644 index 0000000000..010374828f --- /dev/null +++ b/third_party/python/fluent.syntax/fluent/syntax/errors.py @@ -0,0 +1,70 @@ +from typing import Tuple, Union + + +class ParseError(Exception): + def __init__(self, code: str, *args: Union[str, None]): + self.code = code + self.args = args + self.message = get_error_message(code, args) + + +def get_error_message(code: str, args: Tuple[Union[str, None], ...]) -> str: + if code == 'E00001': + return 'Generic error' + if code == 'E0002': + return 'Expected an entry start' + if code == 'E0003': + return 'Expected token: "{}"'.format(args[0]) + if code == 'E0004': + return 'Expected a character from range: "{}"'.format(args[0]) + if code == 'E0005': + msg = 'Expected message "{}" to have a value or attributes' + return msg.format(args[0]) + if code == 'E0006': + msg = 'Expected term "-{}" to have a value' + return msg.format(args[0]) + if code == 'E0007': + return 'Keyword cannot end with a whitespace' + if code == 'E0008': + return 'The callee has to be an upper-case identifier or a term' + if code == 'E0009': + return 'The argument name has to be a simple identifier' + if code == 'E0010': + return 'Expected one of the variants to be marked as default (*)' + if code == 'E0011': + return 'Expected at least one variant after "->"' + if code == 'E0012': + return 'Expected value' + if code == 'E0013': + return 'Expected variant key' + if code == 'E0014': + return 'Expected literal' + if code == 'E0015': + return 'Only one variant can be marked as default (*)' + if code == 'E0016': + return 'Message references cannot be used as selectors' + if code == 'E0017': + return 'Terms cannot be used as selectors' + if code == 'E0018': + return 'Attributes of messages cannot be used as selectors' + if code == 'E0019': + return 'Attributes of terms cannot be used as placeables' + if code == 'E0020': + return 'Unterminated string expression' + if code == 'E0021': + return 'Positional arguments must not follow named arguments' + if code == 'E0022': + return 'Named arguments must be unique' + if code == 'E0024': + return 'Cannot access variants of a message.' + if code == 'E0025': + return 'Unknown escape sequence: \\{}.'.format(args[0]) + if code == 'E0026': + return 'Invalid Unicode escape sequence: {}.'.format(args[0]) + if code == 'E0027': + return 'Unbalanced closing brace in TextElement.' + if code == 'E0028': + return 'Expected an inline expression' + if code == 'E0029': + return 'Expected simple expression as selector' + return code diff --git a/third_party/python/fluent.syntax/fluent/syntax/parser.py b/third_party/python/fluent.syntax/fluent/syntax/parser.py new file mode 100644 index 0000000000..87075409f1 --- /dev/null +++ b/third_party/python/fluent.syntax/fluent/syntax/parser.py @@ -0,0 +1,701 @@ +import re +from typing import Any, Callable, List, Set, TypeVar, Union, cast +from . import ast +from .stream import EOL, FluentParserStream +from .errors import ParseError + +R = TypeVar("R", bound=ast.SyntaxNode) + + +def with_span(fn: Callable[..., R]) -> Callable[..., R]: + def decorated(self: 'FluentParser', ps: FluentParserStream, *args: Any, **kwargs: Any) -> Any: + if not self.with_spans: + return fn(self, ps, *args, **kwargs) + + start = ps.index + node = fn(self, ps, *args, **kwargs) + + # Don't re-add the span if the node already has it. This may happen + # when one decorated function calls another decorated function. + if node.span is not None: + return node + + end = ps.index + node.add_span(start, end) + return node + + return decorated + + +class FluentParser: + """This class is used to parse Fluent source content. + + ``with_spans`` enables source information in the form of + :class:`.ast.Span` objects for each :class:`.ast.SyntaxNode`. + """ + + def __init__(self, with_spans: bool = True): + self.with_spans = with_spans + + def parse(self, source: str) -> ast.Resource: + """Create a :class:`.ast.Resource` from a Fluent source. + """ + ps = FluentParserStream(source) + ps.skip_blank_block() + + entries: List[ast.EntryType] = [] + last_comment = None + + while ps.current_char: + entry = self.get_entry_or_junk(ps) + blank_lines = ps.skip_blank_block() + + # Regular Comments require special logic. Comments may be attached + # to Messages or Terms if they are followed immediately by them. + # However they should parse as standalone when they're followed by + # Junk. Consequently, we only attach Comments once we know that the + # Message or the Term parsed successfully. + if isinstance(entry, ast.Comment) and len(blank_lines) == 0 \ + and ps.current_char: + # Stash the comment and decide what to do with it + # in the next pass. + last_comment = entry + continue + + if last_comment is not None: + if isinstance(entry, (ast.Message, ast.Term)): + entry.comment = last_comment + if self.with_spans: + cast(ast.Span, entry.span).start = cast(ast.Span, entry.comment.span).start + else: + entries.append(last_comment) + # In either case, the stashed comment has been dealt with; + # clear it. + last_comment = None + + entries.append(entry) + + res = ast.Resource(entries) + + if self.with_spans: + res.add_span(0, ps.index) + + return res + + def parse_entry(self, source: str) -> ast.EntryType: + """Parse the first :class:`.ast.Entry` in source. + + Skip all encountered comments and start parsing at the first :class:`.ast.Message` + or :class:`.ast.Term` start. Return :class:`.ast.Junk` if the parsing is not successful. + + Preceding comments are ignored unless they contain syntax errors + themselves, in which case :class:`.ast.Junk` for the invalid comment is returned. + """ + ps = FluentParserStream(source) + ps.skip_blank_block() + + while ps.current_char == '#': + skipped = self.get_entry_or_junk(ps) + if isinstance(skipped, ast.Junk): + # Don't skip Junk comments. + return skipped + ps.skip_blank_block() + + return self.get_entry_or_junk(ps) + + def get_entry_or_junk(self, ps: FluentParserStream) -> ast.EntryType: + entry_start_pos = ps.index + + try: + entry = self.get_entry(ps) + ps.expect_line_end() + return entry + except ParseError as err: + error_index = ps.index + ps.skip_to_next_entry_start(entry_start_pos) + next_entry_start = ps.index + if next_entry_start < error_index: + # The position of the error must be inside of the Junk's span. + error_index = next_entry_start + + # Create a Junk instance + slice = ps.string[entry_start_pos:next_entry_start] + junk = ast.Junk(slice) + if self.with_spans: + junk.add_span(entry_start_pos, next_entry_start) + annot = ast.Annotation(err.code, list(err.args) if err.args else None, err.message) + annot.add_span(error_index, error_index) + junk.add_annotation(annot) + return junk + + def get_entry(self, ps: FluentParserStream) -> ast.EntryType: + if ps.current_char == '#': + return self.get_comment(ps) + + if ps.current_char == '-': + return self.get_term(ps) + + if ps.is_identifier_start(): + return self.get_message(ps) + + raise ParseError('E0002') + + @with_span + def get_comment(self, ps: FluentParserStream) -> Union[ast.Comment, ast.GroupComment, ast.ResourceComment]: + # 0 - comment + # 1 - group comment + # 2 - resource comment + level = -1 + content = '' + + while True: + i = -1 + while ps.current_char == '#' \ + and (i < (2 if level == -1 else level)): + ps.next() + i += 1 + + if level == -1: + level = i + + if ps.current_char != EOL: + ps.expect_char(' ') + ch = ps.take_char(lambda x: x != EOL) + while ch: + content += ch + ch = ps.take_char(lambda x: x != EOL) + + if ps.is_next_line_comment(level=level): + content += cast(str, ps.current_char) + ps.next() + else: + break + + if level == 0: + return ast.Comment(content) + elif level == 1: + return ast.GroupComment(content) + elif level == 2: + return ast.ResourceComment(content) + + # never happens if ps.current_char == '#' when called + return cast(ast.Comment, None) + + @with_span + def get_message(self, ps: FluentParserStream) -> ast.Message: + id = self.get_identifier(ps) + ps.skip_blank_inline() + ps.expect_char('=') + + value = self.maybe_get_pattern(ps) + attrs = self.get_attributes(ps) + + if value is None and len(attrs) == 0: + raise ParseError('E0005', id.name) + + return ast.Message(id, value, attrs) + + @with_span + def get_term(self, ps: FluentParserStream) -> ast.Term: + ps.expect_char('-') + id = self.get_identifier(ps) + + ps.skip_blank_inline() + ps.expect_char('=') + + value = self.maybe_get_pattern(ps) + if value is None: + raise ParseError('E0006', id.name) + + attrs = self.get_attributes(ps) + return ast.Term(id, value, attrs) + + @with_span + def get_attribute(self, ps: FluentParserStream) -> ast.Attribute: + ps.expect_char('.') + + key = self.get_identifier(ps) + + ps.skip_blank_inline() + ps.expect_char('=') + + value = self.maybe_get_pattern(ps) + if value is None: + raise ParseError('E0012') + + return ast.Attribute(key, value) + + def get_attributes(self, ps: FluentParserStream) -> List[ast.Attribute]: + attrs: List[ast.Attribute] = [] + ps.peek_blank() + + while ps.is_attribute_start(): + ps.skip_to_peek() + attr = self.get_attribute(ps) + attrs.append(attr) + ps.peek_blank() + + return attrs + + @with_span + def get_identifier(self, ps: FluentParserStream) -> ast.Identifier: + name = ps.take_id_start() + if name is None: + raise ParseError('E0004', 'a-zA-Z') + + ch = ps.take_id_char() + while ch: + name += ch + ch = ps.take_id_char() + + return ast.Identifier(name) + + def get_variant_key(self, ps: FluentParserStream) -> Union[ast.Identifier, ast.NumberLiteral]: + ch = ps.current_char + + if ch is None: + raise ParseError('E0013') + + cc = ord(ch) + if ((cc >= 48 and cc <= 57) or cc == 45): # 0-9, - + return self.get_number(ps) + + return self.get_identifier(ps) + + @with_span + def get_variant(self, ps: FluentParserStream, has_default: bool) -> ast.Variant: + default_index = False + + if ps.current_char == '*': + if has_default: + raise ParseError('E0015') + ps.next() + default_index = True + + ps.expect_char('[') + ps.skip_blank() + + key = self.get_variant_key(ps) + + ps.skip_blank() + ps.expect_char(']') + + value = self.maybe_get_pattern(ps) + if value is None: + raise ParseError('E0012') + + return ast.Variant(key, value, default_index) + + def get_variants(self, ps: FluentParserStream) -> List[ast.Variant]: + variants: List[ast.Variant] = [] + has_default = False + + ps.skip_blank() + while ps.is_variant_start(): + variant = self.get_variant(ps, has_default) + + if variant.default: + has_default = True + + variants.append(variant) + ps.expect_line_end() + ps.skip_blank() + + if len(variants) == 0: + raise ParseError('E0011') + + if not has_default: + raise ParseError('E0010') + + return variants + + def get_digits(self, ps: FluentParserStream) -> str: + num = '' + + ch = ps.take_digit() + while ch: + num += ch + ch = ps.take_digit() + + if len(num) == 0: + raise ParseError('E0004', '0-9') + + return num + + @with_span + def get_number(self, ps: FluentParserStream) -> ast.NumberLiteral: + num = '' + + if ps.current_char == '-': + num += '-' + ps.next() + + num += self.get_digits(ps) + + if ps.current_char == '.': + num += '.' + ps.next() + num += self.get_digits(ps) + + return ast.NumberLiteral(num) + + def maybe_get_pattern(self, ps: FluentParserStream) -> Union[ast.Pattern, None]: + '''Parse an inline or a block Pattern, or None + + maybe_get_pattern distinguishes between patterns which start on the + same line as the indentifier (aka inline singleline patterns and inline + multiline patterns), and patterns which start on a new line (aka block + patterns). The distinction is important for the dedentation logic: the + indent of the first line of a block pattern must be taken into account + when calculating the maximum common indent. + ''' + ps.peek_blank_inline() + if ps.is_value_start(): + ps.skip_to_peek() + return self.get_pattern(ps, is_block=False) + + ps.peek_blank_block() + if ps.is_value_continuation(): + ps.skip_to_peek() + return self.get_pattern(ps, is_block=True) + + return None + + @with_span + def get_pattern(self, ps: FluentParserStream, is_block: bool) -> ast.Pattern: + elements: List[Any] = [] + if is_block: + # A block pattern is a pattern which starts on a new line. Measure + # the indent of this first line for the dedentation logic. + blank_start = ps.index + first_indent = ps.skip_blank_inline() + elements.append(self.Indent(first_indent, blank_start, ps.index)) + common_indent_length = len(first_indent) + else: + # Should get fixed by the subsequent min() operation + common_indent_length = cast(int, float('infinity')) + + while ps.current_char: + if ps.current_char == EOL: + blank_start = ps.index + blank_lines = ps.peek_blank_block() + if ps.is_value_continuation(): + ps.skip_to_peek() + indent = ps.skip_blank_inline() + common_indent_length = min(common_indent_length, len(indent)) + elements.append(self.Indent(blank_lines + indent, blank_start, ps.index)) + continue + + # The end condition for get_pattern's while loop is a newline + # which is not followed by a valid pattern continuation. + ps.reset_peek() + break + + if ps.current_char == '}': + raise ParseError('E0027') + + element: Union[ast.TextElement, ast.Placeable] + if ps.current_char == '{': + element = self.get_placeable(ps) + else: + element = self.get_text_element(ps) + + elements.append(element) + + dedented = self.dedent(elements, common_indent_length) + return ast.Pattern(dedented) + + class Indent(ast.SyntaxNode): + def __init__(self, value: str, start: int, end: int): + super(FluentParser.Indent, self).__init__() + self.value = value + self.add_span(start, end) + + def dedent(self, + elements: List[Union[ast.TextElement, ast.Placeable, Indent]], + common_indent: int + ) -> List[Union[ast.TextElement, ast.Placeable]]: + '''Dedent a list of elements by removing the maximum common indent from + the beginning of text lines. The common indent is calculated in + get_pattern. + ''' + trimmed: List[Union[ast.TextElement, ast.Placeable]] = [] + + for element in elements: + if isinstance(element, ast.Placeable): + trimmed.append(element) + continue + + if isinstance(element, self.Indent): + # Strip the common indent. + element.value = element.value[:len(element.value) - common_indent] + if len(element.value) == 0: + continue + + prev = trimmed[-1] if len(trimmed) > 0 else None + if isinstance(prev, ast.TextElement): + # Join adjacent TextElements by replacing them with their sum. + sum = ast.TextElement(prev.value + element.value) + if self.with_spans: + sum.add_span(cast(ast.Span, prev.span).start, cast(ast.Span, element.span).end) + trimmed[-1] = sum + continue + + if isinstance(element, self.Indent): + # If the indent hasn't been merged into a preceding + # TextElements, convert it into a new TextElement. + text_element = ast.TextElement(element.value) + if self.with_spans: + text_element.add_span(cast(ast.Span, element.span).start, cast(ast.Span, element.span).end) + element = text_element + + trimmed.append(element) + + # Trim trailing whitespace from the Pattern. + last_element = trimmed[-1] if len(trimmed) > 0 else None + if isinstance(last_element, ast.TextElement): + last_element.value = last_element.value.rstrip(' \n\r') + if last_element.value == "": + trimmed.pop() + + return trimmed + + @with_span + def get_text_element(self, ps: FluentParserStream) -> ast.TextElement: + buf = '' + + while ps.current_char: + ch = ps.current_char + + if ch == '{' or ch == '}': + return ast.TextElement(buf) + + if ch == EOL: + return ast.TextElement(buf) + + buf += ch + ps.next() + + return ast.TextElement(buf) + + def get_escape_sequence(self, ps: FluentParserStream) -> str: + next = ps.current_char + + if next == '\\' or next == '"': + ps.next() + return f'\\{next}' + + if next == 'u': + return self.get_unicode_escape_sequence(ps, next, 4) + + if next == 'U': + return self.get_unicode_escape_sequence(ps, next, 6) + + raise ParseError('E0025', next) + + def get_unicode_escape_sequence(self, ps: FluentParserStream, u: str, digits: int) -> str: + ps.expect_char(u) + sequence = '' + for _ in range(digits): + ch = ps.take_hex_digit() + if not ch: + raise ParseError('E0026', f'\\{u}{sequence}{ps.current_char}') + sequence += ch + + return f'\\{u}{sequence}' + + @with_span + def get_placeable(self, ps: FluentParserStream) -> ast.Placeable: + ps.expect_char('{') + ps.skip_blank() + expression = self.get_expression(ps) + ps.expect_char('}') + return ast.Placeable(expression) + + @with_span + def get_expression(self, ps: FluentParserStream) -> Union[ast.InlineExpression, + ast.Placeable, + ast.SelectExpression]: + selector = self.get_inline_expression(ps) + + ps.skip_blank() + + if ps.current_char == '-': + if ps.peek() != '>': + ps.reset_peek() + return selector + + if isinstance(selector, ast.MessageReference): + if selector.attribute is None: + raise ParseError('E0016') + else: + raise ParseError('E0018') + + elif ( + isinstance(selector, ast.TermReference) + ): + if selector.attribute is None: + raise ParseError('E0017') + elif not ( + isinstance(selector, ( + ast.StringLiteral, + ast.NumberLiteral, + ast.VariableReference, + ast.FunctionReference, + )) + ): + raise ParseError('E0029') + + ps.next() + ps.next() + + ps.skip_blank_inline() + ps.expect_line_end() + + variants = self.get_variants(ps) + return ast.SelectExpression(selector, variants) + + if ( + isinstance(selector, ast.TermReference) + and selector.attribute is not None + ): + raise ParseError('E0019') + + return selector + + @with_span + def get_inline_expression(self, ps: FluentParserStream) -> Union[ast.InlineExpression, ast.Placeable]: + if ps.current_char == '{': + return self.get_placeable(ps) + + if ps.is_number_start(): + return self.get_number(ps) + + if ps.current_char == '"': + return self.get_string(ps) + + if ps.current_char == '$': + ps.next() + id = self.get_identifier(ps) + return ast.VariableReference(id) + + if ps.current_char == '-': + ps.next() + id = self.get_identifier(ps) + attribute = None + if ps.current_char == '.': + ps.next() + attribute = self.get_identifier(ps) + arguments = None + ps.peek_blank() + if ps.current_peek == '(': + ps.skip_to_peek() + arguments = self.get_call_arguments(ps) + return ast.TermReference(id, attribute, arguments) + + if ps.is_identifier_start(): + id = self.get_identifier(ps) + ps.peek_blank() + + if ps.current_peek == '(': + # It's a Function. Ensure it's all upper-case. + if not re.match('^[A-Z][A-Z0-9_-]*$', id.name): + raise ParseError('E0008') + ps.skip_to_peek() + args = self.get_call_arguments(ps) + return ast.FunctionReference(id, args) + + attribute = None + if ps.current_char == '.': + ps.next() + attribute = self.get_identifier(ps) + + return ast.MessageReference(id, attribute) + + raise ParseError('E0028') + + @with_span + def get_call_argument(self, + ps: FluentParserStream + ) -> Union[ast.InlineExpression, ast.NamedArgument, ast.Placeable]: + exp = self.get_inline_expression(ps) + + ps.skip_blank() + + if ps.current_char != ':': + return exp + + if isinstance(exp, ast.MessageReference) and exp.attribute is None: + ps.next() + ps.skip_blank() + + value = self.get_literal(ps) + return ast.NamedArgument(exp.id, value) + + raise ParseError('E0009') + + @with_span + def get_call_arguments(self, ps: FluentParserStream) -> ast.CallArguments: + positional: List[Union[ast.InlineExpression, ast.Placeable]] = [] + named: List[ast.NamedArgument] = [] + argument_names: Set[str] = set() + + ps.expect_char('(') + ps.skip_blank() + + while True: + if ps.current_char == ')': + break + + arg = self.get_call_argument(ps) + if isinstance(arg, ast.NamedArgument): + if arg.name.name in argument_names: + raise ParseError('E0022') + named.append(arg) + argument_names.add(arg.name.name) + elif len(argument_names) > 0: + raise ParseError('E0021') + else: + positional.append(arg) + + ps.skip_blank() + + if ps.current_char == ',': + ps.next() + ps.skip_blank() + continue + + break + + ps.expect_char(')') + return ast.CallArguments(positional, named) + + @with_span + def get_string(self, ps: FluentParserStream) -> ast.StringLiteral: + value = '' + + ps.expect_char('"') + + while True: + ch = ps.take_char(lambda x: x != '"' and x != EOL) + if not ch: + break + if ch == '\\': + value += self.get_escape_sequence(ps) + else: + value += ch + + if ps.current_char == EOL: + raise ParseError('E0020') + + ps.expect_char('"') + + return ast.StringLiteral(value) + + @with_span + def get_literal(self, ps: FluentParserStream) -> Union[ast.NumberLiteral, ast.StringLiteral]: + if ps.is_number_start(): + return self.get_number(ps) + if ps.current_char == '"': + return self.get_string(ps) + raise ParseError('E0014') diff --git a/third_party/python/fluent.syntax/fluent/syntax/py.typed b/third_party/python/fluent.syntax/fluent/syntax/py.typed new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/third_party/python/fluent.syntax/fluent/syntax/py.typed diff --git a/third_party/python/fluent.syntax/fluent/syntax/serializer.py b/third_party/python/fluent.syntax/fluent/syntax/serializer.py new file mode 100644 index 0000000000..68ea89b3d3 --- /dev/null +++ b/third_party/python/fluent.syntax/fluent/syntax/serializer.py @@ -0,0 +1,237 @@ +from typing import List, Union +from . import ast + + +def indent_except_first_line(content: str) -> str: + return " ".join( + content.splitlines(True) + ) + + +def includes_new_line(elem: Union[ast.TextElement, ast.Placeable]) -> bool: + return isinstance(elem, ast.TextElement) and "\n" in elem.value + + +def is_select_expr(elem: Union[ast.TextElement, ast.Placeable]) -> bool: + return ( + isinstance(elem, ast.Placeable) and + isinstance(elem.expression, ast.SelectExpression)) + + +def should_start_on_new_line(pattern: ast.Pattern) -> bool: + is_multiline = any(is_select_expr(elem) for elem in pattern.elements) \ + or any(includes_new_line(elem) for elem in pattern.elements) + + if is_multiline: + first_element = pattern.elements[0] + if isinstance(first_element, ast.TextElement): + first_char = first_element.value[0] + if first_char in ("[", ".", "*"): + return False + return True + return False + + +class FluentSerializer: + """FluentSerializer converts :class:`.ast.SyntaxNode` objects to unicode strings. + + `with_junk` controls if parse errors are written back or not. + """ + HAS_ENTRIES = 1 + + def __init__(self, with_junk: bool = False): + self.with_junk = with_junk + + def serialize(self, resource: ast.Resource) -> str: + "Serialize a :class:`.ast.Resource` to a string." + if not isinstance(resource, ast.Resource): + raise Exception('Unknown resource type: {}'.format(type(resource))) + + state = 0 + + parts: List[str] = [] + for entry in resource.body: + if not isinstance(entry, ast.Junk) or self.with_junk: + parts.append(self.serialize_entry(entry, state)) + if not state & self.HAS_ENTRIES: + state |= self.HAS_ENTRIES + + return "".join(parts) + + def serialize_entry(self, entry: ast.EntryType, state: int = 0) -> str: + "Serialize an :class:`.ast.Entry` to a string." + if isinstance(entry, ast.Message): + return serialize_message(entry) + if isinstance(entry, ast.Term): + return serialize_term(entry) + if isinstance(entry, ast.Comment): + if state & self.HAS_ENTRIES: + return "\n{}\n".format(serialize_comment(entry, "#")) + return "{}\n".format(serialize_comment(entry, "#")) + if isinstance(entry, ast.GroupComment): + if state & self.HAS_ENTRIES: + return "\n{}\n".format(serialize_comment(entry, "##")) + return "{}\n".format(serialize_comment(entry, "##")) + if isinstance(entry, ast.ResourceComment): + if state & self.HAS_ENTRIES: + return "\n{}\n".format(serialize_comment(entry, "###")) + return "{}\n".format(serialize_comment(entry, "###")) + if isinstance(entry, ast.Junk): + return serialize_junk(entry) + raise Exception('Unknown entry type: {}'.format(type(entry))) + + +def serialize_comment(comment: Union[ast.Comment, ast.GroupComment, ast.ResourceComment], prefix: str = "#") -> str: + if not comment.content: + return f'{prefix}\n' + + prefixed = "\n".join([ + prefix if len(line) == 0 else f"{prefix} {line}" + for line in comment.content.split("\n") + ]) + # Add the trailing line break. + return f'{prefixed}\n' + + +def serialize_junk(junk: ast.Junk) -> str: + return junk.content or '' + + +def serialize_message(message: ast.Message) -> str: + parts: List[str] = [] + + if message.comment: + parts.append(serialize_comment(message.comment)) + + parts.append(f"{message.id.name} =") + + if message.value: + parts.append(serialize_pattern(message.value)) + + if message.attributes: + for attribute in message.attributes: + parts.append(serialize_attribute(attribute)) + + parts.append("\n") + return ''.join(parts) + + +def serialize_term(term: ast.Term) -> str: + parts: List[str] = [] + + if term.comment: + parts.append(serialize_comment(term.comment)) + + parts.append(f"-{term.id.name} =") + parts.append(serialize_pattern(term.value)) + + if term.attributes: + for attribute in term.attributes: + parts.append(serialize_attribute(attribute)) + + parts.append("\n") + return ''.join(parts) + + +def serialize_attribute(attribute: ast.Attribute) -> str: + return "\n .{} ={}".format( + attribute.id.name, + indent_except_first_line(serialize_pattern(attribute.value)) + ) + + +def serialize_pattern(pattern: ast.Pattern) -> str: + content = "".join(serialize_element(elem) for elem in pattern.elements) + content = indent_except_first_line(content) + + if should_start_on_new_line(pattern): + return f'\n {content}' + + return f' {content}' + + +def serialize_element(element: ast.PatternElement) -> str: + if isinstance(element, ast.TextElement): + return element.value + if isinstance(element, ast.Placeable): + return serialize_placeable(element) + raise Exception('Unknown element type: {}'.format(type(element))) + + +def serialize_placeable(placeable: ast.Placeable) -> str: + expr = placeable.expression + if isinstance(expr, ast.Placeable): + return "{{{}}}".format(serialize_placeable(expr)) + if isinstance(expr, ast.SelectExpression): + # Special-case select expressions to control the withespace around the + # opening and the closing brace. + return "{{ {}}}".format(serialize_expression(expr)) + if isinstance(expr, ast.Expression): + return "{{ {} }}".format(serialize_expression(expr)) + raise Exception('Unknown expression type: {}'.format(type(expr))) + + +def serialize_expression(expression: Union[ast.Expression, ast.Placeable]) -> str: + if isinstance(expression, ast.StringLiteral): + return f'"{expression.value}"' + if isinstance(expression, ast.NumberLiteral): + return expression.value + if isinstance(expression, ast.VariableReference): + return f"${expression.id.name}" + if isinstance(expression, ast.TermReference): + out = f"-{expression.id.name}" + if expression.attribute is not None: + out += f".{expression.attribute.name}" + if expression.arguments is not None: + out += serialize_call_arguments(expression.arguments) + return out + if isinstance(expression, ast.MessageReference): + out = expression.id.name + if expression.attribute is not None: + out += f".{expression.attribute.name}" + return out + if isinstance(expression, ast.FunctionReference): + args = serialize_call_arguments(expression.arguments) + return f"{expression.id.name}{args}" + if isinstance(expression, ast.SelectExpression): + out = "{} ->".format( + serialize_expression(expression.selector)) + for variant in expression.variants: + out += serialize_variant(variant) + return f"{out}\n" + if isinstance(expression, ast.Placeable): + return serialize_placeable(expression) + raise Exception('Unknown expression type: {}'.format(type(expression))) + + +def serialize_variant(variant: ast.Variant) -> str: + return "\n{}[{}]{}".format( + " *" if variant.default else " ", + serialize_variant_key(variant.key), + indent_except_first_line(serialize_pattern(variant.value)) + ) + + +def serialize_call_arguments(expr: ast.CallArguments) -> str: + positional = ", ".join( + serialize_expression(arg) for arg in expr.positional) + named = ", ".join( + serialize_named_argument(arg) for arg in expr.named) + if len(expr.positional) > 0 and len(expr.named) > 0: + return f'({positional}, {named})' + return '({})'.format(positional or named) + + +def serialize_named_argument(arg: ast.NamedArgument) -> str: + return "{}: {}".format( + arg.name.name, + serialize_expression(arg.value) + ) + + +def serialize_variant_key(key: Union[ast.Identifier, ast.NumberLiteral]) -> str: + if isinstance(key, ast.Identifier): + return key.name + if isinstance(key, ast.NumberLiteral): + return key.value + raise Exception('Unknown variant key type: {}'.format(type(key))) diff --git a/third_party/python/fluent.syntax/fluent/syntax/stream.py b/third_party/python/fluent.syntax/fluent/syntax/stream.py new file mode 100644 index 0000000000..150ac933ca --- /dev/null +++ b/third_party/python/fluent.syntax/fluent/syntax/stream.py @@ -0,0 +1,283 @@ +from typing import Callable, Union +from typing_extensions import Literal +from .errors import ParseError + + +class ParserStream: + def __init__(self, string: str): + self.string = string + self.index = 0 + self.peek_offset = 0 + + def get(self, offset: int) -> Union[str, None]: + try: + return self.string[offset] + except IndexError: + return None + + def char_at(self, offset: int) -> Union[str, None]: + # When the cursor is at CRLF, return LF but don't move the cursor. The + # cursor still points to the EOL position, which in this case is the + # beginning of the compound CRLF sequence. This ensures slices of + # [inclusive, exclusive) continue to work properly. + if self.get(offset) == '\r' \ + and self.get(offset + 1) == '\n': + return '\n' + + return self.get(offset) + + @property + def current_char(self) -> Union[str, None]: + return self.char_at(self.index) + + @property + def current_peek(self) -> Union[str, None]: + return self.char_at(self.index + self.peek_offset) + + def next(self) -> Union[str, None]: + self.peek_offset = 0 + # Skip over CRLF as if it was a single character. + if self.get(self.index) == '\r' \ + and self.get(self.index + 1) == '\n': + self.index += 1 + self.index += 1 + return self.get(self.index) + + def peek(self) -> Union[str, None]: + # Skip over CRLF as if it was a single character. + if self.get(self.index + self.peek_offset) == '\r' \ + and self.get(self.index + self.peek_offset + 1) == '\n': + self.peek_offset += 1 + self.peek_offset += 1 + return self.get(self.index + self.peek_offset) + + def reset_peek(self, offset: int = 0) -> None: + self.peek_offset = offset + + def skip_to_peek(self) -> None: + self.index += self.peek_offset + self.peek_offset = 0 + + +EOL = '\n' +EOF = None +SPECIAL_LINE_START_CHARS = ('}', '.', '[', '*') + + +class FluentParserStream(ParserStream): + + def peek_blank_inline(self) -> str: + start = self.index + self.peek_offset + while self.current_peek == ' ': + self.peek() + return self.string[start:self.index + self.peek_offset] + + def skip_blank_inline(self) -> str: + blank = self.peek_blank_inline() + self.skip_to_peek() + return blank + + def peek_blank_block(self) -> str: + blank = "" + while True: + line_start = self.peek_offset + self.peek_blank_inline() + + if self.current_peek == EOL: + blank += EOL + self.peek() + continue + + if self.current_peek is EOF: + # Treat the blank line at EOF as a blank block. + return blank + + # Any other char; reset to column 1 on this line. + self.reset_peek(line_start) + return blank + + def skip_blank_block(self) -> str: + blank = self.peek_blank_block() + self.skip_to_peek() + return blank + + def peek_blank(self) -> None: + while self.current_peek in (" ", EOL): + self.peek() + + def skip_blank(self) -> None: + self.peek_blank() + self.skip_to_peek() + + def expect_char(self, ch: str) -> Literal[True]: + if self.current_char == ch: + self.next() + return True + + raise ParseError('E0003', ch) + + def expect_line_end(self) -> Literal[True]: + if self.current_char is EOF: + # EOF is a valid line end in Fluent. + return True + + if self.current_char == EOL: + self.next() + return True + + # Unicode Character 'SYMBOL FOR NEWLINE' (U+2424) + raise ParseError('E0003', '\u2424') + + def take_char(self, f: Callable[[str], bool]) -> Union[str, Literal[False], None]: + ch = self.current_char + if ch is None: + return EOF + if f(ch): + self.next() + return ch + return False + + def is_char_id_start(self, ch: Union[str, None]) -> bool: + if ch is None: + return False + + cc = ord(ch) + return (cc >= 97 and cc <= 122) or \ + (cc >= 65 and cc <= 90) + + def is_identifier_start(self) -> bool: + return self.is_char_id_start(self.current_peek) + + def is_number_start(self) -> bool: + ch = self.peek() if self.current_char == '-' else self.current_char + if ch is None: + self.reset_peek() + return False + + cc = ord(ch) + is_digit = cc >= 48 and cc <= 57 + self.reset_peek() + return is_digit + + def is_char_pattern_continuation(self, ch: Union[str, None]) -> bool: + if ch is EOF: + return False + + return ch not in SPECIAL_LINE_START_CHARS + + def is_value_start(self) -> bool: + # Inline Patterns may start with any char. + return self.current_peek is not EOF and self.current_peek != EOL + + def is_value_continuation(self) -> bool: + column1 = self.peek_offset + self.peek_blank_inline() + + if self.current_peek == '{': + self.reset_peek(column1) + return True + + if self.peek_offset - column1 == 0: + return False + + if self.is_char_pattern_continuation(self.current_peek): + self.reset_peek(column1) + return True + + return False + + # -1 - any + # 0 - comment + # 1 - group comment + # 2 - resource comment + def is_next_line_comment(self, level: int = -1) -> bool: + if self.current_peek != EOL: + return False + + i = 0 + + while (i <= level or (level == -1 and i < 3)): + if self.peek() != '#': + if i <= level and level != -1: + self.reset_peek() + return False + break + i += 1 + + # The first char after #, ## or ###. + if self.peek() in (' ', EOL): + self.reset_peek() + return True + + self.reset_peek() + return False + + def is_variant_start(self) -> bool: + current_peek_offset = self.peek_offset + if self.current_peek == '*': + self.peek() + if self.current_peek == '[' and self.peek() != '[': + self.reset_peek(current_peek_offset) + return True + + self.reset_peek(current_peek_offset) + return False + + def is_attribute_start(self) -> bool: + return self.current_peek == '.' + + def skip_to_next_entry_start(self, junk_start: int) -> None: + last_newline = self.string.rfind(EOL, 0, self.index) + if junk_start < last_newline: + # Last seen newline is _after_ the junk start. It's safe to rewind + # without the risk of resuming at the same broken entry. + self.index = last_newline + + while self.current_char: + # We're only interested in beginnings of line. + if self.current_char != EOL: + self.next() + continue + + # Break if the first char in this line looks like an entry start. + first = self.next() + if self.is_char_id_start(first) or first == '-' or first == '#': + break + + # Syntax 0.4 compatibility + peek = self.peek() + self.reset_peek() + if (first, peek) == ('/', '/') or (first, peek) == ('[', '['): + break + + def take_id_start(self) -> Union[str, None]: + if self.is_char_id_start(self.current_char): + ret = self.current_char + self.next() + return ret + + raise ParseError('E0004', 'a-zA-Z') + + def take_id_char(self) -> Union[str, Literal[False], None]: + def closure(ch: str) -> bool: + cc = ord(ch) + return ((cc >= 97 and cc <= 122) or + (cc >= 65 and cc <= 90) or + (cc >= 48 and cc <= 57) or + cc == 95 or cc == 45) + return self.take_char(closure) + + def take_digit(self) -> Union[str, Literal[False], None]: + def closure(ch: str) -> bool: + cc = ord(ch) + return (cc >= 48 and cc <= 57) + return self.take_char(closure) + + def take_hex_digit(self) -> Union[str, Literal[False], None]: + def closure(ch: str) -> bool: + cc = ord(ch) + return ( + (cc >= 48 and cc <= 57) # 0-9 + or (cc >= 65 and cc <= 70) # A-F + or (cc >= 97 and cc <= 102)) # a-f + return self.take_char(closure) diff --git a/third_party/python/fluent.syntax/fluent/syntax/visitor.py b/third_party/python/fluent.syntax/fluent/syntax/visitor.py new file mode 100644 index 0000000000..0df9f5963e --- /dev/null +++ b/third_party/python/fluent.syntax/fluent/syntax/visitor.py @@ -0,0 +1,65 @@ +from typing import Any, List +from .ast import BaseNode, Node + + +class Visitor: + '''Read-only visitor pattern. + + Subclass this to gather information from an AST. + To generally define which nodes not to descend in to, overload + `generic_visit`. + To handle specific node types, add methods like `visit_Pattern`. + If you want to still descend into the children of the node, call + `generic_visit` of the superclass. + ''' + + def visit(self, node: Any) -> None: + if isinstance(node, list): + for child in node: + self.visit(child) + return + if not isinstance(node, BaseNode): + return + nodename = type(node).__name__ + visit = getattr(self, f'visit_{nodename}', self.generic_visit) + visit(node) + + def generic_visit(self, node: BaseNode) -> None: + for propvalue in vars(node).values(): + self.visit(propvalue) + + +class Transformer(Visitor): + '''In-place AST Transformer pattern. + + Subclass this to create an in-place modified variant + of the given AST. + If you need to keep the original AST around, pass + a `node.clone()` to the transformer. + ''' + + def visit(self, node: Any) -> Any: + if not isinstance(node, BaseNode): + return node + + nodename = type(node).__name__ + visit = getattr(self, f'visit_{nodename}', self.generic_visit) + return visit(node) + + def generic_visit(self, node: Node) -> Node: # type: ignore + for propname, propvalue in vars(node).items(): + if isinstance(propvalue, list): + new_vals: List[Any] = [] + for child in propvalue: + new_val = self.visit(child) + if new_val is not None: + new_vals.append(new_val) + # in-place manipulation + propvalue[:] = new_vals + elif isinstance(propvalue, BaseNode): + new_val = self.visit(propvalue) + if new_val is None: + delattr(node, propname) + else: + setattr(node, propname, new_val) + return node |