Diffstat (limited to 'sqlglot/tokens.py')
-rw-r--r-- | sqlglot/tokens.py | 163
1 file changed, 95 insertions, 68 deletions
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index cf2e31f..64c1f92 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -87,6 +87,7 @@ class TokenType(AutoName):
     FLOAT = auto()
     DOUBLE = auto()
     DECIMAL = auto()
+    BIGDECIMAL = auto()
     CHAR = auto()
     NCHAR = auto()
     VARCHAR = auto()
@@ -214,6 +215,7 @@ class TokenType(AutoName):
     ISNULL = auto()
     JOIN = auto()
     JOIN_MARKER = auto()
+    KEEP = auto()
     LANGUAGE = auto()
     LATERAL = auto()
     LAZY = auto()
@@ -231,6 +233,7 @@ class TokenType(AutoName):
     MOD = auto()
     NATURAL = auto()
     NEXT = auto()
+    NEXT_VALUE_FOR = auto()
     NO_ACTION = auto()
     NOTNULL = auto()
     NULL = auto()
@@ -315,7 +318,7 @@ class TokenType(AutoName):


 class Token:
-    __slots__ = ("token_type", "text", "line", "col", "comments")
+    __slots__ = ("token_type", "text", "line", "col", "end", "comments")

     @classmethod
     def number(cls, number: int) -> Token:
@@ -343,22 +346,29 @@ class Token:
         text: str,
         line: int = 1,
         col: int = 1,
+        end: int = 0,
         comments: t.List[str] = [],
     ) -> None:
         self.token_type = token_type
         self.text = text
         self.line = line
-        self.col = col - len(text)
-        self.col = self.col if self.col > 1 else 1
+        size = len(text)
+        self.col = col
+        self.end = end if end else size
         self.comments = comments

+    @property
+    def start(self) -> int:
+        """Returns the start of the token."""
+        return self.end - len(self.text)
+
     def __repr__(self) -> str:
         attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__)
         return f"<Token {attributes}>"


 class _Tokenizer(type):
-    def __new__(cls, clsname, bases, attrs):  # type: ignore
+    def __new__(cls, clsname, bases, attrs):
         klass = super().__new__(cls, clsname, bases, attrs)

         klass._QUOTES = {
@@ -433,25 +443,25 @@ class Tokenizer(metaclass=_Tokenizer):
         "#": TokenType.HASH,
     }

-    QUOTES: t.List[t.Tuple[str, str] | str] = ["'"]
-
     BIT_STRINGS: t.List[str | t.Tuple[str, str]] = []
-
-    HEX_STRINGS: t.List[str | t.Tuple[str, str]] = []
-
     BYTE_STRINGS: t.List[str | t.Tuple[str, str]] = []
-
+    HEX_STRINGS: t.List[str | t.Tuple[str, str]] = []
     IDENTIFIERS: t.List[str | t.Tuple[str, str]] = ['"']
-
-    STRING_ESCAPES = ["'"]
-
-    _STRING_ESCAPES: t.Set[str] = set()
-
     IDENTIFIER_ESCAPES = ['"']
+    QUOTES: t.List[t.Tuple[str, str] | str] = ["'"]
+    STRING_ESCAPES = ["'"]
+    VAR_SINGLE_TOKENS: t.Set[str] = set()

+    _COMMENTS: t.Dict[str, str] = {}
+    _BIT_STRINGS: t.Dict[str, str] = {}
+    _BYTE_STRINGS: t.Dict[str, str] = {}
+    _HEX_STRINGS: t.Dict[str, str] = {}
+    _IDENTIFIERS: t.Dict[str, str] = {}
     _IDENTIFIER_ESCAPES: t.Set[str] = set()
+    _QUOTES: t.Dict[str, str] = {}
+    _STRING_ESCAPES: t.Set[str] = set()

-    KEYWORDS = {
+    KEYWORDS: t.Dict[t.Optional[str], TokenType] = {
         **{f"{{%{postfix}": TokenType.BLOCK_START for postfix in ("", "+", "-")},
         **{f"{prefix}%}}": TokenType.BLOCK_END for prefix in ("", "+", "-")},
         "{{+": TokenType.BLOCK_START,
@@ -553,6 +563,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "IS": TokenType.IS,
         "ISNULL": TokenType.ISNULL,
         "JOIN": TokenType.JOIN,
+        "KEEP": TokenType.KEEP,
         "LATERAL": TokenType.LATERAL,
         "LAZY": TokenType.LAZY,
         "LEADING": TokenType.LEADING,
@@ -565,6 +576,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "MERGE": TokenType.MERGE,
         "NATURAL": TokenType.NATURAL,
         "NEXT": TokenType.NEXT,
+        "NEXT VALUE FOR": TokenType.NEXT_VALUE_FOR,
         "NO ACTION": TokenType.NO_ACTION,
         "NOT": TokenType.NOT,
         "NOTNULL": TokenType.NOTNULL,
@@ -632,6 +644,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "UPDATE": TokenType.UPDATE,
         "USE": TokenType.USE,
         "USING": TokenType.USING,
+        "UUID": TokenType.UUID,
         "VALUES": TokenType.VALUES,
         "VIEW": TokenType.VIEW,
         "VOLATILE": TokenType.VOLATILE,
@@ -661,6 +674,8 @@
         "INT8": TokenType.BIGINT,
         "DEC": TokenType.DECIMAL,
         "DECIMAL": TokenType.DECIMAL,
+        "BIGDECIMAL": TokenType.BIGDECIMAL,
+        "BIGNUMERIC": TokenType.BIGDECIMAL,
         "MAP": TokenType.MAP,
         "NULLABLE": TokenType.NULLABLE,
         "NUMBER": TokenType.DECIMAL,
@@ -742,7 +757,7 @@
     ENCODE: t.Optional[str] = None

     COMMENTS = ["--", ("/*", "*/"), ("{#", "#}")]
-    KEYWORD_TRIE = None  # autofilled
+    KEYWORD_TRIE: t.Dict = {}  # autofilled

     IDENTIFIER_CAN_START_WITH_DIGIT = False

@@ -776,19 +791,28 @@
         self._col = 1
         self._comments: t.List[str] = []

-        self._char = None
-        self._end = None
-        self._peek = None
+        self._char = ""
+        self._end = False
+        self._peek = ""
         self._prev_token_line = -1
         self._prev_token_comments: t.List[str] = []
-        self._prev_token_type = None
+        self._prev_token_type: t.Optional[TokenType] = None

     def tokenize(self, sql: str) -> t.List[Token]:
         """Returns a list of tokens corresponding to the SQL string `sql`."""
         self.reset()
         self.sql = sql
         self.size = len(sql)
-        self._scan()
+        try:
+            self._scan()
+        except Exception as e:
+            start = self._current - 50
+            end = self._current + 50
+            start = start if start > 0 else 0
+            end = end if end < self.size else self.size - 1
+            context = self.sql[start:end]
+            raise ValueError(f"Error tokenizing '{context}'") from e
+
         return self.tokens

     def _scan(self, until: t.Optional[t.Callable] = None) -> None:
@@ -810,9 +834,12 @@
             if until and until():
                 break

+        if self.tokens:
+            self.tokens[-1].comments.extend(self._comments)
+
     def _chars(self, size: int) -> str:
         if size == 1:
-            return self._char  # type: ignore
+            return self._char
         start = self._current - 1
         end = start + size
         if end <= self.size:
@@ -821,17 +848,15 @@

     def _advance(self, i: int = 1) -> None:
         if self.WHITE_SPACE.get(self._char) is TokenType.BREAK:
-            self._set_new_line()
+            self._col = 1
+            self._line += 1
+        else:
+            self._col += i

-        self._col += i
         self._current += i
-        self._end = self._current >= self.size  # type: ignore
-        self._char = self.sql[self._current - 1]  # type: ignore
-        self._peek = self.sql[self._current] if self._current < self.size else ""  # type: ignore
-
-    def _set_new_line(self) -> None:
-        self._col = 1
-        self._line += 1
+        self._end = self._current >= self.size
+        self._char = self.sql[self._current - 1]
+        self._peek = "" if self._end else self.sql[self._current]

     @property
     def _text(self) -> str:
@@ -840,13 +865,14 @@
     def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None:
         self._prev_token_line = self._line
         self._prev_token_comments = self._comments
-        self._prev_token_type = token_type  # type: ignore
+        self._prev_token_type = token_type
         self.tokens.append(
             Token(
                 token_type,
                 self._text if text is None else text,
                 self._line,
                 self._col,
+                self._current,
                 self._comments,
             )
         )
@@ -881,7 +907,7 @@
             if skip:
                 result = 1
             else:
-                result, trie = in_trie(trie, char.upper())  # type: ignore
+                result, trie = in_trie(trie, char.upper())

             if result == 0:
                 break
@@ -910,7 +936,7 @@
         if not word:
             if self._char in self.SINGLE_TOKENS:
-                self._add(self.SINGLE_TOKENS[self._char], text=self._char)  # type: ignore
+                self._add(self.SINGLE_TOKENS[self._char], text=self._char)
                 return
             self._scan_var()
             return
@@ -927,29 +953,31 @@ class Tokenizer(metaclass=_Tokenizer):
         self._add(self.KEYWORDS[word], text=word)

     def _scan_comment(self, comment_start: str) -> bool:
-        if comment_start not in self._COMMENTS:  # type: ignore
+        if comment_start not in self._COMMENTS:
             return False

         comment_start_line = self._line
         comment_start_size = len(comment_start)
-        comment_end = self._COMMENTS[comment_start]  # type: ignore
+        comment_end = self._COMMENTS[comment_start]

         if comment_end:
-            comment_end_size = len(comment_end)
+            # Skip the comment's start delimiter
+            self._advance(comment_start_size)

+            comment_end_size = len(comment_end)
             while not self._end and self._chars(comment_end_size) != comment_end:
                 self._advance()

-            self._comments.append(self._text[comment_start_size : -comment_end_size + 1])  # type: ignore
+            self._comments.append(self._text[comment_start_size : -comment_end_size + 1])
             self._advance(comment_end_size - 1)
         else:
             while not self._end and not self.WHITE_SPACE.get(self._peek) is TokenType.BREAK:
                 self._advance()
-            self._comments.append(self._text[comment_start_size:])  # type: ignore
+            self._comments.append(self._text[comment_start_size:])

         # Leading comment is attached to the succeeding token, whilst trailing comment to the preceding.
         # Multiple consecutive comments are preserved by appending them to the current comments list.
-        if comment_start_line == self._prev_token_line or self._end:
+        if comment_start_line == self._prev_token_line:
             self.tokens[-1].comments.extend(self._comments)
             self._comments = []
             self._prev_token_line = self._line
@@ -958,7 +986,7 @@

     def _scan_number(self) -> None:
         if self._char == "0":
-            peek = self._peek.upper()  # type: ignore
+            peek = self._peek.upper()
             if peek == "B":
                 return self._scan_bits()
             elif peek == "X":
@@ -968,7 +996,7 @@
         scientific = 0

         while True:
-            if self._peek.isdigit():  # type: ignore
+            if self._peek.isdigit():
                 self._advance()
             elif self._peek == "." and not decimal:
                 decimal = True
@@ -976,24 +1004,23 @@
             elif self._peek in ("-", "+") and scientific == 1:
                 scientific += 1
                 self._advance()
-            elif self._peek.upper() == "E" and not scientific:  # type: ignore
+            elif self._peek.upper() == "E" and not scientific:
                 scientific += 1
                 self._advance()
-            elif self._peek.isidentifier():  # type: ignore
+            elif self._peek.isidentifier():
                 number_text = self._text
-                literal = []
+                literal = ""

-                while self._peek.strip() and self._peek not in self.SINGLE_TOKENS:  # type: ignore
-                    literal.append(self._peek.upper())  # type: ignore
+                while self._peek.strip() and self._peek not in self.SINGLE_TOKENS:
+                    literal += self._peek.upper()
                     self._advance()

-                literal = "".join(literal)  # type: ignore
-                token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal))  # type: ignore
+                token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal))

                 if token_type:
                     self._add(TokenType.NUMBER, number_text)
                     self._add(TokenType.DCOLON, "::")
-                    return self._add(token_type, literal)  # type: ignore
+                    return self._add(token_type, literal)
                 elif self.IDENTIFIER_CAN_START_WITH_DIGIT:
                     return self._add(TokenType.VAR)

@@ -1020,7 +1047,7 @@

     def _extract_value(self) -> str:
         while True:
-            char = self._peek.strip()  # type: ignore
+            char = self._peek.strip()
             if char and char not in self.SINGLE_TOKENS:
                 self._advance()
             else:
@@ -1029,35 +1056,35 @@
         return self._text

     def _scan_string(self, quote: str) -> bool:
-        quote_end = self._QUOTES.get(quote)  # type: ignore
+        quote_end = self._QUOTES.get(quote)
         if quote_end is None:
             return False

         self._advance(len(quote))
         text = self._extract_string(quote_end)
-        text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text  # type: ignore
+        text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text
         self._add(TokenType.NATIONAL if quote[0].upper() == "N" else TokenType.STRING, text)
         return True

     # X'1234, b'0110', E'\\\\\' etc.
     def _scan_formatted_string(self, string_start: str) -> bool:
-        if string_start in self._HEX_STRINGS:  # type: ignore
-            delimiters = self._HEX_STRINGS  # type: ignore
+        if string_start in self._HEX_STRINGS:
+            delimiters = self._HEX_STRINGS
             token_type = TokenType.HEX_STRING
             base = 16
-        elif string_start in self._BIT_STRINGS:  # type: ignore
-            delimiters = self._BIT_STRINGS  # type: ignore
+        elif string_start in self._BIT_STRINGS:
+            delimiters = self._BIT_STRINGS
             token_type = TokenType.BIT_STRING
             base = 2
-        elif string_start in self._BYTE_STRINGS:  # type: ignore
-            delimiters = self._BYTE_STRINGS  # type: ignore
+        elif string_start in self._BYTE_STRINGS:
+            delimiters = self._BYTE_STRINGS
             token_type = TokenType.BYTE_STRING
             base = None
         else:
             return False

         self._advance(len(string_start))
-        string_end = delimiters.get(string_start)
+        string_end = delimiters[string_start]
         text = self._extract_string(string_end)

         if base is None:
@@ -1083,20 +1110,20 @@
             self._advance()
             if self._char == identifier_end:
                 if identifier_end_is_escape and self._peek == identifier_end:
-                    text += identifier_end  # type: ignore
+                    text += identifier_end
                     self._advance()
                     continue

                 break

-            text += self._char  # type: ignore
+            text += self._char

         self._add(TokenType.IDENTIFIER, text)

     def _scan_var(self) -> None:
         while True:
-            char = self._peek.strip()  # type: ignore
-            if char and char not in self.SINGLE_TOKENS:
+            char = self._peek.strip()
+            if char and (char in self.VAR_SINGLE_TOKENS or char not in self.SINGLE_TOKENS):
                 self._advance()
             else:
                 break
@@ -1115,9 +1142,9 @@
                 self._peek == delimiter or self._peek in self._STRING_ESCAPES
             ):
                 if self._peek == delimiter:
-                    text += self._peek  # type: ignore
+                    text += self._peek
                 else:
-                    text += self._char + self._peek  # type: ignore
+                    text += self._char + self._peek

                 if self._current + 1 < self.size:
                     self._advance(2)
@@ -1131,7 +1158,7 @@

                 if self._end:
                     raise RuntimeError(f"Missing {delimiter} from {self._line}:{self._start}")
-                text += self._char  # type: ignore
+                text += self._char
                 self._advance()

         return text
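The snippet below is a minimal sketch of the behavior this commit introduces; it is not part of the commit itself, and it assumes this revision of sqlglot is importable (the sample SQL strings are arbitrary). The new `end` slot and the derived `start` property let each token be mapped back to an offset range in the source string, and tokenize() now wraps scanner failures in a ValueError carrying up to 50 characters of context around the failure point.

    from sqlglot.tokens import Tokenizer

    sql = "SELECT col1 FROM tbl"
    for token in Tokenizer().tokenize(sql):
        # Token.start is derived from the new end slot (end - len(text)),
        # so simple tokens map back to a slice of the source SQL.
        print(token.token_type, token.text, token.start, token.end)

    try:
        # An unterminated string literal raises inside _extract_string;
        # tokenize() now re-raises it as a ValueError whose message quotes
        # the surrounding SQL instead of leaking a bare RuntimeError.
        Tokenizer().tokenize("SELECT 'unterminated")
    except ValueError as e:
        print(e)

Note that tokenize() clamps the context window to the bounds of the input, so the wrapped message stays short even when the failure occurs near the start or end of the string.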