author    | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-06-11 12:46:10 +0000
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-06-11 12:46:10 +0000
commit    | d142aecb38fbfd35bf2a0732f5391a807bff3a5e (patch)
tree      | f5430f0e6a127d39c663e958045aa1bbb462c58b /sqlglot
parent    | Releasing debian version 15.0.0-1. (diff)
download  | sqlglot-d142aecb38fbfd35bf2a0732f5391a807bff3a5e.tar.xz  sqlglot-d142aecb38fbfd35bf2a0732f5391a807bff3a5e.zip
Merging upstream version 15.2.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
-rw-r--r-- | sqlglot/dialects/bigquery.py | 2
-rw-r--r-- | sqlglot/dialects/clickhouse.py | 56
-rw-r--r-- | sqlglot/dialects/dialect.py | 27
-rw-r--r-- | sqlglot/dialects/duckdb.py | 2
-rw-r--r-- | sqlglot/dialects/hive.py | 41
-rw-r--r-- | sqlglot/dialects/mysql.py | 3
-rw-r--r-- | sqlglot/dialects/postgres.py | 7
-rw-r--r-- | sqlglot/dialects/presto.py | 11
-rw-r--r-- | sqlglot/dialects/redshift.py | 20
-rw-r--r-- | sqlglot/dialects/snowflake.py | 4
-rw-r--r-- | sqlglot/dialects/spark2.py | 23
-rw-r--r-- | sqlglot/dialects/sqlite.py | 2
-rw-r--r-- | sqlglot/dialects/teradata.py | 17
-rw-r--r-- | sqlglot/expressions.py | 72
-rw-r--r-- | sqlglot/generator.py | 93
-rw-r--r-- | sqlglot/optimizer/optimize_joins.py | 2
-rw-r--r-- | sqlglot/optimizer/pushdown_predicates.py | 4
-rw-r--r-- | sqlglot/parser.py | 184
-rw-r--r-- | sqlglot/serde.py | 8
-rw-r--r-- | sqlglot/tokens.py | 45
-rw-r--r-- | sqlglot/transforms.py | 11
21 files changed, 465 insertions, 169 deletions
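
The expressions.py hunk below widens `DataType.is_type` and `Cast.is_type` to accept one or more type names as strings (or `DataType`/`DataType.Type` values), building each argument with `DataType.build`; the dialect modules then switch from enum comparisons to calls such as `is_type("array")`. A minimal sketch of the new signature in use (the SQL snippets are illustrative, not taken from this patch):

```python
import sqlglot
from sqlglot import exp

# Cast.is_type / DataType.is_type now accept plain type-name strings,
# and several candidates at once (any match returns True).
cast = sqlglot.parse_one("CAST(x AS VARCHAR(10))").find(exp.Cast)
assert cast is not None
print(cast.is_type("varchar"))           # True
print(cast.is_type("char", "varchar"))   # True
print(cast.is_type("date"))              # False

# Explicit construction still goes through DataType.build.
print(exp.DataType.build("timestamptz").is_type("timestamptz"))  # True
```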
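
dialect.py also gains `left_to_substring_sql` and `right_to_substring_sql`, which Hive and Presto register for the new `exp.Left`/`exp.Right` expressions, replacing the per-dialect LEFT/RIGHT lambdas removed from mysql.py and spark2.py. A rough usage sketch; the exact generated SQL may differ slightly between sqlglot versions:

```python
import sqlglot

# LEFT/RIGHT parse into exp.Left/exp.Right and are rewritten as SUBSTRING
# when generating dialects that map them, such as Hive and Presto.
print(sqlglot.transpile("SELECT LEFT(name, 3) FROM t", read="mysql", write="hive")[0])
# e.g. SELECT SUBSTRING(name, 1, 3) FROM t

print(sqlglot.transpile("SELECT RIGHT(name, 3) FROM t", read="mysql", write="presto")[0])
# e.g. SELECT SUBSTRING(name, LENGTH(name) - (3 - 1)) FROM t
```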
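
Two smaller Presto-facing changes appear further down: the generator's new `IS_BOOL` flag (disabled for Presto) unwraps `x IS TRUE` / `x IS FALSE` into plain boolean expressions, and the new `epoch_cast_to_ts` transform replaces a cast of the literal 'epoch' to a temporal type with the corresponding timestamp string. A hedged sketch; the output in the comment is approximate:

```python
import sqlglot

# With write="presto", boolean IS comparisons are unwrapped and 'epoch'
# casts are rewritten by the epoch_cast_to_ts preprocessing transform.
sql = "SELECT x IS TRUE, y IS FALSE, CAST('epoch' AS TIMESTAMP)"
print(sqlglot.transpile(sql, write="presto")[0])
# e.g. SELECT x, NOT y, CAST('1970-01-01 00:00:00' AS TIMESTAMP)
```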
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 1a58337..5b10852 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -327,6 +327,8 @@ class BigQuery(Dialect): exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } + RESERVED_KEYWORDS = {*generator.Generator.RESERVED_KEYWORDS, "hash"} + def array_sql(self, expression: exp.Array) -> str: first_arg = seq_get(expression.expressions, 0) if isinstance(first_arg, exp.Subqueryable): diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index c8a9525..fc48379 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -27,14 +27,15 @@ class ClickHouse(Dialect): class Tokenizer(tokens.Tokenizer): COMMENTS = ["--", "#", "#!", ("/*", "*/")] IDENTIFIERS = ['"', "`"] + STRING_ESCAPES = ["'", "\\"] BIT_STRINGS = [("0b", "")] HEX_STRINGS = [("0x", ""), ("0X", "")] KEYWORDS = { **tokens.Tokenizer.KEYWORDS, - "ASOF": TokenType.ASOF, "ATTACH": TokenType.COMMAND, "DATETIME64": TokenType.DATETIME64, + "DICTIONARY": TokenType.DICTIONARY, "FINAL": TokenType.FINAL, "FLOAT32": TokenType.FLOAT, "FLOAT64": TokenType.DOUBLE, @@ -97,7 +98,6 @@ class ClickHouse(Dialect): TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - { TokenType.ANY, - TokenType.ASOF, TokenType.SEMI, TokenType.ANTI, TokenType.SETTINGS, @@ -182,7 +182,7 @@ class ClickHouse(Dialect): return self.expression(exp.CTE, this=statement, alias=statement and statement.this) - def _parse_join_side_and_kind( + def _parse_join_parts( self, ) -> t.Tuple[t.Optional[Token], t.Optional[Token], t.Optional[Token]]: is_global = self._match(TokenType.GLOBAL) and self._prev @@ -201,7 +201,7 @@ class ClickHouse(Dialect): join = super()._parse_join(skip_join_token) if join: - join.set("global", join.args.pop("natural", None)) + join.set("global", join.args.pop("method", None)) return join def _parse_function( @@ -245,6 +245,23 @@ class ClickHouse(Dialect): ) -> t.List[t.Optional[exp.Expression]]: return super()._parse_wrapped_id_vars(optional=True) + def _parse_primary_key( + self, wrapped_optional: bool = False, in_props: bool = False + ) -> exp.Expression: + return super()._parse_primary_key( + wrapped_optional=wrapped_optional or in_props, in_props=in_props + ) + + def _parse_on_property(self) -> t.Optional[exp.Property]: + index = self._index + if self._match_text_seq("CLUSTER"): + this = self._parse_id_var() + if this: + return self.expression(exp.OnCluster, this=this) + else: + self._retreat(index) + return None + class Generator(generator.Generator): STRUCT_DELIMITER = ("(", ")") @@ -292,6 +309,7 @@ class ClickHouse(Dialect): **generator.Generator.PROPERTIES_LOCATION, exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA, + exp.OnCluster: exp.Properties.Location.POST_NAME, } JOIN_HINTS = False @@ -299,6 +317,18 @@ class ClickHouse(Dialect): EXPLICIT_UNION = True GROUPINGS_SEP = "" + # there's no list in docs, but it can be found in Clickhouse code + # see `ClickHouse/src/Parsers/ParserCreate*.cpp` + ON_CLUSTER_TARGETS = { + "DATABASE", + "TABLE", + "VIEW", + "DICTIONARY", + "INDEX", + "FUNCTION", + "NAMED COLLECTION", + } + def cte_sql(self, expression: exp.CTE) -> str: if isinstance(expression.this, exp.Alias): return self.sql(expression, "this") @@ -321,3 +351,21 @@ class ClickHouse(Dialect): def placeholder_sql(self, expression: exp.Placeholder) -> str: return f"{{{expression.name}: {self.sql(expression, 'kind')}}}" + + def 
oncluster_sql(self, expression: exp.OnCluster) -> str: + return f"ON CLUSTER {self.sql(expression, 'this')}" + + def createable_sql( + self, + expression: exp.Create, + locations: dict[exp.Properties.Location, list[exp.Property]], + ) -> str: + kind = self.sql(expression, "kind").upper() + if kind in self.ON_CLUSTER_TARGETS and locations.get(exp.Properties.Location.POST_NAME): + this_name = self.sql(expression.this, "this") + this_properties = " ".join( + [self.sql(prop) for prop in locations[exp.Properties.Location.POST_NAME]] + ) + this_schema = self.schema_columns_sql(expression.this) + return f"{this_name}{self.sep()}{this_properties}{self.sep()}{this_schema}" + return super().createable_sql(expression, locations) diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 890a3c3..4958bc6 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -104,6 +104,10 @@ class _Dialect(type): klass.byte_start, klass.byte_end = get_start_end(TokenType.BYTE_STRING) klass.raw_start, klass.raw_end = get_start_end(TokenType.RAW_STRING) + klass.tokenizer_class.identifiers_can_start_with_digit = ( + klass.identifiers_can_start_with_digit + ) + return klass @@ -111,6 +115,7 @@ class Dialect(metaclass=_Dialect): index_offset = 0 unnest_column_only = False alias_post_tablesample = False + identifiers_can_start_with_digit = False normalize_functions: t.Optional[str] = "upper" null_ordering = "nulls_are_small" @@ -231,6 +236,7 @@ class Dialect(metaclass=_Dialect): "time_trie": self.inverse_time_trie, "unnest_column_only": self.unnest_column_only, "alias_post_tablesample": self.alias_post_tablesample, + "identifiers_can_start_with_digit": self.identifiers_can_start_with_digit, "normalize_functions": self.normalize_functions, "null_ordering": self.null_ordering, **opts, @@ -443,7 +449,7 @@ def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: unit = seq_get(args, 0) this = seq_get(args, 1) - if isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.DATE): + if isinstance(this, exp.Cast) and this.is_type("date"): return exp.DateTrunc(unit=unit, this=this) return exp.TimestampTrunc(this=this, unit=unit) @@ -468,6 +474,25 @@ def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> s ) +def left_to_substring_sql(self: Generator, expression: exp.Left) -> str: + expression = expression.copy() + return self.sql( + exp.Substring( + this=expression.this, start=exp.Literal.number(1), length=expression.expression + ) + ) + + +def right_to_substring_sql(self: Generator, expression: exp.Left) -> str: + expression = expression.copy() + return self.sql( + exp.Substring( + this=expression.this, + start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1), + ) + ) + + def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str: return f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)" diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index 662882d..f31da73 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -71,7 +71,7 @@ def _struct_sql(self: generator.Generator, expression: exp.Struct) -> str: def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str: - if expression.this == exp.DataType.Type.ARRAY: + if expression.is_type("array"): return f"{self.expressions(expression, flat=True)}[]" return self.datatype_sql(expression) diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index fbd626a..650a1e1 100644 --- 
a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import ( create_with_partitions_sql, format_time_lambda, if_sql, + left_to_substring_sql, locate_to_strposition, max_or_greatest, min_or_least, @@ -17,6 +18,7 @@ from sqlglot.dialects.dialect import ( no_safe_divide_sql, no_trycast_sql, rename_func, + right_to_substring_sql, strposition_to_locate_sql, struct_extract_sql, timestrtotime_sql, @@ -89,7 +91,7 @@ def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> s annotate_types(this) - if this.type.is_type(exp.DataType.Type.JSON): + if this.type.is_type("json"): return self.sql(this) return self.func("TO_JSON", this, expression.args.get("options")) @@ -149,6 +151,7 @@ def _to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str class Hive(Dialect): alias_post_tablesample = True + identifiers_can_start_with_digit = True time_mapping = { "y": "%Y", @@ -190,7 +193,6 @@ class Hive(Dialect): IDENTIFIERS = ["`"] STRING_ESCAPES = ["\\"] ENCODE = "utf-8" - IDENTIFIER_CAN_START_WITH_DIGIT = True KEYWORDS = { **tokens.Tokenizer.KEYWORDS, @@ -276,6 +278,39 @@ class Hive(Dialect): "cluster": lambda self: self._parse_sort(exp.Cluster, "CLUSTER", "BY"), } + def _parse_types( + self, check_func: bool = False, schema: bool = False + ) -> t.Optional[exp.Expression]: + """ + Spark (and most likely Hive) treats casts to CHAR(length) and VARCHAR(length) as casts to + STRING in all contexts except for schema definitions. For example, this is in Spark v3.4.0: + + spark-sql (default)> select cast(1234 as varchar(2)); + 23/06/06 15:51:18 WARN CharVarcharUtils: The Spark cast operator does not support + char/varchar type and simply treats them as string type. Please use string type + directly to avoid confusion. Otherwise, you can set spark.sql.legacy.charVarcharAsString + to true, so that Spark treat them as string type as same as Spark 3.0 and earlier + + 1234 + Time taken: 4.265 seconds, Fetched 1 row(s) + + This shows that Spark doesn't truncate the value into '12', which is inconsistent with + what other dialects (e.g. postgres) do, so we need to drop the length to transpile correctly. 
+ + Reference: https://spark.apache.org/docs/latest/sql-ref-datatypes.html + """ + this = super()._parse_types(check_func=check_func, schema=schema) + + if this and not schema: + return this.transform( + lambda node: node.replace(exp.DataType.build("text")) + if isinstance(node, exp.DataType) and node.is_type("char", "varchar") + else node, + copy=False, + ) + + return this + class Generator(generator.Generator): LIMIT_FETCH = "LIMIT" TABLESAMPLE_WITH_METHOD = False @@ -323,6 +358,7 @@ class Hive(Dialect): exp.JSONExtract: rename_func("GET_JSON_OBJECT"), exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"), exp.JSONFormat: _json_format_sql, + exp.Left: left_to_substring_sql, exp.Map: var_map_sql, exp.Max: max_or_greatest, exp.Min: min_or_least, @@ -332,6 +368,7 @@ class Hive(Dialect): exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"), exp.RegexpLike: lambda self, e: self.binary(e, "RLIKE"), exp.RegexpSplit: rename_func("SPLIT"), + exp.Right: right_to_substring_sql, exp.SafeDivide: no_safe_divide_sql, exp.SchemaCommentProperty: lambda self, e: self.naked_property(e), exp.SetAgg: rename_func("COLLECT_SET"), diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 2b41860..75023ff 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -186,9 +186,6 @@ class MySQL(Dialect): "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"), "DATE_SUB": parse_date_delta_with_interval(exp.DateSub), "INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)), - "LEFT": lambda args: exp.Substring( - this=seq_get(args, 0), start=exp.Literal.number(1), length=seq_get(args, 1) - ), "LOCATE": locate_to_strposition, "STR_TO_DATE": _str_to_date, } diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index ab61880..8d84024 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -18,7 +18,9 @@ from sqlglot.dialects.dialect import ( rename_func, str_position_sql, timestamptrunc_sql, + timestrtotime_sql, trim_sql, + ts_or_ds_to_date_sql, ) from sqlglot.helper import seq_get from sqlglot.parser import binary_range_parser @@ -104,7 +106,7 @@ def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> s def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str: - if expression.this == exp.DataType.Type.ARRAY: + if expression.is_type("array"): return f"{self.expressions(expression, flat=True)}[]" return self.datatype_sql(expression) @@ -353,12 +355,13 @@ class Postgres(Dialect): exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.Substring: _substring_sql, exp.TimestampTrunc: timestamptrunc_sql, - exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)", + exp.TimeStrToTime: timestrtotime_sql, exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", exp.TableSample: no_tablesample_sql, exp.ToChar: lambda self, e: self.function_fallback_sql(e), exp.Trim: trim_sql, exp.TryCast: no_trycast_sql, + exp.TsOrDsToDate: ts_or_ds_to_date_sql("postgres"), exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})", exp.DataType: _datatype_sql, exp.GroupConcat: _string_agg_sql, diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 52a04a4..d839864 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -8,10 +8,12 @@ from sqlglot.dialects.dialect import ( date_trunc_to_time, format_time_lambda, if_sql, + left_to_substring_sql, 
no_ilike_sql, no_pivot_sql, no_safe_divide_sql, rename_func, + right_to_substring_sql, struct_extract_sql, timestamptrunc_sql, timestrtotime_sql, @@ -30,7 +32,7 @@ def _approx_distinct_sql(self: generator.Generator, expression: exp.ApproxDistin def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str: sql = self.datatype_sql(expression) - if expression.this == exp.DataType.Type.TIMESTAMPTZ: + if expression.is_type("timestamptz"): sql = f"{sql} WITH TIME ZONE" return sql @@ -240,6 +242,7 @@ class Presto(Dialect): INTERVAL_ALLOWS_PLURAL_FORM = False JOIN_HINTS = False TABLE_HINTS = False + IS_BOOL = False STRUCT_DELIMITER = ("(", ")") PROPERTIES_LOCATION = { @@ -272,6 +275,7 @@ class Presto(Dialect): exp.BitwiseOr: lambda self, e: f"BITWISE_OR({self.sql(e, 'this')}, {self.sql(e, 'expression')})", exp.BitwiseRightShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_RIGHT({self.sql(e, 'this')}, {self.sql(e, 'expression')})", exp.BitwiseXor: lambda self, e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.Cast: transforms.preprocess([transforms.epoch_cast_to_ts]), exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", exp.DataType: _datatype_sql, exp.DateAdd: lambda self, e: self.func( @@ -292,11 +296,13 @@ class Presto(Dialect): exp.ILike: no_ilike_sql, exp.Initcap: _initcap_sql, exp.Lateral: _explode_to_unnest_sql, + exp.Left: left_to_substring_sql, exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"), exp.LogicalAnd: rename_func("BOOL_AND"), exp.LogicalOr: rename_func("BOOL_OR"), exp.Pivot: no_pivot_sql, exp.Quantile: _quantile_sql, + exp.Right: right_to_substring_sql, exp.SafeDivide: no_safe_divide_sql, exp.Schema: _schema_sql, exp.Select: transforms.preprocess( @@ -319,6 +325,7 @@ class Presto(Dialect): exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.time_format}))", exp.TimeToStr: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimeToUnix: rename_func("TO_UNIXTIME"), + exp.TryCast: transforms.preprocess([transforms.epoch_cast_to_ts]), exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)", exp.TsOrDsAdd: _ts_or_ds_add_sql, exp.TsOrDsToDate: _ts_or_ds_to_date_sql, @@ -356,7 +363,7 @@ class Presto(Dialect): else: target_type = None - if target_type and target_type.is_type(exp.DataType.Type.TIMESTAMP): + if target_type and target_type.is_type("timestamp"): to = target_type.copy() if target_type is start.to: diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index 55e393a..b0a6774 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -3,6 +3,7 @@ from __future__ import annotations import typing as t from sqlglot import exp, transforms +from sqlglot.dialects.dialect import rename_func from sqlglot.dialects.postgres import Postgres from sqlglot.helper import seq_get from sqlglot.tokens import TokenType @@ -24,26 +25,29 @@ class Redshift(Postgres): FUNCTIONS = { **Postgres.Parser.FUNCTIONS, "DATEADD": lambda args: exp.DateAdd( - this=seq_get(args, 2), + this=exp.TsOrDsToDate(this=seq_get(args, 2)), expression=seq_get(args, 1), unit=seq_get(args, 0), ), "DATEDIFF": lambda args: exp.DateDiff( - this=seq_get(args, 2), - expression=seq_get(args, 1), + this=exp.TsOrDsToDate(this=seq_get(args, 2)), + expression=exp.TsOrDsToDate(this=seq_get(args, 1)), unit=seq_get(args, 0), ), "NVL": exp.Coalesce.from_arg_list, + "STRTOL": exp.FromBase.from_arg_list, } CONVERT_TYPE_FIRST = 
True - def _parse_types(self, check_func: bool = False) -> t.Optional[exp.Expression]: - this = super()._parse_types(check_func=check_func) + def _parse_types( + self, check_func: bool = False, schema: bool = False + ) -> t.Optional[exp.Expression]: + this = super()._parse_types(check_func=check_func, schema=schema) if ( isinstance(this, exp.DataType) - and this.this == exp.DataType.Type.VARCHAR + and this.is_type("varchar") and this.expressions and this.expressions[0].this == exp.column("MAX") ): @@ -99,10 +103,12 @@ class Redshift(Postgres): ), exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})", exp.DistStyleProperty: lambda self, e: self.naked_property(e), + exp.FromBase: rename_func("STRTOL"), exp.JSONExtract: _json_sql, exp.JSONExtractScalar: _json_sql, exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})", + exp.TsOrDsToDate: lambda self, e: self.sql(e.this), } # Postgres maps exp.Pivot to no_pivot_sql, but Redshift support pivots @@ -158,7 +164,7 @@ class Redshift(Postgres): without precision we convert it to `VARCHAR(max)` and if it does have precision then we just convert `TEXT` to `VARCHAR`. """ - if expression.this == exp.DataType.Type.TEXT: + if expression.is_type("text"): expression = expression.copy() expression.set("this", exp.DataType.Type.VARCHAR) precision = expression.args.get("expressions") diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 756e8e9..821d991 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -153,9 +153,9 @@ def _nullifzero_to_if(args: t.List) -> exp.Expression: def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str: - if expression.this == exp.DataType.Type.ARRAY: + if expression.is_type("array"): return "ARRAY" - elif expression.this == exp.DataType.Type.MAP: + elif expression.is_type("map"): return "OBJECT" return self.datatype_sql(expression) diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py index 912b86b..bf24240 100644 --- a/sqlglot/dialects/spark2.py +++ b/sqlglot/dialects/spark2.py @@ -110,11 +110,6 @@ class Spark2(Hive): **Hive.Parser.FUNCTIONS, "MAP_FROM_ARRAYS": exp.Map.from_arg_list, "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list, - "LEFT": lambda args: exp.Substring( - this=seq_get(args, 0), - start=exp.Literal.number(1), - length=seq_get(args, 1), - ), "SHIFTLEFT": lambda args: exp.BitwiseLeftShift( this=seq_get(args, 0), expression=seq_get(args, 1), @@ -123,14 +118,6 @@ class Spark2(Hive): this=seq_get(args, 0), expression=seq_get(args, 1), ), - "RIGHT": lambda args: exp.Substring( - this=seq_get(args, 0), - start=exp.Sub( - this=exp.Length(this=seq_get(args, 0)), - expression=exp.Add(this=seq_get(args, 1), expression=exp.Literal.number(1)), - ), - length=seq_get(args, 1), - ), "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list, "IIF": exp.If.from_arg_list, "AGGREGATE": exp.Reduce.from_arg_list, @@ -240,17 +227,17 @@ class Spark2(Hive): TRANSFORMS.pop(exp.ArrayJoin) TRANSFORMS.pop(exp.ArraySort) TRANSFORMS.pop(exp.ILike) + TRANSFORMS.pop(exp.Left) + TRANSFORMS.pop(exp.Right) WRAP_DERIVED_VALUES = False CREATE_FUNCTION_RETURN_AS = False def cast_sql(self, expression: exp.Cast) -> str: - if isinstance(expression.this, exp.Cast) and expression.this.is_type( - exp.DataType.Type.JSON - ): + if isinstance(expression.this, exp.Cast) and expression.this.is_type("json"): schema = 
f"'{self.sql(expression, 'to')}'" return self.func("FROM_JSON", expression.this.this, schema) - if expression.to.is_type(exp.DataType.Type.JSON): + if expression.is_type("json"): return self.func("TO_JSON", expression.this) return super(Hive.Generator, self).cast_sql(expression) @@ -260,7 +247,7 @@ class Spark2(Hive): expression, sep=": " if isinstance(expression.parent, exp.DataType) - and expression.parent.is_type(exp.DataType.Type.STRUCT) + and expression.parent.is_type("struct") else sep, ) diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index 56e7773..4e800b0 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -132,7 +132,7 @@ class SQLite(Dialect): LIMIT_FETCH = "LIMIT" def cast_sql(self, expression: exp.Cast) -> str: - if expression.to.this == exp.DataType.Type.DATE: + if expression.is_type("date"): return self.func("DATE", expression.this) return super().cast_sql(expression) diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index 9b39178..514aecb 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -183,3 +183,20 @@ class Teradata(Dialect): each_sql = f" EACH {each_sql}" if each_sql else "" return f"RANGE_N({this} BETWEEN {expressions_sql}{each_sql})" + + def createable_sql( + self, + expression: exp.Create, + locations: dict[exp.Properties.Location, list[exp.Property]], + ) -> str: + kind = self.sql(expression, "kind").upper() + if kind == "TABLE" and locations.get(exp.Properties.Location.POST_NAME): + this_name = self.sql(expression.this, "this") + this_properties = self.properties( + exp.Properties(expressions=locations[exp.Properties.Location.POST_NAME]), + wrapped=False, + prefix=",", + ) + this_schema = self.schema_columns_sql(expression.this) + return f"{this_name}{this_properties}{self.sep()}{this_schema}" + return super().createable_sql(expression, locations) diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index a4c4e95..da4a4ed 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -1653,12 +1653,16 @@ class Join(Expression): "side": False, "kind": False, "using": False, - "natural": False, + "method": False, "global": False, "hint": False, } @property + def method(self) -> str: + return self.text("method").upper() + + @property def kind(self) -> str: return self.text("kind").upper() @@ -1913,6 +1917,24 @@ class LanguageProperty(Property): arg_types = {"this": True} +class DictProperty(Property): + arg_types = {"this": True, "kind": True, "settings": False} + + +class DictSubProperty(Property): + pass + + +class DictRange(Property): + arg_types = {"this": True, "min": True, "max": True} + + +# Clickhouse CREATE ... ON CLUSTER modifier +# https://clickhouse.com/docs/en/sql-reference/distributed-ddl +class OnCluster(Property): + arg_types = {"this": True} + + class LikeProperty(Property): arg_types = {"this": True, "expressions": False} @@ -2797,12 +2819,12 @@ class Select(Subqueryable): Returns: Select: the modified expression. 
""" - parse_args = {"dialect": dialect, **opts} + parse_args: t.Dict[str, t.Any] = {"dialect": dialect, **opts} try: - expression = maybe_parse(expression, into=Join, prefix="JOIN", **parse_args) # type: ignore + expression = maybe_parse(expression, into=Join, prefix="JOIN", **parse_args) except ParseError: - expression = maybe_parse(expression, into=(Join, Expression), **parse_args) # type: ignore + expression = maybe_parse(expression, into=(Join, Expression), **parse_args) join = expression if isinstance(expression, Join) else Join(this=expression) @@ -2810,14 +2832,14 @@ class Select(Subqueryable): join.this.replace(join.this.subquery()) if join_type: - natural: t.Optional[Token] + method: t.Optional[Token] side: t.Optional[Token] kind: t.Optional[Token] - natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args) # type: ignore + method, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args) # type: ignore - if natural: - join.set("natural", True) + if method: + join.set("method", method.text) if side: join.set("side", side.text) if kind: @@ -3222,6 +3244,18 @@ class DataType(Expression): DATE = auto() DATETIME = auto() DATETIME64 = auto() + INT4RANGE = auto() + INT4MULTIRANGE = auto() + INT8RANGE = auto() + INT8MULTIRANGE = auto() + NUMRANGE = auto() + NUMMULTIRANGE = auto() + TSRANGE = auto() + TSMULTIRANGE = auto() + TSTZRANGE = auto() + TSTZMULTIRANGE = auto() + DATERANGE = auto() + DATEMULTIRANGE = auto() DECIMAL = auto() DOUBLE = auto() FLOAT = auto() @@ -3331,8 +3365,8 @@ class DataType(Expression): return DataType(**{**data_type_exp.args, **kwargs}) - def is_type(self, dtype: DataType.Type) -> bool: - return self.this == dtype + def is_type(self, *dtypes: str | DataType | DataType.Type) -> bool: + return any(self.this == DataType.build(dtype).this for dtype in dtypes) # https://www.postgresql.org/docs/15/datatype-pseudo.html @@ -3846,8 +3880,8 @@ class Cast(Func): def output_name(self) -> str: return self.name - def is_type(self, dtype: DataType.Type) -> bool: - return self.to.is_type(dtype) + def is_type(self, *dtypes: str | DataType | DataType.Type) -> bool: + return self.to.is_type(*dtypes) class CastToStrType(Func): @@ -4130,8 +4164,16 @@ class Least(Func): is_var_len_args = True +class Left(Func): + arg_types = {"this": True, "expression": True} + + +class Right(Func): + arg_types = {"this": True, "expression": True} + + class Length(Func): - pass + _sql_names = ["LENGTH", "LEN"] class Levenshtein(Func): @@ -4356,6 +4398,10 @@ class NumberToStr(Func): arg_types = {"this": True, "format": True} +class FromBase(Func): + arg_types = {"this": True, "expression": True} + + class Struct(Func): arg_types = {"expressions": True} is_var_len_args = True diff --git a/sqlglot/generator.py b/sqlglot/generator.py index f1ec398..97cbe15 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -44,6 +44,8 @@ class Generator: Default: "upper" alias_post_tablesample (bool): if the table alias comes after tablesample Default: False + identifiers_can_start_with_digit (bool): if an unquoted identifier can start with digit + Default: False unsupported_level (ErrorLevel): determines the generator's behavior when it encounters unsupported expressions. Default ErrorLevel.WARN. null_ordering (str): Indicates the default null ordering method to use if not explicitly set. 
@@ -188,6 +190,8 @@ class Generator: exp.Cluster: exp.Properties.Location.POST_SCHEMA, exp.DataBlocksizeProperty: exp.Properties.Location.POST_NAME, exp.DefinerProperty: exp.Properties.Location.POST_CREATE, + exp.DictRange: exp.Properties.Location.POST_SCHEMA, + exp.DictProperty: exp.Properties.Location.POST_SCHEMA, exp.DistKeyProperty: exp.Properties.Location.POST_SCHEMA, exp.DistStyleProperty: exp.Properties.Location.POST_SCHEMA, exp.EngineProperty: exp.Properties.Location.POST_SCHEMA, @@ -233,6 +237,7 @@ class Generator: JOIN_HINTS = True TABLE_HINTS = True + IS_BOOL = True RESERVED_KEYWORDS: t.Set[str] = set() WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.With) @@ -264,6 +269,7 @@ class Generator: "index_offset", "unnest_column_only", "alias_post_tablesample", + "identifiers_can_start_with_digit", "normalize_functions", "unsupported_level", "unsupported_messages", @@ -304,6 +310,7 @@ class Generator: index_offset=0, unnest_column_only=False, alias_post_tablesample=False, + identifiers_can_start_with_digit=False, normalize_functions="upper", unsupported_level=ErrorLevel.WARN, null_ordering=None, @@ -337,6 +344,7 @@ class Generator: self.index_offset = index_offset self.unnest_column_only = unnest_column_only self.alias_post_tablesample = alias_post_tablesample + self.identifiers_can_start_with_digit = identifiers_can_start_with_digit self.normalize_functions = normalize_functions self.unsupported_level = unsupported_level self.unsupported_messages = [] @@ -634,35 +642,31 @@ class Generator: this = f" {this}" if this else "" return f"UNIQUE{this}" + def createable_sql( + self, expression: exp.Create, locations: dict[exp.Properties.Location, list[exp.Property]] + ) -> str: + return self.sql(expression, "this") + def create_sql(self, expression: exp.Create) -> str: kind = self.sql(expression, "kind").upper() properties = expression.args.get("properties") - properties_exp = expression.copy() properties_locs = self.locate_properties(properties) if properties else {} + + this = self.createable_sql(expression, properties_locs) + + properties_sql = "" if properties_locs.get(exp.Properties.Location.POST_SCHEMA) or properties_locs.get( exp.Properties.Location.POST_WITH ): - properties_exp.set( - "properties", + properties_sql = self.sql( exp.Properties( expressions=[ *properties_locs[exp.Properties.Location.POST_SCHEMA], *properties_locs[exp.Properties.Location.POST_WITH], ] - ), + ) ) - if kind == "TABLE" and properties_locs.get(exp.Properties.Location.POST_NAME): - this_name = self.sql(expression.this, "this") - this_properties = self.properties( - exp.Properties(expressions=properties_locs[exp.Properties.Location.POST_NAME]), - wrapped=False, - ) - this_schema = f"({self.expressions(expression.this)})" - this = f"{this_name}, {this_properties} {this_schema}" - properties_sql = "" - else: - this = self.sql(expression, "this") - properties_sql = self.sql(properties_exp, "properties") + begin = " BEGIN" if expression.args.get("begin") else "" expression_sql = self.sql(expression, "expression") if expression_sql: @@ -894,6 +898,7 @@ class Generator: expression.quoted or should_identify(text, self.identify) or lower in self.RESERVED_KEYWORDS + or (not self.identifiers_can_start_with_digit and text[:1].isdigit()) ): text = f"{self.identifier_start}{text}{self.identifier_end}" return text @@ -1082,7 +1087,7 @@ class Generator: def lockingproperty_sql(self, expression: exp.LockingProperty) -> str: kind = expression.args.get("kind") - this: str = f" {this}" if expression.this else "" 
+ this = f" {self.sql(expression, 'this')}" if expression.this else "" for_or_in = expression.args.get("for_or_in") lock_type = expression.args.get("lock_type") override = " OVERRIDE" if expression.args.get("override") else "" @@ -1313,7 +1318,7 @@ class Generator: op_sql = " ".join( op for op in ( - "NATURAL" if expression.args.get("natural") else None, + expression.method, "GLOBAL" if expression.args.get("global") else None, expression.side, expression.kind, @@ -1573,9 +1578,12 @@ class Generator: def schema_sql(self, expression: exp.Schema) -> str: this = self.sql(expression, "this") this = f"{this} " if this else "" - sql = f"({self.sep('')}{self.expressions(expression)}{self.seg(')', sep='')}" + sql = self.schema_columns_sql(expression) return f"{this}{sql}" + def schema_columns_sql(self, expression: exp.Schema) -> str: + return f"({self.sep('')}{self.expressions(expression)}{self.seg(')', sep='')}" + def star_sql(self, expression: exp.Star) -> str: except_ = self.expressions(expression, key="except", flat=True) except_ = f"{self.seg(self.STAR_MAPPING['except'])} ({except_})" if except_ else "" @@ -1643,32 +1651,26 @@ class Generator: def window_sql(self, expression: exp.Window) -> str: this = self.sql(expression, "this") - partition = self.partition_by_sql(expression) - order = expression.args.get("order") - order_sql = self.order_sql(order, flat=True) if order else "" - - partition_sql = partition + " " if partition and order else partition - - spec = expression.args.get("spec") - spec_sql = " " + self.windowspec_sql(spec) if spec else "" - + order = self.order_sql(order, flat=True) if order else "" + spec = self.sql(expression, "spec") alias = self.sql(expression, "alias") over = self.sql(expression, "over") or "OVER" + this = f"{this} {'AS' if expression.arg_key == 'windows' else over}" first = expression.args.get("first") - if first is not None: - first = " FIRST " if first else " LAST " - first = first or "" + if first is None: + first = "" + else: + first = "FIRST" if first else "LAST" if not partition and not order and not spec and alias: return f"{this} {alias}" - window_args = alias + first + partition_sql + order_sql + spec_sql - - return f"{this} ({window_args.strip()})" + args = " ".join(arg for arg in (alias, first, partition, order, spec) if arg) + return f"{this} ({args})" def partition_by_sql(self, expression: exp.Window | exp.MatchRecognize) -> str: partition = self.expressions(expression, key="partition_by", flat=True) @@ -2125,6 +2127,10 @@ class Generator: return self.binary(expression, "ILIKE ANY") def is_sql(self, expression: exp.Is) -> str: + if not self.IS_BOOL and isinstance(expression.expression, exp.Boolean): + return self.sql( + expression.this if expression.expression.this else exp.not_(expression.this) + ) return self.binary(expression, "IS") def like_sql(self, expression: exp.Like) -> str: @@ -2322,6 +2328,25 @@ class Generator: return self.sql(exp.cast(expression.this, "text")) + def dictproperty_sql(self, expression: exp.DictProperty) -> str: + this = self.sql(expression, "this") + kind = self.sql(expression, "kind") + settings_sql = self.expressions(expression, key="settings", sep=" ") + args = f"({self.sep('')}{settings_sql}{self.seg(')', sep='')}" if settings_sql else "()" + return f"{this}({kind}{args})" + + def dictrange_sql(self, expression: exp.DictRange) -> str: + this = self.sql(expression, "this") + max = self.sql(expression, "max") + min = self.sql(expression, "min") + return f"{this}(MIN {min} MAX {max})" + + def dictsubproperty_sql(self, 
expression: exp.DictSubProperty) -> str: + return f"{self.sql(expression, 'this')} {self.sql(expression, 'value')}" + + def oncluster_sql(self, expression: exp.OnCluster) -> str: + return "" + def cached_generator( cache: t.Optional[t.Dict[int, str]] = None diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py index 43436cb..4e0c3a1 100644 --- a/sqlglot/optimizer/optimize_joins.py +++ b/sqlglot/optimizer/optimize_joins.py @@ -1,7 +1,7 @@ from sqlglot import exp from sqlglot.helper import tsort -JOIN_ATTRS = ("on", "side", "kind", "using", "natural") +JOIN_ATTRS = ("on", "side", "kind", "using", "method") def optimize_joins(expression): diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py index 96dda33..b89a82b 100644 --- a/sqlglot/optimizer/pushdown_predicates.py +++ b/sqlglot/optimizer/pushdown_predicates.py @@ -10,10 +10,10 @@ def pushdown_predicates(expression): Example: >>> import sqlglot - >>> sql = "SELECT * FROM (SELECT * FROM x AS x) AS y WHERE y.a = 1" + >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1" >>> expression = sqlglot.parse_one(sql) >>> pushdown_predicates(expression).sql() - 'SELECT * FROM (SELECT * FROM x AS x WHERE y.a = 1) AS y WHERE TRUE' + 'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE' Args: expression (sqlglot.Expression): expression to optimize diff --git a/sqlglot/parser.py b/sqlglot/parser.py index e77bb5a..96bd6e3 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -155,6 +155,18 @@ class Parser(metaclass=_Parser): TokenType.DATETIME, TokenType.DATETIME64, TokenType.DATE, + TokenType.INT4RANGE, + TokenType.INT4MULTIRANGE, + TokenType.INT8RANGE, + TokenType.INT8MULTIRANGE, + TokenType.NUMRANGE, + TokenType.NUMMULTIRANGE, + TokenType.TSRANGE, + TokenType.TSMULTIRANGE, + TokenType.TSTZRANGE, + TokenType.TSTZMULTIRANGE, + TokenType.DATERANGE, + TokenType.DATEMULTIRANGE, TokenType.DECIMAL, TokenType.BIGDECIMAL, TokenType.UUID, @@ -193,6 +205,7 @@ class Parser(metaclass=_Parser): TokenType.SCHEMA, TokenType.TABLE, TokenType.VIEW, + TokenType.DICTIONARY, } CREATABLES = { @@ -220,6 +233,7 @@ class Parser(metaclass=_Parser): TokenType.DELETE, TokenType.DESC, TokenType.DESCRIBE, + TokenType.DICTIONARY, TokenType.DIV, TokenType.END, TokenType.EXECUTE, @@ -272,6 +286,7 @@ class Parser(metaclass=_Parser): TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - { TokenType.APPLY, + TokenType.ASOF, TokenType.FULL, TokenType.LEFT, TokenType.LOCK, @@ -375,6 +390,11 @@ class Parser(metaclass=_Parser): TokenType.EXCEPT, } + JOIN_METHODS = { + TokenType.NATURAL, + TokenType.ASOF, + } + JOIN_SIDES = { TokenType.LEFT, TokenType.RIGHT, @@ -465,7 +485,7 @@ class Parser(metaclass=_Parser): exp.Where: lambda self: self._parse_where(), exp.Window: lambda self: self._parse_named_window(), exp.With: lambda self: self._parse_with(), - "JOIN_TYPE": lambda self: self._parse_join_side_and_kind(), + "JOIN_TYPE": lambda self: self._parse_join_parts(), } STATEMENT_PARSERS = { @@ -580,6 +600,8 @@ class Parser(metaclass=_Parser): ), "JOURNAL": lambda self, **kwargs: self._parse_journal(**kwargs), "LANGUAGE": lambda self: self._parse_property_assignment(exp.LanguageProperty), + "LAYOUT": lambda self: self._parse_dict_property(this="LAYOUT"), + "LIFETIME": lambda self: self._parse_dict_range(this="LIFETIME"), "LIKE": lambda self: self._parse_create_like(), "LOCATION": lambda self: self._parse_property_assignment(exp.LocationProperty), "LOCK": lambda self: 
self._parse_locking(), @@ -594,7 +616,8 @@ class Parser(metaclass=_Parser): "PARTITION BY": lambda self: self._parse_partitioned_by(), "PARTITIONED BY": lambda self: self._parse_partitioned_by(), "PARTITIONED_BY": lambda self: self._parse_partitioned_by(), - "PRIMARY KEY": lambda self: self._parse_primary_key(), + "PRIMARY KEY": lambda self: self._parse_primary_key(in_props=True), + "RANGE": lambda self: self._parse_dict_range(this="RANGE"), "RETURNS": lambda self: self._parse_returns(), "ROW": lambda self: self._parse_row(), "ROW_FORMAT": lambda self: self._parse_property_assignment(exp.RowFormatProperty), @@ -603,6 +626,7 @@ class Parser(metaclass=_Parser): exp.SettingsProperty, expressions=self._parse_csv(self._parse_set_item) ), "SORTKEY": lambda self: self._parse_sortkey(), + "SOURCE": lambda self: self._parse_dict_property(this="SOURCE"), "STABLE": lambda self: self.expression( exp.StabilityProperty, this=exp.Literal.string("STABLE") ), @@ -1133,13 +1157,16 @@ class Parser(metaclass=_Parser): begin = None clone = None + def extend_props(temp_props: t.Optional[exp.Expression]) -> None: + nonlocal properties + if properties and temp_props: + properties.expressions.extend(temp_props.expressions) + elif temp_props: + properties = temp_props + if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE): this = self._parse_user_defined_function(kind=create_token.token_type) - temp_properties = self._parse_properties() - if properties and temp_properties: - properties.expressions.extend(temp_properties.expressions) - elif temp_properties: - properties = temp_properties + extend_props(self._parse_properties()) self._match(TokenType.ALIAS) begin = self._match(TokenType.BEGIN) @@ -1154,21 +1181,13 @@ class Parser(metaclass=_Parser): table_parts = self._parse_table_parts(schema=True) # exp.Properties.Location.POST_NAME - if self._match(TokenType.COMMA): - temp_properties = self._parse_properties(before=True) - if properties and temp_properties: - properties.expressions.extend(temp_properties.expressions) - elif temp_properties: - properties = temp_properties + self._match(TokenType.COMMA) + extend_props(self._parse_properties(before=True)) this = self._parse_schema(this=table_parts) # exp.Properties.Location.POST_SCHEMA and POST_WITH - temp_properties = self._parse_properties() - if properties and temp_properties: - properties.expressions.extend(temp_properties.expressions) - elif temp_properties: - properties = temp_properties + extend_props(self._parse_properties()) self._match(TokenType.ALIAS) @@ -1178,11 +1197,7 @@ class Parser(metaclass=_Parser): or self._match(TokenType.WITH, advance=False) or self._match(TokenType.L_PAREN, advance=False) ): - temp_properties = self._parse_properties() - if properties and temp_properties: - properties.expressions.extend(temp_properties.expressions) - elif temp_properties: - properties = temp_properties + extend_props(self._parse_properties()) expression = self._parse_ddl_select() @@ -1192,11 +1207,7 @@ class Parser(metaclass=_Parser): index = self._parse_index() # exp.Properties.Location.POST_EXPRESSION or exp.Properties.Location.POST_INDEX - temp_properties = self._parse_properties() - if properties and temp_properties: - properties.expressions.extend(temp_properties.expressions) - elif temp_properties: - properties = temp_properties + extend_props(self._parse_properties()) if not index: break @@ -1888,8 +1899,16 @@ class Parser(metaclass=_Parser): this = self._parse_query_modifiers(this) elif (table or nested) and 
self._match(TokenType.L_PAREN): - this = self._parse_table() if table else self._parse_select(nested=True) - this = self._parse_set_operations(self._parse_query_modifiers(this)) + if self._match(TokenType.PIVOT): + this = self._parse_simplified_pivot() + elif self._match(TokenType.FROM): + this = exp.select("*").from_( + t.cast(exp.From, self._parse_from(skip_from_token=True)) + ) + else: + this = self._parse_table() if table else self._parse_select(nested=True) + this = self._parse_set_operations(self._parse_query_modifiers(this)) + self._match_r_paren() # early return so that subquery unions aren't parsed again @@ -1902,10 +1921,6 @@ class Parser(metaclass=_Parser): expressions=self._parse_csv(self._parse_value), alias=self._parse_table_alias(), ) - elif self._match(TokenType.PIVOT): - this = self._parse_simplified_pivot() - elif self._match(TokenType.FROM): - this = exp.select("*").from_(t.cast(exp.From, self._parse_from(skip_from_token=True))) else: this = None @@ -2154,11 +2169,11 @@ class Parser(metaclass=_Parser): return expression - def _parse_join_side_and_kind( + def _parse_join_parts( self, ) -> t.Tuple[t.Optional[Token], t.Optional[Token], t.Optional[Token]]: return ( - self._match(TokenType.NATURAL) and self._prev, + self._match_set(self.JOIN_METHODS) and self._prev, self._match_set(self.JOIN_SIDES) and self._prev, self._match_set(self.JOIN_KINDS) and self._prev, ) @@ -2168,14 +2183,14 @@ class Parser(metaclass=_Parser): return self.expression(exp.Join, this=self._parse_table()) index = self._index - natural, side, kind = self._parse_join_side_and_kind() + method, side, kind = self._parse_join_parts() hint = self._prev.text if self._match_texts(self.JOIN_HINTS) else None join = self._match(TokenType.JOIN) if not skip_join_token and not join: self._retreat(index) kind = None - natural = None + method = None side = None outer_apply = self._match_pair(TokenType.OUTER, TokenType.APPLY, False) @@ -2187,12 +2202,10 @@ class Parser(metaclass=_Parser): if outer_apply: side = Token(TokenType.LEFT, "LEFT") - kwargs: t.Dict[ - str, t.Optional[exp.Expression] | bool | str | t.List[t.Optional[exp.Expression]] - ] = {"this": self._parse_table()} + kwargs: t.Dict[str, t.Any] = {"this": self._parse_table()} - if natural: - kwargs["natural"] = True + if method: + kwargs["method"] = method.text if side: kwargs["side"] = side.text if kind: @@ -2205,7 +2218,7 @@ class Parser(metaclass=_Parser): elif self._match(TokenType.USING): kwargs["using"] = self._parse_wrapped_id_vars() - return self.expression(exp.Join, **kwargs) # type: ignore + return self.expression(exp.Join, **kwargs) def _parse_index( self, @@ -2886,7 +2899,9 @@ class Parser(metaclass=_Parser): exp.DataTypeSize, this=this, expression=self._parse_var(any_token=True) ) - def _parse_types(self, check_func: bool = False) -> t.Optional[exp.Expression]: + def _parse_types( + self, check_func: bool = False, schema: bool = False + ) -> t.Optional[exp.Expression]: index = self._index prefix = self._match_text_seq("SYSUDTLIB", ".") @@ -2908,7 +2923,9 @@ class Parser(metaclass=_Parser): if is_struct: expressions = self._parse_csv(self._parse_struct_types) elif nested: - expressions = self._parse_csv(self._parse_types) + expressions = self._parse_csv( + lambda: self._parse_types(check_func=check_func, schema=schema) + ) else: expressions = self._parse_csv(self._parse_type_size) @@ -2943,7 +2960,9 @@ class Parser(metaclass=_Parser): if is_struct: expressions = self._parse_csv(self._parse_struct_types) else: - expressions = 
self._parse_csv(self._parse_types) + expressions = self._parse_csv( + lambda: self._parse_types(check_func=check_func, schema=schema) + ) if not self._match(TokenType.GT): self.raise_error("Expecting >") @@ -3038,11 +3057,7 @@ class Parser(metaclass=_Parser): else exp.Literal.string(value) ) else: - field = ( - self._parse_star() - or self._parse_function(anonymous=True) - or self._parse_id_var() - ) + field = self._parse_field(anonymous_func=True) if isinstance(field, exp.Func): # bigquery allows function calls like x.y.count(...) @@ -3113,10 +3128,11 @@ class Parser(metaclass=_Parser): self, any_token: bool = False, tokens: t.Optional[t.Collection[TokenType]] = None, + anonymous_func: bool = False, ) -> t.Optional[exp.Expression]: return ( self._parse_primary() - or self._parse_function() + or self._parse_function(anonymous=anonymous_func) or self._parse_id_var(any_token=any_token, tokens=tokens) ) @@ -3270,7 +3286,7 @@ class Parser(metaclass=_Parser): # column defs are not really columns, they're identifiers if isinstance(this, exp.Column): this = this.this - kind = self._parse_types() + kind = self._parse_types(schema=True) if self._match_text_seq("FOR", "ORDINALITY"): return self.expression(exp.ColumnDef, this=this, ordinality=True) @@ -3483,16 +3499,18 @@ class Parser(metaclass=_Parser): exp.ForeignKey, expressions=expressions, reference=reference, **options # type: ignore ) - def _parse_primary_key(self) -> exp.Expression: + def _parse_primary_key( + self, wrapped_optional: bool = False, in_props: bool = False + ) -> exp.Expression: desc = ( self._match_set((TokenType.ASC, TokenType.DESC)) and self._prev.token_type == TokenType.DESC ) - if not self._match(TokenType.L_PAREN, advance=False): + if not in_props and not self._match(TokenType.L_PAREN, advance=False): return self.expression(exp.PrimaryKeyColumnConstraint, desc=desc) - expressions = self._parse_wrapped_csv(self._parse_field) + expressions = self._parse_wrapped_csv(self._parse_field, optional=wrapped_optional) options = self._parse_key_constraint_options() return self.expression(exp.PrimaryKey, expressions=expressions, options=options) @@ -3509,10 +3527,11 @@ class Parser(metaclass=_Parser): return this bracket_kind = self._prev.token_type - expressions: t.List[t.Optional[exp.Expression]] if self._match(TokenType.COLON): - expressions = [self.expression(exp.Slice, expression=self._parse_conjunction())] + expressions: t.List[t.Optional[exp.Expression]] = [ + self.expression(exp.Slice, expression=self._parse_conjunction()) + ] else: expressions = self._parse_csv(lambda: self._parse_slice(self._parse_conjunction())) @@ -4011,22 +4030,15 @@ class Parser(metaclass=_Parser): self, any_token: bool = True, tokens: t.Optional[t.Collection[TokenType]] = None, - prefix_tokens: t.Optional[t.Collection[TokenType]] = None, ) -> t.Optional[exp.Expression]: identifier = self._parse_identifier() if identifier: return identifier - prefix = "" - - if prefix_tokens: - while self._match_set(prefix_tokens): - prefix += self._prev.text - if (any_token and self._advance_any()) or self._match_set(tokens or self.ID_VAR_TOKENS): quoted = self._prev.token_type == TokenType.STRING - return exp.Identifier(this=prefix + self._prev.text, quoted=quoted) + return exp.Identifier(this=self._prev.text, quoted=quoted) return None @@ -4472,6 +4484,44 @@ class Parser(metaclass=_Parser): size = len(start.text) return exp.Command(this=text[:size], expression=text[size:]) + def _parse_dict_property(self, this: str) -> exp.DictProperty: + settings = [] + + 
self._match_l_paren() + kind = self._parse_id_var() + + if self._match(TokenType.L_PAREN): + while True: + key = self._parse_id_var() + value = self._parse_primary() + + if not key and value is None: + break + settings.append(self.expression(exp.DictSubProperty, this=key, value=value)) + self._match(TokenType.R_PAREN) + + self._match_r_paren() + + return self.expression( + exp.DictProperty, + this=this, + kind=kind.this if kind else None, + settings=settings, + ) + + def _parse_dict_range(self, this: str) -> exp.DictRange: + self._match_l_paren() + has_min = self._match_text_seq("MIN") + if has_min: + min = self._parse_var() or self._parse_primary() + self._match_text_seq("MAX") + max = self._parse_var() or self._parse_primary() + else: + max = self._parse_var() or self._parse_primary() + min = exp.Literal.number(0) + self._match_r_paren() + return self.expression(exp.DictRange, this=this, min=min, max=max) + def _find_parser( self, parsers: t.Dict[str, t.Callable], trie: t.Dict ) -> t.Optional[t.Callable]: diff --git a/sqlglot/serde.py b/sqlglot/serde.py index c5203a7..b019035 100644 --- a/sqlglot/serde.py +++ b/sqlglot/serde.py @@ -5,7 +5,7 @@ import typing as t from sqlglot import expressions as exp if t.TYPE_CHECKING: - JSON = t.Union[dict, list, str, float, int, bool] + JSON = t.Union[dict, list, str, float, int, bool, None] Node = t.Union[t.List["Node"], exp.DataType.Type, exp.Expression, JSON] @@ -24,12 +24,12 @@ def dump(node: Node) -> JSON: klass = node.__class__.__qualname__ if node.__class__.__module__ != exp.__name__: klass = f"{node.__module__}.{klass}" - obj = { + obj: t.Dict = { "class": klass, "args": {k: dump(v) for k, v in node.args.items() if v is not None and v != []}, } if node.type: - obj["type"] = node.type.sql() + obj["type"] = dump(node.type) if node.comments: obj["comments"] = node.comments if node._meta is not None: @@ -60,7 +60,7 @@ def load(obj: JSON) -> Node: klass = getattr(module, class_name) expression = klass(**{k: load(v) for k, v in obj["args"].items()}) - expression.type = obj.get("type") + expression.type = t.cast(exp.DataType, load(obj.get("type"))) expression.comments = obj.get("comments") expression._meta = obj.get("meta") diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index ad329d2..a30ec24 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -113,6 +113,18 @@ class TokenType(AutoName): DATETIME = auto() DATETIME64 = auto() DATE = auto() + INT4RANGE = auto() + INT4MULTIRANGE = auto() + INT8RANGE = auto() + INT8MULTIRANGE = auto() + NUMRANGE = auto() + NUMMULTIRANGE = auto() + TSRANGE = auto() + TSMULTIRANGE = auto() + TSTZRANGE = auto() + TSTZMULTIRANGE = auto() + DATERANGE = auto() + DATEMULTIRANGE = auto() UUID = auto() GEOGRAPHY = auto() NULLABLE = auto() @@ -167,6 +179,7 @@ class TokenType(AutoName): DELETE = auto() DESC = auto() DESCRIBE = auto() + DICTIONARY = auto() DISTINCT = auto() DIV = auto() DROP = auto() @@ -480,6 +493,7 @@ class Tokenizer(metaclass=_Tokenizer): "ANY": TokenType.ANY, "ASC": TokenType.ASC, "AS": TokenType.ALIAS, + "ASOF": TokenType.ASOF, "AUTOINCREMENT": TokenType.AUTO_INCREMENT, "AUTO_INCREMENT": TokenType.AUTO_INCREMENT, "BEGIN": TokenType.BEGIN, @@ -669,6 +683,18 @@ class Tokenizer(metaclass=_Tokenizer): "TIMESTAMPLTZ": TokenType.TIMESTAMPLTZ, "DATE": TokenType.DATE, "DATETIME": TokenType.DATETIME, + "INT4RANGE": TokenType.INT4RANGE, + "INT4MULTIRANGE": TokenType.INT4MULTIRANGE, + "INT8RANGE": TokenType.INT8RANGE, + "INT8MULTIRANGE": TokenType.INT8MULTIRANGE, + "NUMRANGE": TokenType.NUMRANGE, + 
"NUMMULTIRANGE": TokenType.NUMMULTIRANGE, + "TSRANGE": TokenType.TSRANGE, + "TSMULTIRANGE": TokenType.TSMULTIRANGE, + "TSTZRANGE": TokenType.TSTZRANGE, + "TSTZMULTIRANGE": TokenType.TSTZMULTIRANGE, + "DATERANGE": TokenType.DATERANGE, + "DATEMULTIRANGE": TokenType.DATEMULTIRANGE, "UNIQUE": TokenType.UNIQUE, "STRUCT": TokenType.STRUCT, "VARIANT": TokenType.VARIANT, @@ -709,8 +735,6 @@ class Tokenizer(metaclass=_Tokenizer): COMMENTS = ["--", ("/*", "*/"), ("{#", "#}")] KEYWORD_TRIE: t.Dict = {} # autofilled - IDENTIFIER_CAN_START_WITH_DIGIT = False - __slots__ = ( "sql", "size", @@ -724,6 +748,7 @@ class Tokenizer(metaclass=_Tokenizer): "_end", "_peek", "_prev_token_line", + "identifiers_can_start_with_digit", ) def __init__(self) -> None: @@ -826,6 +851,12 @@ class Tokenizer(metaclass=_Tokenizer): def _text(self) -> str: return self.sql[self._start : self._current] + def peek(self, i: int = 0) -> str: + i = self._current + i + if i < self.size: + return self.sql[i] + return "" + def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None: self._prev_token_line = self._line self.tokens.append( @@ -962,8 +993,12 @@ class Tokenizer(metaclass=_Tokenizer): if self._peek.isdigit(): self._advance() elif self._peek == "." and not decimal: - decimal = True - self._advance() + after = self.peek(1) + if after.isdigit() or not after.strip(): + decimal = True + self._advance() + else: + return self._add(TokenType.VAR) elif self._peek in ("-", "+") and scientific == 1: scientific += 1 self._advance() @@ -984,7 +1019,7 @@ class Tokenizer(metaclass=_Tokenizer): self._add(TokenType.NUMBER, number_text) self._add(TokenType.DCOLON, "::") return self._add(token_type, literal) - elif self.IDENTIFIER_CAN_START_WITH_DIGIT: + elif self.identifiers_can_start_with_digit: # type: ignore return self._add(TokenType.VAR) self._add(TokenType.NUMBER, number_text) diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index a1ec1bd..ba72616 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -268,6 +268,17 @@ def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression return expression +def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression: + if ( + isinstance(expression, (exp.Cast, exp.TryCast)) + and expression.name.lower() == "epoch" + and expression.to.this in exp.DataType.TEMPORAL_TYPES + ): + expression.this.replace(exp.Literal.string("1970-01-01 00:00:00")) + + return expression + + def preprocess( transforms: t.List[t.Callable[[exp.Expression], exp.Expression]], ) -> t.Callable[[Generator, exp.Expression], str]: |