diff options
Diffstat (limited to 'sqlglot/dialects')
-rw-r--r-- | sqlglot/dialects/bigquery.py | 52 | ||||
-rw-r--r-- | sqlglot/dialects/dialect.py | 32 | ||||
-rw-r--r-- | sqlglot/dialects/duckdb.py | 2 | ||||
-rw-r--r-- | sqlglot/dialects/hive.py | 27 | ||||
-rw-r--r-- | sqlglot/dialects/mysql.py | 15 | ||||
-rw-r--r-- | sqlglot/dialects/postgres.py | 3 | ||||
-rw-r--r-- | sqlglot/dialects/presto.py | 11 | ||||
-rw-r--r-- | sqlglot/dialects/snowflake.py | 29 | ||||
-rw-r--r-- | sqlglot/dialects/spark.py | 9 | ||||
-rw-r--r-- | sqlglot/dialects/starrocks.py | 10 | ||||
-rw-r--r-- | sqlglot/dialects/tsql.py | 13 |
11 files changed, 181 insertions, 22 deletions
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 1f1f90a..432fd8c 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -33,6 +33,49 @@ def _date_add_sql(data_type, kind): return func +def _subquery_to_unnest_if_values(self, expression): + if not isinstance(expression.this, exp.Values): + return self.subquery_sql(expression) + rows = [list(tuple_exp.find_all(exp.Literal)) for tuple_exp in expression.this.find_all(exp.Tuple)] + structs = [] + for row in rows: + aliases = [ + exp.alias_(value, column_name) for value, column_name in zip(row, expression.args["alias"].args["columns"]) + ] + structs.append(exp.Struct(expressions=aliases)) + unnest_exp = exp.Unnest(expressions=[exp.Array(expressions=structs)]) + return self.unnest_sql(unnest_exp) + + +def _returnsproperty_sql(self, expression): + value = expression.args.get("value") + if isinstance(value, exp.Schema): + value = f"{value.this} <{self.expressions(value)}>" + else: + value = self.sql(value) + return f"RETURNS {value}" + + +def _create_sql(self, expression): + kind = expression.args.get("kind") + returns = expression.find(exp.ReturnsProperty) + if kind.upper() == "FUNCTION" and returns and returns.args.get("is_table"): + expression = expression.copy() + expression.set("kind", "TABLE FUNCTION") + if isinstance( + expression.expression, + ( + exp.Subquery, + exp.Literal, + ), + ): + expression.set("expression", expression.expression.this) + + return self.create_sql(expression) + + return self.create_sql(expression) + + class BigQuery(Dialect): unnest_column_only = True @@ -77,8 +120,14 @@ class BigQuery(Dialect): TokenType.CURRENT_TIME: exp.CurrentTime, } + NESTED_TYPE_TOKENS = { + *Parser.NESTED_TYPE_TOKENS, + TokenType.TABLE, + } + class Generator(Generator): TRANSFORMS = { + **Generator.TRANSFORMS, exp.Array: inline_array_sql, exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.DateAdd: _date_add_sql("DATE", "ADD"), @@ -91,6 +140,9 @@ class BigQuery(Dialect): exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"), exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"), exp.VariancePop: rename_func("VAR_POP"), + exp.Subquery: _subquery_to_unnest_if_values, + exp.ReturnsProperty: _returnsproperty_sql, + exp.Create: _create_sql, } TYPE_MAPPING = { diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 0120e71..0ab584e 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -245,6 +245,11 @@ def no_tablesample_sql(self, expression): return self.sql(expression.this) +def no_pivot_sql(self, expression): + self.unsupported("PIVOT unsupported") + return self.sql(expression) + + def no_trycast_sql(self, expression): return self.cast_sql(expression) @@ -282,3 +287,30 @@ def format_time_lambda(exp_class, dialect, default=None): ) return _format_time + + +def create_with_partitions_sql(self, expression): + """ + In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the + PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding + columns are removed from the create statement. + """ + has_schema = isinstance(expression.this, exp.Schema) + is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW") + + if has_schema and is_partitionable: + expression = expression.copy() + prop = expression.find(exp.PartitionedByProperty) + value = prop and prop.args.get("value") + if prop and not isinstance(value, exp.Schema): + schema = expression.this + columns = {v.name.upper() for v in value.expressions} + partitions = [col for col in schema.expressions if col.name.upper() in columns] + schema.set( + "expressions", + [e for e in schema.expressions if e not in partitions], + ) + prop.replace(exp.PartitionedByProperty(this=prop.this, value=exp.Schema(expressions=partitions))) + expression.set("this", schema) + + return self.create_sql(expression) diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index 4ca9e84..e09c3dd 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -5,6 +5,7 @@ from sqlglot.dialects.dialect import ( arrow_json_extract_scalar_sql, arrow_json_extract_sql, format_time_lambda, + no_pivot_sql, no_safe_divide_sql, no_tablesample_sql, rename_func, @@ -122,6 +123,7 @@ class DuckDB(Dialect): exp.JSONExtractScalar: arrow_json_extract_scalar_sql, exp.JSONBExtract: arrow_json_extract_sql, exp.JSONBExtractScalar: arrow_json_extract_scalar_sql, + exp.Pivot: no_pivot_sql, exp.RegexpLike: rename_func("REGEXP_MATCHES"), exp.RegexpSplit: rename_func("STR_SPLIT_REGEX"), exp.SafeDivide: no_safe_divide_sql, diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 59aa8fa..7a27bb3 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -2,6 +2,7 @@ from sqlglot import exp, transforms from sqlglot.dialects.dialect import ( Dialect, approx_count_distinct_sql, + create_with_partitions_sql, format_time_lambda, if_sql, no_ilike_sql, @@ -53,7 +54,7 @@ def _array_sort(self, expression): def _property_sql(self, expression): key = expression.name value = self.sql(expression, "value") - return f"'{key}' = {value}" + return f"'{key}'={value}" def _str_to_unix(self, expression): @@ -218,15 +219,6 @@ class Hive(Dialect): } class Generator(Generator): - ROOT_PROPERTIES = [ - exp.PartitionedByProperty, - exp.FileFormatProperty, - exp.SchemaCommentProperty, - exp.LocationProperty, - exp.TableFormatProperty, - ] - WITH_PROPERTIES = [exp.AnonymousProperty] - TYPE_MAPPING = { **Generator.TYPE_MAPPING, exp.DataType.Type.TEXT: "STRING", @@ -255,13 +247,13 @@ class Hive(Dialect): exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"), exp.Map: _map_sql, HiveMap: _map_sql, - exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e.args['value'])}", + exp.Create: create_with_partitions_sql, exp.Quantile: rename_func("PERCENTILE"), exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"), exp.RegexpLike: lambda self, e: self.binary(e, "RLIKE"), exp.RegexpSplit: rename_func("SPLIT"), exp.SafeDivide: no_safe_divide_sql, - exp.SchemaCommentProperty: lambda self, e: f"COMMENT {self.sql(e.args['value'])}", + exp.SchemaCommentProperty: lambda self, e: self.naked_property(e), exp.SetAgg: rename_func("COLLECT_SET"), exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))", exp.StrPosition: lambda self, e: f"LOCATE({csv(self.sql(e, 'substr'), self.sql(e, 'this'), self.sql(e, 'position'))})", @@ -282,6 +274,17 @@ class Hive(Dialect): exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({csv(self.sql(e, 'this'), _time_format(self, e))})", exp.UnixToTime: rename_func("FROM_UNIXTIME"), exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"), + exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'value')}", + } + + WITH_PROPERTIES = {exp.AnonymousProperty} + + ROOT_PROPERTIES = { + exp.PartitionedByProperty, + exp.FileFormatProperty, + exp.SchemaCommentProperty, + exp.LocationProperty, + exp.TableFormatProperty, } def with_properties(self, properties): diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 87a2c41..8449379 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -172,6 +172,11 @@ class MySQL(Dialect): ), } + PROPERTY_PARSERS = { + **Parser.PROPERTY_PARSERS, + TokenType.ENGINE: lambda self: self._parse_property_assignment(exp.EngineProperty), + } + class Generator(Generator): NULL_ORDERING_SUPPORTED = False @@ -190,3 +195,13 @@ class MySQL(Dialect): exp.StrToTime: _str_to_date_sql, exp.Trim: _trim_sql, } + + ROOT_PROPERTIES = { + exp.EngineProperty, + exp.AutoIncrementProperty, + exp.CharacterSetProperty, + exp.CollateProperty, + exp.SchemaCommentProperty, + } + + WITH_PROPERTIES = {} diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index c796839..aaa07a1 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -7,6 +7,7 @@ from sqlglot.dialects.dialect import ( no_paren_current_date_sql, no_tablesample_sql, no_trycast_sql, + str_position_sql, ) from sqlglot.generator import Generator from sqlglot.parser import Parser @@ -158,7 +159,6 @@ class Postgres(Dialect): "ALWAYS": TokenType.ALWAYS, "BY DEFAULT": TokenType.BY_DEFAULT, "IDENTITY": TokenType.IDENTITY, - "FOR": TokenType.FOR, "GENERATED": TokenType.GENERATED, "DOUBLE PRECISION": TokenType.DOUBLE, "BIGSERIAL": TokenType.BIGSERIAL, @@ -204,6 +204,7 @@ class Postgres(Dialect): exp.DateAdd: _date_add_sql("+"), exp.DateSub: _date_add_sql("-"), exp.Lateral: _lateral_sql, + exp.StrPosition: str_position_sql, exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.Substring: _substring_sql, exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 7253f7e..85647c5 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -146,13 +146,16 @@ class Presto(Dialect): STRUCT_DELIMITER = ("(", ")") - WITH_PROPERTIES = [ + ROOT_PROPERTIES = { + exp.SchemaCommentProperty, + } + + WITH_PROPERTIES = { exp.PartitionedByProperty, exp.FileFormatProperty, - exp.SchemaCommentProperty, exp.AnonymousProperty, exp.TableFormatProperty, - ] + } TYPE_MAPPING = { **Generator.TYPE_MAPPING, @@ -184,13 +187,11 @@ class Presto(Dialect): exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.date_format}) AS DATE)", exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.dateint_format}) AS INT)", exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)", - exp.FileFormatProperty: lambda self, e: self.property_sql(e), exp.If: if_sql, exp.ILike: no_ilike_sql, exp.Initcap: _initcap_sql, exp.Lateral: _explode_to_unnest_sql, exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"), - exp.PartitionedByProperty: lambda self, e: f"PARTITIONED_BY = {self.sql(e.args['value'])}", exp.Quantile: _quantile_sql, exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"), exp.SafeDivide: no_safe_divide_sql, diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index b5d4f0a..1b718f7 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -1,5 +1,10 @@ from sqlglot import exp -from sqlglot.dialects.dialect import Dialect, format_time_lambda, rename_func +from sqlglot.dialects.dialect import ( + Dialect, + format_time_lambda, + inline_array_sql, + rename_func, +) from sqlglot.expressions import Literal from sqlglot.generator import Generator from sqlglot.helper import list_get @@ -104,6 +109,8 @@ class Snowflake(Dialect): "ARRAYAGG": exp.ArrayAgg.from_arg_list, "IFF": exp.If.from_arg_list, "TO_TIMESTAMP": _snowflake_to_timestamp, + "ARRAY_CONSTRUCT": exp.Array.from_arg_list, + "RLIKE": exp.RegexpLike.from_arg_list, } FUNCTION_PARSERS = { @@ -111,6 +118,11 @@ class Snowflake(Dialect): "DATE_PART": lambda self: self._parse_extract(), } + FUNC_TOKENS = { + *Parser.FUNC_TOKENS, + TokenType.RLIKE, + } + COLUMN_OPERATORS = { **Parser.COLUMN_OPERATORS, TokenType.COLON: lambda self, this, path: self.expression( @@ -120,6 +132,11 @@ class Snowflake(Dialect): ), } + PROPERTY_PARSERS = { + **Parser.PROPERTY_PARSERS, + TokenType.PARTITION_BY: lambda self: self._parse_partitioned_by(), + } + class Tokenizer(Tokenizer): QUOTES = ["'", "$$"] ESCAPE = "\\" @@ -137,6 +154,7 @@ class Snowflake(Dialect): "TIMESTAMP_NTZ": TokenType.TIMESTAMP, "TIMESTAMP_TZ": TokenType.TIMESTAMPTZ, "TIMESTAMPNTZ": TokenType.TIMESTAMP, + "SAMPLE": TokenType.TABLE_SAMPLE, } class Generator(Generator): @@ -145,6 +163,8 @@ class Snowflake(Dialect): exp.If: rename_func("IFF"), exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.UnixToTime: _unix_to_time, + exp.Array: inline_array_sql, + exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}", } TYPE_MAPPING = { @@ -152,6 +172,13 @@ class Snowflake(Dialect): exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ", } + ROOT_PROPERTIES = { + exp.PartitionedByProperty, + exp.ReturnsProperty, + exp.LanguageProperty, + exp.SchemaCommentProperty, + } + def except_op(self, expression): if not expression.args.get("distinct", False): self.unsupported("EXCEPT with All is not supported in Snowflake") diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index c051178..5446e83 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -1,5 +1,9 @@ from sqlglot import exp -from sqlglot.dialects.dialect import no_ilike_sql, rename_func +from sqlglot.dialects.dialect import ( + create_with_partitions_sql, + no_ilike_sql, + rename_func, +) from sqlglot.dialects.hive import Hive, HiveMap from sqlglot.helper import list_get @@ -10,7 +14,7 @@ def _create_sql(self, e): if kind.upper() == "TABLE" and temporary is True: return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}" - return self.create_sql(e) + return create_with_partitions_sql(self, e) def _map_sql(self, expression): @@ -73,6 +77,7 @@ class Spark(Hive): } class Generator(Hive.Generator): + TYPE_MAPPING = { **Hive.Generator.TYPE_MAPPING, exp.DataType.Type.TINYINT: "BYTE", diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py index b9cd584..ef8c82d 100644 --- a/sqlglot/dialects/starrocks.py +++ b/sqlglot/dialects/starrocks.py @@ -1,4 +1,5 @@ from sqlglot import exp +from sqlglot.dialects.dialect import rename_func from sqlglot.dialects.mysql import MySQL @@ -10,3 +11,12 @@ class StarRocks(MySQL): exp.DataType.Type.TIMESTAMP: "DATETIME", exp.DataType.Type.TIMESTAMPTZ: "DATETIME", } + + TRANSFORMS = { + **MySQL.Generator.TRANSFORMS, + exp.DateDiff: rename_func("DATEDIFF"), + exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.TimeStrToDate: rename_func("TO_DATE"), + exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.sql(e, 'this')}, {self.format_time(e)})", + exp.UnixToTime: rename_func("FROM_UNIXTIME"), + } diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 68bb9bd..73b232e 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -1,6 +1,7 @@ from sqlglot import exp from sqlglot.dialects.dialect import Dialect from sqlglot.generator import Generator +from sqlglot.parser import Parser from sqlglot.tokens import Tokenizer, TokenType @@ -17,6 +18,7 @@ class TSQL(Dialect): "REAL": TokenType.FLOAT, "NTEXT": TokenType.TEXT, "SMALLDATETIME": TokenType.DATETIME, + "DATETIME2": TokenType.DATETIME, "DATETIMEOFFSET": TokenType.TIMESTAMPTZ, "TIME": TokenType.TIMESTAMP, "VARBINARY": TokenType.BINARY, @@ -24,15 +26,24 @@ class TSQL(Dialect): "MONEY": TokenType.MONEY, "SMALLMONEY": TokenType.SMALLMONEY, "ROWVERSION": TokenType.ROWVERSION, - "SQL_VARIANT": TokenType.SQL_VARIANT, "UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER, "XML": TokenType.XML, + "SQL_VARIANT": TokenType.VARIANT, } + class Parser(Parser): + def _parse_convert(self): + to = self._parse_types() + self._match(TokenType.COMMA) + this = self._parse_field() + return self.expression(exp.Cast, this=this, to=to) + class Generator(Generator): TYPE_MAPPING = { **Generator.TYPE_MAPPING, exp.DataType.Type.BOOLEAN: "BIT", exp.DataType.Type.INT: "INTEGER", exp.DataType.Type.DECIMAL: "NUMERIC", + exp.DataType.Type.DATETIME: "DATETIME2", + exp.DataType.Type.VARIANT: "SQL_VARIANT", } |