From 8fe30fd23dc37ec3516e530a86d1c4b604e71241 Mon Sep 17 00:00:00 2001
From: Daniel Baumann
Date: Sun, 10 Dec 2023 11:46:01 +0100
Subject: Merging upstream version 20.1.0.

Signed-off-by: Daniel Baumann
---
 sqlglot/tokens.py | 66 +++++++++++++++++++++++++++++++++++++++++--------------
 1 file changed, 50 insertions(+), 16 deletions(-)

(limited to 'sqlglot/tokens.py')

diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index 9784c63..e4c3204 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -7,6 +7,9 @@ from sqlglot.errors import TokenError
 from sqlglot.helper import AutoName
 from sqlglot.trie import TrieResult, in_trie, new_trie
 
+if t.TYPE_CHECKING:
+    from sqlglot.dialects.dialect import DialectType
+
 
 class TokenType(AutoName):
     L_PAREN = auto()
@@ -34,6 +37,7 @@ class TokenType(AutoName):
     EQ = auto()
     NEQ = auto()
     NULLSAFE_EQ = auto()
+    COLON_EQ = auto()
     AND = auto()
     OR = auto()
     AMP = auto()
@@ -56,6 +60,7 @@ class TokenType(AutoName):
     SESSION_PARAMETER = auto()
     DAMP = auto()
     XOR = auto()
+    DSTAR = auto()
 
     BLOCK_START = auto()
     BLOCK_END = auto()
@@ -274,6 +279,7 @@ class TokenType(AutoName):
     OBJECT_IDENTIFIER = auto()
     OFFSET = auto()
     ON = auto()
+    OPERATOR = auto()
     ORDER_BY = auto()
     ORDERED = auto()
     ORDINALITY = auto()
@@ -295,6 +301,7 @@ class TokenType(AutoName):
     QUOTE = auto()
     RANGE = auto()
     RECURSIVE = auto()
+    REFRESH = auto()
     REPLACE = auto()
     RETURNING = auto()
     REFERENCES = auto()
@@ -371,7 +378,7 @@ class Token:
         col: int = 1,
         start: int = 0,
         end: int = 0,
-        comments: t.List[str] = [],
+        comments: t.Optional[t.List[str]] = None,
     ) -> None:
         """Token initializer.
 
@@ -390,7 +397,7 @@ class Token:
         self.col = col
         self.start = start
         self.end = end
-        self.comments = comments
+        self.comments = [] if comments is None else comments
 
     def __repr__(self) -> str:
         attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__)
@@ -497,11 +504,8 @@ class Tokenizer(metaclass=_Tokenizer):
     QUOTES: t.List[t.Tuple[str, str] | str] = ["'"]
     STRING_ESCAPES = ["'"]
     VAR_SINGLE_TOKENS: t.Set[str] = set()
-    ESCAPE_SEQUENCES: t.Dict[str, str] = {}
 
     # Autofilled
-    IDENTIFIERS_CAN_START_WITH_DIGIT: bool = False
-
     _COMMENTS: t.Dict[str, str] = {}
     _FORMAT_STRINGS: t.Dict[str, t.Tuple[str, TokenType]] = {}
     _IDENTIFIERS: t.Dict[str, str] = {}
@@ -523,6 +527,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "<=": TokenType.LTE,
         "<>": TokenType.NEQ,
         "!=": TokenType.NEQ,
+        ":=": TokenType.COLON_EQ,
         "<=>": TokenType.NULLSAFE_EQ,
         "->": TokenType.ARROW,
         "->>": TokenType.DARROW,
@@ -689,17 +694,22 @@ class Tokenizer(metaclass=_Tokenizer):
         "BOOLEAN": TokenType.BOOLEAN,
         "BYTE": TokenType.TINYINT,
         "MEDIUMINT": TokenType.MEDIUMINT,
+        "INT1": TokenType.TINYINT,
         "TINYINT": TokenType.TINYINT,
+        "INT16": TokenType.SMALLINT,
         "SHORT": TokenType.SMALLINT,
         "SMALLINT": TokenType.SMALLINT,
         "INT128": TokenType.INT128,
+        "HUGEINT": TokenType.INT128,
         "INT2": TokenType.SMALLINT,
         "INTEGER": TokenType.INT,
         "INT": TokenType.INT,
         "INT4": TokenType.INT,
+        "INT32": TokenType.INT,
+        "INT64": TokenType.BIGINT,
         "LONG": TokenType.BIGINT,
         "BIGINT": TokenType.BIGINT,
-        "INT8": TokenType.BIGINT,
+        "INT8": TokenType.TINYINT,
         "DEC": TokenType.DECIMAL,
         "DECIMAL": TokenType.DECIMAL,
         "BIGDECIMAL": TokenType.BIGDECIMAL,
@@ -781,7 +791,6 @@ class Tokenizer(metaclass=_Tokenizer):
         "\t": TokenType.SPACE,
         "\n": TokenType.BREAK,
         "\r": TokenType.BREAK,
-        "\r\n": TokenType.BREAK,
     }
 
     COMMANDS = {
@@ -803,6 +812,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "sql",
         "size",
         "tokens",
+        "dialect",
         "_start",
         "_current",
         "_line",
@@ -814,7 +824,10 @@ class Tokenizer(metaclass=_Tokenizer):
         "_prev_token_line",
     )
 
-    def __init__(self) -> None:
+    def __init__(self, dialect: DialectType = None) -> None:
+        from sqlglot.dialects import Dialect
+
+        self.dialect = Dialect.get_or_raise(dialect)
         self.reset()
 
     def reset(self) -> None:
@@ -850,13 +863,26 @@ class Tokenizer(metaclass=_Tokenizer):
 
     def _scan(self, until: t.Optional[t.Callable] = None) -> None:
         while self.size and not self._end:
-            self._start = self._current
-            self._advance()
+            current = self._current
+
+            # skip spaces inline rather than iteratively call advance()
+            # for performance reasons
+            while current < self.size:
+                char = self.sql[current]
+
+                if char.isspace() and (char == " " or char == "\t"):
+                    current += 1
+                else:
+                    break
+
+            n = current - self._current
+            self._start = current
+            self._advance(n if n > 1 else 1)
 
             if self._char is None:
                 break
 
-            if self._char not in self.WHITE_SPACE:
+            if not self._char.isspace():
                 if self._char.isdigit():
                     self._scan_number()
                 elif self._char in self._IDENTIFIERS:
@@ -881,6 +907,10 @@ class Tokenizer(metaclass=_Tokenizer):
 
     def _advance(self, i: int = 1, alnum: bool = False) -> None:
         if self.WHITE_SPACE.get(self._char) is TokenType.BREAK:
+            # Ensures we don't count an extra line if we get a \r\n line break sequence
+            if self._char == "\r" and self._peek == "\n":
+                i = 2
+
             self._col = 1
             self._line += 1
         else:
@@ -982,7 +1012,7 @@ class Tokenizer(metaclass=_Tokenizer):
             if end < self.size:
                 char = self.sql[end]
                 single_token = single_token or char in self.SINGLE_TOKENS
-                is_space = char in self.WHITE_SPACE
+                is_space = char.isspace()
 
                 if not is_space or not prev_space:
                     if is_space:
@@ -994,7 +1024,7 @@ class Tokenizer(metaclass=_Tokenizer):
                     skip = True
             else:
                 char = ""
-                chars = " "
+                break
 
         if word:
             if self._scan_string(word):
@@ -1086,7 +1116,7 @@ class Tokenizer(metaclass=_Tokenizer):
                     self._add(TokenType.NUMBER, number_text)
                     self._add(TokenType.DCOLON, "::")
                     return self._add(token_type, literal)
-                elif self.IDENTIFIERS_CAN_START_WITH_DIGIT:
+                elif self.dialect.IDENTIFIERS_CAN_START_WITH_DIGIT:
                     return self._add(TokenType.VAR)
 
                 self._advance(-len(literal))
@@ -1208,8 +1238,12 @@ class Tokenizer(metaclass=_Tokenizer):
             if self._end:
                 raise TokenError(f"Missing {delimiter} from {self._line}:{self._start}")
 
-            if self.ESCAPE_SEQUENCES and self._peek and self._char in self.STRING_ESCAPES:
-                escaped_sequence = self.ESCAPE_SEQUENCES.get(self._char + self._peek)
+            if (
+                self.dialect.ESCAPE_SEQUENCES
+                and self._peek
+                and self._char in self.STRING_ESCAPES
+            ):
+                escaped_sequence = self.dialect.ESCAPE_SEQUENCES.get(self._char + self._peek)
                 if escaped_sequence:
                     self._advance(2)
                     text += escaped_sequence
-- 
cgit v1.2.3
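
For context, a minimal sketch of what the Tokenizer change enables, assuming
sqlglot 20.1.0 is installed; the "mysql" dialect name and the sample SQL are
illustrative, not taken from the patch:

    # The Tokenizer now resolves a Dialect at construction time, so per-dialect
    # settings such as ESCAPE_SEQUENCES and IDENTIFIERS_CAN_START_WITH_DIGIT are
    # read from self.dialect instead of class-level Tokenizer attributes.
    from sqlglot.tokens import Tokenizer, TokenType

    # Dialect.get_or_raise(None) falls back to the default dialect, so the old
    # zero-argument form Tokenizer() keeps working.
    tokenizer = Tokenizer(dialect="mysql")

    # ":=" now tokenizes as a single token, the new TokenType.COLON_EQ.
    tokens = tokenizer.tokenize("SELECT @x := 1")
    assert any(token.token_type == TokenType.COLON_EQ for token in tokens)

Binding the dialect to the tokenizer instance rather than to the class is what
lets one Tokenizer type serve every dialect while the \r\n fix, the inline
whitespace skipping in _scan(), and the INT8/INT1/INT16/INT32/INT64 keyword
remapping apply uniformly.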