Diffstat (limited to 'sqlglot/tokens.py')
-rw-r--r-- | sqlglot/tokens.py | 91
1 file changed, 87 insertions, 4 deletions
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index e4c3204..de9d4c4 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -1,9 +1,10 @@
 from __future__ import annotations
 
+import os
 import typing as t
 from enum import auto
 
-from sqlglot.errors import TokenError
+from sqlglot.errors import SqlglotError, TokenError
 from sqlglot.helper import AutoName
 from sqlglot.trie import TrieResult, in_trie, new_trie
 
@@ -11,6 +12,19 @@ if t.TYPE_CHECKING:
     from sqlglot.dialects.dialect import DialectType
 
 
+try:
+    from sqlglotrs import (  # type: ignore
+        Tokenizer as RsTokenizer,
+        TokenizerDialectSettings as RsTokenizerDialectSettings,
+        TokenizerSettings as RsTokenizerSettings,
+        TokenTypeSettings as RsTokenTypeSettings,
+    )
+
+    USE_RS_TOKENIZER = os.environ.get("SQLGLOTRS_TOKENIZER", "1") == "1"
+except ImportError:
+    USE_RS_TOKENIZER = False
+
+
 class TokenType(AutoName):
     L_PAREN = auto()
     R_PAREN = auto()
@@ -83,6 +97,7 @@ class TokenType(AutoName):
     NATIONAL_STRING = auto()
     RAW_STRING = auto()
     HEREDOC_STRING = auto()
+    UNICODE_STRING = auto()
 
     # types
     BIT = auto()
@@ -347,6 +362,10 @@ class TokenType(AutoName):
     TIMESTAMP_SNAPSHOT = auto()
 
 
+_ALL_TOKEN_TYPES = list(TokenType)
+_TOKEN_TYPE_TO_INDEX = {token_type: i for i, token_type in enumerate(_ALL_TOKEN_TYPES)}
+
+
 class Token:
     __slots__ = ("token_type", "text", "line", "col", "start", "end", "comments")
 
@@ -432,6 +451,7 @@ class _Tokenizer(type):
             **_quotes_to_format(TokenType.HEX_STRING, klass.HEX_STRINGS),
             **_quotes_to_format(TokenType.RAW_STRING, klass.RAW_STRINGS),
             **_quotes_to_format(TokenType.HEREDOC_STRING, klass.HEREDOC_STRINGS),
+            **_quotes_to_format(TokenType.UNICODE_STRING, klass.UNICODE_STRINGS),
         }
 
         klass._STRING_ESCAPES = set(klass.STRING_ESCAPES)
@@ -455,6 +475,46 @@ class _Tokenizer(type):
             if " " in key or any(single in key for single in klass.SINGLE_TOKENS)
         )
 
+        if USE_RS_TOKENIZER:
+            settings = RsTokenizerSettings(
+                white_space={k: _TOKEN_TYPE_TO_INDEX[v] for k, v in klass.WHITE_SPACE.items()},
+                single_tokens={k: _TOKEN_TYPE_TO_INDEX[v] for k, v in klass.SINGLE_TOKENS.items()},
+                keywords={k: _TOKEN_TYPE_TO_INDEX[v] for k, v in klass.KEYWORDS.items()},
+                numeric_literals=klass.NUMERIC_LITERALS,
+                identifiers=klass._IDENTIFIERS,
+                identifier_escapes=klass._IDENTIFIER_ESCAPES,
+                string_escapes=klass._STRING_ESCAPES,
+                quotes=klass._QUOTES,
+                format_strings={
+                    k: (v1, _TOKEN_TYPE_TO_INDEX[v2])
+                    for k, (v1, v2) in klass._FORMAT_STRINGS.items()
+                },
+                has_bit_strings=bool(klass.BIT_STRINGS),
+                has_hex_strings=bool(klass.HEX_STRINGS),
+                comments=klass._COMMENTS,
+                var_single_tokens=klass.VAR_SINGLE_TOKENS,
+                commands={_TOKEN_TYPE_TO_INDEX[v] for v in klass.COMMANDS},
+                command_prefix_tokens={
+                    _TOKEN_TYPE_TO_INDEX[v] for v in klass.COMMAND_PREFIX_TOKENS
+                },
+            )
+            token_types = RsTokenTypeSettings(
+                bit_string=_TOKEN_TYPE_TO_INDEX[TokenType.BIT_STRING],
+                break_=_TOKEN_TYPE_TO_INDEX[TokenType.BREAK],
+                dcolon=_TOKEN_TYPE_TO_INDEX[TokenType.DCOLON],
+                heredoc_string=_TOKEN_TYPE_TO_INDEX[TokenType.HEREDOC_STRING],
+                hex_string=_TOKEN_TYPE_TO_INDEX[TokenType.HEX_STRING],
+                identifier=_TOKEN_TYPE_TO_INDEX[TokenType.IDENTIFIER],
+                number=_TOKEN_TYPE_TO_INDEX[TokenType.NUMBER],
+                parameter=_TOKEN_TYPE_TO_INDEX[TokenType.PARAMETER],
+                semicolon=_TOKEN_TYPE_TO_INDEX[TokenType.SEMICOLON],
+                string=_TOKEN_TYPE_TO_INDEX[TokenType.STRING],
+                var=_TOKEN_TYPE_TO_INDEX[TokenType.VAR],
+            )
+            klass._RS_TOKENIZER = RsTokenizer(settings, token_types)
+        else:
+            klass._RS_TOKENIZER = None
+
         return klass
 
 
@@ -499,6 +559,7 @@ class Tokenizer(metaclass=_Tokenizer):
     HEX_STRINGS: t.List[str | t.Tuple[str, str]] = []
     RAW_STRINGS: t.List[str | t.Tuple[str, str]] = []
    HEREDOC_STRINGS: t.List[str | t.Tuple[str, str]] = []
+    UNICODE_STRINGS: t.List[str | t.Tuple[str, str]] = []
     IDENTIFIERS: t.List[str | t.Tuple[str, str]] = ['"']
     IDENTIFIER_ESCAPES = ['"']
     QUOTES: t.List[t.Tuple[str, str] | str] = ["'"]
@@ -513,6 +574,7 @@ class Tokenizer(metaclass=_Tokenizer):
     _QUOTES: t.Dict[str, str] = {}
     _STRING_ESCAPES: t.Set[str] = set()
     _KEYWORD_TRIE: t.Dict = {}
+    _RS_TOKENIZER: t.Optional[t.Any] = None
 
     KEYWORDS: t.Dict[str, TokenType] = {
         **{f"{{%{postfix}": TokenType.BLOCK_START for postfix in ("", "+", "-")},
@@ -804,7 +866,6 @@ class Tokenizer(metaclass=_Tokenizer):
 
     # handle numeric literals like in hive (3L = BIGINT)
     NUMERIC_LITERALS: t.Dict[str, str] = {}
-    ENCODE: t.Optional[str] = None
 
     COMMENTS = ["--", ("/*", "*/")]
 
@@ -822,12 +883,20 @@ class Tokenizer(metaclass=_Tokenizer):
         "_end",
         "_peek",
         "_prev_token_line",
+        "_rs_dialect_settings",
     )
 
     def __init__(self, dialect: DialectType = None) -> None:
         from sqlglot.dialects import Dialect
 
         self.dialect = Dialect.get_or_raise(dialect)
+
+        if USE_RS_TOKENIZER:
+            self._rs_dialect_settings = RsTokenizerDialectSettings(
+                escape_sequences=self.dialect.ESCAPE_SEQUENCES,
+                identifiers_can_start_with_digit=self.dialect.IDENTIFIERS_CAN_START_WITH_DIGIT,
+            )
+
         self.reset()
 
     def reset(self) -> None:
@@ -847,6 +916,9 @@ class Tokenizer(metaclass=_Tokenizer):
 
     def tokenize(self, sql: str) -> t.List[Token]:
         """Returns a list of tokens corresponding to the SQL string `sql`."""
+        if USE_RS_TOKENIZER:
+            return self.tokenize_rs(sql)
+
         self.reset()
         self.sql = sql
         self.size = len(sql)
@@ -910,6 +982,7 @@ class Tokenizer(metaclass=_Tokenizer):
                 # Ensures we don't count an extra line if we get a \r\n line break sequence
                 if self._char == "\r" and self._peek == "\n":
                     i = 2
+                    self._start += 1
 
                 self._col = 1
                 self._line += 1
@@ -1184,8 +1257,6 @@ class Tokenizer(metaclass=_Tokenizer):
                 raise TokenError(
                     f"Numeric string contains invalid characters from {self._line}:{self._start}"
                 )
-        else:
-            text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text
 
         self._add(token_type, text)
         return True
@@ -1254,3 +1325,15 @@ class Tokenizer(metaclass=_Tokenizer):
             text += self.sql[current : self._current - 1]
 
         return text
+
+    def tokenize_rs(self, sql: str) -> t.List[Token]:
+        if not self._RS_TOKENIZER:
+            raise SqlglotError("Rust tokenizer is not available")
+
+        try:
+            tokens = self._RS_TOKENIZER.tokenize(sql, self._rs_dialect_settings)
+            for token in tokens:
+                token.token_type = _ALL_TOKEN_TYPES[token.token_type_index]
+            return tokens
+        except Exception as e:
+            raise TokenError(str(e))
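
For anyone exercising the change locally, here is a minimal sketch, assuming a sqlglot build that includes this commit. With the optional sqlglotrs package installed, tokenize() delegates to tokenize_rs(), and each Rust token's token_type_index is mapped back to a Python TokenType through the new _ALL_TOKEN_TYPES table; exporting SQLGLOTRS_TOKENIZER=0 before the first import keeps the pure-Python path:

import os

# Must be set before sqlglot.tokens is imported, because USE_RS_TOKENIZER
# is evaluated once at module import time.
os.environ["SQLGLOTRS_TOKENIZER"] = "0"

from sqlglot.tokens import Tokenizer, USE_RS_TOKENIZER

print(USE_RS_TOKENIZER)  # False: the env var forced the Python tokenizer

# Either backend should yield the same tokens, roughly:
# [(TokenType.SELECT, 'SELECT'), (TokenType.NUMBER, '1'), (TokenType.SEMICOLON, ';')]
tokens = Tokenizer().tokenize("SELECT 1;")
print([(token.token_type, token.text) for token in tokens])

The design choice behind the two module-level tables is that token types cross the Python/Rust boundary as plain list indices: _TOKEN_TYPE_TO_INDEX[t] encodes a TokenType for the Rust side, _ALL_TOKEN_TYPES[i] inverts it, so _ALL_TOKEN_TYPES[_TOKEN_TYPE_TO_INDEX[t]] is t for every member and no enum values have to be duplicated in Rust.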