Diffstat (limited to 'sqlglot/tokens.py')
-rw-r--r--  sqlglot/tokens.py  247
1 file changed, 150 insertions(+), 97 deletions(-)
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index 766c01a..95d84d6 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -1,3 +1,6 @@
+from __future__ import annotations
+
+import typing as t
 from enum import auto
 
 from sqlglot.helper import AutoName
@@ -27,6 +30,7 @@ class TokenType(AutoName):
     NOT = auto()
     EQ = auto()
     NEQ = auto()
+    NULLSAFE_EQ = auto()
     AND = auto()
     OR = auto()
     AMP = auto()
@@ -36,12 +40,14 @@ class TokenType(AutoName):
     TILDA = auto()
     ARROW = auto()
     DARROW = auto()
+    FARROW = auto()
+    HASH = auto()
     HASH_ARROW = auto()
     DHASH_ARROW = auto()
     LR_ARROW = auto()
-    ANNOTATION = auto()
     DOLLAR = auto()
     PARAMETER = auto()
+    SESSION_PARAMETER = auto()
     SPACE = auto()
     BREAK = auto()
 
@@ -73,7 +79,7 @@ class TokenType(AutoName):
     NVARCHAR = auto()
     TEXT = auto()
     BINARY = auto()
-    BYTEA = auto()
+    VARBINARY = auto()
     JSON = auto()
     TIMESTAMP = auto()
     TIMESTAMPTZ = auto()
@@ -142,6 +148,7 @@ class TokenType(AutoName):
     DESCRIBE = auto()
     DETERMINISTIC = auto()
     DISTINCT = auto()
+    DISTINCT_FROM = auto()
     DISTRIBUTE_BY = auto()
     DIV = auto()
     DROP = auto()
@@ -238,6 +245,7 @@ class TokenType(AutoName):
     RETURNS = auto()
     RIGHT = auto()
     RLIKE = auto()
+    ROLLBACK = auto()
     ROLLUP = auto()
     ROW = auto()
     ROWS = auto()
@@ -287,37 +295,49 @@ class TokenType(AutoName):
 
 
 class Token:
-    __slots__ = ("token_type", "text", "line", "col")
+    __slots__ = ("token_type", "text", "line", "col", "comment")
 
     @classmethod
-    def number(cls, number):
+    def number(cls, number: int) -> Token:
+        """Returns a NUMBER token with `number` as its text."""
         return cls(TokenType.NUMBER, str(number))
 
     @classmethod
-    def string(cls, string):
+    def string(cls, string: str) -> Token:
+        """Returns a STRING token with `string` as its text."""
         return cls(TokenType.STRING, string)
 
     @classmethod
-    def identifier(cls, identifier):
+    def identifier(cls, identifier: str) -> Token:
+        """Returns an IDENTIFIER token with `identifier` as its text."""
         return cls(TokenType.IDENTIFIER, identifier)
 
     @classmethod
-    def var(cls, var):
+    def var(cls, var: str) -> Token:
+        """Returns a VAR token with `var` as its text."""
         return cls(TokenType.VAR, var)
 
-    def __init__(self, token_type, text, line=1, col=1):
+    def __init__(
+        self,
+        token_type: TokenType,
+        text: str,
+        line: int = 1,
+        col: int = 1,
+        comment: t.Optional[str] = None,
+    ) -> None:
         self.token_type = token_type
         self.text = text
         self.line = line
         self.col = max(col - len(text), 1)
+        self.comment = comment
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__)
         return f"<Token {attributes}>"
 
 
 class _Tokenizer(type):
-    def __new__(cls, clsname, bases, attrs):
+    def __new__(cls, clsname, bases, attrs):  # type: ignore
         klass = super().__new__(cls, clsname, bases, attrs)
 
         klass._QUOTES = cls._delimeter_list_to_dict(klass.QUOTES)
@@ -325,27 +345,29 @@ class _Tokenizer(type):
         klass._HEX_STRINGS = cls._delimeter_list_to_dict(klass.HEX_STRINGS)
         klass._BYTE_STRINGS = cls._delimeter_list_to_dict(klass.BYTE_STRINGS)
         klass._IDENTIFIERS = cls._delimeter_list_to_dict(klass.IDENTIFIERS)
+        klass._ESCAPES = set(klass.ESCAPES)
         klass._COMMENTS = dict(
-            (comment, None) if isinstance(comment, str) else (comment[0], comment[1]) for comment in klass.COMMENTS
+            (comment, None) if isinstance(comment, str) else (comment[0], comment[1])
+            for comment in klass.COMMENTS
         )
 
         klass.KEYWORD_TRIE = new_trie(
             key.upper()
-            for key, value in {
+            for key in {
                 **klass.KEYWORDS,
                 **{comment: TokenType.COMMENT for comment in klass._COMMENTS},
                 **{quote: TokenType.QUOTE for quote in klass._QUOTES},
                 **{bit_string: TokenType.BIT_STRING for bit_string in klass._BIT_STRINGS},
                 **{hex_string: TokenType.HEX_STRING for hex_string in klass._HEX_STRINGS},
                 **{byte_string: TokenType.BYTE_STRING for byte_string in klass._BYTE_STRINGS},
-            }.items()
+            }
             if " " in key or any(single in key for single in klass.SINGLE_TOKENS)
         )
 
         return klass
 
     @staticmethod
-    def _delimeter_list_to_dict(list):
+    def _delimeter_list_to_dict(list: t.List[str | t.Tuple[str, str]]) -> t.Dict[str, str]:
         return dict((item, item) if isinstance(item, str) else (item[0], item[1]) for item in list)
 
 
@@ -375,26 +397,26 @@ class Tokenizer(metaclass=_Tokenizer):
         "*": TokenType.STAR,
         "~": TokenType.TILDA,
         "?": TokenType.PLACEHOLDER,
-        "#": TokenType.ANNOTATION,
         "@": TokenType.PARAMETER,
         # used for breaking a var like x'y' but nothing else
        # the token type doesn't matter
         "'": TokenType.QUOTE,
         "`": TokenType.IDENTIFIER,
         '"': TokenType.IDENTIFIER,
+        "#": TokenType.HASH,
     }
 
-    QUOTES = ["'"]
+    QUOTES: t.List[t.Tuple[str, str] | str] = ["'"]
 
-    BIT_STRINGS = []
+    BIT_STRINGS: t.List[str | t.Tuple[str, str]] = []
 
-    HEX_STRINGS = []
+    HEX_STRINGS: t.List[str | t.Tuple[str, str]] = []
 
-    BYTE_STRINGS = []
+    BYTE_STRINGS: t.List[str | t.Tuple[str, str]] = []
 
-    IDENTIFIERS = ['"']
+    IDENTIFIERS: t.List[str | t.Tuple[str, str]] = ['"']
 
-    ESCAPE = "'"
+    ESCAPES = ["'"]
 
     KEYWORDS = {
         "/*+": TokenType.HINT,
@@ -406,8 +428,10 @@ class Tokenizer(metaclass=_Tokenizer):
         "<=": TokenType.LTE,
         "<>": TokenType.NEQ,
         "!=": TokenType.NEQ,
+        "<=>": TokenType.NULLSAFE_EQ,
         "->": TokenType.ARROW,
         "->>": TokenType.DARROW,
+        "=>": TokenType.FARROW,
         "#>": TokenType.HASH_ARROW,
         "#>>": TokenType.DHASH_ARROW,
         "<->": TokenType.LR_ARROW,
@@ -454,6 +478,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "DESCRIBE": TokenType.DESCRIBE,
         "DETERMINISTIC": TokenType.DETERMINISTIC,
         "DISTINCT": TokenType.DISTINCT,
+        "DISTINCT FROM": TokenType.DISTINCT_FROM,
         "DISTRIBUTE BY": TokenType.DISTRIBUTE_BY,
         "DIV": TokenType.DIV,
         "DROP": TokenType.DROP,
@@ -543,6 +568,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "RETURNS": TokenType.RETURNS,
         "RIGHT": TokenType.RIGHT,
         "RLIKE": TokenType.RLIKE,
+        "ROLLBACK": TokenType.ROLLBACK,
         "ROLLUP": TokenType.ROLLUP,
         "ROW": TokenType.ROW,
         "ROWS": TokenType.ROWS,
@@ -622,8 +648,9 @@ class Tokenizer(metaclass=_Tokenizer):
         "TEXT": TokenType.TEXT,
         "CLOB": TokenType.TEXT,
         "BINARY": TokenType.BINARY,
-        "BLOB": TokenType.BINARY,
-        "BYTEA": TokenType.BINARY,
+        "BLOB": TokenType.VARBINARY,
+        "BYTEA": TokenType.VARBINARY,
+        "VARBINARY": TokenType.VARBINARY,
         "TIMESTAMP": TokenType.TIMESTAMP,
         "TIMESTAMPTZ": TokenType.TIMESTAMPTZ,
         "TIMESTAMPLTZ": TokenType.TIMESTAMPLTZ,
@@ -655,13 +682,13 @@ class Tokenizer(metaclass=_Tokenizer):
         TokenType.SET,
         TokenType.SHOW,
         TokenType.TRUNCATE,
-        TokenType.USE,
         TokenType.VACUUM,
+        TokenType.ROLLBACK,
     }
 
     # handle numeric literals like in hive (3L = BIGINT)
-    NUMERIC_LITERALS = {}
-    ENCODE = None
+    NUMERIC_LITERALS: t.Dict[str, str] = {}
+    ENCODE: t.Optional[str] = None
 
     COMMENTS = ["--", ("/*", "*/")]
     KEYWORD_TRIE = None  # autofilled
@@ -674,33 +701,39 @@ class Tokenizer(metaclass=_Tokenizer):
         "_current",
         "_line",
         "_col",
+        "_comment",
         "_char",
         "_end",
         "_peek",
+        "_prev_token_line",
+        "_prev_token_comment",
         "_prev_token_type",
+        "_replace_backslash",
     )
 
-    def __init__(self):
-        """
-        Tokenizer consumes a sql string and produces an array of :class:`~sqlglot.tokens.Token`
-        """
+    def __init__(self) -> None:
+        self._replace_backslash = "\\" in self._ESCAPES  # type: ignore
         self.reset()
 
-    def reset(self):
+    def reset(self) -> None:
         self.sql = ""
         self.size = 0
-        self.tokens = []
+        self.tokens: t.List[Token] = []
         self._start = 0
         self._current = 0
         self._line = 1
         self._col = 1
+        self._comment = None
 
         self._char = None
         self._end = None
         self._peek = None
+        self._prev_token_line = -1
+        self._prev_token_comment = None
         self._prev_token_type = None
 
-    def tokenize(self, sql):
+    def tokenize(self, sql: str) -> t.List[Token]:
+        """Returns a list of tokens corresponding to the SQL string `sql`."""
         self.reset()
         self.sql = sql
         self.size = len(sql)
@@ -712,14 +745,14 @@ class Tokenizer(metaclass=_Tokenizer):
             if not self._char:
                 break
 
-            white_space = self.WHITE_SPACE.get(self._char)
-            identifier_end = self._IDENTIFIERS.get(self._char)
+            white_space = self.WHITE_SPACE.get(self._char)  # type: ignore
+            identifier_end = self._IDENTIFIERS.get(self._char)  # type: ignore
 
             if white_space:
                 if white_space == TokenType.BREAK:
                     self._col = 1
                     self._line += 1
-            elif self._char.isdigit():
+            elif self._char.isdigit():  # type:ignore
                 self._scan_number()
             elif identifier_end:
                 self._scan_identifier(identifier_end)
@@ -727,38 +760,51 @@ class Tokenizer(metaclass=_Tokenizer):
                 self._scan_keywords()
         return self.tokens
 
-    def _chars(self, size):
+    def _chars(self, size: int) -> str:
         if size == 1:
-            return self._char
+            return self._char  # type: ignore
         start = self._current - 1
         end = start + size
         if end <= self.size:
             return self.sql[start:end]
         return ""
 
-    def _advance(self, i=1):
+    def _advance(self, i: int = 1) -> None:
         self._col += i
         self._current += i
-        self._end = self._current >= self.size
-        self._char = self.sql[self._current - 1]
-        self._peek = self.sql[self._current] if self._current < self.size else ""
+        self._end = self._current >= self.size  # type: ignore
+        self._char = self.sql[self._current - 1]  # type: ignore
+        self._peek = self.sql[self._current] if self._current < self.size else ""  # type: ignore
 
     @property
-    def _text(self):
+    def _text(self) -> str:
         return self.sql[self._start : self._current]
 
-    def _add(self, token_type, text=None):
-        self._prev_token_type = token_type
-        self.tokens.append(Token(token_type, self._text if text is None else text, self._line, self._col))
+    def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None:
+        self._prev_token_line = self._line
+        self._prev_token_comment = self._comment
+        self._prev_token_type = token_type  # type: ignore
+        self.tokens.append(
+            Token(
+                token_type,
+                self._text if text is None else text,
+                self._line,
+                self._col,
+                self._comment,
+            )
+        )
+        self._comment = None
 
-        if token_type in self.COMMANDS and (len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON):
+        if token_type in self.COMMANDS and (
+            len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON
+        ):
            self._start = self._current
             while not self._end and self._peek != ";":
                 self._advance()
             if self._start < self._current:
                 self._add(TokenType.STRING)
 
-    def _scan_keywords(self):
+    def _scan_keywords(self) -> None:
         size = 0
         word = None
         chars = self._text
@@ -771,7 +817,7 @@ class Tokenizer(metaclass=_Tokenizer):
             if skip:
                 result = 1
             else:
-                result, trie = in_trie(trie, char.upper())
+                result, trie = in_trie(trie, char.upper())  # type: ignore
 
             if result == 0:
                 break
@@ -793,15 +839,11 @@ class Tokenizer(metaclass=_Tokenizer):
             else:
                 skip = True
         else:
-            chars = None
+            chars = None  # type: ignore
 
         if not word:
             if self._char in self.SINGLE_TOKENS:
-                token = self.SINGLE_TOKENS[self._char]
-                if token == TokenType.ANNOTATION:
-                    self._scan_annotation()
-                    return
-                self._add(token)
+                self._add(self.SINGLE_TOKENS[self._char])  # type: ignore
                 return
             self._scan_var()
             return
@@ -816,31 +858,41 @@ class Tokenizer(metaclass=_Tokenizer):
         self._advance(size - 1)
         self._add(self.KEYWORDS[word.upper()])
 
-    def _scan_comment(self, comment_start):
-        if comment_start not in self._COMMENTS:
+    def _scan_comment(self, comment_start: str) -> bool:
+        if comment_start not in self._COMMENTS:  # type: ignore
             return False
 
-        comment_end = self._COMMENTS[comment_start]
+        comment_start_line = self._line
+        comment_start_size = len(comment_start)
+        comment_end = self._COMMENTS[comment_start]  # type: ignore
 
         if comment_end:
             comment_end_size = len(comment_end)
 
             while not self._end and self._chars(comment_end_size) != comment_end:
                 self._advance()
+
+            self._comment = self._text[comment_start_size : -comment_end_size + 1]  # type: ignore
             self._advance(comment_end_size - 1)
         else:
-            while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK:
+            while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK:  # type: ignore
                 self._advance()
-        return True
+            self._comment = self._text[comment_start_size:]  # type: ignore
 
-    def _scan_annotation(self):
-        while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK and self._peek != ",":
-            self._advance()
-        self._add(TokenType.ANNOTATION, self._text[1:])
+        # Leading comment is attached to the succeeding token, whilst trailing comment to the preceding. If both
+        # types of comment can be attached to a token, the trailing one is discarded in favour of the leading one.
 
-    def _scan_number(self):
+        if comment_start_line == self._prev_token_line:
+            if self._prev_token_comment is None:
+                self.tokens[-1].comment = self._comment
+
+            self._comment = None
+
+        return True
+
+    def _scan_number(self) -> None:
         if self._char == "0":
-            peek = self._peek.upper()
+            peek = self._peek.upper()  # type: ignore
             if peek == "B":
                 return self._scan_bits()
             elif peek == "X":
@@ -850,7 +902,7 @@ class Tokenizer(metaclass=_Tokenizer):
         scientific = 0
 
         while True:
-            if self._peek.isdigit():
+            if self._peek.isdigit():  # type: ignore
                 self._advance()
             elif self._peek == "." and not decimal:
                 decimal = True
@@ -858,25 +910,25 @@ class Tokenizer(metaclass=_Tokenizer):
             elif self._peek in ("-", "+") and scientific == 1:
                 scientific += 1
                 self._advance()
-            elif self._peek.upper() == "E" and not scientific:
+            elif self._peek.upper() == "E" and not scientific:  # type: ignore
                 scientific += 1
                 self._advance()
-            elif self._peek.isalpha():
+            elif self._peek.isalpha():  # type: ignore
                 self._add(TokenType.NUMBER)
                 literal = []
-                while self._peek.isalpha():
-                    literal.append(self._peek.upper())
+                while self._peek.isalpha():  # type: ignore
+                    literal.append(self._peek.upper())  # type: ignore
                     self._advance()
-                literal = "".join(literal)
-                token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal))
+                literal = "".join(literal)  # type: ignore
+                token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal))  # type: ignore
                 if token_type:
                     self._add(TokenType.DCOLON, "::")
-                    return self._add(token_type, literal)
+                    return self._add(token_type, literal)  # type: ignore
                 return self._advance(-len(literal))
             else:
                 return self._add(TokenType.NUMBER)
 
-    def _scan_bits(self):
+    def _scan_bits(self) -> None:
         self._advance()
         value = self._extract_value()
         try:
@@ -884,7 +936,7 @@ class Tokenizer(metaclass=_Tokenizer):
         except ValueError:
             self._add(TokenType.IDENTIFIER)
 
-    def _scan_hex(self):
+    def _scan_hex(self) -> None:
         self._advance()
         value = self._extract_value()
         try:
@@ -892,9 +944,9 @@ class Tokenizer(metaclass=_Tokenizer):
         except ValueError:
             self._add(TokenType.IDENTIFIER)
 
-    def _extract_value(self):
+    def _extract_value(self) -> str:
         while True:
-            char = self._peek.strip()
+            char = self._peek.strip()  # type: ignore
             if char and char not in self.SINGLE_TOKENS:
                 self._advance()
             else:
@@ -902,31 +954,30 @@ class Tokenizer(metaclass=_Tokenizer):
 
         return self._text
 
-    def _scan_string(self, quote):
-        quote_end = self._QUOTES.get(quote)
+    def _scan_string(self, quote: str) -> bool:
+        quote_end = self._QUOTES.get(quote)  # type: ignore
         if quote_end is None:
             return False
 
         self._advance(len(quote))
         text = self._extract_string(quote_end)
-
-        text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text
-        text = text.replace("\\\\", "\\") if self.ESCAPE == "\\" else text
+        text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text  # type: ignore
+        text = text.replace("\\\\", "\\") if self._replace_backslash else text
         self._add(TokenType.STRING, text)
         return True
 
     # X'1234, b'0110', E'\\\\\' etc.
-    def _scan_formatted_string(self, string_start):
-        if string_start in self._HEX_STRINGS:
-            delimiters = self._HEX_STRINGS
+    def _scan_formatted_string(self, string_start: str) -> bool:
+        if string_start in self._HEX_STRINGS:  # type: ignore
+            delimiters = self._HEX_STRINGS  # type: ignore
             token_type = TokenType.HEX_STRING
             base = 16
-        elif string_start in self._BIT_STRINGS:
-            delimiters = self._BIT_STRINGS
+        elif string_start in self._BIT_STRINGS:  # type: ignore
+            delimiters = self._BIT_STRINGS  # type: ignore
             token_type = TokenType.BIT_STRING
             base = 2
-        elif string_start in self._BYTE_STRINGS:
-            delimiters = self._BYTE_STRINGS
+        elif string_start in self._BYTE_STRINGS:  # type: ignore
+            delimiters = self._BYTE_STRINGS  # type: ignore
             token_type = TokenType.BYTE_STRING
             base = None
         else:
@@ -942,11 +993,13 @@ class Tokenizer(metaclass=_Tokenizer):
             try:
                 self._add(token_type, f"{int(text, base)}")
             except:
-                raise RuntimeError(f"Numeric string contains invalid characters from {self._line}:{self._start}")
+                raise RuntimeError(
+                    f"Numeric string contains invalid characters from {self._line}:{self._start}"
+                )
 
         return True
 
-    def _scan_identifier(self, identifier_end):
+    def _scan_identifier(self, identifier_end: str) -> None:
         while self._peek != identifier_end:
             if self._end:
                 raise RuntimeError(f"Missing {identifier_end} from {self._line}:{self._start}")
@@ -954,9 +1007,9 @@ class Tokenizer(metaclass=_Tokenizer):
             self._advance()
         self._add(TokenType.IDENTIFIER, self._text[1:-1])
 
-    def _scan_var(self):
+    def _scan_var(self) -> None:
         while True:
-            char = self._peek.strip()
+            char = self._peek.strip()  # type: ignore
             if char and char not in self.SINGLE_TOKENS:
                 self._advance()
             else:
@@ -967,12 +1020,12 @@ class Tokenizer(metaclass=_Tokenizer):
             else self.KEYWORDS.get(self._text.upper(), TokenType.VAR)
         )
 
-    def _extract_string(self, delimiter):
+    def _extract_string(self, delimiter: str) -> str:
         text = ""
         delim_size = len(delimiter)
 
         while True:
-            if self._char == self.ESCAPE and self._peek == delimiter:
+            if self._char in self._ESCAPES and self._peek == delimiter:  # type: ignore
                 text += delimiter
                 self._advance(2)
             else:
@@ -983,7 +1036,7 @@ class Tokenizer(metaclass=_Tokenizer):
             if self._end:
                 raise RuntimeError(f"Missing {delimiter} from {self._line}:{self._start}")
 
-            text += self._char
+            text += self._char  # type: ignore
             self._advance()
 
         return text
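The main behavioural change in this diff is that comments are no longer dropped (or funnelled through the removed ANNOTATION path): _scan_comment now buffers the comment text, and _add stores it on the new Token.comment slot, following the leading/trailing rule documented in _scan_comment. A minimal sketch of the resulting behaviour, assuming this revision of sqlglot.tokens is importable; the exact whitespace kept in the captured comment text is an assumption here, so it is printed with repr:

    from sqlglot.tokens import Tokenizer

    # Trailing comment: it starts on the same line as the previously emitted
    # token, so _scan_comment attaches it to that token (the NUMBER token "1").
    tokens = Tokenizer().tokenize("SELECT 1 -- forty-two")
    print(tokens[-1].token_type, repr(tokens[-1].comment))

    # Leading comment: no token was emitted on its line, so it stays buffered
    # in self._comment and _add attaches it to the next token (SELECT).
    tokens = Tokenizer().tokenize("-- header\nSELECT 1")
    print(tokens[0].token_type, repr(tokens[0].comment))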
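The diff also replaces the scalar ESCAPE attribute with the ESCAPES list, which the _Tokenizer metaclass collapses into the _ESCAPES set, so a dialect tokenizer can declare several escape characters and _extract_string will honour any of them. A sketch under that assumption; BackslashTokenizer is a hypothetical subclass for illustration, not part of this commit:

    from sqlglot.tokens import Tokenizer

    class BackslashTokenizer(Tokenizer):
        # Hypothetical dialect: both the quote character and a backslash can
        # escape the string delimiter. __init__ derives _replace_backslash
        # from the presence of "\\" in this list.
        ESCAPES = ["'", "\\"]

    # The backslash-escaped quote is folded into the STRING token's text.
    tokens = BackslashTokenizer().tokenize(r"SELECT 'it\'s'")
    print([(token.token_type, token.text) for token in tokens])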