Diffstat (limited to 'sqlglot/tokens.py')
-rw-r--r--  sqlglot/tokens.py  91
1 file changed, 87 insertions, 4 deletions
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index e4c3204..de9d4c4 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -1,9 +1,10 @@
from __future__ import annotations
+import os
import typing as t
from enum import auto
-from sqlglot.errors import TokenError
+from sqlglot.errors import SqlglotError, TokenError
from sqlglot.helper import AutoName
from sqlglot.trie import TrieResult, in_trie, new_trie
@@ -11,6 +12,19 @@ if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType
+try:
+ from sqlglotrs import ( # type: ignore
+ Tokenizer as RsTokenizer,
+ TokenizerDialectSettings as RsTokenizerDialectSettings,
+ TokenizerSettings as RsTokenizerSettings,
+ TokenTypeSettings as RsTokenTypeSettings,
+ )
+
+ USE_RS_TOKENIZER = os.environ.get("SQLGLOTRS_TOKENIZER", "1") == "1"
+except ImportError:
+ USE_RS_TOKENIZER = False
+
+
class TokenType(AutoName):
L_PAREN = auto()
R_PAREN = auto()
@@ -83,6 +97,7 @@ class TokenType(AutoName):
NATIONAL_STRING = auto()
RAW_STRING = auto()
HEREDOC_STRING = auto()
+ UNICODE_STRING = auto()
# types
BIT = auto()
@@ -347,6 +362,10 @@ class TokenType(AutoName):
TIMESTAMP_SNAPSHOT = auto()
+_ALL_TOKEN_TYPES = list(TokenType)
+_TOKEN_TYPE_TO_INDEX = {token_type: i for i, token_type in enumerate(_ALL_TOKEN_TYPES)}
+
+
class Token:
__slots__ = ("token_type", "text", "line", "col", "start", "end", "comments")
@@ -432,6 +451,7 @@ class _Tokenizer(type):
**_quotes_to_format(TokenType.HEX_STRING, klass.HEX_STRINGS),
**_quotes_to_format(TokenType.RAW_STRING, klass.RAW_STRINGS),
**_quotes_to_format(TokenType.HEREDOC_STRING, klass.HEREDOC_STRINGS),
+ **_quotes_to_format(TokenType.UNICODE_STRING, klass.UNICODE_STRINGS),
}
klass._STRING_ESCAPES = set(klass.STRING_ESCAPES)
@@ -455,6 +475,46 @@ class _Tokenizer(type):
if " " in key or any(single in key for single in klass.SINGLE_TOKENS)
)
+ if USE_RS_TOKENIZER:
+ settings = RsTokenizerSettings(
+ white_space={k: _TOKEN_TYPE_TO_INDEX[v] for k, v in klass.WHITE_SPACE.items()},
+ single_tokens={k: _TOKEN_TYPE_TO_INDEX[v] for k, v in klass.SINGLE_TOKENS.items()},
+ keywords={k: _TOKEN_TYPE_TO_INDEX[v] for k, v in klass.KEYWORDS.items()},
+ numeric_literals=klass.NUMERIC_LITERALS,
+ identifiers=klass._IDENTIFIERS,
+ identifier_escapes=klass._IDENTIFIER_ESCAPES,
+ string_escapes=klass._STRING_ESCAPES,
+ quotes=klass._QUOTES,
+ format_strings={
+ k: (v1, _TOKEN_TYPE_TO_INDEX[v2])
+ for k, (v1, v2) in klass._FORMAT_STRINGS.items()
+ },
+ has_bit_strings=bool(klass.BIT_STRINGS),
+ has_hex_strings=bool(klass.HEX_STRINGS),
+ comments=klass._COMMENTS,
+ var_single_tokens=klass.VAR_SINGLE_TOKENS,
+ commands={_TOKEN_TYPE_TO_INDEX[v] for v in klass.COMMANDS},
+ command_prefix_tokens={
+ _TOKEN_TYPE_TO_INDEX[v] for v in klass.COMMAND_PREFIX_TOKENS
+ },
+ )
+ token_types = RsTokenTypeSettings(
+ bit_string=_TOKEN_TYPE_TO_INDEX[TokenType.BIT_STRING],
+ break_=_TOKEN_TYPE_TO_INDEX[TokenType.BREAK],
+ dcolon=_TOKEN_TYPE_TO_INDEX[TokenType.DCOLON],
+ heredoc_string=_TOKEN_TYPE_TO_INDEX[TokenType.HEREDOC_STRING],
+ hex_string=_TOKEN_TYPE_TO_INDEX[TokenType.HEX_STRING],
+ identifier=_TOKEN_TYPE_TO_INDEX[TokenType.IDENTIFIER],
+ number=_TOKEN_TYPE_TO_INDEX[TokenType.NUMBER],
+ parameter=_TOKEN_TYPE_TO_INDEX[TokenType.PARAMETER],
+ semicolon=_TOKEN_TYPE_TO_INDEX[TokenType.SEMICOLON],
+ string=_TOKEN_TYPE_TO_INDEX[TokenType.STRING],
+ var=_TOKEN_TYPE_TO_INDEX[TokenType.VAR],
+ )
+ klass._RS_TOKENIZER = RsTokenizer(settings, token_types)
+ else:
+ klass._RS_TOKENIZER = None
+
return klass
@@ -499,6 +559,7 @@ class Tokenizer(metaclass=_Tokenizer):
HEX_STRINGS: t.List[str | t.Tuple[str, str]] = []
RAW_STRINGS: t.List[str | t.Tuple[str, str]] = []
HEREDOC_STRINGS: t.List[str | t.Tuple[str, str]] = []
+ UNICODE_STRINGS: t.List[str | t.Tuple[str, str]] = []
IDENTIFIERS: t.List[str | t.Tuple[str, str]] = ['"']
IDENTIFIER_ESCAPES = ['"']
QUOTES: t.List[t.Tuple[str, str] | str] = ["'"]
@@ -513,6 +574,7 @@ class Tokenizer(metaclass=_Tokenizer):
_QUOTES: t.Dict[str, str] = {}
_STRING_ESCAPES: t.Set[str] = set()
_KEYWORD_TRIE: t.Dict = {}
+ _RS_TOKENIZER: t.Optional[t.Any] = None
KEYWORDS: t.Dict[str, TokenType] = {
**{f"{{%{postfix}": TokenType.BLOCK_START for postfix in ("", "+", "-")},
@@ -804,7 +866,6 @@ class Tokenizer(metaclass=_Tokenizer):
# handle numeric literals like in hive (3L = BIGINT)
NUMERIC_LITERALS: t.Dict[str, str] = {}
- ENCODE: t.Optional[str] = None
COMMENTS = ["--", ("/*", "*/")]
@@ -822,12 +883,20 @@ class Tokenizer(metaclass=_Tokenizer):
"_end",
"_peek",
"_prev_token_line",
+ "_rs_dialect_settings",
)
def __init__(self, dialect: DialectType = None) -> None:
from sqlglot.dialects import Dialect
self.dialect = Dialect.get_or_raise(dialect)
+
+ if USE_RS_TOKENIZER:
+ self._rs_dialect_settings = RsTokenizerDialectSettings(
+ escape_sequences=self.dialect.ESCAPE_SEQUENCES,
+ identifiers_can_start_with_digit=self.dialect.IDENTIFIERS_CAN_START_WITH_DIGIT,
+ )
+
self.reset()
def reset(self) -> None:
@@ -847,6 +916,9 @@ class Tokenizer(metaclass=_Tokenizer):
def tokenize(self, sql: str) -> t.List[Token]:
"""Returns a list of tokens corresponding to the SQL string `sql`."""
+ if USE_RS_TOKENIZER:
+ return self.tokenize_rs(sql)
+
self.reset()
self.sql = sql
self.size = len(sql)
@@ -910,6 +982,7 @@ class Tokenizer(metaclass=_Tokenizer):
# Ensures we don't count an extra line if we get a \r\n line break sequence
if self._char == "\r" and self._peek == "\n":
i = 2
+ self._start += 1
self._col = 1
self._line += 1
@@ -1184,8 +1257,6 @@ class Tokenizer(metaclass=_Tokenizer):
raise TokenError(
f"Numeric string contains invalid characters from {self._line}:{self._start}"
)
- else:
- text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text
self._add(token_type, text)
return True
@@ -1254,3 +1325,15 @@ class Tokenizer(metaclass=_Tokenizer):
text += self.sql[current : self._current - 1]
return text
+
+ def tokenize_rs(self, sql: str) -> t.List[Token]:
+ if not self._RS_TOKENIZER:
+ raise SqlglotError("Rust tokenizer is not available")
+
+ try:
+ tokens = self._RS_TOKENIZER.tokenize(sql, self._rs_dialect_settings)
+ for token in tokens:
+ token.token_type = _ALL_TOKEN_TYPES[token.token_type_index]
+ return tokens
+ except Exception as e:
+ raise TokenError(str(e))
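
For context, a minimal usage sketch of the toggle introduced in this patch, assuming sqlglot (and optionally the sqlglotrs extension) is installed. The SQLGLOTRS_TOKENIZER environment variable, tokenize(), and the tokenize_rs() fallback come from the diff above; the sample SQL string is illustrative only:

    import os

    # Opt out of the Rust tokenizer before sqlglot is imported; USE_RS_TOKENIZER
    # is computed at import time and defaults to "1" when sqlglotrs is installed.
    os.environ["SQLGLOTRS_TOKENIZER"] = "0"

    from sqlglot.tokens import Tokenizer

    tokenizer = Tokenizer()

    # tokenize() dispatches to tokenize_rs() whenever USE_RS_TOKENIZER is true;
    # otherwise the pure-Python scanner is used. tokenize_rs() maps each Rust
    # token's token_type_index back to a TokenType via _ALL_TOKEN_TYPES.
    for token in tokenizer.tokenize("SELECT 1 AS x"):
        print(token.token_type, repr(token.text), token.line, token.col)

If the Rust tokenizer is enabled but fails on an input, tokenize_rs() wraps the error in a TokenError, matching the error type raised by the pure-Python path.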