diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-12-19 11:01:55 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-12-19 11:01:55 +0000 |
commit | f1c2dbe3b17a0d5edffbb65b85b642d0bb2756c5 (patch) | |
tree | 5dce0fe2a11381761496eb973c20750f44db56d5 /sqlglot | |
parent | Releasing debian version 20.1.0-1. (diff) | |
download | sqlglot-f1c2dbe3b17a0d5edffbb65b85b642d0bb2756c5.tar.xz sqlglot-f1c2dbe3b17a0d5edffbb65b85b642d0bb2756c5.zip |
Merging upstream version 20.3.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
-rw-r--r-- | sqlglot/dialects/__init__.py | 20 | ||||
-rw-r--r-- | sqlglot/dialects/bigquery.py | 7 | ||||
-rw-r--r-- | sqlglot/dialects/clickhouse.py | 5 | ||||
-rw-r--r-- | sqlglot/dialects/dialect.py | 101 | ||||
-rw-r--r-- | sqlglot/dialects/drill.py | 1 | ||||
-rw-r--r-- | sqlglot/dialects/duckdb.py | 37 | ||||
-rw-r--r-- | sqlglot/dialects/hive.py | 1 | ||||
-rw-r--r-- | sqlglot/dialects/mysql.py | 2 | ||||
-rw-r--r-- | sqlglot/dialects/postgres.py | 1 | ||||
-rw-r--r-- | sqlglot/dialects/presto.py | 41 | ||||
-rw-r--r-- | sqlglot/dialects/snowflake.py | 80 | ||||
-rw-r--r-- | sqlglot/dialects/teradata.py | 24 | ||||
-rw-r--r-- | sqlglot/dialects/tsql.py | 7 | ||||
-rw-r--r-- | sqlglot/executor/python.py | 3 | ||||
-rw-r--r-- | sqlglot/expressions.py | 46 | ||||
-rw-r--r-- | sqlglot/generator.py | 67 | ||||
-rw-r--r-- | sqlglot/optimizer/eliminate_subqueries.py | 33 | ||||
-rw-r--r-- | sqlglot/optimizer/merge_subqueries.py | 17 | ||||
-rw-r--r-- | sqlglot/optimizer/pushdown_predicates.py | 13 | ||||
-rw-r--r-- | sqlglot/optimizer/scope.py | 39 | ||||
-rw-r--r-- | sqlglot/optimizer/simplify.py | 153 | ||||
-rw-r--r-- | sqlglot/optimizer/unnest_subqueries.py | 2 | ||||
-rw-r--r-- | sqlglot/parser.py | 129 | ||||
-rw-r--r-- | sqlglot/planner.py | 11 | ||||
-rw-r--r-- | sqlglot/tokens.py | 91 |
25 files changed, 642 insertions, 289 deletions
diff --git a/sqlglot/dialects/__init__.py b/sqlglot/dialects/__init__.py index 8212669..04990ac 100644 --- a/sqlglot/dialects/__init__.py +++ b/sqlglot/dialects/__init__.py @@ -12,7 +12,7 @@ classes as needed. ### Implementing a custom Dialect -Consider the following example: +Creating a new SQL dialect may seem complicated at first, but it is actually quite simple in SQLGlot: ```python from sqlglot import exp @@ -23,9 +23,10 @@ from sqlglot.tokens import Tokenizer, TokenType class Custom(Dialect): class Tokenizer(Tokenizer): - QUOTES = ["'", '"'] - IDENTIFIERS = ["`"] + QUOTES = ["'", '"'] # Strings can be delimited by either single or double quotes + IDENTIFIERS = ["`"] # Identifiers can be delimited by backticks + # Associates certain meaningful words with tokens that capture their intent KEYWORDS = { **Tokenizer.KEYWORDS, "INT64": TokenType.BIGINT, @@ -33,8 +34,12 @@ class Custom(Dialect): } class Generator(Generator): - TRANSFORMS = {exp.Array: lambda self, e: f"[{self.expressions(e)}]"} + # Specifies how AST nodes, i.e. subclasses of exp.Expression, should be converted into SQL + TRANSFORMS = { + exp.Array: lambda self, e: f"[{self.expressions(e)}]", + } + # Specifies how AST nodes representing data types should be converted into SQL TYPE_MAPPING = { exp.DataType.Type.TINYINT: "INT64", exp.DataType.Type.SMALLINT: "INT64", @@ -48,10 +53,9 @@ class Custom(Dialect): } ``` -This is a typical example of adding a new dialect implementation in SQLGlot: we specify its identifier and string -delimiters, as well as what tokens it uses for its types and how they're associated with SQLGlot types. Since -the `Expression` classes are common for each dialect supported in SQLGlot, we may also need to override the generation -logic for some expressions; this is usually done by adding new entries to the `TRANSFORMS` mapping. +The above example demonstrates how certain parts of the base `Dialect` class can be overridden to match a different +specification. Even though it is a fairly realistic starting point, we strongly encourage the reader to study existing +dialect implementations in order to understand how their various components can be modified, depending on the use-case. ---- """ diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 2a9dde9..1b06cbf 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -215,6 +215,7 @@ def _unix_to_time_sql(self: BigQuery.Generator, expression: exp.UnixToTime) -> s class BigQuery(Dialect): + WEEK_OFFSET = -1 UNNEST_COLUMN_ONLY = True SUPPORTS_USER_DEFINED_TYPES = False SUPPORTS_SEMI_ANTI_JOIN = False @@ -437,11 +438,7 @@ class BigQuery(Dialect): elif isinstance(this, exp.Literal): table_name = this.name - if ( - self._curr - and self._prev.end == self._curr.start - 1 - and self._parse_var(any_token=True) - ): + if self._is_connected() and self._parse_var(any_token=True): table_name += self._prev.text this = exp.Identifier(this=table_name, quoted=True) diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index da182aa..7a3f897 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -83,6 +83,11 @@ class ClickHouse(Dialect): } class Parser(parser.Parser): + # Tested in ClickHouse's playground, it seems that the following two queries do the same thing + # * select x from t1 union all select x from t2 limit 1; + # * select x from t1 union all (select x from t2 limit 1); + MODIFIERS_ATTACHED_TO_UNION = False + FUNCTIONS = { **parser.Parser.FUNCTIONS, "ANY": exp.AnyValue.from_arg_list, diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index c7cea64..b7eef45 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -21,11 +21,14 @@ DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub] class Dialects(str, Enum): + """Dialects supported by SQLGLot.""" + DIALECT = "" BIGQUERY = "bigquery" CLICKHOUSE = "clickhouse" DATABRICKS = "databricks" + DORIS = "doris" DRILL = "drill" DUCKDB = "duckdb" HIVE = "hive" @@ -43,16 +46,22 @@ class Dialects(str, Enum): TERADATA = "teradata" TRINO = "trino" TSQL = "tsql" - Doris = "doris" class NormalizationStrategy(str, AutoName): """Specifies the strategy according to which identifiers should be normalized.""" - LOWERCASE = auto() # Unquoted identifiers are lowercased - UPPERCASE = auto() # Unquoted identifiers are uppercased - CASE_SENSITIVE = auto() # Always case-sensitive, regardless of quotes - CASE_INSENSITIVE = auto() # Always case-insensitive, regardless of quotes + LOWERCASE = auto() + """Unquoted identifiers are lowercased.""" + + UPPERCASE = auto() + """Unquoted identifiers are uppercased.""" + + CASE_SENSITIVE = auto() + """Always case-sensitive, regardless of quotes.""" + + CASE_INSENSITIVE = auto() + """Always case-insensitive, regardless of quotes.""" class _Dialect(type): @@ -117,6 +126,7 @@ class _Dialect(type): klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING) klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING) klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING) + klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING) if enum not in ("", "bigquery"): klass.generator_class.SELECT_KINDS = () @@ -131,74 +141,84 @@ class _Dialect(type): class Dialect(metaclass=_Dialect): - # Determines the base index offset for arrays INDEX_OFFSET = 0 + """Determines the base index offset for arrays.""" + + WEEK_OFFSET = 0 + """Determines the day of week of DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" - # If true unnest table aliases are considered only as column aliases UNNEST_COLUMN_ONLY = False + """Determines whether or not `UNNEST` table aliases are treated as column aliases.""" - # Determines whether or not the table alias comes after tablesample ALIAS_POST_TABLESAMPLE = False + """Determines whether or not the table alias comes after tablesample.""" - # Specifies the strategy according to which identifiers should be normalized. NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE + """Specifies the strategy according to which identifiers should be normalized.""" - # Determines whether or not an unquoted identifier can start with a digit IDENTIFIERS_CAN_START_WITH_DIGIT = False + """Determines whether or not an unquoted identifier can start with a digit.""" - # Determines whether or not the DPIPE token ('||') is a string concatenation operator DPIPE_IS_STRING_CONCAT = True + """Determines whether or not the DPIPE token (`||`) is a string concatenation operator.""" - # Determines whether or not CONCAT's arguments must be strings STRICT_STRING_CONCAT = False + """Determines whether or not `CONCAT`'s arguments must be strings.""" - # Determines whether or not user-defined data types are supported SUPPORTS_USER_DEFINED_TYPES = True + """Determines whether or not user-defined data types are supported.""" - # Determines whether or not SEMI/ANTI JOINs are supported SUPPORTS_SEMI_ANTI_JOIN = True + """Determines whether or not `SEMI` or `ANTI` joins are supported.""" - # Determines how function names are going to be normalized NORMALIZE_FUNCTIONS: bool | str = "upper" + """Determines how function names are going to be normalized.""" - # Determines whether the base comes first in the LOG function LOG_BASE_FIRST = True + """Determines whether the base comes first in the `LOG` function.""" - # Indicates the default null ordering method to use if not explicitly set - # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last" NULL_ORDERING = "nulls_are_small" + """ + Indicates the default `NULL` ordering method to use if not explicitly set. + Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` + """ - # Whether the behavior of a / b depends on the types of a and b. - # False means a / b is always float division. - # True means a / b is integer division if both a and b are integers. TYPED_DIVISION = False + """ + Whether the behavior of `a / b` depends on the types of `a` and `b`. + False means `a / b` is always float division. + True means `a / b` is integer division if both `a` and `b` are integers. + """ - # False means 1 / 0 throws an error. - # True means 1 / 0 returns null. SAFE_DIVISION = False + """Determines whether division by zero throws an error (`False`) or returns NULL (`True`).""" - # A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string CONCAT_COALESCE = False + """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" DATE_FORMAT = "'%Y-%m-%d'" DATEINT_FORMAT = "'%Y%m%d'" TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" - # Custom time mappings in which the key represents dialect time format - # and the value represents a python time format TIME_MAPPING: t.Dict[str, str] = {} + """Associates this dialect's time formats with their equivalent Python `strftime` format.""" # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE - # special syntax cast(x as date format 'yyyy') defaults to time_mapping FORMAT_MAPPING: t.Dict[str, str] = {} + """ + Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. + If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. + """ - # Mapping of an unescaped escape sequence to the corresponding character ESCAPE_SEQUENCES: t.Dict[str, str] = {} + """Mapping of an unescaped escape sequence to the corresponding character.""" - # Columns that are auto-generated by the engine corresponding to this dialect - # Such columns may be excluded from SELECT * queries, for example PSEUDOCOLUMNS: t.Set[str] = set() + """ + Columns that are auto-generated by the engine corresponding to this dialect. + For example, such columns may be excluded from `SELECT *` queries. + """ # --- Autofilled --- @@ -221,13 +241,15 @@ class Dialect(metaclass=_Dialect): IDENTIFIER_START = '"' IDENTIFIER_END = '"' - # Delimiters for bit, hex and byte literals + # Delimiters for bit, hex, byte and unicode literals BIT_START: t.Optional[str] = None BIT_END: t.Optional[str] = None HEX_START: t.Optional[str] = None HEX_END: t.Optional[str] = None BYTE_START: t.Optional[str] = None BYTE_END: t.Optional[str] = None + UNICODE_START: t.Optional[str] = None + UNICODE_END: t.Optional[str] = None @classmethod def get_or_raise(cls, dialect: DialectType) -> Dialect: @@ -275,6 +297,7 @@ class Dialect(metaclass=_Dialect): def format_time( cls, expression: t.Optional[str | exp.Expression] ) -> t.Optional[exp.Expression]: + """Converts a time format in this dialect to its equivalent Python `strftime` format.""" if isinstance(expression, str): return exp.Literal.string( # the time formats are quoted @@ -306,9 +329,9 @@ class Dialect(metaclass=_Dialect): """ Transforms an identifier in a way that resembles how it'd be resolved by this dialect. - For example, an identifier like FoO would be resolved as foo in Postgres, because it + For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so - it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive, + it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier. There are also dialects like Spark, which are case-insensitive even when quotes are @@ -356,8 +379,8 @@ class Dialect(metaclass=_Dialect): Args: text: The text to check. identify: - "always" or `True`: Always returns true. - "safe": True if the identifier is case-insensitive. + `"always"` or `True`: Always returns `True`. + `"safe"`: Only returns `True` if the identifier is case-insensitive. Returns: Whether or not the given text can be identified. @@ -371,6 +394,14 @@ class Dialect(metaclass=_Dialect): return False def quote_identifier(self, expression: E, identify: bool = True) -> E: + """ + Adds quotes to a given identifier. + + Args: + expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. + identify: If set to `False`, the quotes will only be added if the identifier is deemed + "unsafe", with respect to its characters and this dialect's normalization strategy. + """ if isinstance(expression, exp.Identifier): name = expression.this expression.set( diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py index 70c96f8..c9b31a0 100644 --- a/sqlglot/dialects/drill.py +++ b/sqlglot/dialects/drill.py @@ -81,7 +81,6 @@ class Drill(Dialect): class Tokenizer(tokens.Tokenizer): IDENTIFIERS = ["`"] STRING_ESCAPES = ["\\"] - ENCODE = "utf-8" class Parser(parser.Parser): STRICT_CAST = False diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index b94e3a6..41afad8 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -84,11 +84,35 @@ def _parse_date_diff(args: t.List) -> exp.Expression: return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)) +def _parse_make_timestamp(args: t.List) -> exp.Expression: + if len(args) == 1: + return exp.UnixToTime(this=seq_get(args, 0), scale=exp.UnixToTime.MICROS) + + return exp.TimestampFromParts( + year=seq_get(args, 0), + month=seq_get(args, 1), + day=seq_get(args, 2), + hour=seq_get(args, 3), + min=seq_get(args, 4), + sec=seq_get(args, 5), + ) + + def _struct_sql(self: DuckDB.Generator, expression: exp.Struct) -> str: - args = [ - f"'{e.name or e.this.name}': {self.sql(e.expressions[0]) if isinstance(e, exp.Bracket) else self.sql(e, 'expression')}" - for e in expression.expressions - ] + args: t.List[str] = [] + for expr in expression.expressions: + if isinstance(expr, exp.Alias): + key = expr.alias + value = expr.this + else: + key = expr.name or expr.this.name + if isinstance(expr, exp.Bracket): + value = expr.expressions[0] + else: + value = expr.expression + + args.append(f"{self.sql(exp.Literal.string(key))}: {self.sql(value)}") + return f"{{{', '.join(args)}}}" @@ -189,9 +213,7 @@ class DuckDB(Dialect): "LIST_REVERSE_SORT": _sort_array_reverse, "LIST_SORT": exp.SortArray.from_arg_list, "LIST_VALUE": exp.Array.from_arg_list, - "MAKE_TIMESTAMP": lambda args: exp.UnixToTime( - this=seq_get(args, 0), scale=exp.UnixToTime.MICROS - ), + "MAKE_TIMESTAMP": _parse_make_timestamp, "MEDIAN": lambda args: exp.PercentileCont( this=seq_get(args, 0), expression=exp.Literal.number(0.5) ), @@ -339,6 +361,7 @@ class DuckDB(Dialect): exp.StrToUnix: lambda self, e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))", exp.Struct: _struct_sql, exp.Timestamp: no_timestamp_sql, + exp.TimestampFromParts: rename_func("MAKE_TIMESTAMP"), exp.TimestampTrunc: timestamptrunc_sql, exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", exp.TimeStrToTime: timestrtotime_sql, diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 0723e37..65c85bb 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -240,7 +240,6 @@ class Hive(Dialect): QUOTES = ["'", '"'] IDENTIFIERS = ["`"] STRING_ESCAPES = ["\\"] - ENCODE = "utf-8" SINGLE_TOKENS = { **tokens.Tokenizer.SINGLE_TOKENS, diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index cfc6e83..5fe3d82 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -650,7 +650,7 @@ class MySQL(Dialect): exp.Min: min_or_least, exp.Month: _remove_ts_or_ds_to_date(), exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"), - exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")), + exp.NullSafeNEQ: lambda self, e: f"NOT {self.binary(e, '<=>')}", exp.Pivot: no_pivot_sql, exp.Select: transforms.preprocess( [ diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index fefddee..bf65edf 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -277,6 +277,7 @@ class Postgres(Dialect): "CONSTRAINT TRIGGER": TokenType.COMMAND, "DECLARE": TokenType.COMMAND, "DO": TokenType.COMMAND, + "EXEC": TokenType.COMMAND, "HSTORE": TokenType.HSTORE, "JSONB": TokenType.JSONB, "MONEY": TokenType.MONEY, diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 10a6074..360ab65 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -186,6 +186,27 @@ def _unix_to_time_sql(self: Presto.Generator, expression: exp.UnixToTime) -> str return "" +def _to_int(expression: exp.Expression) -> exp.Expression: + if not expression.type: + from sqlglot.optimizer.annotate_types import annotate_types + + annotate_types(expression) + if expression.type and expression.type.this not in exp.DataType.INTEGER_TYPES: + return exp.cast(expression, to=exp.DataType.Type.BIGINT) + return expression + + +def _parse_to_char(args: t.List) -> exp.TimeToStr: + fmt = seq_get(args, 1) + if isinstance(fmt, exp.Literal): + # We uppercase this to match Teradata's format mapping keys + fmt.set("this", fmt.this.upper()) + + # We use "teradata" on purpose here, because the time formats are different in Presto. + # See https://prestodb.io/docs/current/functions/teradata.html?highlight=to_char#to_char + return format_time_lambda(exp.TimeToStr, "teradata")(args) + + class Presto(Dialect): INDEX_OFFSET = 1 NULL_ORDERING = "nulls_are_last" @@ -201,6 +222,12 @@ class Presto(Dialect): NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE class Tokenizer(tokens.Tokenizer): + UNICODE_STRINGS = [ + (prefix + q, q) + for q in t.cast(t.List[str], tokens.Tokenizer.QUOTES) + for prefix in ("U&", "u&") + ] + KEYWORDS = { **tokens.Tokenizer.KEYWORDS, "START": TokenType.BEGIN, @@ -253,8 +280,9 @@ class Presto(Dialect): "STRPOS": lambda args: exp.StrPosition( this=seq_get(args, 0), substr=seq_get(args, 1), instance=seq_get(args, 2) ), - "TO_UNIXTIME": exp.TimeToUnix.from_arg_list, + "TO_CHAR": _parse_to_char, "TO_HEX": exp.Hex.from_arg_list, + "TO_UNIXTIME": exp.TimeToUnix.from_arg_list, "TO_UTF8": lambda args: exp.Encode( this=seq_get(args, 0), charset=exp.Literal.string("utf-8") ), @@ -315,7 +343,12 @@ class Presto(Dialect): exp.Cast: transforms.preprocess([transforms.epoch_cast_to_ts]), exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", exp.DateAdd: lambda self, e: self.func( - "DATE_ADD", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this + "DATE_ADD", + exp.Literal.string(e.text("unit") or "day"), + _to_int( + e.expression, + ), + e.this, ), exp.DateDiff: lambda self, e: self.func( "DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this @@ -325,7 +358,7 @@ class Presto(Dialect): exp.DateSub: lambda self, e: self.func( "DATE_ADD", exp.Literal.string(e.text("unit") or "day"), - e.expression * -1, + _to_int(e.expression * -1), e.this, ), exp.Decode: lambda self, e: encode_decode_sql(self, e, "FROM_UTF8"), @@ -354,6 +387,7 @@ class Presto(Dialect): exp.Right: right_to_substring_sql, exp.SafeDivide: no_safe_divide_sql, exp.Schema: _schema_sql, + exp.SchemaCommentProperty: lambda self, e: self.naked_property(e), exp.Select: transforms.preprocess( [ transforms.eliminate_qualify, @@ -377,6 +411,7 @@ class Presto(Dialect): exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.TIME_FORMAT}))", exp.TimeToStr: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimeToUnix: rename_func("TO_UNIXTIME"), + exp.ToChar: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})", exp.TryCast: transforms.preprocess([transforms.epoch_cast_to_ts]), exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)", exp.TsOrDsAdd: _ts_or_ds_add_sql, diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index cdbc071..f09a990 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -293,7 +293,6 @@ class Snowflake(Dialect): "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), "TIMEDIFF": _parse_datediff, "TIMESTAMPDIFF": _parse_datediff, - "TO_ARRAY": exp.Array.from_arg_list, "TO_TIMESTAMP": _parse_to_timestamp, "TO_VARCHAR": exp.ToChar.from_arg_list, "ZEROIFNULL": _zeroifnull_to_if, @@ -369,36 +368,58 @@ class Snowflake(Dialect): return lateral + def _parse_at_before(self, table: exp.Table) -> exp.Table: + # https://docs.snowflake.com/en/sql-reference/constructs/at-before + index = self._index + if self._match_texts(("AT", "BEFORE")): + this = self._prev.text.upper() + kind = ( + self._match(TokenType.L_PAREN) + and self._match_texts(self.HISTORICAL_DATA_KIND) + and self._prev.text.upper() + ) + expression = self._match(TokenType.FARROW) and self._parse_bitwise() + + if expression: + self._match_r_paren() + when = self.expression( + exp.HistoricalData, this=this, kind=kind, expression=expression + ) + table.set("when", when) + else: + self._retreat(index) + + return table + def _parse_table_parts(self, schema: bool = False) -> exp.Table: # https://docs.snowflake.com/en/user-guide/querying-stage - table: t.Optional[exp.Expression] = None - if self._match_text_seq("@"): - table_name = "@" - while self._curr: - self._advance() - table_name += self._prev.text - if not self._match_set(self.STAGED_FILE_SINGLE_TOKENS, advance=False): - break - while self._match_set(self.STAGED_FILE_SINGLE_TOKENS): - table_name += self._prev.text - - table = exp.var(table_name) - elif self._match(TokenType.STRING, advance=False): + if self._match(TokenType.STRING, advance=False): table = self._parse_string() + elif self._match_text_seq("@", advance=False): + table = self._parse_location_path() + else: + table = None if table: file_format = None pattern = None - if self._match_text_seq("(", "FILE_FORMAT", "=>"): - file_format = self._parse_string() or super()._parse_table_parts() - if self._match_text_seq(",", "PATTERN", "=>"): + self._match(TokenType.L_PAREN) + while self._curr and not self._match(TokenType.R_PAREN): + if self._match_text_seq("FILE_FORMAT", "=>"): + file_format = self._parse_string() or super()._parse_table_parts() + elif self._match_text_seq("PATTERN", "=>"): pattern = self._parse_string() - self._match_r_paren() + else: + break + + self._match(TokenType.COMMA) - return self.expression(exp.Table, this=table, format=file_format, pattern=pattern) + table = self.expression(exp.Table, this=table, format=file_format, pattern=pattern) + else: + table = super()._parse_table_parts(schema=schema) - return super()._parse_table_parts(schema=schema) + return self._parse_at_before(table) def _parse_id_var( self, @@ -438,17 +459,17 @@ class Snowflake(Dialect): def _parse_location(self) -> exp.LocationProperty: self._match(TokenType.EQ) + return self.expression(exp.LocationProperty, this=self._parse_location_path()) - parts = [self._parse_var(any_token=True)] + def _parse_location_path(self) -> exp.Var: + parts = [self._advance_any(ignore_reserved=True)] - while self._match(TokenType.SLASH): - if self._curr and self._prev.end + 1 == self._curr.start: - parts.append(self._parse_var(any_token=True)) - else: - parts.append(exp.Var(this="")) - return self.expression( - exp.LocationProperty, this=exp.var("/".join(str(p) for p in parts)) - ) + # We avoid consuming a comma token because external tables like @foo and @bar + # can be joined in a query with a comma separator. + while self._is_connected() and not self._match(TokenType.COMMA, advance=False): + parts.append(self._advance_any(ignore_reserved=True)) + + return exp.var("".join(part.text for part in parts if part)) class Tokenizer(tokens.Tokenizer): STRING_ESCAPES = ["\\", "'"] @@ -562,6 +583,7 @@ class Snowflake(Dialect): "TO_CHAR", exp.cast(e.this, "timestamp"), self.format_time(e) ), exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})", + exp.ToArray: rename_func("TO_ARRAY"), exp.ToChar: lambda self, e: self.function_fallback_sql(e), exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression), exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True), diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index 141d9c0..0ccc567 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -12,22 +12,30 @@ class Teradata(Dialect): TYPED_DIVISION = True TIME_MAPPING = { - "Y": "%Y", - "YYYY": "%Y", "YY": "%y", - "MMMM": "%B", + "Y4": "%Y", + "YYYY": "%Y", + "M4": "%B", + "M3": "%b", + "M": "%-M", + "MI": "%M", + "MM": "%m", "MMM": "%b", - "DD": "%d", + "MMMM": "%B", "D": "%-d", - "HH": "%H", + "DD": "%d", + "D3": "%j", + "DDD": "%j", "H": "%-H", - "MM": "%M", - "M": "%-M", - "SS": "%S", + "HH": "%H", + "HH24": "%H", "S": "%-S", + "SS": "%S", "SSSSSS": "%f", "E": "%a", "EE": "%a", + "E3": "%a", + "E4": "%A", "EEE": "%a", "EEEE": "%A", } diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index c3d4f0a..165a703 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -701,6 +701,13 @@ class TSQL(Dialect): exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } + def set_operation(self, expression: exp.Union, op: str) -> str: + limit = expression.args.get("limit") + if limit: + return self.sql(expression.limit(limit.pop(), copy=False)) + + return super().set_operation(expression, op) + def setitem_sql(self, expression: exp.SetItem) -> str: this = expression.this if isinstance(this, exp.EQ) and not isinstance(this.left, exp.Parameter): diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py index e1e597d..3277e65 100644 --- a/sqlglot/executor/python.py +++ b/sqlglot/executor/python.py @@ -343,6 +343,9 @@ class PythonExecutor: else: sink.rows = left.rows + right.rows + if not math.isinf(step.limit): + sink.rows = sink.rows[0 : step.limit] + return self.context({step.name: sink}) diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 99722be..8246769 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -1105,14 +1105,7 @@ class Create(DDL): # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_table_clone_statement # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_table_copy class Clone(Expression): - arg_types = { - "this": True, - "when": False, - "kind": False, - "shallow": False, - "expression": False, - "copy": False, - } + arg_types = {"this": True, "shallow": False, "copy": False} class Describe(Expression): @@ -1213,6 +1206,10 @@ class RawString(Condition): pass +class UnicodeString(Condition): + arg_types = {"this": True, "escape": False} + + class Column(Condition): arg_types = {"this": True, "table": False, "db": False, "catalog": False, "join_mark": False} @@ -1967,7 +1964,12 @@ class Offset(Expression): class Order(Expression): - arg_types = {"this": False, "expressions": True} + arg_types = {"this": False, "expressions": True, "interpolate": False} + + +# https://clickhouse.com/docs/en/sql-reference/statements/select/order-by#order-by-expr-with-fill-modifier +class WithFill(Expression): + arg_types = {"from": False, "to": False, "step": False} # hive specific sorts @@ -1985,7 +1987,7 @@ class Sort(Order): class Ordered(Expression): - arg_types = {"this": True, "desc": False, "nulls_first": True} + arg_types = {"this": True, "desc": False, "nulls_first": True, "with_fill": False} class Property(Expression): @@ -2522,6 +2524,11 @@ class IndexTableHint(Expression): arg_types = {"this": True, "expressions": False, "target": False} +# https://docs.snowflake.com/en/sql-reference/constructs/at-before +class HistoricalData(Expression): + arg_types = {"this": True, "kind": True, "expression": True} + + class Table(Expression): arg_types = { "this": True, @@ -2538,6 +2545,7 @@ class Table(Expression): "pattern": False, "index": False, "ordinality": False, + "when": False, } @property @@ -4310,6 +4318,11 @@ class Array(Func): is_var_len_args = True +# https://docs.snowflake.com/en/sql-reference/functions/to_array +class ToArray(Func): + pass + + # https://docs.snowflake.com/en/sql-reference/functions/to_char # https://docs.oracle.com/en/database/oracle/oracle-database/23/sqlrf/TO_CHAR-number.html class ToChar(Func): @@ -5233,6 +5246,19 @@ class UnixToTimeStr(Func): pass +class TimestampFromParts(Func): + """Constructs a timestamp given its constituent parts.""" + + arg_types = { + "year": True, + "month": True, + "day": True, + "hour": True, + "min": True, + "sec": True, + } + + class Upper(Func): _sql_names = ["UPPER", "UCASE"] diff --git a/sqlglot/generator.py b/sqlglot/generator.py index f3f9060..c571e8f 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -862,15 +862,7 @@ class Generator: this = self.sql(expression, "this") shallow = "SHALLOW " if expression.args.get("shallow") else "" keyword = "COPY" if expression.args.get("copy") and self.SUPPORTS_TABLE_COPY else "CLONE" - this = f"{shallow}{keyword} {this}" - when = self.sql(expression, "when") - - if when: - kind = self.sql(expression, "kind") - expr = self.sql(expression, "expression") - return f"{this} {when} ({kind} => {expr})" - - return this + return f"{shallow}{keyword} {this}" def describe_sql(self, expression: exp.Describe) -> str: return f"DESCRIBE {self.sql(expression, 'this')}" @@ -923,6 +915,14 @@ class Generator: return f"{self.dialect.BYTE_START}{this}{self.dialect.BYTE_END}" return this + def unicodestring_sql(self, expression: exp.UnicodeString) -> str: + this = self.sql(expression, "this") + if self.dialect.UNICODE_START: + escape = self.sql(expression, "escape") + escape = f" UESCAPE {escape}" if escape else "" + return f"{self.dialect.UNICODE_START}{this}{self.dialect.UNICODE_END}{escape}" + return this + def rawstring_sql(self, expression: exp.RawString) -> str: string = self.escape_str(expression.this.replace("\\", "\\\\")) return f"{self.dialect.QUOTE_START}{string}{self.dialect.QUOTE_END}" @@ -1400,6 +1400,12 @@ class Generator: target = f" FOR {target}" if target else "" return f"{this}{target} ({self.expressions(expression, flat=True)})" + def historicaldata_sql(self, expression: exp.HistoricalData) -> str: + this = self.sql(expression, "this") + kind = self.sql(expression, "kind") + expr = self.sql(expression, "expression") + return f"{this} ({kind} => {expr})" + def table_sql(self, expression: exp.Table, sep: str = " AS ") -> str: table = ".".join( self.sql(part) @@ -1436,6 +1442,10 @@ class Generator: ordinality = f" WITH ORDINALITY{alias}" alias = "" + when = self.sql(expression, "when") + if when: + table = f"{table} {when}" + return f"{table}{version}{file_format}{alias}{index}{hints}{pivots}{joins}{laterals}{ordinality}" def tablesample_sql( @@ -1784,7 +1794,24 @@ class Generator: def order_sql(self, expression: exp.Order, flat: bool = False) -> str: this = self.sql(expression, "this") this = f"{this} " if this else this - return self.op_expressions(f"{this}ORDER BY", expression, flat=this or flat) # type: ignore + order = self.op_expressions(f"{this}ORDER BY", expression, flat=this or flat) # type: ignore + interpolated_values = [ + f"{self.sql(named_expression, 'alias')} AS {self.sql(named_expression, 'this')}" + for named_expression in expression.args.get("interpolate") or [] + ] + interpolate = ( + f" INTERPOLATE ({', '.join(interpolated_values)})" if interpolated_values else "" + ) + return f"{order}{interpolate}" + + def withfill_sql(self, expression: exp.WithFill) -> str: + from_sql = self.sql(expression, "from") + from_sql = f" FROM {from_sql}" if from_sql else "" + to_sql = self.sql(expression, "to") + to_sql = f" TO {to_sql}" if to_sql else "" + step_sql = self.sql(expression, "step") + step_sql = f" STEP {step_sql}" if step_sql else "" + return f"WITH FILL{from_sql}{to_sql}{step_sql}" def cluster_sql(self, expression: exp.Cluster) -> str: return self.op_expressions("CLUSTER BY", expression) @@ -1826,7 +1853,10 @@ class Generator: this = f"CASE WHEN {this} IS NULL THEN 1 ELSE 0 END{null_sort_order}, {this}" nulls_sort_change = "" - return f"{this}{sort_order}{nulls_sort_change}" + with_fill = self.sql(expression, "with_fill") + with_fill = f" {with_fill}" if with_fill else "" + + return f"{this}{sort_order}{nulls_sort_change}{with_fill}" def matchrecognize_sql(self, expression: exp.MatchRecognize) -> str: partition = self.partition_by_sql(expression) @@ -3048,11 +3078,24 @@ class Generator: def operator_sql(self, expression: exp.Operator) -> str: return self.binary(expression, f"OPERATOR({self.sql(expression, 'operator')})") + def toarray_sql(self, expression: exp.ToArray) -> str: + arg = expression.this + if not arg.type: + from sqlglot.optimizer.annotate_types import annotate_types + + arg = annotate_types(arg) + + if arg.is_type(exp.DataType.Type.ARRAY): + return self.sql(arg) + + cond_for_null = arg.is_(exp.null()) + return self.sql(exp.func("IF", cond_for_null, exp.null(), exp.Array(expressions=[arg]))) + def _simplify_unless_literal(self, expression: E) -> E: if not isinstance(expression, exp.Literal): from sqlglot.optimizer.simplify import simplify - expression = simplify(expression) + expression = simplify(expression, dialect=self.dialect) return expression diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index 1ab7768..1230cea 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -95,9 +95,6 @@ def eliminate_subqueries(expression): def _eliminate(scope, existing_ctes, taken): - if scope.is_union: - return _eliminate_union(scope, existing_ctes, taken) - if scope.is_derived_table: return _eliminate_derived_table(scope, existing_ctes, taken) @@ -105,36 +102,6 @@ def _eliminate(scope, existing_ctes, taken): return _eliminate_cte(scope, existing_ctes, taken) -def _eliminate_union(scope, existing_ctes, taken): - duplicate_cte_alias = existing_ctes.get(scope.expression) - - alias = duplicate_cte_alias or find_new_name(taken=taken, base="cte") - - taken[alias] = scope - - # Try to maintain the selections - expressions = scope.expression.selects - selects = [ - exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name, copy=False) - for e in expressions - if e.alias_or_name - ] - # If not all selections have an alias, just select * - if len(selects) != len(expressions): - selects = ["*"] - - scope.expression.replace( - exp.select(*selects).from_(exp.alias_(exp.table_(alias), alias=alias, copy=False)) - ) - - if not duplicate_cte_alias: - existing_ctes[scope.expression] = alias - return exp.CTE( - this=scope.expression, - alias=exp.TableAlias(this=exp.to_identifier(alias)), - ) - - def _eliminate_derived_table(scope, existing_ctes, taken): # This makes sure that we don't: # - drop the "pivot" arg from a pivoted subquery diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index a74bea7..ea148cc 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -174,6 +174,22 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join): for col in inner_projections[selection].find_all(exp.Column) ) + def _is_recursive(): + # Recursive CTEs look like this: + # WITH RECURSIVE cte AS ( + # SELECT * FROM x <-- inner scope + # UNION ALL + # SELECT * FROM cte <-- outer scope + # ) + cte = inner_scope.expression.parent + node = outer_scope.expression.parent + + while node: + if node is cte: + return True + node = node.parent + return False + return ( isinstance(outer_scope.expression, exp.Select) and not outer_scope.expression.is_star @@ -197,6 +213,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join): ) and not _outer_select_joins_on_inner_select_join() and not _is_a_window_expression_in_unmergable_operation() + and not _is_recursive() ) diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py index f7348b5..10ff13a 100644 --- a/sqlglot/optimizer/pushdown_predicates.py +++ b/sqlglot/optimizer/pushdown_predicates.py @@ -4,7 +4,7 @@ from sqlglot.optimizer.scope import build_scope, find_in_scope from sqlglot.optimizer.simplify import simplify -def pushdown_predicates(expression): +def pushdown_predicates(expression, dialect=None): """ Rewrite sqlglot AST to pushdown predicates in FROMS and JOINS @@ -36,7 +36,7 @@ def pushdown_predicates(expression): if isinstance(parent, exp.Join) and parent.side == "RIGHT": selected_sources = {k: (node, source)} break - pushdown(where.this, selected_sources, scope_ref_count) + pushdown(where.this, selected_sources, scope_ref_count, dialect) # joins should only pushdown into itself, not to other joins # so we limit the selected sources to only itself @@ -44,17 +44,20 @@ def pushdown_predicates(expression): name = join.alias_or_name if name in scope.selected_sources: pushdown( - join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count + join.args.get("on"), + {name: scope.selected_sources[name]}, + scope_ref_count, + dialect, ) return expression -def pushdown(condition, sources, scope_ref_count): +def pushdown(condition, sources, scope_ref_count, dialect): if not condition: return - condition = condition.replace(simplify(condition)) + condition = condition.replace(simplify(condition, dialect=dialect)) cnf_like = normalized(condition) or not normalized(condition, dnf=True) predicates = list( diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index b7e527e..d34857d 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -37,6 +37,7 @@ class Scope: For example: SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; The LATERAL VIEW EXPLODE gets x as a source. + cte_sources (dict[str, Scope]): Sources from CTES outer_column_list (list[str]): If this is a derived table or CTE, and the outer query defines a column list of it's alias of this scope, this is that list of columns. For example: @@ -61,11 +62,14 @@ class Scope: parent=None, scope_type=ScopeType.ROOT, lateral_sources=None, + cte_sources=None, ): self.expression = expression self.sources = sources or {} - self.lateral_sources = lateral_sources.copy() if lateral_sources else {} + self.lateral_sources = lateral_sources or {} + self.cte_sources = cte_sources or {} self.sources.update(self.lateral_sources) + self.sources.update(self.cte_sources) self.outer_column_list = outer_column_list or [] self.parent = parent self.scope_type = scope_type @@ -92,13 +96,17 @@ class Scope: self._pivots = None self._references = None - def branch(self, expression, scope_type, chain_sources=None, **kwargs): + def branch( + self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs + ): """Branch from the current scope to a new, inner scope""" return Scope( expression=expression.unnest(), - sources={**self.cte_sources, **(chain_sources or {})}, + sources=sources.copy() if sources else None, parent=self, scope_type=scope_type, + cte_sources={**self.cte_sources, **(cte_sources or {})}, + lateral_sources=lateral_sources.copy() if lateral_sources else None, **kwargs, ) @@ -306,20 +314,6 @@ class Scope: return self._references @property - def cte_sources(self): - """ - Sources that are CTEs. - - Returns: - dict[str, Scope]: Mapping of source alias to Scope - """ - return { - alias: scope - for alias, scope in self.sources.items() - if isinstance(scope, Scope) and scope.is_cte - } - - @property def external_columns(self): """ Columns that appear to reference sources in outer scopes. @@ -515,7 +509,10 @@ def _traverse_scope(scope): elif isinstance(scope.expression, exp.Union): yield from _traverse_union(scope) elif isinstance(scope.expression, exp.Subquery): - yield from _traverse_subqueries(scope) + if scope.is_root: + yield from _traverse_select(scope) + else: + yield from _traverse_subqueries(scope) elif isinstance(scope.expression, exp.Table): yield from _traverse_tables(scope) elif isinstance(scope.expression, exp.UDTF): @@ -572,7 +569,7 @@ def _traverse_ctes(scope): for child_scope in _traverse_scope( scope.branch( cte.this, - chain_sources=sources, + cte_sources=sources, outer_column_list=cte.alias_column_names, scope_type=ScopeType.CTE, ) @@ -584,12 +581,14 @@ def _traverse_ctes(scope): if recursive_scope: child_scope.add_source(alias, recursive_scope) + child_scope.cte_sources[alias] = recursive_scope # append the final child_scope yielded if child_scope: scope.cte_scopes.append(child_scope) scope.sources.update(sources) + scope.cte_sources.update(sources) def _is_derived_table(expression: exp.Subquery) -> bool: @@ -725,7 +724,7 @@ def _traverse_ddl(scope): yield from _traverse_ctes(scope) query_scope = scope.branch( - scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, chain_sources=scope.sources + scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, sources=scope.sources ) query_scope._collect() query_scope._ctes = scope.ctes + query_scope._ctes diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index d4e2e60..6ae08d0 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import datetime import functools import itertools @@ -6,10 +8,17 @@ from collections import deque from decimal import Decimal import sqlglot -from sqlglot import exp +from sqlglot import Dialect, exp from sqlglot.helper import first, is_iterable, merge_ranges, while_changing from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope +if t.TYPE_CHECKING: + from sqlglot.dialects.dialect import DialectType + + DateTruncBinaryTransform = t.Callable[ + [exp.Expression, datetime.date, str, Dialect], t.Optional[exp.Expression] + ] + # Final means that an expression should not be simplified FINAL = "final" @@ -18,7 +27,9 @@ class UnsupportedUnit(Exception): pass -def simplify(expression, constant_propagation=False): +def simplify( + expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None +): """ Rewrite sqlglot AST to simplify expressions. @@ -36,15 +47,18 @@ def simplify(expression, constant_propagation=False): sqlglot.Expression: simplified expression """ + dialect = Dialect.get_or_raise(dialect) + # group by expressions cannot be simplified, for example # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 # the projection must exactly match the group by key for group in expression.find_all(exp.Group): select = group.parent + assert select groups = set(group.expressions) group.meta[FINAL] = True - for e in select.selects: + for e in select.expressions: for node, *_ in e.walk(): if node in groups: e.meta[FINAL] = True @@ -84,7 +98,8 @@ def simplify(expression, constant_propagation=False): node = simplify_literals(node, root) node = simplify_equality(node) node = simplify_parens(node) - node = simplify_datetrunc_predicate(node) + node = simplify_datetrunc(node, dialect) + node = sort_comparison(node) if root: expression.replace(node) @@ -117,14 +132,30 @@ def rewrite_between(expression: exp.Expression) -> exp.Expression: This is done because comparison simplification is only done on lt/lte/gt/gte. """ if isinstance(expression, exp.Between): - return exp.and_( + negate = isinstance(expression.parent, exp.Not) + + expression = exp.and_( exp.GTE(this=expression.this.copy(), expression=expression.args["low"]), exp.LTE(this=expression.this.copy(), expression=expression.args["high"]), copy=False, ) + + if negate: + expression = exp.paren(expression, copy=False) + return expression +COMPLEMENT_COMPARISONS = { + exp.LT: exp.GTE, + exp.GT: exp.LTE, + exp.LTE: exp.GT, + exp.GTE: exp.LT, + exp.EQ: exp.NEQ, + exp.NEQ: exp.EQ, +} + + def simplify_not(expression): """ Demorgan's Law @@ -132,10 +163,15 @@ def simplify_not(expression): NOT (x AND y) -> NOT x OR NOT y """ if isinstance(expression, exp.Not): - if is_null(expression.this): + this = expression.this + if is_null(this): return exp.null() - if isinstance(expression.this, exp.Paren): - condition = expression.this.unnest() + if this.__class__ in COMPLEMENT_COMPARISONS: + return COMPLEMENT_COMPARISONS[this.__class__]( + this=this.this, expression=this.expression + ) + if isinstance(this, exp.Paren): + condition = this.unnest() if isinstance(condition, exp.And): return exp.or_( exp.not_(condition.left, copy=False), @@ -150,14 +186,14 @@ def simplify_not(expression): ) if is_null(condition): return exp.null() - if always_true(expression.this): + if always_true(this): return exp.false() - if is_false(expression.this): + if is_false(this): return exp.true() - if isinstance(expression.this, exp.Not): + if isinstance(this, exp.Not): # double negation # NOT NOT x -> x - return expression.this.this + return this.this return expression @@ -249,12 +285,6 @@ def _simplify_comparison(expression, left, right, or_=False): except StopIteration: return expression - # make sure the comparison is always of the form x > 1 instead of 1 < x - if left.__class__ in INVERSE_COMPARISONS and l == ll: - left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll) - if right.__class__ in INVERSE_COMPARISONS and r == rl: - right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl) - if l.is_number and r.is_number: l = float(l.name) r = float(r.name) @@ -397,13 +427,7 @@ def propagate_constants(expression, root=True): # TODO: create a helper that can be used to detect nested literal expressions such # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too if isinstance(l, exp.Column) and isinstance(r, exp.Literal): - pass - elif isinstance(r, exp.Column) and isinstance(l, exp.Literal): - l, r = r, l - else: - continue - - constant_mapping[l] = (id(l), r) + constant_mapping[l] = (id(l), r) if constant_mapping: for column in find_all_in_scope(expression, exp.Column): @@ -458,11 +482,7 @@ def simplify_equality(expression: exp.Expression) -> exp.Expression: if isinstance(expression, COMPARISONS): l, r = expression.left, expression.right - if l.__class__ in INVERSE_OPS: - pass - elif r.__class__ in INVERSE_OPS: - l, r = r, l - else: + if not l.__class__ in INVERSE_OPS: return expression if r.is_number: @@ -650,7 +670,7 @@ def simplify_coalesce(expression): # Find the first constant arg for arg_index, arg in enumerate(coalesce.expressions): - if _is_constant(other): + if _is_constant(arg): break else: return expression @@ -752,7 +772,7 @@ def simplify_conditionals(expression): DateRange = t.Tuple[datetime.date, datetime.date] -def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]: +def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]: """ Get the date range for a DATE_TRUNC equality comparison: @@ -761,7 +781,7 @@ def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]: Returns: tuple of [min, max) or None if a value can never be equal to `date` for `unit` """ - floor = date_floor(date, unit) + floor = date_floor(date, unit, dialect) if date != floor: # This will always be False, except for NULL values. @@ -780,9 +800,9 @@ def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Exp def _datetrunc_eq( - left: exp.Expression, date: datetime.date, unit: str + left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect ) -> t.Optional[exp.Expression]: - drange = _datetrunc_range(date, unit) + drange = _datetrunc_range(date, unit, dialect) if not drange: return None @@ -790,9 +810,9 @@ def _datetrunc_eq( def _datetrunc_neq( - left: exp.Expression, date: datetime.date, unit: str + left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect ) -> t.Optional[exp.Expression]: - drange = _datetrunc_range(date, unit) + drange = _datetrunc_range(date, unit, dialect) if not drange: return None @@ -803,41 +823,39 @@ def _datetrunc_neq( ) -DateTruncBinaryTransform = t.Callable[ - [exp.Expression, datetime.date, str], t.Optional[exp.Expression] -] DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = { - exp.LT: lambda l, d, u: l < date_literal(date_floor(d, u)), - exp.GT: lambda l, d, u: l >= date_literal(date_floor(d, u) + interval(u)), - exp.LTE: lambda l, d, u: l < date_literal(date_floor(d, u) + interval(u)), - exp.GTE: lambda l, d, u: l >= date_literal(date_ceil(d, u)), + exp.LT: lambda l, dt, u, d: l + < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u)), + exp.GT: lambda l, dt, u, d: l >= date_literal(date_floor(dt, u, d) + interval(u)), + exp.LTE: lambda l, dt, u, d: l < date_literal(date_floor(dt, u, d) + interval(u)), + exp.GTE: lambda l, dt, u, d: l >= date_literal(date_ceil(dt, u, d)), exp.EQ: _datetrunc_eq, exp.NEQ: _datetrunc_neq, } DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS} +DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc) def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool: - return isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)) and _is_date_literal(right) + return isinstance(left, DATETRUNCS) and _is_date_literal(right) @catch(ModuleNotFoundError, UnsupportedUnit) -def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression: +def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression: """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`""" comparison = expression.__class__ - if comparison not in DATETRUNC_COMPARISONS: + if isinstance(expression, DATETRUNCS): + date = extract_date(expression.this) + if date and expression.unit: + return date_literal(date_floor(date, expression.unit.name.lower(), dialect)) + elif comparison not in DATETRUNC_COMPARISONS: return expression if isinstance(expression, exp.Binary): l, r = expression.left, expression.right - if _is_datetrunc_predicate(l, r): - pass - elif _is_datetrunc_predicate(r, l): - comparison = INVERSE_COMPARISONS.get(comparison, comparison) - l, r = r, l - else: + if not _is_datetrunc_predicate(l, r): return expression l = t.cast(exp.DateTrunc, l) @@ -847,7 +865,7 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression: if not date: return expression - return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression + return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit, dialect) or expression elif isinstance(expression, exp.In): l = expression.this rs = expression.expressions @@ -861,7 +879,7 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression: date = extract_date(r) if not date: return expression - drange = _datetrunc_range(date, unit) + drange = _datetrunc_range(date, unit, dialect) if drange: ranges.append(drange) @@ -875,6 +893,23 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression: return expression +def sort_comparison(expression: exp.Expression) -> exp.Expression: + if expression.__class__ in COMPLEMENT_COMPARISONS: + l, r = expression.this, expression.expression + l_column = isinstance(l, exp.Column) + r_column = isinstance(r, exp.Column) + l_const = _is_constant(l) + r_const = _is_constant(r) + + if (l_column and not r_column) or (r_const and not l_const): + return expression + if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)): + return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)( + this=r, expression=l + ) + return expression + + # CROSS joins result in an empty table if the right table is empty. # So we can only simplify certain types of joins to CROSS. # Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x @@ -1034,7 +1069,7 @@ def interval(unit: str, n: int = 1): raise UnsupportedUnit(f"Unsupported unit: {unit}") -def date_floor(d: datetime.date, unit: str) -> datetime.date: +def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date: if unit == "year": return d.replace(month=1, day=1) if unit == "quarter": @@ -1050,15 +1085,15 @@ def date_floor(d: datetime.date, unit: str) -> datetime.date: return d.replace(month=d.month, day=1) if unit == "week": # Assuming week starts on Monday (0) and ends on Sunday (6) - return d - datetime.timedelta(days=d.weekday()) + return d - datetime.timedelta(days=d.weekday() - dialect.WEEK_OFFSET) if unit == "day": return d raise UnsupportedUnit(f"Unsupported unit: {unit}") -def date_ceil(d: datetime.date, unit: str) -> datetime.date: - floor = date_floor(d, unit) +def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date: + floor = date_floor(d, unit, dialect) if floor == d: return d diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py index 242fc87..4d35175 100644 --- a/sqlglot/optimizer/unnest_subqueries.py +++ b/sqlglot/optimizer/unnest_subqueries.py @@ -65,6 +65,8 @@ def unnest(select, parent_select, next_alias_name): ) ): column = exp.Max(this=column) + elif not isinstance(select.parent, exp.Subquery): + return _replace(select.parent, column) parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False) diff --git a/sqlglot/parser.py b/sqlglot/parser.py index c7e27a3..3d01a84 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -568,6 +568,7 @@ class Parser(metaclass=_Parser): exp.Sort: lambda self: self._parse_sort(exp.Sort, TokenType.SORT_BY), exp.Table: lambda self: self._parse_table_parts(), exp.TableAlias: lambda self: self._parse_table_alias(), + exp.When: lambda self: seq_get(self._parse_when_matched(), 0), exp.Where: lambda self: self._parse_where(), exp.Window: lambda self: self._parse_named_window(), exp.With: lambda self: self._parse_with(), @@ -635,6 +636,11 @@ class Parser(metaclass=_Parser): TokenType.HEREDOC_STRING: lambda self, token: self.expression( exp.RawString, this=token.text ), + TokenType.UNICODE_STRING: lambda self, token: self.expression( + exp.UnicodeString, + this=token.text, + escape=self._match_text_seq("UESCAPE") and self._parse_string(), + ), TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(), } @@ -907,7 +913,7 @@ class Parser(metaclass=_Parser): INSERT_ALTERNATIVES = {"ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK"} CLONE_KEYWORDS = {"CLONE", "COPY"} - CLONE_KINDS = {"TIMESTAMP", "OFFSET", "STATEMENT"} + HISTORICAL_DATA_KIND = {"TIMESTAMP", "OFFSET", "STATEMENT", "STREAM"} OPCLASS_FOLLOW_KEYWORDS = {"ASC", "DESC", "NULLS"} OPTYPE_FOLLOW_TOKENS = {TokenType.COMMA, TokenType.R_PAREN} @@ -947,6 +953,10 @@ class Parser(metaclass=_Parser): # Whether the TRIM function expects the characters to trim as its first argument TRIM_PATTERN_FIRST = False + # Whether query modifiers such as LIMIT are attached to the UNION node (vs its right operand) + MODIFIERS_ATTACHED_TO_UNION = True + UNION_MODIFIERS = {"order", "limit", "offset"} + __slots__ = ( "error_level", "error_message_context", @@ -1162,6 +1172,9 @@ class Parser(metaclass=_Parser): def _find_sql(self, start: Token, end: Token) -> str: return self.sql[start.start : end.end + 1] + def _is_connected(self) -> bool: + return self._prev and self._curr and self._prev.end + 1 == self._curr.start + def _advance(self, times: int = 1) -> None: self._index += times self._curr = seq_get(self._tokens, self._index) @@ -1404,23 +1417,8 @@ class Parser(metaclass=_Parser): if self._match_texts(self.CLONE_KEYWORDS): copy = self._prev.text.lower() == "copy" - clone = self._parse_table(schema=True) - when = self._match_texts(("AT", "BEFORE")) and self._prev.text.upper() - clone_kind = ( - self._match(TokenType.L_PAREN) - and self._match_texts(self.CLONE_KINDS) - and self._prev.text.upper() - ) - clone_expression = self._match(TokenType.FARROW) and self._parse_bitwise() - self._match(TokenType.R_PAREN) clone = self.expression( - exp.Clone, - this=clone, - when=when, - kind=clone_kind, - shallow=shallow, - expression=clone_expression, - copy=copy, + exp.Clone, this=self._parse_table(schema=True), shallow=shallow, copy=copy ) return self.expression( @@ -2471,13 +2469,7 @@ class Parser(metaclass=_Parser): pattern = None define = ( - self._parse_csv( - lambda: self.expression( - exp.Alias, - alias=self._parse_id_var(any_token=True), - this=self._match(TokenType.ALIAS) and self._parse_conjunction(), - ) - ) + self._parse_csv(self._parse_name_as_expression) if self._match_text_seq("DEFINE") else None ) @@ -3124,6 +3116,18 @@ class Parser(metaclass=_Parser): return self.expression(exp.Connect, start=start, connect=connect) + def _parse_name_as_expression(self) -> exp.Alias: + return self.expression( + exp.Alias, + alias=self._parse_id_var(any_token=True), + this=self._match(TokenType.ALIAS) and self._parse_conjunction(), + ) + + def _parse_interpolate(self) -> t.Optional[t.List[exp.Expression]]: + if self._match_text_seq("INTERPOLATE"): + return self._parse_wrapped_csv(self._parse_name_as_expression) + return None + def _parse_order( self, this: t.Optional[exp.Expression] = None, skip_order_token: bool = False ) -> t.Optional[exp.Expression]: @@ -3131,7 +3135,10 @@ class Parser(metaclass=_Parser): return this return self.expression( - exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered) + exp.Order, + this=this, + expressions=self._parse_csv(self._parse_ordered), + interpolate=self._parse_interpolate(), ) def _parse_sort(self, exp_class: t.Type[E], token: TokenType) -> t.Optional[E]: @@ -3161,7 +3168,21 @@ class Parser(metaclass=_Parser): ): nulls_first = True - return self.expression(exp.Ordered, this=this, desc=desc, nulls_first=nulls_first) + if self._match_text_seq("WITH", "FILL"): + with_fill = self.expression( + exp.WithFill, + **{ # type: ignore + "from": self._match(TokenType.FROM) and self._parse_bitwise(), + "to": self._match_text_seq("TO") and self._parse_bitwise(), + "step": self._match_text_seq("STEP") and self._parse_bitwise(), + }, + ) + else: + with_fill = None + + return self.expression( + exp.Ordered, this=this, desc=desc, nulls_first=nulls_first, with_fill=with_fill + ) def _parse_limit( self, this: t.Optional[exp.Expression] = None, top: bool = False @@ -3253,28 +3274,40 @@ class Parser(metaclass=_Parser): return locks def _parse_set_operations(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: - if not self._match_set(self.SET_OPERATIONS): - return this + while this and self._match_set(self.SET_OPERATIONS): + token_type = self._prev.token_type - token_type = self._prev.token_type + if token_type == TokenType.UNION: + operation = exp.Union + elif token_type == TokenType.EXCEPT: + operation = exp.Except + else: + operation = exp.Intersect - if token_type == TokenType.UNION: - expression = exp.Union - elif token_type == TokenType.EXCEPT: - expression = exp.Except - else: - expression = exp.Intersect + comments = self._prev.comments + distinct = self._match(TokenType.DISTINCT) or not self._match(TokenType.ALL) + by_name = self._match_text_seq("BY", "NAME") + expression = self._parse_select(nested=True, parse_set_operation=False) - return self.expression( - expression, - comments=self._prev.comments, - this=this, - distinct=self._match(TokenType.DISTINCT) or not self._match(TokenType.ALL), - by_name=self._match_text_seq("BY", "NAME"), - expression=self._parse_set_operations( - self._parse_select(nested=True, parse_set_operation=False) - ), - ) + this = self.expression( + operation, + comments=comments, + this=this, + distinct=distinct, + by_name=by_name, + expression=expression, + ) + + if isinstance(this, exp.Union) and self.MODIFIERS_ATTACHED_TO_UNION: + expression = this.expression + + if expression: + for arg in self.UNION_MODIFIERS: + expr = expression.args.get(arg) + if expr: + this.set(arg, expr.pop()) + + return this def _parse_expression(self) -> t.Optional[exp.Expression]: return self._parse_alias(self._parse_conjunction()) @@ -3595,7 +3628,7 @@ class Parser(metaclass=_Parser): exp.DataType, this=exp.DataType.Type.INTERVAL, expressions=span ) else: - this = self.expression(exp.Interval, unit=unit) + this = self.expression(exp.DataType, this=self.expression(exp.Interval, unit=unit)) if maybe_func and check_func: index2 = self._index @@ -4891,8 +4924,8 @@ class Parser(metaclass=_Parser): return self.expression(exp.Var, this=self._prev.text) return self._parse_placeholder() - def _advance_any(self) -> t.Optional[Token]: - if self._curr and self._curr.token_type not in self.RESERVED_TOKENS: + def _advance_any(self, ignore_reserved: bool = False) -> t.Optional[Token]: + if self._curr and (ignore_reserved or self._curr.token_type not in self.RESERVED_TOKENS): self._advance() return self._prev return None diff --git a/sqlglot/planner.py b/sqlglot/planner.py index 07ee739..bbc52ab 100644 --- a/sqlglot/planner.py +++ b/sqlglot/planner.py @@ -425,16 +425,27 @@ class SetOperation(Step): cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None ) -> Step: assert isinstance(expression, exp.Union) + left = Step.from_expression(expression.left, ctes) + # SELECT 1 UNION SELECT 2 <-- these subqueries don't have names + left.name = left.name or "left" right = Step.from_expression(expression.right, ctes) + right.name = right.name or "right" step = cls( op=expression.__class__, left=left.name, right=right.name, distinct=bool(expression.args.get("distinct")), ) + step.add_dependency(left) step.add_dependency(right) + + limit = expression.args.get("limit") + + if limit: + step.limit = int(limit.text("expression")) + return step def _to_s(self, indent: str) -> t.List[str]: diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index e4c3204..de9d4c4 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -1,9 +1,10 @@ from __future__ import annotations +import os import typing as t from enum import auto -from sqlglot.errors import TokenError +from sqlglot.errors import SqlglotError, TokenError from sqlglot.helper import AutoName from sqlglot.trie import TrieResult, in_trie, new_trie @@ -11,6 +12,19 @@ if t.TYPE_CHECKING: from sqlglot.dialects.dialect import DialectType +try: + from sqlglotrs import ( # type: ignore + Tokenizer as RsTokenizer, + TokenizerDialectSettings as RsTokenizerDialectSettings, + TokenizerSettings as RsTokenizerSettings, + TokenTypeSettings as RsTokenTypeSettings, + ) + + USE_RS_TOKENIZER = os.environ.get("SQLGLOTRS_TOKENIZER", "1") == "1" +except ImportError: + USE_RS_TOKENIZER = False + + class TokenType(AutoName): L_PAREN = auto() R_PAREN = auto() @@ -83,6 +97,7 @@ class TokenType(AutoName): NATIONAL_STRING = auto() RAW_STRING = auto() HEREDOC_STRING = auto() + UNICODE_STRING = auto() # types BIT = auto() @@ -347,6 +362,10 @@ class TokenType(AutoName): TIMESTAMP_SNAPSHOT = auto() +_ALL_TOKEN_TYPES = list(TokenType) +_TOKEN_TYPE_TO_INDEX = {token_type: i for i, token_type in enumerate(_ALL_TOKEN_TYPES)} + + class Token: __slots__ = ("token_type", "text", "line", "col", "start", "end", "comments") @@ -432,6 +451,7 @@ class _Tokenizer(type): **_quotes_to_format(TokenType.HEX_STRING, klass.HEX_STRINGS), **_quotes_to_format(TokenType.RAW_STRING, klass.RAW_STRINGS), **_quotes_to_format(TokenType.HEREDOC_STRING, klass.HEREDOC_STRINGS), + **_quotes_to_format(TokenType.UNICODE_STRING, klass.UNICODE_STRINGS), } klass._STRING_ESCAPES = set(klass.STRING_ESCAPES) @@ -455,6 +475,46 @@ class _Tokenizer(type): if " " in key or any(single in key for single in klass.SINGLE_TOKENS) ) + if USE_RS_TOKENIZER: + settings = RsTokenizerSettings( + white_space={k: _TOKEN_TYPE_TO_INDEX[v] for k, v in klass.WHITE_SPACE.items()}, + single_tokens={k: _TOKEN_TYPE_TO_INDEX[v] for k, v in klass.SINGLE_TOKENS.items()}, + keywords={k: _TOKEN_TYPE_TO_INDEX[v] for k, v in klass.KEYWORDS.items()}, + numeric_literals=klass.NUMERIC_LITERALS, + identifiers=klass._IDENTIFIERS, + identifier_escapes=klass._IDENTIFIER_ESCAPES, + string_escapes=klass._STRING_ESCAPES, + quotes=klass._QUOTES, + format_strings={ + k: (v1, _TOKEN_TYPE_TO_INDEX[v2]) + for k, (v1, v2) in klass._FORMAT_STRINGS.items() + }, + has_bit_strings=bool(klass.BIT_STRINGS), + has_hex_strings=bool(klass.HEX_STRINGS), + comments=klass._COMMENTS, + var_single_tokens=klass.VAR_SINGLE_TOKENS, + commands={_TOKEN_TYPE_TO_INDEX[v] for v in klass.COMMANDS}, + command_prefix_tokens={ + _TOKEN_TYPE_TO_INDEX[v] for v in klass.COMMAND_PREFIX_TOKENS + }, + ) + token_types = RsTokenTypeSettings( + bit_string=_TOKEN_TYPE_TO_INDEX[TokenType.BIT_STRING], + break_=_TOKEN_TYPE_TO_INDEX[TokenType.BREAK], + dcolon=_TOKEN_TYPE_TO_INDEX[TokenType.DCOLON], + heredoc_string=_TOKEN_TYPE_TO_INDEX[TokenType.HEREDOC_STRING], + hex_string=_TOKEN_TYPE_TO_INDEX[TokenType.HEX_STRING], + identifier=_TOKEN_TYPE_TO_INDEX[TokenType.IDENTIFIER], + number=_TOKEN_TYPE_TO_INDEX[TokenType.NUMBER], + parameter=_TOKEN_TYPE_TO_INDEX[TokenType.PARAMETER], + semicolon=_TOKEN_TYPE_TO_INDEX[TokenType.SEMICOLON], + string=_TOKEN_TYPE_TO_INDEX[TokenType.STRING], + var=_TOKEN_TYPE_TO_INDEX[TokenType.VAR], + ) + klass._RS_TOKENIZER = RsTokenizer(settings, token_types) + else: + klass._RS_TOKENIZER = None + return klass @@ -499,6 +559,7 @@ class Tokenizer(metaclass=_Tokenizer): HEX_STRINGS: t.List[str | t.Tuple[str, str]] = [] RAW_STRINGS: t.List[str | t.Tuple[str, str]] = [] HEREDOC_STRINGS: t.List[str | t.Tuple[str, str]] = [] + UNICODE_STRINGS: t.List[str | t.Tuple[str, str]] = [] IDENTIFIERS: t.List[str | t.Tuple[str, str]] = ['"'] IDENTIFIER_ESCAPES = ['"'] QUOTES: t.List[t.Tuple[str, str] | str] = ["'"] @@ -513,6 +574,7 @@ class Tokenizer(metaclass=_Tokenizer): _QUOTES: t.Dict[str, str] = {} _STRING_ESCAPES: t.Set[str] = set() _KEYWORD_TRIE: t.Dict = {} + _RS_TOKENIZER: t.Optional[t.Any] = None KEYWORDS: t.Dict[str, TokenType] = { **{f"{{%{postfix}": TokenType.BLOCK_START for postfix in ("", "+", "-")}, @@ -804,7 +866,6 @@ class Tokenizer(metaclass=_Tokenizer): # handle numeric literals like in hive (3L = BIGINT) NUMERIC_LITERALS: t.Dict[str, str] = {} - ENCODE: t.Optional[str] = None COMMENTS = ["--", ("/*", "*/")] @@ -822,12 +883,20 @@ class Tokenizer(metaclass=_Tokenizer): "_end", "_peek", "_prev_token_line", + "_rs_dialect_settings", ) def __init__(self, dialect: DialectType = None) -> None: from sqlglot.dialects import Dialect self.dialect = Dialect.get_or_raise(dialect) + + if USE_RS_TOKENIZER: + self._rs_dialect_settings = RsTokenizerDialectSettings( + escape_sequences=self.dialect.ESCAPE_SEQUENCES, + identifiers_can_start_with_digit=self.dialect.IDENTIFIERS_CAN_START_WITH_DIGIT, + ) + self.reset() def reset(self) -> None: @@ -847,6 +916,9 @@ class Tokenizer(metaclass=_Tokenizer): def tokenize(self, sql: str) -> t.List[Token]: """Returns a list of tokens corresponding to the SQL string `sql`.""" + if USE_RS_TOKENIZER: + return self.tokenize_rs(sql) + self.reset() self.sql = sql self.size = len(sql) @@ -910,6 +982,7 @@ class Tokenizer(metaclass=_Tokenizer): # Ensures we don't count an extra line if we get a \r\n line break sequence if self._char == "\r" and self._peek == "\n": i = 2 + self._start += 1 self._col = 1 self._line += 1 @@ -1184,8 +1257,6 @@ class Tokenizer(metaclass=_Tokenizer): raise TokenError( f"Numeric string contains invalid characters from {self._line}:{self._start}" ) - else: - text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text self._add(token_type, text) return True @@ -1254,3 +1325,15 @@ class Tokenizer(metaclass=_Tokenizer): text += self.sql[current : self._current - 1] return text + + def tokenize_rs(self, sql: str) -> t.List[Token]: + if not self._RS_TOKENIZER: + raise SqlglotError("Rust tokenizer is not available") + + try: + tokens = self._RS_TOKENIZER.tokenize(sql, self._rs_dialect_settings) + for token in tokens: + token.token_type = _ALL_TOKEN_TYPES[token.token_type_index] + return tokens + except Exception as e: + raise TokenError(str(e)) |