Diffstat (limited to 'sqlglot/tokens.py')
-rw-r--r--  sqlglot/tokens.py | 163
1 file changed, 95 insertions(+), 68 deletions(-)
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index cf2e31f..64c1f92 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -87,6 +87,7 @@ class TokenType(AutoName):
FLOAT = auto()
DOUBLE = auto()
DECIMAL = auto()
+ BIGDECIMAL = auto()
CHAR = auto()
NCHAR = auto()
VARCHAR = auto()
@@ -214,6 +215,7 @@ class TokenType(AutoName):
ISNULL = auto()
JOIN = auto()
JOIN_MARKER = auto()
+ KEEP = auto()
LANGUAGE = auto()
LATERAL = auto()
LAZY = auto()
@@ -231,6 +233,7 @@ class TokenType(AutoName):
MOD = auto()
NATURAL = auto()
NEXT = auto()
+ NEXT_VALUE_FOR = auto()
NO_ACTION = auto()
NOTNULL = auto()
NULL = auto()
@@ -315,7 +318,7 @@ class TokenType(AutoName):
class Token:
- __slots__ = ("token_type", "text", "line", "col", "comments")
+ __slots__ = ("token_type", "text", "line", "col", "end", "comments")
@classmethod
def number(cls, number: int) -> Token:
@@ -343,22 +346,29 @@ class Token:
text: str,
line: int = 1,
col: int = 1,
+ end: int = 0,
comments: t.List[str] = [],
) -> None:
self.token_type = token_type
self.text = text
self.line = line
- self.col = col - len(text)
- self.col = self.col if self.col > 1 else 1
+ size = len(text)
+ self.col = col
+ self.end = end if end else size
self.comments = comments
+ @property
+ def start(self) -> int:
+ """Returns the start of the token."""
+ return self.end - len(self.text)
+
def __repr__(self) -> str:
attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__)
return f"<Token {attributes}>"
class _Tokenizer(type):
- def __new__(cls, clsname, bases, attrs): # type: ignore
+ def __new__(cls, clsname, bases, attrs):
klass = super().__new__(cls, clsname, bases, attrs)
klass._QUOTES = {
@@ -433,25 +443,25 @@ class Tokenizer(metaclass=_Tokenizer):
"#": TokenType.HASH,
}
- QUOTES: t.List[t.Tuple[str, str] | str] = ["'"]
-
BIT_STRINGS: t.List[str | t.Tuple[str, str]] = []
-
- HEX_STRINGS: t.List[str | t.Tuple[str, str]] = []
-
BYTE_STRINGS: t.List[str | t.Tuple[str, str]] = []
-
+ HEX_STRINGS: t.List[str | t.Tuple[str, str]] = []
IDENTIFIERS: t.List[str | t.Tuple[str, str]] = ['"']
-
- STRING_ESCAPES = ["'"]
-
- _STRING_ESCAPES: t.Set[str] = set()
-
IDENTIFIER_ESCAPES = ['"']
+ QUOTES: t.List[t.Tuple[str, str] | str] = ["'"]
+ STRING_ESCAPES = ["'"]
+ VAR_SINGLE_TOKENS: t.Set[str] = set()
+ _COMMENTS: t.Dict[str, str] = {}
+ _BIT_STRINGS: t.Dict[str, str] = {}
+ _BYTE_STRINGS: t.Dict[str, str] = {}
+ _HEX_STRINGS: t.Dict[str, str] = {}
+ _IDENTIFIERS: t.Dict[str, str] = {}
_IDENTIFIER_ESCAPES: t.Set[str] = set()
+ _QUOTES: t.Dict[str, str] = {}
+ _STRING_ESCAPES: t.Set[str] = set()
- KEYWORDS = {
+ KEYWORDS: t.Dict[t.Optional[str], TokenType] = {
**{f"{{%{postfix}": TokenType.BLOCK_START for postfix in ("", "+", "-")},
**{f"{prefix}%}}": TokenType.BLOCK_END for prefix in ("", "+", "-")},
"{{+": TokenType.BLOCK_START,
@@ -553,6 +563,7 @@ class Tokenizer(metaclass=_Tokenizer):
"IS": TokenType.IS,
"ISNULL": TokenType.ISNULL,
"JOIN": TokenType.JOIN,
+ "KEEP": TokenType.KEEP,
"LATERAL": TokenType.LATERAL,
"LAZY": TokenType.LAZY,
"LEADING": TokenType.LEADING,
@@ -565,6 +576,7 @@ class Tokenizer(metaclass=_Tokenizer):
"MERGE": TokenType.MERGE,
"NATURAL": TokenType.NATURAL,
"NEXT": TokenType.NEXT,
+ "NEXT VALUE FOR": TokenType.NEXT_VALUE_FOR,
"NO ACTION": TokenType.NO_ACTION,
"NOT": TokenType.NOT,
"NOTNULL": TokenType.NOTNULL,
@@ -632,6 +644,7 @@ class Tokenizer(metaclass=_Tokenizer):
"UPDATE": TokenType.UPDATE,
"USE": TokenType.USE,
"USING": TokenType.USING,
+ "UUID": TokenType.UUID,
"VALUES": TokenType.VALUES,
"VIEW": TokenType.VIEW,
"VOLATILE": TokenType.VOLATILE,
@@ -661,6 +674,8 @@ class Tokenizer(metaclass=_Tokenizer):
"INT8": TokenType.BIGINT,
"DEC": TokenType.DECIMAL,
"DECIMAL": TokenType.DECIMAL,
+ "BIGDECIMAL": TokenType.BIGDECIMAL,
+ "BIGNUMERIC": TokenType.BIGDECIMAL,
"MAP": TokenType.MAP,
"NULLABLE": TokenType.NULLABLE,
"NUMBER": TokenType.DECIMAL,
@@ -742,7 +757,7 @@ class Tokenizer(metaclass=_Tokenizer):
ENCODE: t.Optional[str] = None
COMMENTS = ["--", ("/*", "*/"), ("{#", "#}")]
- KEYWORD_TRIE = None # autofilled
+ KEYWORD_TRIE: t.Dict = {} # autofilled
IDENTIFIER_CAN_START_WITH_DIGIT = False
@@ -776,19 +791,28 @@ class Tokenizer(metaclass=_Tokenizer):
self._col = 1
self._comments: t.List[str] = []
- self._char = None
- self._end = None
- self._peek = None
+ self._char = ""
+ self._end = False
+ self._peek = ""
self._prev_token_line = -1
self._prev_token_comments: t.List[str] = []
- self._prev_token_type = None
+ self._prev_token_type: t.Optional[TokenType] = None
def tokenize(self, sql: str) -> t.List[Token]:
"""Returns a list of tokens corresponding to the SQL string `sql`."""
self.reset()
self.sql = sql
self.size = len(sql)
- self._scan()
+ try:
+ self._scan()
+ except Exception as e:
+ start = self._current - 50
+ end = self._current + 50
+ start = start if start > 0 else 0
+ end = end if end < self.size else self.size - 1
+ context = self.sql[start:end]
+ raise ValueError(f"Error tokenizing '{context}'") from e
+
return self.tokens
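With this wrapper, a failure deep inside _scan now surfaces as a ValueError carrying up to 50 characters of SQL on either side of the failure point, chained from the original exception. A sketch; the unterminated string below is just one way to trigger it:

    from sqlglot.tokens import Tokenizer

    try:
        Tokenizer().tokenize("SELECT 'unterminated")  # string literal is never closed
    except ValueError as e:
        print(e)  # message embeds the offending SQL context; the root cause is chained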
def _scan(self, until: t.Optional[t.Callable] = None) -> None:
@@ -810,9 +834,12 @@ class Tokenizer(metaclass=_Tokenizer):
if until and until():
break
+ if self.tokens:
+ self.tokens[-1].comments.extend(self._comments)
+
def _chars(self, size: int) -> str:
if size == 1:
- return self._char # type: ignore
+ return self._char
start = self._current - 1
end = start + size
if end <= self.size:
@@ -821,17 +848,15 @@ class Tokenizer(metaclass=_Tokenizer):
def _advance(self, i: int = 1) -> None:
if self.WHITE_SPACE.get(self._char) is TokenType.BREAK:
- self._set_new_line()
+ self._col = 1
+ self._line += 1
+ else:
+ self._col += i
- self._col += i
self._current += i
- self._end = self._current >= self.size # type: ignore
- self._char = self.sql[self._current - 1] # type: ignore
- self._peek = self.sql[self._current] if self._current < self.size else "" # type: ignore
-
- def _set_new_line(self) -> None:
- self._col = 1
- self._line += 1
+ self._end = self._current >= self.size
+ self._char = self.sql[self._current - 1]
+ self._peek = "" if self._end else self.sql[self._current]
@property
def _text(self) -> str:
@@ -840,13 +865,14 @@ class Tokenizer(metaclass=_Tokenizer):
def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None:
self._prev_token_line = self._line
self._prev_token_comments = self._comments
- self._prev_token_type = token_type # type: ignore
+ self._prev_token_type = token_type
self.tokens.append(
Token(
token_type,
self._text if text is None else text,
self._line,
self._col,
+ self._current,
self._comments,
)
)
@@ -881,7 +907,7 @@ class Tokenizer(metaclass=_Tokenizer):
if skip:
result = 1
else:
- result, trie = in_trie(trie, char.upper()) # type: ignore
+ result, trie = in_trie(trie, char.upper())
if result == 0:
break
@@ -910,7 +936,7 @@ class Tokenizer(metaclass=_Tokenizer):
if not word:
if self._char in self.SINGLE_TOKENS:
- self._add(self.SINGLE_TOKENS[self._char], text=self._char) # type: ignore
+ self._add(self.SINGLE_TOKENS[self._char], text=self._char)
return
self._scan_var()
return
@@ -927,29 +953,31 @@ class Tokenizer(metaclass=_Tokenizer):
self._add(self.KEYWORDS[word], text=word)
def _scan_comment(self, comment_start: str) -> bool:
- if comment_start not in self._COMMENTS: # type: ignore
+ if comment_start not in self._COMMENTS:
return False
comment_start_line = self._line
comment_start_size = len(comment_start)
- comment_end = self._COMMENTS[comment_start] # type: ignore
+ comment_end = self._COMMENTS[comment_start]
if comment_end:
- comment_end_size = len(comment_end)
+ # Skip the comment's start delimiter
+ self._advance(comment_start_size)
+ comment_end_size = len(comment_end)
while not self._end and self._chars(comment_end_size) != comment_end:
self._advance()
- self._comments.append(self._text[comment_start_size : -comment_end_size + 1]) # type: ignore
+ self._comments.append(self._text[comment_start_size : -comment_end_size + 1])
self._advance(comment_end_size - 1)
else:
while not self._end and not self.WHITE_SPACE.get(self._peek) is TokenType.BREAK:
self._advance()
- self._comments.append(self._text[comment_start_size:]) # type: ignore
+ self._comments.append(self._text[comment_start_size:])
# Leading comment is attached to the succeeding token, whilst trailing comment to the preceding.
# Multiple consecutive comments are preserved by appending them to the current comments list.
- if comment_start_line == self._prev_token_line or self._end:
+ if comment_start_line == self._prev_token_line:
self.tokens[-1].comments.extend(self._comments)
self._comments = []
self._prev_token_line = self._line
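A sketch of the attachment rule: a comment that starts on the same line as the previous token stays with that token; otherwise it is carried forward to the next one (and, via the new fallback in _scan, to the last token at end of input):

    from sqlglot.tokens import Tokenizer

    tokens = Tokenizer().tokenize("SELECT 1 -- trailing comment")
    assert tokens[-1].comments == [" trailing comment"]  # attached to the NUMBER token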
@@ -958,7 +986,7 @@ class Tokenizer(metaclass=_Tokenizer):
def _scan_number(self) -> None:
if self._char == "0":
- peek = self._peek.upper() # type: ignore
+ peek = self._peek.upper()
if peek == "B":
return self._scan_bits()
elif peek == "X":
@@ -968,7 +996,7 @@ class Tokenizer(metaclass=_Tokenizer):
scientific = 0
while True:
- if self._peek.isdigit(): # type: ignore
+ if self._peek.isdigit():
self._advance()
elif self._peek == "." and not decimal:
decimal = True
@@ -976,24 +1004,23 @@ class Tokenizer(metaclass=_Tokenizer):
elif self._peek in ("-", "+") and scientific == 1:
scientific += 1
self._advance()
- elif self._peek.upper() == "E" and not scientific: # type: ignore
+ elif self._peek.upper() == "E" and not scientific:
scientific += 1
self._advance()
- elif self._peek.isidentifier(): # type: ignore
+ elif self._peek.isidentifier():
number_text = self._text
- literal = []
+ literal = ""
- while self._peek.strip() and self._peek not in self.SINGLE_TOKENS: # type: ignore
- literal.append(self._peek.upper()) # type: ignore
+ while self._peek.strip() and self._peek not in self.SINGLE_TOKENS:
+ literal += self._peek.upper()
self._advance()
- literal = "".join(literal) # type: ignore
- token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal)) # type: ignore
+ token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal))
if token_type:
self._add(TokenType.NUMBER, number_text)
self._add(TokenType.DCOLON, "::")
- return self._add(token_type, literal) # type: ignore
+ return self._add(token_type, literal)
elif self.IDENTIFIER_CAN_START_WITH_DIGIT:
return self._add(TokenType.VAR)
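NUMERIC_LITERALS is not populated in the base Tokenizer, so this branch only fires in dialects that define numeric suffixes. A hypothetical subclass sketch (the class name and the "D" -> "DOUBLE" mapping are illustrative, not part of this diff):

    from sqlglot.tokens import Tokenizer, TokenType

    class SuffixTokenizer(Tokenizer):
        NUMERIC_LITERALS = {"D": "DOUBLE"}  # hypothetical suffix mapping

    tokens = SuffixTokenizer().tokenize("SELECT 10.5D")
    # The suffix is split off as a cast: NUMBER, then "::", then the mapped type.
    assert [t.token_type for t in tokens[-3:]] == [
        TokenType.NUMBER,
        TokenType.DCOLON,
        TokenType.DOUBLE,
    ]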
@@ -1020,7 +1047,7 @@ class Tokenizer(metaclass=_Tokenizer):
def _extract_value(self) -> str:
while True:
- char = self._peek.strip() # type: ignore
+ char = self._peek.strip()
if char and char not in self.SINGLE_TOKENS:
self._advance()
else:
@@ -1029,35 +1056,35 @@ class Tokenizer(metaclass=_Tokenizer):
return self._text
def _scan_string(self, quote: str) -> bool:
- quote_end = self._QUOTES.get(quote) # type: ignore
+ quote_end = self._QUOTES.get(quote)
if quote_end is None:
return False
self._advance(len(quote))
text = self._extract_string(quote_end)
- text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text # type: ignore
+ text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text
self._add(TokenType.NATIONAL if quote[0].upper() == "N" else TokenType.STRING, text)
return True
# X'1234, b'0110', E'\\\\\' etc.
def _scan_formatted_string(self, string_start: str) -> bool:
- if string_start in self._HEX_STRINGS: # type: ignore
- delimiters = self._HEX_STRINGS # type: ignore
+ if string_start in self._HEX_STRINGS:
+ delimiters = self._HEX_STRINGS
token_type = TokenType.HEX_STRING
base = 16
- elif string_start in self._BIT_STRINGS: # type: ignore
- delimiters = self._BIT_STRINGS # type: ignore
+ elif string_start in self._BIT_STRINGS:
+ delimiters = self._BIT_STRINGS
token_type = TokenType.BIT_STRING
base = 2
- elif string_start in self._BYTE_STRINGS: # type: ignore
- delimiters = self._BYTE_STRINGS # type: ignore
+ elif string_start in self._BYTE_STRINGS:
+ delimiters = self._BYTE_STRINGS
token_type = TokenType.BYTE_STRING
base = None
else:
return False
self._advance(len(string_start))
- string_end = delimiters.get(string_start)
+ string_end = delimiters[string_start]
text = self._extract_string(string_end)
if base is None:
@@ -1083,20 +1110,20 @@ class Tokenizer(metaclass=_Tokenizer):
self._advance()
if self._char == identifier_end:
if identifier_end_is_escape and self._peek == identifier_end:
- text += identifier_end # type: ignore
+ text += identifier_end
self._advance()
continue
break
- text += self._char # type: ignore
+ text += self._char
self._add(TokenType.IDENTIFIER, text)
def _scan_var(self) -> None:
while True:
- char = self._peek.strip() # type: ignore
- if char and char not in self.SINGLE_TOKENS:
+ char = self._peek.strip()
+ if char and (char in self.VAR_SINGLE_TOKENS or char not in self.SINGLE_TOKENS):
self._advance()
else:
break
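The new VAR_SINGLE_TOKENS hook (declared empty earlier in this diff) lets a dialect keep selected single-token characters inside a scanned variable name. A hypothetical subclass sketch (the "#" choice and class name are illustrative, not part of this diff):

    from sqlglot.tokens import Tokenizer

    class HashVarTokenizer(Tokenizer):
        VAR_SINGLE_TOKENS = {"#"}  # hypothetical: keep "#" inside variable names

    tokens = HashVarTokenizer().tokenize("SELECT foo#bar")
    # Without the override, "#" would split the name into VAR, HASH, VAR.
    assert tokens[-1].text == "foo#bar"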
@@ -1115,9 +1142,9 @@ class Tokenizer(metaclass=_Tokenizer):
self._peek == delimiter or self._peek in self._STRING_ESCAPES
):
if self._peek == delimiter:
- text += self._peek # type: ignore
+ text += self._peek
else:
- text += self._char + self._peek # type: ignore
+ text += self._char + self._peek
if self._current + 1 < self.size:
self._advance(2)
@@ -1131,7 +1158,7 @@ class Tokenizer(metaclass=_Tokenizer):
if self._end:
raise RuntimeError(f"Missing {delimiter} from {self._line}:{self._start}")
- text += self._char # type: ignore
+ text += self._char
self._advance()
return text