Diffstat (limited to 'sqlglot/tokens.py')
-rw-r--r--  sqlglot/tokens.py | 247 ++++++++++++++++++++++++++++-----------------
1 file changed, 150 insertions(+), 97 deletions(-)
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index 766c01a..95d84d6 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -1,3 +1,6 @@
+from __future__ import annotations
+
+import typing as t
from enum import auto
from sqlglot.helper import AutoName
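
Note: the added `from __future__ import annotations` (PEP 563) makes every annotation a lazily-evaluated string, which is what lets the classmethods below annotate their return type as `Token` from inside the class body. A minimal, self-contained sketch of the pattern:

    from __future__ import annotations

    class Point:
        @classmethod
        def origin(cls) -> Point:  # forward reference works without quotes
            return cls()
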
@@ -27,6 +30,7 @@ class TokenType(AutoName):
NOT = auto()
EQ = auto()
NEQ = auto()
+ NULLSAFE_EQ = auto()
AND = auto()
OR = auto()
AMP = auto()
@@ -36,12 +40,14 @@ class TokenType(AutoName):
TILDA = auto()
ARROW = auto()
DARROW = auto()
+ FARROW = auto()
+ HASH = auto()
HASH_ARROW = auto()
DHASH_ARROW = auto()
LR_ARROW = auto()
- ANNOTATION = auto()
DOLLAR = auto()
PARAMETER = auto()
+ SESSION_PARAMETER = auto()
SPACE = auto()
BREAK = auto()
@@ -73,7 +79,7 @@ class TokenType(AutoName):
NVARCHAR = auto()
TEXT = auto()
BINARY = auto()
- BYTEA = auto()
+ VARBINARY = auto()
JSON = auto()
TIMESTAMP = auto()
TIMESTAMPTZ = auto()
@@ -142,6 +148,7 @@ class TokenType(AutoName):
DESCRIBE = auto()
DETERMINISTIC = auto()
DISTINCT = auto()
+ DISTINCT_FROM = auto()
DISTRIBUTE_BY = auto()
DIV = auto()
DROP = auto()
@@ -238,6 +245,7 @@ class TokenType(AutoName):
RETURNS = auto()
RIGHT = auto()
RLIKE = auto()
+ ROLLBACK = auto()
ROLLUP = auto()
ROW = auto()
ROWS = auto()
@@ -287,37 +295,49 @@ class TokenType(AutoName):
class Token:
- __slots__ = ("token_type", "text", "line", "col")
+ __slots__ = ("token_type", "text", "line", "col", "comment")
@classmethod
- def number(cls, number):
+ def number(cls, number: int) -> Token:
+ """Returns a NUMBER token with `number` as its text."""
return cls(TokenType.NUMBER, str(number))
@classmethod
- def string(cls, string):
+ def string(cls, string: str) -> Token:
+ """Returns a STRING token with `string` as its text."""
return cls(TokenType.STRING, string)
@classmethod
- def identifier(cls, identifier):
+ def identifier(cls, identifier: str) -> Token:
+ """Returns an IDENTIFIER token with `identifier` as its text."""
return cls(TokenType.IDENTIFIER, identifier)
@classmethod
- def var(cls, var):
+ def var(cls, var: str) -> Token:
+ """Returns an VAR token with `var` as its text."""
return cls(TokenType.VAR, var)
- def __init__(self, token_type, text, line=1, col=1):
+ def __init__(
+ self,
+ token_type: TokenType,
+ text: str,
+ line: int = 1,
+ col: int = 1,
+ comment: t.Optional[str] = None,
+ ) -> None:
self.token_type = token_type
self.text = text
self.line = line
self.col = max(col - len(text), 1)
+ self.comment = comment
- def __repr__(self):
+ def __repr__(self) -> str:
attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__)
return f"<Token {attributes}>"
class _Tokenizer(type):
- def __new__(cls, clsname, bases, attrs):
+ def __new__(cls, clsname, bases, attrs): # type: ignore
klass = super().__new__(cls, clsname, bases, attrs)
klass._QUOTES = cls._delimeter_list_to_dict(klass.QUOTES)
@@ -325,27 +345,29 @@ class _Tokenizer(type):
klass._HEX_STRINGS = cls._delimeter_list_to_dict(klass.HEX_STRINGS)
klass._BYTE_STRINGS = cls._delimeter_list_to_dict(klass.BYTE_STRINGS)
klass._IDENTIFIERS = cls._delimeter_list_to_dict(klass.IDENTIFIERS)
+ klass._ESCAPES = set(klass.ESCAPES)
klass._COMMENTS = dict(
- (comment, None) if isinstance(comment, str) else (comment[0], comment[1]) for comment in klass.COMMENTS
+ (comment, None) if isinstance(comment, str) else (comment[0], comment[1])
+ for comment in klass.COMMENTS
)
klass.KEYWORD_TRIE = new_trie(
key.upper()
- for key, value in {
+ for key in {
**klass.KEYWORDS,
**{comment: TokenType.COMMENT for comment in klass._COMMENTS},
**{quote: TokenType.QUOTE for quote in klass._QUOTES},
**{bit_string: TokenType.BIT_STRING for bit_string in klass._BIT_STRINGS},
**{hex_string: TokenType.HEX_STRING for hex_string in klass._HEX_STRINGS},
**{byte_string: TokenType.BYTE_STRING for byte_string in klass._BYTE_STRINGS},
- }.items()
+ }
if " " in key or any(single in key for single in klass.SINGLE_TOKENS)
)
return klass
@staticmethod
- def _delimeter_list_to_dict(list):
+ def _delimeter_list_to_dict(list: t.List[str | t.Tuple[str, str]]) -> t.Dict[str, str]:
return dict((item, item) if isinstance(item, str) else (item[0], item[1]) for item in list)
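
Note: KEYWORD_TRIE only indexes keys that a plain var scan cannot match, i.e. those containing a space ("DISTINCT FROM") or a SINGLE_TOKENS character ("/*+", "<=>"). sqlglot.trie's in_trie walks candidate text against it one step at a time, returning 0 (dead end), 1 (valid prefix) or 2 (full match):

    from sqlglot.trie import in_trie, new_trie

    trie = new_trie(["DISTINCT FROM", "<=>"])
    in_trie(trie, "DIS")            # (1, ...) valid prefix: keep consuming
    in_trie(trie, "DISTINCT FROM")  # (2, ...) complete keyword
    in_trie(trie, "?")              # (0, ...) no keyword starts this way
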
@@ -375,26 +397,26 @@ class Tokenizer(metaclass=_Tokenizer):
"*": TokenType.STAR,
"~": TokenType.TILDA,
"?": TokenType.PLACEHOLDER,
- "#": TokenType.ANNOTATION,
"@": TokenType.PARAMETER,
# used for breaking a var like x'y' but nothing else
# the token type doesn't matter
"'": TokenType.QUOTE,
"`": TokenType.IDENTIFIER,
'"': TokenType.IDENTIFIER,
+ "#": TokenType.HASH,
}
- QUOTES = ["'"]
+ QUOTES: t.List[t.Tuple[str, str] | str] = ["'"]
- BIT_STRINGS = []
+ BIT_STRINGS: t.List[str | t.Tuple[str, str]] = []
- HEX_STRINGS = []
+ HEX_STRINGS: t.List[str | t.Tuple[str, str]] = []
- BYTE_STRINGS = []
+ BYTE_STRINGS: t.List[str | t.Tuple[str, str]] = []
- IDENTIFIERS = ['"']
+ IDENTIFIERS: t.List[str | t.Tuple[str, str]] = ['"']
- ESCAPE = "'"
+ ESCAPES = ["'"]
KEYWORDS = {
"/*+": TokenType.HINT,
@@ -406,8 +428,10 @@ class Tokenizer(metaclass=_Tokenizer):
"<=": TokenType.LTE,
"<>": TokenType.NEQ,
"!=": TokenType.NEQ,
+ "<=>": TokenType.NULLSAFE_EQ,
"->": TokenType.ARROW,
"->>": TokenType.DARROW,
+ "=>": TokenType.FARROW,
"#>": TokenType.HASH_ARROW,
"#>>": TokenType.DHASH_ARROW,
"<->": TokenType.LR_ARROW,
@@ -454,6 +478,7 @@ class Tokenizer(metaclass=_Tokenizer):
"DESCRIBE": TokenType.DESCRIBE,
"DETERMINISTIC": TokenType.DETERMINISTIC,
"DISTINCT": TokenType.DISTINCT,
+ "DISTINCT FROM": TokenType.DISTINCT_FROM,
"DISTRIBUTE BY": TokenType.DISTRIBUTE_BY,
"DIV": TokenType.DIV,
"DROP": TokenType.DROP,
@@ -543,6 +568,7 @@ class Tokenizer(metaclass=_Tokenizer):
"RETURNS": TokenType.RETURNS,
"RIGHT": TokenType.RIGHT,
"RLIKE": TokenType.RLIKE,
+ "ROLLBACK": TokenType.ROLLBACK,
"ROLLUP": TokenType.ROLLUP,
"ROW": TokenType.ROW,
"ROWS": TokenType.ROWS,
@@ -622,8 +648,9 @@ class Tokenizer(metaclass=_Tokenizer):
"TEXT": TokenType.TEXT,
"CLOB": TokenType.TEXT,
"BINARY": TokenType.BINARY,
- "BLOB": TokenType.BINARY,
- "BYTEA": TokenType.BINARY,
+ "BLOB": TokenType.VARBINARY,
+ "BYTEA": TokenType.VARBINARY,
+ "VARBINARY": TokenType.VARBINARY,
"TIMESTAMP": TokenType.TIMESTAMP,
"TIMESTAMPTZ": TokenType.TIMESTAMPTZ,
"TIMESTAMPLTZ": TokenType.TIMESTAMPLTZ,
@@ -655,13 +682,13 @@ class Tokenizer(metaclass=_Tokenizer):
TokenType.SET,
TokenType.SHOW,
TokenType.TRUNCATE,
- TokenType.USE,
TokenType.VACUUM,
+ TokenType.ROLLBACK,
}
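
Note: tokens in COMMANDS trigger the swallow-to-semicolon behaviour in _add below: when such a token starts a statement, the remainder of the statement is captured as one STRING token. ROLLBACK now opts in, while USE goes back to normal tokenization. Roughly:

    tokens = Tokenizer().tokenize("SHOW TABLES LIKE 'foo'")
    # -> [SHOW, STRING " TABLES LIKE 'foo'"] (leading space kept verbatim)
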
# handle numeric literals like in hive (3L = BIGINT)
- NUMERIC_LITERALS = {}
- ENCODE = None
+ NUMERIC_LITERALS: t.Dict[str, str] = {}
+ ENCODE: t.Optional[str] = None
COMMENTS = ["--", ("/*", "*/")]
KEYWORD_TRIE = None # autofilled
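
Note: NUMERIC_LITERALS maps a suffix to a type keyword, so a dialect tokenizer would set something like the hypothetical subclass below, making "3L" tokenize as NUMBER "3", DCOLON "::", BIGINT "L" -- effectively a cast (see _scan_number below):

    class BigintSuffixTokenizer(Tokenizer):  # illustration only
        NUMERIC_LITERALS = {"L": "BIGINT"}

    tokens = BigintSuffixTokenizer().tokenize("3L")
    assert [t.text for t in tokens] == ["3", "::", "L"]
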
@@ -674,33 +701,39 @@ class Tokenizer(metaclass=_Tokenizer):
"_current",
"_line",
"_col",
+ "_comment",
"_char",
"_end",
"_peek",
+ "_prev_token_line",
+ "_prev_token_comment",
"_prev_token_type",
+ "_replace_backslash",
)
- def __init__(self):
- """
- Tokenizer consumes a sql string and produces an array of :class:`~sqlglot.tokens.Token`
- """
+ def __init__(self) -> None:
+ self._replace_backslash = "\\" in self._ESCAPES # type: ignore
self.reset()
- def reset(self):
+ def reset(self) -> None:
self.sql = ""
self.size = 0
- self.tokens = []
+ self.tokens: t.List[Token] = []
self._start = 0
self._current = 0
self._line = 1
self._col = 1
+ self._comment = None
self._char = None
self._end = None
self._peek = None
+ self._prev_token_line = -1
+ self._prev_token_comment = None
self._prev_token_type = None
- def tokenize(self, sql):
+ def tokenize(self, sql: str) -> t.List[Token]:
+ """Returns a list of tokens corresponding to the SQL string `sql`."""
self.reset()
self.sql = sql
self.size = len(sql)
@@ -712,14 +745,14 @@ class Tokenizer(metaclass=_Tokenizer):
if not self._char:
break
- white_space = self.WHITE_SPACE.get(self._char)
- identifier_end = self._IDENTIFIERS.get(self._char)
+ white_space = self.WHITE_SPACE.get(self._char) # type: ignore
+ identifier_end = self._IDENTIFIERS.get(self._char) # type: ignore
if white_space:
if white_space == TokenType.BREAK:
self._col = 1
self._line += 1
- elif self._char.isdigit():
+ elif self._char.isdigit(): # type: ignore
self._scan_number()
elif identifier_end:
self._scan_identifier(identifier_end)
@@ -727,38 +760,51 @@ class Tokenizer(metaclass=_Tokenizer):
self._scan_keywords()
return self.tokens
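
Note: tokenize() calls reset() up front, so a single instance can be reused across statements:

    tokenizer = Tokenizer()
    for token in tokenizer.tokenize("SELECT 1"):
        print(token)  # <Token token_type: TokenType.SELECT, text: SELECT, ...>
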
- def _chars(self, size):
+ def _chars(self, size: int) -> str:
if size == 1:
- return self._char
+ return self._char # type: ignore
start = self._current - 1
end = start + size
if end <= self.size:
return self.sql[start:end]
return ""
- def _advance(self, i=1):
+ def _advance(self, i: int = 1) -> None:
self._col += i
self._current += i
- self._end = self._current >= self.size
- self._char = self.sql[self._current - 1]
- self._peek = self.sql[self._current] if self._current < self.size else ""
+ self._end = self._current >= self.size # type: ignore
+ self._char = self.sql[self._current - 1] # type: ignore
+ self._peek = self.sql[self._current] if self._current < self.size else "" # type: ignore
@property
- def _text(self):
+ def _text(self) -> str:
return self.sql[self._start : self._current]
- def _add(self, token_type, text=None):
- self._prev_token_type = token_type
- self.tokens.append(Token(token_type, self._text if text is None else text, self._line, self._col))
+ def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None:
+ self._prev_token_line = self._line
+ self._prev_token_comment = self._comment
+ self._prev_token_type = token_type # type: ignore
+ self.tokens.append(
+ Token(
+ token_type,
+ self._text if text is None else text,
+ self._line,
+ self._col,
+ self._comment,
+ )
+ )
+ self._comment = None
- if token_type in self.COMMANDS and (len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON):
+ if token_type in self.COMMANDS and (
+ len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON
+ ):
self._start = self._current
while not self._end and self._peek != ";":
self._advance()
if self._start < self._current:
self._add(TokenType.STRING)
- def _scan_keywords(self):
+ def _scan_keywords(self) -> None:
size = 0
word = None
chars = self._text
@@ -771,7 +817,7 @@ class Tokenizer(metaclass=_Tokenizer):
if skip:
result = 1
else:
- result, trie = in_trie(trie, char.upper())
+ result, trie = in_trie(trie, char.upper()) # type: ignore
if result == 0:
break
@@ -793,15 +839,11 @@ class Tokenizer(metaclass=_Tokenizer):
else:
skip = True
else:
- chars = None
+ chars = None # type: ignore
if not word:
if self._char in self.SINGLE_TOKENS:
- token = self.SINGLE_TOKENS[self._char]
- if token == TokenType.ANNOTATION:
- self._scan_annotation()
- return
- self._add(token)
+ self._add(self.SINGLE_TOKENS[self._char]) # type: ignore
return
self._scan_var()
return
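
Note: with TokenType.ANNOTATION removed, "#" falls through the plain SINGLE_TOKENS branch above and is emitted as a HASH token ("#>"/"#>>" still win first via the keyword trie), instead of swallowing text to the end of the line:

    tokens = Tokenizer().tokenize("a # b")
    assert tokens[1].token_type == TokenType.HASH
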
@@ -816,31 +858,41 @@ class Tokenizer(metaclass=_Tokenizer):
self._advance(size - 1)
self._add(self.KEYWORDS[word.upper()])
- def _scan_comment(self, comment_start):
- if comment_start not in self._COMMENTS:
+ def _scan_comment(self, comment_start: str) -> bool:
+ if comment_start not in self._COMMENTS: # type: ignore
return False
- comment_end = self._COMMENTS[comment_start]
+ comment_start_line = self._line
+ comment_start_size = len(comment_start)
+ comment_end = self._COMMENTS[comment_start] # type: ignore
if comment_end:
comment_end_size = len(comment_end)
while not self._end and self._chars(comment_end_size) != comment_end:
self._advance()
+
+ self._comment = self._text[comment_start_size : -comment_end_size + 1] # type: ignore
self._advance(comment_end_size - 1)
else:
- while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK:
+ while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK: # type: ignore
self._advance()
- return True
+ self._comment = self._text[comment_start_size:] # type: ignore
- def _scan_annotation(self):
- while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK and self._peek != ",":
- self._advance()
- self._add(TokenType.ANNOTATION, self._text[1:])
+ # Leading comment is attached to the succeeding token, whilst trailing comment to the preceding. If both
+ # types of comment can be attached to a token, the trailing one is discarded in favour of the leading one.
- def _scan_number(self):
+ if comment_start_line == self._prev_token_line:
+ if self._prev_token_comment is None:
+ self.tokens[-1].comment = self._comment
+
+ self._comment = None
+
+ return True
+
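
Note: comments are no longer dropped. A comment on the same line as the previous token becomes that token's trailing comment; otherwise it is held in self._comment and attached to the next token in _add (a leading comment wins when both would apply). For instance:

    tokens = Tokenizer().tokenize("SELECT 1 -- one")
    assert tokens[-1].comment == " one"      # trailing, attached to NUMBER "1"

    tokens = Tokenizer().tokenize("/* hint */ SELECT 1")
    assert tokens[0].comment == " hint "     # leading, attached to SELECT
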
+ def _scan_number(self) -> None:
if self._char == "0":
- peek = self._peek.upper()
+ peek = self._peek.upper() # type: ignore
if peek == "B":
return self._scan_bits()
elif peek == "X":
@@ -850,7 +902,7 @@ class Tokenizer(metaclass=_Tokenizer):
scientific = 0
while True:
- if self._peek.isdigit():
+ if self._peek.isdigit(): # type: ignore
self._advance()
elif self._peek == "." and not decimal:
decimal = True
@@ -858,25 +910,25 @@ class Tokenizer(metaclass=_Tokenizer):
elif self._peek in ("-", "+") and scientific == 1:
scientific += 1
self._advance()
- elif self._peek.upper() == "E" and not scientific:
+ elif self._peek.upper() == "E" and not scientific: # type: ignore
scientific += 1
self._advance()
- elif self._peek.isalpha():
+ elif self._peek.isalpha(): # type: ignore
self._add(TokenType.NUMBER)
literal = []
- while self._peek.isalpha():
- literal.append(self._peek.upper())
+ while self._peek.isalpha(): # type: ignore
+ literal.append(self._peek.upper()) # type: ignore
self._advance()
- literal = "".join(literal)
- token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal))
+ literal = "".join(literal) # type: ignore
+ token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal)) # type: ignore
if token_type:
self._add(TokenType.DCOLON, "::")
- return self._add(token_type, literal)
+ return self._add(token_type, literal) # type: ignore
return self._advance(-len(literal))
else:
return self._add(TokenType.NUMBER)
- def _scan_bits(self):
+ def _scan_bits(self) -> None:
self._advance()
value = self._extract_value()
try:
@@ -884,7 +936,7 @@ class Tokenizer(metaclass=_Tokenizer):
except ValueError:
self._add(TokenType.IDENTIFIER)
- def _scan_hex(self):
+ def _scan_hex(self) -> None:
self._advance()
value = self._extract_value()
try:
@@ -892,9 +944,9 @@ class Tokenizer(metaclass=_Tokenizer):
except ValueError:
self._add(TokenType.IDENTIFIER)
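
Note: for 0b/0x literals the digits are converted with the prefix intact -- int() accepts "0x1A" when given an explicit base -- and an invalid payload (e.g. "0xZZ") raises ValueError, falling back to a plain IDENTIFIER token. The elided _add calls presumably mirror the int(text, base) conversion in _scan_formatted_string below:

    int("0x1A", 16)   # 26
    int("0b101", 2)   # 5
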
- def _extract_value(self):
+ def _extract_value(self) -> str:
while True:
- char = self._peek.strip()
+ char = self._peek.strip() # type: ignore
if char and char not in self.SINGLE_TOKENS:
self._advance()
else:
@@ -902,31 +954,30 @@ class Tokenizer(metaclass=_Tokenizer):
return self._text
- def _scan_string(self, quote):
- quote_end = self._QUOTES.get(quote)
+ def _scan_string(self, quote: str) -> bool:
+ quote_end = self._QUOTES.get(quote) # type: ignore
if quote_end is None:
return False
self._advance(len(quote))
text = self._extract_string(quote_end)
-
- text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text
- text = text.replace("\\\\", "\\") if self.ESCAPE == "\\" else text
+ text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text # type: ignore
+ text = text.replace("\\\\", "\\") if self._replace_backslash else text
self._add(TokenType.STRING, text)
return True
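
Note: ESCAPES generalizes the old scalar ESCAPE to a set, and backslash replacement is now keyed off "\\" in ESCAPES (precomputed in __init__ as _replace_backslash). With the default ESCAPES = ["'"], a doubled quote escapes the delimiter:

    tokens = Tokenizer().tokenize("SELECT 'it''s'")
    assert tokens[-1].text == "it's"   # collapsed by _extract_string below
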
# X'1234, b'0110', E'\\\\\' etc.
- def _scan_formatted_string(self, string_start):
- if string_start in self._HEX_STRINGS:
- delimiters = self._HEX_STRINGS
+ def _scan_formatted_string(self, string_start: str) -> bool:
+ if string_start in self._HEX_STRINGS: # type: ignore
+ delimiters = self._HEX_STRINGS # type: ignore
token_type = TokenType.HEX_STRING
base = 16
- elif string_start in self._BIT_STRINGS:
- delimiters = self._BIT_STRINGS
+ elif string_start in self._BIT_STRINGS: # type: ignore
+ delimiters = self._BIT_STRINGS # type: ignore
token_type = TokenType.BIT_STRING
base = 2
- elif string_start in self._BYTE_STRINGS:
- delimiters = self._BYTE_STRINGS
+ elif string_start in self._BYTE_STRINGS: # type: ignore
+ delimiters = self._BYTE_STRINGS # type: ignore
token_type = TokenType.BYTE_STRING
base = None
else:
@@ -942,11 +993,13 @@ class Tokenizer(metaclass=_Tokenizer):
try:
self._add(token_type, f"{int(text, base)}")
except:
- raise RuntimeError(f"Numeric string contains invalid characters from {self._line}:{self._start}")
+ raise RuntimeError(
+ f"Numeric string contains invalid characters from {self._line}:{self._start}"
+ )
return True
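
Note: a dialect declaring, say, HEX_STRINGS = [("x'", "'")] (a hypothetical subclass here) would turn "x'1F'" into a HEX_STRING token with decimal text via the int(text, base) call above, while byte strings (base=None) keep their text verbatim:

    class HexTokenizer(Tokenizer):  # illustration only
        HEX_STRINGS = [("x'", "'")]

    tokens = HexTokenizer().tokenize("x'1F'")
    assert tokens[0].text == "31"  # int("1F", 16)
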
- def _scan_identifier(self, identifier_end):
+ def _scan_identifier(self, identifier_end: str) -> None:
while self._peek != identifier_end:
if self._end:
raise RuntimeError(f"Missing {identifier_end} from {self._line}:{self._start}")
@@ -954,9 +1007,9 @@ class Tokenizer(metaclass=_Tokenizer):
self._advance()
self._add(TokenType.IDENTIFIER, self._text[1:-1])
- def _scan_var(self):
+ def _scan_var(self) -> None:
while True:
- char = self._peek.strip()
+ char = self._peek.strip() # type: ignore
if char and char not in self.SINGLE_TOKENS:
self._advance()
else:
@@ -967,12 +1020,12 @@ class Tokenizer(metaclass=_Tokenizer):
else self.KEYWORDS.get(self._text.upper(), TokenType.VAR)
)
- def _extract_string(self, delimiter):
+ def _extract_string(self, delimiter: str) -> str:
text = ""
delim_size = len(delimiter)
while True:
- if self._char == self.ESCAPE and self._peek == delimiter:
+ if self._char in self._ESCAPES and self._peek == delimiter: # type: ignore
text += delimiter
self._advance(2)
else:
@@ -983,7 +1036,7 @@ class Tokenizer(metaclass=_Tokenizer):
if self._end:
raise RuntimeError(f"Missing {delimiter} from {self._line}:{self._start}")
- text += self._char
+ text += self._char # type: ignore
self._advance()
return text