diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-09-25 08:20:09 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-09-25 08:20:09 +0000 |
commit | 4554ab4c7d6b2bbbaa6f4d0b810bf477d1a505a6 (patch) | |
tree | 8f4f60a82ab9cd6dcd41397e4ecb2960c332b209 /sqlglot | |
parent | Releasing debian version 18.5.1-1. (diff) | |
download | sqlglot-4554ab4c7d6b2bbbaa6f4d0b810bf477d1a505a6.tar.xz sqlglot-4554ab4c7d6b2bbbaa6f4d0b810bf477d1a505a6.zip |
Merging upstream version 18.7.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
28 files changed, 982 insertions, 258 deletions
diff --git a/sqlglot/_typing.py b/sqlglot/_typing.py index 2acbbf7..86d965a 100644 --- a/sqlglot/_typing.py +++ b/sqlglot/_typing.py @@ -4,5 +4,10 @@ import typing as t import sqlglot +# A little hack for backwards compatibility with Python 3.7. +# For example, we might want a TypeVar for objects that support comparison e.g. SupportsRichComparisonT from typeshed. +# But Python 3.7 doesn't support Protocols, so we'd also need typing_extensions, which we don't want as a dependency. +A = t.TypeVar("A", bound=t.Any) + E = t.TypeVar("E", bound="sqlglot.exp.Expression") T = t.TypeVar("T") diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py index 3acf494..ca85376 100644 --- a/sqlglot/dataframe/sql/column.py +++ b/sqlglot/dataframe/sql/column.py @@ -212,7 +212,15 @@ class Column: return self.expression.sql(**{"dialect": SparkSession().dialect, **kwargs}) def alias(self, name: str) -> Column: - new_expression = exp.alias_(self.column_expression, name) + from sqlglot.dataframe.sql.session import SparkSession + + dialect = SparkSession().dialect + alias: exp.Expression = sqlglot.maybe_parse(name, dialect=dialect) + new_expression = exp.alias_( + self.column_expression, + alias.this if isinstance(alias, exp.Column) else name, + dialect=dialect, + ) return Column(new_expression) def asc(self) -> Column: diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 6c71624..1349c56 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -12,6 +12,7 @@ from sqlglot.dialects.dialect import ( date_add_interval_sql, datestrtodate_sql, format_time_lambda, + if_sql, inline_array_sql, json_keyvalue_comma_sql, max_or_greatest, @@ -176,6 +177,8 @@ def _parse_to_hex(args: t.List) -> exp.Hex | exp.MD5: class BigQuery(Dialect): UNNEST_COLUMN_ONLY = True SUPPORTS_USER_DEFINED_TYPES = False + SUPPORTS_SEMI_ANTI_JOIN = False + LOG_BASE_FIRST = False # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#case_sensitivity RESOLVES_IDENTIFIERS_AS_UPPERCASE = None @@ -256,7 +259,6 @@ class BigQuery(Dialect): "RECORD": TokenType.STRUCT, "TIMESTAMP": TokenType.TIMESTAMPTZ, "NOT DETERMINISTIC": TokenType.VOLATILE, - "UNKNOWN": TokenType.NULL, "FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT, } KEYWORDS.pop("DIV") @@ -264,7 +266,6 @@ class BigQuery(Dialect): class Parser(parser.Parser): PREFIXED_PIVOT_COLUMNS = True - LOG_BASE_FIRST = False LOG_DEFAULTS_TO_LN = True FUNCTIONS = { @@ -292,9 +293,7 @@ class BigQuery(Dialect): expression=seq_get(args, 1), position=seq_get(args, 2), occurrence=seq_get(args, 3), - group=exp.Literal.number(1) - if re.compile(str(seq_get(args, 1))).groups == 1 - else None, + group=exp.Literal.number(1) if re.compile(args[1].name).groups == 1 else None, ), "SHA256": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(256)), "SHA512": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(512)), @@ -344,6 +343,11 @@ class BigQuery(Dialect): "OPTIONS": lambda self: exp.Properties(expressions=self._parse_with_property()), } + RANGE_PARSERS = parser.Parser.RANGE_PARSERS.copy() + RANGE_PARSERS.pop(TokenType.OVERLAPS, None) + + NULL_TOKENS = {TokenType.NULL, TokenType.UNKNOWN} + def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]: this = super()._parse_table_part(schema=schema) or self._parse_number() @@ -413,8 +417,8 @@ class BigQuery(Dialect): TABLE_HINTS = False LIMIT_FETCH = "LIMIT" RENAME_TABLE_WITH_DB = False - ESCAPE_LINE_BREAK = True NVL2_SUPPORTED = False + UNNEST_WITH_ORDINALITY = False TRANSFORMS = { **generator.Generator.TRANSFORMS, @@ -434,6 +438,7 @@ class BigQuery(Dialect): exp.GenerateSeries: rename_func("GENERATE_ARRAY"), exp.GroupConcat: rename_func("STRING_AGG"), exp.Hex: rename_func("TO_HEX"), + exp.If: if_sql(false_value="NULL"), exp.ILike: no_ilike_sql, exp.IntDiv: rename_func("DIV"), exp.JSONFormat: rename_func("TO_JSON_STRING"), @@ -455,10 +460,11 @@ class BigQuery(Dialect): exp.ReturnsProperty: _returnsproperty_sql, exp.Select: transforms.preprocess( [ - transforms.explode_to_unnest, + transforms.explode_to_unnest(), _unqualify_unnest, transforms.eliminate_distinct_on, _alias_ordered_group, + transforms.eliminate_semi_and_anti_joins, ] ), exp.SHA2: lambda self, e: self.func( @@ -514,6 +520,18 @@ class BigQuery(Dialect): exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } + UNESCAPED_SEQUENCE_TABLE = str.maketrans( # type: ignore + { + "\a": "\\a", + "\b": "\\b", + "\f": "\\f", + "\n": "\\n", + "\r": "\\r", + "\t": "\\t", + "\v": "\\v", + } + ) + # from: https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#reserved_keywords RESERVED_KEYWORDS = { *generator.Generator.RESERVED_KEYWORDS, diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index d552f4c..7446081 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -113,15 +113,11 @@ class ClickHouse(Dialect): *parser.Parser.JOIN_KINDS, TokenType.ANY, TokenType.ASOF, - TokenType.ANTI, - TokenType.SEMI, TokenType.ARRAY, } - TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - { + TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - { TokenType.ANY, - TokenType.SEMI, - TokenType.ANTI, TokenType.SETTINGS, TokenType.FORMAT, TokenType.ARRAY, diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py index 6ec0487..39daad7 100644 --- a/sqlglot/dialects/databricks.py +++ b/sqlglot/dialects/databricks.py @@ -51,8 +51,6 @@ class Databricks(Spark): exp.ToChar: lambda self, e: self.function_fallback_sql(e), } - PARAMETER_TOKEN = "$" - class Tokenizer(Spark.Tokenizer): HEX_STRINGS = [] diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index d4811c5..ccf04da 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -125,6 +125,12 @@ class _Dialect(type): if not klass.STRICT_STRING_CONCAT and klass.DPIPE_IS_STRING_CONCAT: klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe + if not klass.SUPPORTS_SEMI_ANTI_JOIN: + klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | { + TokenType.ANTI, + TokenType.SEMI, + } + klass.generator_class.can_identify = klass.can_identify return klass @@ -156,9 +162,15 @@ class Dialect(metaclass=_Dialect): # Determines whether or not user-defined data types are supported SUPPORTS_USER_DEFINED_TYPES = True + # Determines whether or not SEMI/ANTI JOINs are supported + SUPPORTS_SEMI_ANTI_JOIN = True + # Determines how function names are going to be normalized NORMALIZE_FUNCTIONS: bool | str = "upper" + # Determines whether the base comes first in the LOG function + LOG_BASE_FIRST = True + # Indicates the default null ordering method to use if not explicitly set # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last" NULL_ORDERING = "nulls_are_small" @@ -331,10 +343,18 @@ def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) - return self.func("APPROX_COUNT_DISTINCT", expression.this) -def if_sql(self: Generator, expression: exp.If) -> str: - return self.func( - "IF", expression.this, expression.args.get("true"), expression.args.get("false") - ) +def if_sql( + name: str = "IF", false_value: t.Optional[exp.Expression | str] = None +) -> t.Callable[[Generator, exp.If], str]: + def _if_sql(self: Generator, expression: exp.If) -> str: + return self.func( + name, + expression.this, + expression.args.get("true"), + expression.args.get("false") or false_value, + ) + + return _if_sql def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str: @@ -751,6 +771,12 @@ def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str: return self.func("MAX", expression.this) +def bool_xor_sql(self: Generator, expression: exp.Xor) -> str: + a = self.sql(expression.left) + b = self.sql(expression.right) + return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})" + + # Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str: return f"{self.sql(expression, 'this')}, {self.sql(expression, 'expression')}" @@ -764,3 +790,10 @@ def is_parse_json(expression: exp.Expression) -> bool: def isnull_to_is_null(args: t.List) -> exp.Expression: return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null())) + + +def move_insert_cte_sql(self: Generator, expression: exp.Insert) -> str: + if expression.expression.args.get("with"): + expression = expression.copy() + expression.set("with", expression.expression.args["with"].pop()) + return self.insert_sql(expression) diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py index 87fb9b5..8b2e708 100644 --- a/sqlglot/dialects/drill.py +++ b/sqlglot/dialects/drill.py @@ -40,6 +40,7 @@ class Drill(Dialect): DATEINT_FORMAT = "'yyyyMMdd'" TIME_FORMAT = "'yyyy-MM-dd HH:mm:ss'" SUPPORTS_USER_DEFINED_TYPES = False + SUPPORTS_SEMI_ANTI_JOIN = False TIME_MAPPING = { "y": "%Y", @@ -135,7 +136,9 @@ class Drill(Dialect): exp.StrPosition: str_position_sql, exp.StrToDate: _str_to_date, exp.Pow: rename_func("POW"), - exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), + exp.Select: transforms.preprocess( + [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins] + ), exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", exp.TimeStrToTime: timestrtotime_sql, diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index ab7a26a..352f11a 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import ( arrow_json_extract_scalar_sql, arrow_json_extract_sql, binary_from_function, + bool_xor_sql, date_trunc_to_time, datestrtodate_sql, encode_decode_sql, @@ -190,6 +191,11 @@ class DuckDB(Dialect): ), } + TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - { + TokenType.SEMI, + TokenType.ANTI, + } + def _parse_types( self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True ) -> t.Optional[exp.Expression]: @@ -224,6 +230,7 @@ class DuckDB(Dialect): STRUCT_DELIMITER = ("(", ")") RENAME_TABLE_WITH_DB = False NVL2_SUPPORTED = False + SEMI_ANTI_JOIN_WITH_SIDE = False TRANSFORMS = { **generator.Generator.TRANSFORMS, @@ -234,7 +241,7 @@ class DuckDB(Dialect): exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.ArraySort: _array_sort_sql, exp.ArraySum: rename_func("LIST_SUM"), - exp.BitwiseXor: lambda self, e: self.func("XOR", e.this, e.expression), + exp.BitwiseXor: rename_func("XOR"), exp.CommentColumnConstraint: no_comment_column_constraint_sql, exp.CurrentDate: lambda self, e: "CURRENT_DATE", exp.CurrentTime: lambda self, e: "CURRENT_TIME", @@ -301,6 +308,7 @@ class DuckDB(Dialect): exp.UnixToTimeStr: lambda self, e: f"CAST(TO_TIMESTAMP({self.sql(e, 'this')}) AS TEXT)", exp.VariancePop: rename_func("VAR_POP"), exp.WeekOfYear: rename_func("WEEKOFYEAR"), + exp.Xor: bool_xor_sql, } TYPE_MAPPING = { diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index bec27d3..a427870 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -111,7 +111,7 @@ def _array_sort_sql(self: Hive.Generator, expression: exp.ArraySort) -> str: def _property_sql(self: Hive.Generator, expression: exp.Property) -> str: - return f"'{expression.name}'={self.sql(expression, 'value')}" + return f"{self.property_name(expression, string_key=True)}={self.sql(expression, 'value')}" def _str_to_unix_sql(self: Hive.Generator, expression: exp.StrToUnix) -> str: @@ -413,7 +413,7 @@ class Hive(Dialect): exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.DATEINT_FORMAT})", exp.FileFormatProperty: lambda self, e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}", exp.FromBase64: rename_func("UNBASE64"), - exp.If: if_sql, + exp.If: if_sql(), exp.ILike: no_ilike_sql, exp.IsNan: rename_func("ISNAN"), exp.JSONExtract: rename_func("GET_JSON_OBJECT"), @@ -466,6 +466,11 @@ class Hive(Dialect): exp.NumberToStr: rename_func("FORMAT_NUMBER"), exp.LastDateOfMonth: rename_func("LAST_DAY"), exp.National: lambda self, e: self.national_sql(e, prefix=""), + exp.ClusteredColumnConstraint: lambda self, e: f"({self.expressions(e, 'this', indent=False)})", + exp.NonClusteredColumnConstraint: lambda self, e: f"({self.expressions(e, 'this', indent=False)})", + exp.NotForReplicationColumnConstraint: lambda self, e: "", + exp.OnProperty: lambda self, e: "", + exp.PrimaryKeyColumnConstraint: lambda self, e: "PRIMARY KEY", } PROPERTIES_LOCATION = { @@ -475,6 +480,35 @@ class Hive(Dialect): exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } + def parameter_sql(self, expression: exp.Parameter) -> str: + this = self.sql(expression, "this") + parent = expression.parent + + if isinstance(parent, exp.EQ) and isinstance(parent.parent, exp.SetItem): + # We need to produce SET key = value instead of SET ${key} = value + return this + + return f"${{{this}}}" + + def schema_sql(self, expression: exp.Schema) -> str: + expression = expression.copy() + + for ordered in expression.find_all(exp.Ordered): + if ordered.args.get("desc") is False: + ordered.set("desc", None) + + return super().schema_sql(expression) + + def constraint_sql(self, expression: exp.Constraint) -> str: + expression = expression.copy() + + for prop in list(expression.find_all(exp.Properties)): + prop.pop() + + this = self.sql(expression, "this") + expressions = self.expressions(expression, sep=" ", flat=True) + return f"CONSTRAINT {this} {expressions}" + def rowformatserdeproperty_sql(self, expression: exp.RowFormatSerdeProperty) -> str: serde_props = self.sql(expression, "serde_properties") serde_props = f" {serde_props}" if serde_props else "" diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 75660f8..554241d 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -102,6 +102,7 @@ class MySQL(Dialect): TIME_FORMAT = "'%Y-%m-%d %T'" DPIPE_IS_STRING_CONCAT = False SUPPORTS_USER_DEFINED_TYPES = False + SUPPORTS_SEMI_ANTI_JOIN = False # https://prestodb.io/docs/current/functions/datetime.html#mysql-date-functions TIME_MAPPING = { @@ -519,7 +520,7 @@ class MySQL(Dialect): return self.expression(exp.SetItem, this=charset, collate=collate, kind="NAMES") - def _parse_type(self) -> t.Optional[exp.Expression]: + def _parse_type(self, parse_interval: bool = True) -> t.Optional[exp.Expression]: # mysql binary is special and can work anywhere, even in order by operations # it operates like a no paren func if self._match(TokenType.BINARY, advance=False): @@ -528,7 +529,7 @@ class MySQL(Dialect): if isinstance(data_type, exp.DataType): return self.expression(exp.Cast, this=self._parse_column(), to=data_type) - return super()._parse_type() + return super()._parse_type(parse_interval=parse_interval) class Generator(generator.Generator): LOCKING_READS_SUPPORTED = True @@ -560,7 +561,9 @@ class MySQL(Dialect): exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"), exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")), exp.Pivot: no_pivot_sql, - exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), + exp.Select: transforms.preprocess( + [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins] + ), exp.StrPosition: strposition_to_locate_sql, exp.StrToDate: _str_to_date_sql, exp.StrToTime: _str_to_date_sql, diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index d049d8e..342fd95 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import ( any_value_to_max_sql, arrow_json_extract_scalar_sql, arrow_json_extract_sql, + bool_xor_sql, datestrtodate_sql, format_time_lambda, max_or_greatest, @@ -110,7 +111,7 @@ def _string_agg_sql(self: Postgres.Generator, expression: exp.GroupConcat) -> st def _datatype_sql(self: Postgres.Generator, expression: exp.DataType) -> str: if expression.is_type("array"): - return f"{self.expressions(expression, flat=True)}[]" + return f"{self.expressions(expression, flat=True)}[]" if expression.expressions else "ARRAY" return self.datatype_sql(expression) @@ -380,25 +381,29 @@ class Postgres(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, exp.AnyValue: any_value_to_max_sql, + exp.Array: lambda self, e: f"{self.normalize_func('ARRAY')}({self.sql(e.expressions[0])})" + if isinstance(seq_get(e.expressions, 0), exp.Select) + else f"{self.normalize_func('ARRAY')}[{self.expressions(e, flat=True)}]", exp.ArrayConcat: rename_func("ARRAY_CAT"), exp.ArrayContained: lambda self, e: self.binary(e, "<@"), exp.ArrayContains: lambda self, e: self.binary(e, "@>"), exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"), exp.BitwiseXor: lambda self, e: self.binary(e, "#"), exp.ColumnDef: transforms.preprocess([_auto_increment_to_serial, _serial_to_generated]), + exp.CurrentDate: no_paren_current_date_sql, + exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", + exp.DateAdd: _date_add_sql("+"), + exp.DateDiff: _date_diff_sql, + exp.DateStrToDate: datestrtodate_sql, + exp.DataType: _datatype_sql, + exp.DateSub: _date_add_sql("-"), exp.Explode: rename_func("UNNEST"), + exp.GroupConcat: _string_agg_sql, exp.JSONExtract: arrow_json_extract_sql, exp.JSONExtractScalar: arrow_json_extract_scalar_sql, exp.JSONBExtract: lambda self, e: self.binary(e, "#>"), exp.JSONBExtractScalar: lambda self, e: self.binary(e, "#>>"), exp.JSONBContains: lambda self, e: self.binary(e, "?"), - exp.Pow: lambda self, e: self.binary(e, "^"), - exp.CurrentDate: no_paren_current_date_sql, - exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", - exp.DateAdd: _date_add_sql("+"), - exp.DateStrToDate: datestrtodate_sql, - exp.DateSub: _date_add_sql("-"), - exp.DateDiff: _date_diff_sql, exp.LogicalOr: rename_func("BOOL_OR"), exp.LogicalAnd: rename_func("BOOL_AND"), exp.Max: max_or_greatest, @@ -412,8 +417,10 @@ class Postgres(Dialect): [transforms.add_within_group_for_percentiles] ), exp.Pivot: no_pivot_sql, + exp.Pow: lambda self, e: self.binary(e, "^"), exp.RegexpLike: lambda self, e: self.binary(e, "~"), exp.RegexpILike: lambda self, e: self.binary(e, "~*"), + exp.Select: transforms.preprocess([transforms.eliminate_semi_and_anti_joins]), exp.StrPosition: str_position_sql, exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.Substring: _substring_sql, @@ -426,11 +433,7 @@ class Postgres(Dialect): exp.TryCast: no_trycast_sql, exp.TsOrDsToDate: ts_or_ds_to_date_sql("postgres"), exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})", - exp.DataType: _datatype_sql, - exp.GroupConcat: _string_agg_sql, - exp.Array: lambda self, e: f"{self.normalize_func('ARRAY')}({self.sql(e.expressions[0])})" - if isinstance(seq_get(e.expressions, 0), exp.Select) - else f"{self.normalize_func('ARRAY')}[{self.expressions(e, flat=True)}]", + exp.Xor: bool_xor_sql, } PROPERTIES_LOCATION = { diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 9ae4c32..0d8d4ab 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -6,6 +6,7 @@ from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, binary_from_function, + bool_xor_sql, date_trunc_to_time, encode_decode_sql, format_time_lambda, @@ -40,7 +41,7 @@ def _explode_to_unnest_sql(self: Presto.Generator, expression: exp.Lateral) -> s this=exp.Unnest( expressions=[expression.this.this], alias=expression.args.get("alias"), - ordinality=isinstance(expression.this, exp.Posexplode), + offset=isinstance(expression.this, exp.Posexplode), ), kind="cross", ) @@ -173,6 +174,7 @@ class Presto(Dialect): TIME_FORMAT = MySQL.TIME_FORMAT TIME_MAPPING = MySQL.TIME_MAPPING STRICT_STRING_CONCAT = True + SUPPORTS_SEMI_ANTI_JOIN = False # https://github.com/trinodb/trino/issues/17 # https://github.com/trinodb/trino/issues/12289 @@ -308,7 +310,7 @@ class Presto(Dialect): exp.First: _first_last_sql, exp.Group: transforms.preprocess([transforms.unalias_group]), exp.Hex: rename_func("TO_HEX"), - exp.If: if_sql, + exp.If: if_sql(), exp.ILike: no_ilike_sql, exp.Initcap: _initcap_sql, exp.ParseJSON: rename_func("JSON_PARSE"), @@ -331,7 +333,8 @@ class Presto(Dialect): [ transforms.eliminate_qualify, transforms.eliminate_distinct_on, - transforms.explode_to_unnest, + transforms.explode_to_unnest(1), + transforms.eliminate_semi_and_anti_joins, ] ), exp.SortArray: _no_sort_array, @@ -340,7 +343,6 @@ class Presto(Dialect): exp.StrToMap: rename_func("SPLIT_TO_MAP"), exp.StrToTime: _str_to_time_sql, exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))", - exp.Struct: rename_func("ROW"), exp.StructExtract: struct_extract_sql, exp.Table: transforms.preprocess([_unnest_sequence]), exp.TimestampTrunc: timestamptrunc_sql, @@ -363,8 +365,16 @@ class Presto(Dialect): [transforms.remove_within_group_for_percentiles] ), exp.Timestamp: transforms.preprocess([transforms.timestamp_to_cast]), + exp.Xor: bool_xor_sql, } + def struct_sql(self, expression: exp.Struct) -> str: + if any(isinstance(arg, (exp.EQ, exp.Slice)) for arg in expression.expressions): + self.unsupported("Struct with key-value definitions is unsupported.") + return self.function_fallback_sql(expression) + + return rename_func("ROW")(self, expression) + def interval_sql(self, expression: exp.Interval) -> str: unit = self.sql(expression, "unit") if expression.this and unit.lower().startswith("week"): diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index b4c7664..2145844 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -138,7 +138,9 @@ class Redshift(Postgres): exp.JSONExtract: _json_sql, exp.JSONExtractScalar: _json_sql, exp.SafeConcat: concat_to_dpipe_sql, - exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), + exp.Select: transforms.preprocess( + [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins] + ), exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})", exp.TsOrDsToDate: ts_or_ds_to_date_sql("redshift"), } diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 5aa946e..5c49331 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -5,9 +5,11 @@ import typing as t from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, + binary_from_function, date_trunc_to_time, datestrtodate_sql, format_time_lambda, + if_sql, inline_array_sql, max_or_greatest, min_or_least, @@ -203,6 +205,7 @@ class Snowflake(Dialect): NULL_ORDERING = "nulls_are_large" TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'" SUPPORTS_USER_DEFINED_TYPES = False + SUPPORTS_SEMI_ANTI_JOIN = False TIME_MAPPING = { "YYYY": "%Y", @@ -240,7 +243,16 @@ class Snowflake(Dialect): **parser.Parser.FUNCTIONS, "ARRAYAGG": exp.ArrayAgg.from_arg_list, "ARRAY_CONSTRUCT": exp.Array.from_arg_list, + "ARRAY_GENERATE_RANGE": lambda args: exp.GenerateSeries( + # ARRAY_GENERATE_RANGE has an exlusive end; we normalize it to be inclusive + start=seq_get(args, 0), + end=exp.Sub(this=seq_get(args, 1), expression=exp.Literal.number(1)), + step=seq_get(args, 2), + ), "ARRAY_TO_STRING": exp.ArrayJoin.from_arg_list, + "BITXOR": binary_from_function(exp.BitwiseXor), + "BIT_XOR": binary_from_function(exp.BitwiseXor), + "BOOLXOR": binary_from_function(exp.Xor), "CONVERT_TIMEZONE": _parse_convert_timezone, "DATE_TRUNC": date_trunc_to_time, "DATEADD": lambda args: exp.DateAdd( @@ -277,7 +289,7 @@ class Snowflake(Dialect): ), } - TIMESTAMPS = parser.Parser.TIMESTAMPS.copy() - {TokenType.TIME} + TIMESTAMPS = parser.Parser.TIMESTAMPS - {TokenType.TIME} RANGE_PARSERS = { **parser.Parser.RANGE_PARSERS, @@ -381,6 +393,7 @@ class Snowflake(Dialect): JOIN_HINTS = False TABLE_HINTS = False QUERY_HINTS = False + AGGREGATE_FILTER_SUPPORTED = False TRANSFORMS = { **generator.Generator.TRANSFORMS, @@ -390,6 +403,7 @@ class Snowflake(Dialect): exp.AtTimeZone: lambda self, e: self.func( "CONVERT_TIMEZONE", e.args.get("zone"), e.this ), + exp.BitwiseXor: rename_func("BITXOR"), exp.DateAdd: lambda self, e: self.func("DATEADD", e.text("unit"), e.expression, e.this), exp.DateDiff: lambda self, e: self.func( "DATEDIFF", e.text("unit"), e.expression, e.this @@ -398,8 +412,11 @@ class Snowflake(Dialect): exp.DataType: _datatype_sql, exp.DayOfWeek: rename_func("DAYOFWEEK"), exp.Extract: rename_func("DATE_PART"), + exp.GenerateSeries: lambda self, e: self.func( + "ARRAY_GENERATE_RANGE", e.args["start"], e.args["end"] + 1, e.args.get("step") + ), exp.GroupConcat: rename_func("LISTAGG"), - exp.If: rename_func("IFF"), + exp.If: if_sql(name="IFF", false_value="NULL"), exp.LogicalAnd: rename_func("BOOLAND_AGG"), exp.LogicalOr: rename_func("BOOLOR_AGG"), exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), @@ -407,7 +424,13 @@ class Snowflake(Dialect): exp.Min: min_or_least, exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", exp.RegexpILike: _regexpilike_sql, - exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), + exp.Select: transforms.preprocess( + [ + transforms.eliminate_distinct_on, + transforms.explode_to_unnest(0), + transforms.eliminate_semi_and_anti_joins, + ] + ), exp.StarMap: rename_func("OBJECT_CONSTRUCT"), exp.StartsWith: rename_func("STARTSWITH"), exp.StrPosition: lambda self, e: self.func( @@ -431,6 +454,7 @@ class Snowflake(Dialect): exp.UnixToTime: _unix_to_time_sql, exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), exp.WeekOfYear: rename_func("WEEKOFYEAR"), + exp.Xor: rename_func("BOOLXOR"), } TYPE_MAPPING = { @@ -449,6 +473,27 @@ class Snowflake(Dialect): exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } + def unnest_sql(self, expression: exp.Unnest) -> str: + selects = ["value"] + unnest_alias = expression.args.get("alias") + + offset = expression.args.get("offset") + if offset: + if unnest_alias: + expression = expression.copy() + unnest_alias.append("columns", offset.pop()) + + selects.append("index") + + subquery = exp.Subquery( + this=exp.select(*selects).from_( + f"TABLE(FLATTEN(INPUT => {self.sql(expression.expressions[0])}))" + ), + ) + alias = self.sql(unnest_alias) + alias = f" AS {alias}" if alias else "" + return f"{self.sql(subquery)}{alias}" + def show_sql(self, expression: exp.Show) -> str: scope = self.sql(expression, "scope") scope = f" {scope}" if scope else "" diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py index 56d33ba..3dc9838 100644 --- a/sqlglot/dialects/spark2.py +++ b/sqlglot/dialects/spark2.py @@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import ( create_with_partitions_sql, format_time_lambda, is_parse_json, + move_insert_cte_sql, pivot_column_names, rename_func, trim_sql, @@ -115,13 +116,6 @@ def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression: return expression -def _insert_sql(self: Spark2.Generator, expression: exp.Insert) -> str: - if expression.expression.args.get("with"): - expression = expression.copy() - expression.set("with", expression.expression.args.pop("with")) - return self.insert_sql(expression) - - class Spark2(Hive): class Parser(Hive.Parser): FUNCTIONS = { @@ -206,7 +200,7 @@ class Spark2(Hive): exp.DayOfYear: rename_func("DAYOFYEAR"), exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}", exp.From: transforms.preprocess([_unalias_pivot]), - exp.Insert: _insert_sql, + exp.Insert: move_insert_cte_sql, exp.LogicalAnd: rename_func("BOOL_AND"), exp.LogicalOr: rename_func("BOOL_OR"), exp.Map: _map_sql, diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index 7bfdf1c..1edfa9d 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -64,6 +64,7 @@ def _transform_create(expression: exp.Expression) -> exp.Expression: class SQLite(Dialect): # https://sqlite.org/forum/forumpost/5e575586ac5c711b?raw RESOLVES_IDENTIFIERS_AS_UPPERCASE = None + SUPPORTS_SEMI_ANTI_JOIN = False class Tokenizer(tokens.Tokenizer): IDENTIFIERS = ['"', ("[", "]"), "`"] @@ -125,7 +126,11 @@ class SQLite(Dialect): exp.Pivot: no_pivot_sql, exp.SafeConcat: concat_to_dpipe_sql, exp.Select: transforms.preprocess( - [transforms.eliminate_distinct_on, transforms.eliminate_qualify] + [ + transforms.eliminate_distinct_on, + transforms.eliminate_qualify, + transforms.eliminate_semi_and_anti_joins, + ] ), exp.TableSample: no_tablesample_sql, exp.TimeStrToTime: lambda self, e: self.sql(e, "this"), diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index d9de968..b9e925a 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -8,6 +8,8 @@ from sqlglot.tokens import TokenType class Teradata(Dialect): + SUPPORTS_SEMI_ANTI_JOIN = False + TIME_MAPPING = { "Y": "%Y", "YYYY": "%Y", @@ -168,7 +170,9 @@ class Teradata(Dialect): **generator.Generator.TRANSFORMS, exp.Max: max_or_greatest, exp.Min: min_or_least, - exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), + exp.Select: transforms.preprocess( + [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins] + ), exp.StrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})", exp.ToChar: lambda self, e: self.function_fallback_sql(e), exp.Use: lambda self, e: f"DATABASE {self.sql(e, 'this')}", diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 2299310..fa62e78 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import ( any_value_to_max_sql, max_or_greatest, min_or_least, + move_insert_cte_sql, parse_date_delta, rename_func, timestrtotime_sql, @@ -206,6 +207,8 @@ class TSQL(Dialect): RESOLVES_IDENTIFIERS_AS_UPPERCASE = None NULL_ORDERING = "nulls_are_small" TIME_FORMAT = "'yyyy-mm-dd hh:mm:ss'" + SUPPORTS_SEMI_ANTI_JOIN = False + LOG_BASE_FIRST = False TIME_MAPPING = { "year": "%Y", @@ -345,6 +348,8 @@ class TSQL(Dialect): } class Parser(parser.Parser): + SET_REQUIRES_ASSIGNMENT_DELIMITER = False + FUNCTIONS = { **parser.Parser.FUNCTIONS, "CHARINDEX": lambda args: exp.StrPosition( @@ -396,7 +401,6 @@ class TSQL(Dialect): TokenType.END: lambda self: self._parse_command(), } - LOG_BASE_FIRST = False LOG_DEFAULTS_TO_LN = True CONCAT_NULL_OUTPUTS_STRING = True @@ -609,11 +613,14 @@ class TSQL(Dialect): exp.Extract: rename_func("DATEPART"), exp.GroupConcat: _string_agg_sql, exp.If: rename_func("IIF"), + exp.Insert: move_insert_cte_sql, exp.Max: max_or_greatest, exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this), exp.Min: min_or_least, exp.NumberToStr: _format_sql, - exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), + exp.Select: transforms.preprocess( + [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins] + ), exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this), exp.SHA2: lambda self, e: self.func( "HASHBYTES", @@ -632,6 +639,14 @@ class TSQL(Dialect): exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } + def setitem_sql(self, expression: exp.SetItem) -> str: + this = expression.this + if isinstance(this, exp.EQ) and not isinstance(this.left, exp.Parameter): + # T-SQL does not use '=' in SET command, except when the LHS is a variable. + return f"{self.sql(this.left)} {self.sql(this.right)}" + + return super().setitem_sql(expression) + def boolean_sql(self, expression: exp.Boolean) -> str: if type(expression.parent) in BIT_TYPES: return "1" if expression.this else "0" @@ -661,16 +676,27 @@ class TSQL(Dialect): exists = expression.args.pop("exists", None) sql = super().create_sql(expression) + table = expression.find(exp.Table) + + if kind == "TABLE" and expression.expression: + sql = f"SELECT * INTO {self.sql(table)} FROM ({self.sql(expression.expression)}) AS temp" + if exists: - table = expression.find(exp.Table) identifier = self.sql(exp.Literal.string(exp.table_name(table) if table else "")) + sql = self.sql(exp.Literal.string(sql)) if kind == "SCHEMA": - sql = f"""IF NOT EXISTS (SELECT * FROM information_schema.schemata WHERE schema_name = {identifier}) EXEC('{sql}')""" + sql = f"""IF NOT EXISTS (SELECT * FROM information_schema.schemata WHERE schema_name = {identifier}) EXEC({sql})""" elif kind == "TABLE": - sql = f"""IF NOT EXISTS (SELECT * FROM information_schema.tables WHERE table_name = {identifier}) EXEC('{sql}')""" + assert table + where = exp.and_( + exp.column("table_name").eq(table.name), + exp.column("table_schema").eq(table.db) if table.db else None, + exp.column("table_catalog").eq(table.catalog) if table.catalog else None, + ) + sql = f"""IF NOT EXISTS (SELECT * FROM information_schema.tables WHERE {where}) EXEC({sql})""" elif kind == "INDEX": index = self.sql(exp.Literal.string(expression.this.text("this"))) - sql = f"""IF NOT EXISTS (SELECT * FROM sys.indexes WHERE object_id = object_id({identifier}) AND name = {index}) EXEC('{sql}')""" + sql = f"""IF NOT EXISTS (SELECT * FROM sys.indexes WHERE object_id = object_id({identifier}) AND name = {index}) EXEC({sql})""" elif expression.args.get("replace"): sql = sql.replace("CREATE OR REPLACE ", "CREATE OR ALTER ", 1) diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 1c3d42a..8e9575e 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -664,16 +664,6 @@ class Expression(metaclass=_Expression): return load(obj) - -IntoType = t.Union[ - str, - t.Type[Expression], - t.Collection[t.Union[str, t.Type[Expression]]], -] -ExpOrStr = t.Union[str, Expression] - - -class Condition(Expression): def and_( self, *expressions: t.Optional[ExpOrStr], @@ -762,11 +752,19 @@ class Condition(Expression): return klass(this=other, expression=this) return klass(this=this, expression=other) - def __getitem__(self, other: ExpOrStr | t.Tuple[ExpOrStr]): + def __getitem__(self, other: ExpOrStr | t.Tuple[ExpOrStr]) -> Bracket: return Bracket( this=self.copy(), expressions=[convert(e, copy=True) for e in ensure_list(other)] ) + def __iter__(self) -> t.Iterator: + if "expressions" in self.arg_types: + return iter(self.args.get("expressions") or []) + # We define this because __getitem__ converts Expression into an iterable, which is + # problematic because one can hit infinite loops if they do "for x in some_expr: ..." + # See: https://peps.python.org/pep-0234/ + raise TypeError(f"'{self.__class__.__name__}' object is not iterable") + def isin( self, *expressions: t.Any, @@ -886,6 +884,18 @@ class Condition(Expression): return not_(self.copy()) +IntoType = t.Union[ + str, + t.Type[Expression], + t.Collection[t.Union[str, t.Type[Expression]]], +] +ExpOrStr = t.Union[str, Expression] + + +class Condition(Expression): + """Logical conditions like x AND y, or simply x""" + + class Predicate(Condition): """Relationships like x = y, x > 1, x >= y.""" @@ -1045,6 +1055,10 @@ class Describe(Expression): arg_types = {"this": True, "kind": False, "expressions": False} +class Kill(Expression): + arg_types = {"this": True, "kind": False} + + class Pragma(Expression): pass @@ -1161,7 +1175,7 @@ class Column(Condition): if self.args.get(part) ] - def to_dot(self) -> Dot: + def to_dot(self) -> Dot | Identifier: """Converts the column into a dot expression.""" parts = self.parts parent = self.parent @@ -1171,7 +1185,7 @@ class Column(Condition): parts.append(parent.expression) parent = parent.parent - return Dot.build(deepcopy(parts)) + return Dot.build(deepcopy(parts)) if len(parts) > 1 else parts[0] class ColumnPosition(Expression): @@ -1607,6 +1621,7 @@ class Index(Expression): "primary": False, "amp": False, # teradata "partition_by": False, # teradata + "where": False, # postgres partial indexes } @@ -1917,7 +1932,7 @@ class Sort(Order): class Ordered(Expression): - arg_types = {"this": True, "desc": True, "nulls_first": True} + arg_types = {"this": True, "desc": False, "nulls_first": True} class Property(Expression): @@ -2569,7 +2584,6 @@ class Intersect(Union): class Unnest(UDTF): arg_types = { "expressions": True, - "ordinality": False, "alias": False, "offset": False, } @@ -2862,6 +2876,7 @@ class Select(Subqueryable): prefix="LIMIT", dialect=dialect, copy=copy, + into_arg="expression", **opts, ) @@ -4007,6 +4022,10 @@ class TimeUnit(Expression): super().__init__(**args) + @property + def unit(self) -> t.Optional[Var]: + return self.args.get("unit") + # https://www.oracletutorial.com/oracle-basics/oracle-interval/ # https://trino.io/docs/current/language/types.html#interval-day-to-second @@ -4018,10 +4037,6 @@ class IntervalSpan(Expression): class Interval(TimeUnit): arg_types = {"this": False, "unit": False} - @property - def unit(self) -> t.Optional[Var]: - return self.args.get("unit") - class IgnoreNulls(Expression): pass @@ -4327,6 +4342,10 @@ class DateDiff(Func, TimeUnit): class DateTrunc(Func): arg_types = {"unit": True, "this": True, "zone": False} + @property + def unit(self) -> Expression: + return self.args["unit"] + class DatetimeAdd(Func, TimeUnit): arg_types = {"this": True, "expression": True, "unit": False} @@ -4427,7 +4446,8 @@ class DateToDi(Func): # https://cloud.google.com/bigquery/docs/reference/standard-sql/date_functions#date class Date(Func): - arg_types = {"this": True, "zone": False} + arg_types = {"this": False, "zone": False, "expressions": False} + is_var_len_args = True class Day(Func): @@ -5131,10 +5151,11 @@ def _apply_builder( prefix=None, into=None, dialect=None, + into_arg="this", **opts, ): if _is_wrong_expression(expression, into): - expression = into(this=expression) + expression = into(**{into_arg: expression}) instance = maybe_copy(instance, copy) expression = maybe_parse( sql_or_expression=expression, @@ -5926,7 +5947,10 @@ def cast(expression: ExpOrStr, to: str | DataType | DataType.Type, **opts) -> Ca The new Cast instance. """ expression = maybe_parse(expression, **opts) - return Cast(this=expression, to=DataType.build(to, **opts)) + data_type = DataType.build(to, **opts) + expression = Cast(this=expression, to=data_type) + expression.type = data_type + return expression def table_( diff --git a/sqlglot/generator.py b/sqlglot/generator.py index d086e8a..b1ee783 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -3,6 +3,7 @@ from __future__ import annotations import logging import typing as t from collections import defaultdict +from functools import reduce from sqlglot import exp from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages @@ -99,6 +100,9 @@ class Generator: exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}", } + # Whether the base comes first + LOG_BASE_FIRST = True + # Whether or not null ordering is supported in order by NULL_ORDERING_SUPPORTED = True @@ -188,6 +192,18 @@ class Generator: # Whether or not the word COLUMN is included when adding a column with ALTER TABLE ALTER_TABLE_ADD_COLUMN_KEYWORD = True + # UNNEST WITH ORDINALITY (presto) instead of UNNEST WITH OFFSET (bigquery) + UNNEST_WITH_ORDINALITY = True + + # Whether or not FILTER (WHERE cond) can be used for conditional aggregation + AGGREGATE_FILTER_SUPPORTED = True + + # Whether or not JOIN sides (LEFT, RIGHT) are supported in conjunction with SEMI/ANTI join kinds + SEMI_ANTI_JOIN_WITH_SIDE = True + + # Whether or not session variables / parameters are supported, e.g. @x in T-SQL + SUPPORTS_PARAMETERS = True + TYPE_MAPPING = { exp.DataType.Type.NCHAR: "CHAR", exp.DataType.Type.NVARCHAR: "VARCHAR", @@ -308,6 +324,8 @@ class Generator: exp.Paren, ) + UNESCAPED_SEQUENCE_TABLE = None # type: ignore + SENTINEL_LINE_BREAK = "__SQLGLOT__LB__" # Autofilled @@ -320,7 +338,6 @@ class Generator: STRICT_STRING_CONCAT = False NORMALIZE_FUNCTIONS: bool | str = "upper" NULL_ORDERING = "nulls_are_small" - ESCAPE_LINE_BREAK = False can_identify: t.Callable[[str, str | bool], bool] @@ -955,9 +972,16 @@ class Generator: return f"{self.seg('FETCH')}{direction}{count} ROWS {with_ties_or_only}" def filter_sql(self, expression: exp.Filter) -> str: - this = self.sql(expression, "this") - where = self.sql(expression, "expression").strip() - return f"{this} FILTER({where})" + if self.AGGREGATE_FILTER_SUPPORTED: + this = self.sql(expression, "this") + where = self.sql(expression, "expression").strip() + return f"{this} FILTER({where})" + + agg = expression.this.copy() + agg_arg = agg.this + cond = expression.expression.this + agg_arg.replace(exp.If(this=cond.copy(), true=agg_arg.copy())) + return self.sql(agg) def hint_sql(self, expression: exp.Hint) -> str: if not self.QUERY_HINTS: @@ -975,13 +999,14 @@ class Generator: table = self.sql(expression, "table") table = f"{self.INDEX_ON} {table}" if table else "" using = self.sql(expression, "using") - using = f" USING {using} " if using else "" + using = f" USING {using}" if using else "" index = "INDEX " if not table else "" columns = self.expressions(expression, key="columns", flat=True) columns = f"({columns})" if columns else "" partition_by = self.expressions(expression, key="partition_by", flat=True) partition_by = f" PARTITION BY {partition_by}" if partition_by else "" - return f"{unique}{primary}{amp}{index}{name}{table}{using}{columns}{partition_by}" + where = self.sql(expression, "where") + return f"{unique}{primary}{amp}{index}{name}{table}{using}{columns}{partition_by}{where}" def identifier_sql(self, expression: exp.Identifier) -> str: text = expression.name @@ -1060,10 +1085,15 @@ class Generator: return properties_locs + def property_name(self, expression: exp.Property, string_key: bool = False) -> str: + if isinstance(expression.this, exp.Dot): + return self.sql(expression, "this") + return f"'{expression.name}'" if string_key else expression.name + def property_sql(self, expression: exp.Property) -> str: property_cls = expression.__class__ if property_cls == exp.Property: - return f"{expression.name}={self.sql(expression, 'value')}" + return f"{self.property_name(expression)}={self.sql(expression, 'value')}" property_name = exp.Properties.PROPERTY_TO_NAME.get(property_cls) if not property_name: @@ -1224,6 +1254,13 @@ class Generator: def introducer_sql(self, expression: exp.Introducer) -> str: return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}" + def kill_sql(self, expression: exp.Kill) -> str: + kind = self.sql(expression, "kind") + kind = f" {kind}" if kind else "" + this = self.sql(expression, "this") + this = f" {this}" if this else "" + return f"KILL{kind}{this}" + def pseudotype_sql(self, expression: exp.PseudoType) -> str: return expression.name.upper() @@ -1386,13 +1423,11 @@ class Generator: return f"{values} AS {alias}" if alias else values # Converts `VALUES...` expression into a series of select unions. - # Note: If you have a lot of unions then this will result in a large number of recursive statements to - # evaluate the expression. You may need to increase `sys.setrecursionlimit` to run and it can also be - # very slow. expression = expression.copy() - column_names = expression.alias and expression.args["alias"].columns + alias_node = expression.args.get("alias") + column_names = alias_node and alias_node.columns - selects = [] + selects: t.List[exp.Subqueryable] = [] for i, tup in enumerate(expression.expressions): row = tup.expressions @@ -1404,14 +1439,18 @@ class Generator: selects.append(exp.Select(expressions=row)) - subquery_expression: exp.Select | exp.Union = selects[0] - if len(selects) > 1: - for select in selects[1:]: - subquery_expression = exp.union( - subquery_expression, select, distinct=False, copy=False - ) + if self.pretty: + # This may result in poor performance for large-cardinality `VALUES` tables, due to + # the deep nesting of the resulting exp.Unions. If this is a problem, either increase + # `sys.setrecursionlimit` to avoid RecursionErrors, or don't set `pretty`. + subqueryable = reduce(lambda x, y: exp.union(x, y, distinct=False, copy=False), selects) + return self.subquery_sql( + subqueryable.subquery(alias_node and alias_node.this, copy=False) + ) - return self.subquery_sql(subquery_expression.subquery(expression.alias, copy=False)) + alias = f" AS {self.sql(alias_node, 'this')}" if alias_node else "" + unions = " UNION ALL ".join(self.sql(select) for select in selects) + return f"({unions}){alias}" def var_sql(self, expression: exp.Var) -> str: return self.sql(expression, "this") @@ -1477,12 +1516,17 @@ class Generator: return f"PRIOR {self.sql(expression, 'this')}" def join_sql(self, expression: exp.Join) -> str: + if not self.SEMI_ANTI_JOIN_WITH_SIDE and expression.kind in ("SEMI", "ANTI"): + side = None + else: + side = expression.side + op_sql = " ".join( op for op in ( expression.method, "GLOBAL" if expression.args.get("global") else None, - expression.side, + side, expression.kind, expression.hint if self.JOIN_HINTS else None, ) @@ -1594,8 +1638,8 @@ class Generator: def escape_str(self, text: str) -> str: text = text.replace(self.QUOTE_END, self._escaped_quote_end) - if self.ESCAPE_LINE_BREAK: - text = text.replace("\n", "\\n") + if self.UNESCAPED_SEQUENCE_TABLE: + text = text.translate(self.UNESCAPED_SEQUENCE_TABLE) elif self.pretty: text = text.replace("\n", self.SENTINEL_LINE_BREAK) return text @@ -1643,7 +1687,7 @@ class Generator: nulls_are_small = self.NULL_ORDERING == "nulls_are_small" nulls_are_last = self.NULL_ORDERING == "nulls_are_last" - sort_order = " DESC" if desc else "" + sort_order = " DESC" if desc else (" ASC" if desc is False else "") nulls_sort_change = "" if nulls_first and ( (asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last @@ -1817,8 +1861,7 @@ class Generator: def parameter_sql(self, expression: exp.Parameter) -> str: this = self.sql(expression, "this") - this = f"{{{this}}}" if expression.args.get("wrapped") else f"{this}" - return f"{self.PARAMETER_TOKEN}{this}" + return f"{self.PARAMETER_TOKEN}{this}" if self.SUPPORTS_PARAMETERS else this def sessionparameter_sql(self, expression: exp.SessionParameter) -> str: this = self.sql(expression, "this") @@ -1858,17 +1901,33 @@ class Generator: def unnest_sql(self, expression: exp.Unnest) -> str: args = self.expressions(expression, flat=True) + alias = expression.args.get("alias") + offset = expression.args.get("offset") + + if self.UNNEST_WITH_ORDINALITY: + if alias and isinstance(offset, exp.Expression): + alias = alias.copy() + alias.append("columns", offset.copy()) + if alias and self.UNNEST_COLUMN_ONLY: columns = alias.columns alias = self.sql(columns[0]) if columns else "" else: - alias = self.sql(expression, "alias") + alias = self.sql(alias) + alias = f" AS {alias}" if alias else alias - ordinality = " WITH ORDINALITY" if expression.args.get("ordinality") else "" - offset = expression.args.get("offset") - offset = f" WITH OFFSET AS {self.sql(offset)}" if offset else "" - return f"UNNEST({args}){ordinality}{alias}{offset}" + if self.UNNEST_WITH_ORDINALITY: + suffix = f" WITH ORDINALITY{alias}" if offset else alias + else: + if isinstance(offset, exp.Expression): + suffix = f"{alias} WITH OFFSET AS {self.sql(offset)}" + elif offset: + suffix = f"{alias} WITH OFFSET" + else: + suffix = alias + + return f"UNNEST({args}){suffix}" def where_sql(self, expression: exp.Where) -> str: this = self.indent(self.sql(expression, "this")) @@ -2471,6 +2530,12 @@ class Generator: def trycast_sql(self, expression: exp.TryCast) -> str: return self.cast_sql(expression, safe_prefix="TRY_") + def log_sql(self, expression: exp.Log) -> str: + args = list(expression.args.values()) + if not self.LOG_BASE_FIRST: + args.reverse() + return self.func("LOG", *args) + def use_sql(self, expression: exp.Use) -> str: kind = self.sql(expression, "kind") kind = f" {kind}" if kind else "" diff --git a/sqlglot/helper.py b/sqlglot/helper.py index 7335d1e..00d49ae 100644 --- a/sqlglot/helper.py +++ b/sqlglot/helper.py @@ -13,9 +13,10 @@ from itertools import count if t.TYPE_CHECKING: from sqlglot import exp - from sqlglot._typing import E, T + from sqlglot._typing import A, E, T from sqlglot.expressions import Expression + CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])") PYTHON_VERSION = sys.version_info[:2] logger = logging.getLogger("sqlglot") @@ -379,7 +380,9 @@ def is_iterable(value: t.Any) -> bool: Returns: A `bool` value indicating if it is an iterable. """ - return hasattr(value, "__iter__") and not isinstance(value, (str, bytes)) + from sqlglot import Expression + + return hasattr(value, "__iter__") and not isinstance(value, (str, bytes, Expression)) def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]: @@ -435,3 +438,22 @@ def dict_depth(d: t.Dict) -> int: def first(it: t.Iterable[T]) -> T: """Returns the first element from an iterable (useful for sets).""" return next(i for i in it) + + +def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]: + if not ranges: + return [] + + ranges = sorted(ranges) + + merged = [ranges[0]] + + for start, end in ranges[1:]: + last_start, last_end = merged[-1] + + if start <= last_end: + merged[-1] = (last_start, max(last_end, end)) + else: + merged.append((start, end)) + + return merged diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index a429655..afc6995 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -158,6 +158,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.TimeAdd, exp.TimeStrToTime, exp.TimeSub, + exp.Timestamp, exp.TimestampAdd, exp.TimestampSub, exp.UnixToTime, @@ -177,6 +178,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.Initcap, exp.Lower, exp.SafeConcat, + exp.SafeDPipe, exp.Substring, exp.TimeToStr, exp.TimeToTimeStr, @@ -242,6 +244,13 @@ class TypeAnnotator(metaclass=_TypeAnnotator): self.annotators = annotators or self.ANNOTATORS self.coerces_to = coerces_to or self.COERCES_TO + # Caches the ids of annotated sub-Expressions, to ensure we only visit them once + self._visited: t.Set[int] = set() + + def _set_type(self, expression: exp.Expression, target_type: exp.DataType) -> None: + expression.type = target_type + self._visited.add(id(expression)) + def annotate(self, expression: E) -> E: for scope in traverse_scope(expression): selects = {} @@ -279,9 +288,9 @@ class TypeAnnotator(metaclass=_TypeAnnotator): source = scope.sources.get(col.table) if isinstance(source, exp.Table): - col.type = self.schema.get_column_type(source, col) + self._set_type(col, self.schema.get_column_type(source, col)) elif source and col.table in selects and col.name in selects[col.table]: - col.type = selects[col.table][col.name].type + self._set_type(col, selects[col.table][col.name].type) # Then (possibly) annotate the remaining expressions in the scope self._maybe_annotate(scope.expression) @@ -289,7 +298,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator): return self._maybe_annotate(expression) # This takes care of non-traversable expressions def _maybe_annotate(self, expression: E) -> E: - if expression.type: + if id(expression) in self._visited: return expression # We've already inferred the expression's type annotator = self.annotators.get(expression.__class__) @@ -338,17 +347,18 @@ class TypeAnnotator(metaclass=_TypeAnnotator): if isinstance(expression, exp.Connector): if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL: - expression.type = exp.DataType.Type.NULL + self._set_type(expression, exp.DataType.Type.NULL) elif exp.DataType.Type.NULL in (left_type, right_type): - expression.type = exp.DataType.build( - "NULLABLE", expressions=exp.DataType.build("BOOLEAN") + self._set_type( + expression, + exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")), ) else: - expression.type = exp.DataType.Type.BOOLEAN + self._set_type(expression, exp.DataType.Type.BOOLEAN) elif isinstance(expression, exp.Predicate): - expression.type = exp.DataType.Type.BOOLEAN + self._set_type(expression, exp.DataType.Type.BOOLEAN) else: - expression.type = self._maybe_coerce(left_type, right_type) + self._set_type(expression, self._maybe_coerce(left_type, right_type)) return expression @@ -357,26 +367,26 @@ class TypeAnnotator(metaclass=_TypeAnnotator): self._annotate_args(expression) if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren): - expression.type = exp.DataType.Type.BOOLEAN + self._set_type(expression, exp.DataType.Type.BOOLEAN) else: - expression.type = expression.this.type + self._set_type(expression, expression.this.type) return expression @t.no_type_check def _annotate_literal(self, expression: exp.Literal) -> exp.Literal: if expression.is_string: - expression.type = exp.DataType.Type.VARCHAR + self._set_type(expression, exp.DataType.Type.VARCHAR) elif expression.is_int: - expression.type = exp.DataType.Type.INT + self._set_type(expression, exp.DataType.Type.INT) else: - expression.type = exp.DataType.Type.DOUBLE + self._set_type(expression, exp.DataType.Type.DOUBLE) return expression @t.no_type_check def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E: - expression.type = target_type + self._set_type(expression, target_type) return self._annotate_args(expression) @t.no_type_check @@ -394,17 +404,20 @@ class TypeAnnotator(metaclass=_TypeAnnotator): for expr in expressions: last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type) - expression.type = last_datatype or exp.DataType.Type.UNKNOWN + self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN) if promote: if expression.type.this in exp.DataType.INTEGER_TYPES: - expression.type = exp.DataType.Type.BIGINT + self._set_type(expression, exp.DataType.Type.BIGINT) elif expression.type.this in exp.DataType.FLOAT_TYPES: - expression.type = exp.DataType.Type.DOUBLE + self._set_type(expression, exp.DataType.Type.DOUBLE) if array: - expression.type = exp.DataType( - this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True + self._set_type( + expression, + exp.DataType( + this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True + ), ) return expression diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py index 015b06a..e45d1e3 100644 --- a/sqlglot/optimizer/canonicalize.py +++ b/sqlglot/optimizer/canonicalize.py @@ -17,9 +17,11 @@ def canonicalize(expression: exp.Expression) -> exp.Expression: exp.replace_children(expression, canonicalize) expression = add_text_to_concat(expression) + expression = replace_date_funcs(expression) expression = coerce_type(expression) expression = remove_redundant_casts(expression) expression = ensure_bool_predicates(expression) + expression = remove_ascending_order(expression) return expression @@ -30,6 +32,14 @@ def add_text_to_concat(node: exp.Expression) -> exp.Expression: return node +def replace_date_funcs(node: exp.Expression) -> exp.Expression: + if isinstance(node, exp.Date) and not node.expressions and not node.args.get("zone"): + return exp.cast(node.this, to=exp.DataType.Type.DATE) + if isinstance(node, exp.Timestamp) and not node.expression: + return exp.cast(node.this, to=exp.DataType.Type.TIMESTAMP) + return node + + def coerce_type(node: exp.Expression) -> exp.Expression: if isinstance(node, exp.Binary): _coerce_date(node.left, node.right) @@ -63,6 +73,14 @@ def ensure_bool_predicates(expression: exp.Expression) -> exp.Expression: return expression +def remove_ascending_order(expression: exp.Expression) -> exp.Expression: + if isinstance(expression, exp.Ordered) and expression.args.get("desc") is False: + # Convert ORDER BY a ASC to ORDER BY a + expression.set("desc", None) + + return expression + + def _coerce_date(a: exp.Expression, b: exp.Expression) -> None: for a, b in itertools.permutations([a, b]): if ( @@ -75,10 +93,7 @@ def _coerce_date(a: exp.Expression, b: exp.Expression) -> None: def _replace_cast(node: exp.Expression, to: str) -> None: - data_type = exp.DataType.build(to) - cast = exp.Cast(this=node.copy(), to=data_type) - cast.type = data_type - node.replace(cast) + node.replace(exp.cast(node.copy(), to=to)) def _replace_int_predicate(expression: exp.Expression) -> None: diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 3974ea4..d08c692 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -1,17 +1,22 @@ import datetime import functools import itertools +import typing as t from collections import deque from decimal import Decimal from sqlglot import exp from sqlglot.generator import cached_generator -from sqlglot.helper import first, while_changing +from sqlglot.helper import first, merge_ranges, while_changing # Final means that an expression should not be simplified FINAL = "final" +class UnsupportedUnit(Exception): + pass + + def simplify(expression): """ Rewrite sqlglot AST to simplify expressions. @@ -72,7 +77,9 @@ def simplify(expression): node = simplify_coalesce(node) node.parent = expression.parent node = simplify_literals(node, root) + node = simplify_equality(node) node = simplify_parens(node) + node = simplify_datetrunc_predicate(node) if root: expression.replace(node) @@ -84,6 +91,21 @@ def simplify(expression): return expression +def catch(*exceptions): + """Decorator that ignores a simplification function if any of `exceptions` are raised""" + + def decorator(func): + def wrapped(expression, *args, **kwargs): + try: + return func(expression, *args, **kwargs) + except exceptions: + return expression + + return wrapped + + return decorator + + def rewrite_between(expression: exp.Expression) -> exp.Expression: """Rewrite x between y and z to x >= y AND x <= z. @@ -196,7 +218,7 @@ COMPARISONS = ( exp.Is, ) -INVERSE_COMPARISONS = { +INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { exp.LT: exp.GT, exp.GT: exp.LT, exp.LTE: exp.GTE, @@ -347,6 +369,87 @@ def absorb_and_eliminate(expression, root=True): return expression +INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { + exp.DateAdd: exp.Sub, + exp.DateSub: exp.Add, + exp.DatetimeAdd: exp.Sub, + exp.DatetimeSub: exp.Add, +} + +INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { + **INVERSE_DATE_OPS, + exp.Add: exp.Sub, + exp.Sub: exp.Add, +} + + +def _is_number(expression: exp.Expression) -> bool: + return expression.is_number + + +def _is_date(expression: exp.Expression) -> bool: + return isinstance(expression, exp.Cast) and extract_date(expression) is not None + + +def _is_interval(expression: exp.Expression) -> bool: + return isinstance(expression, exp.Interval) and extract_interval(expression) is not None + + +@catch(ModuleNotFoundError, UnsupportedUnit) +def simplify_equality(expression: exp.Expression) -> exp.Expression: + """ + Use the subtraction and addition properties of equality to simplify expressions: + + x + 1 = 3 becomes x = 2 + + There are two binary operations in the above expression: + and = + Here's how we reference all the operands in the code below: + + l r + x + 1 = 3 + a b + """ + if isinstance(expression, COMPARISONS): + l, r = expression.left, expression.right + + if l.__class__ in INVERSE_OPS: + pass + elif r.__class__ in INVERSE_OPS: + l, r = r, l + else: + return expression + + if r.is_number: + a_predicate = _is_number + b_predicate = _is_number + elif _is_date(r): + a_predicate = _is_date + b_predicate = _is_interval + else: + return expression + + if l.__class__ in INVERSE_DATE_OPS: + a = l.this + b = exp.Interval( + this=l.expression.copy(), + unit=l.unit.copy(), + ) + else: + a, b = l.left, l.right + + if not a_predicate(a) and b_predicate(b): + pass + elif not a_predicate(b) and b_predicate(a): + a, b = b, a + else: + return expression + + return expression.__class__( + this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b) + ) + return expression + + def simplify_literals(expression, root=True): if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector): return _flat_simplify(expression, _simplify_binary, root) @@ -530,6 +633,123 @@ def simplify_concat(expression): return new_args[0] if len(new_args) == 1 else concat_type(expressions=new_args) +DateRange = t.Tuple[datetime.date, datetime.date] + + +def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]: + """ + Get the date range for a DATE_TRUNC equality comparison: + + Example: + _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01)) + Returns: + tuple of [min, max) or None if a value can never be equal to `date` for `unit` + """ + floor = date_floor(date, unit) + + if date != floor: + # This will always be False, except for NULL values. + return None + + return floor, floor + interval(unit) + + +def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression: + """Get the logical expression for a date range""" + return exp.and_( + left >= date_literal(drange[0]), + left < date_literal(drange[1]), + copy=False, + ) + + +def _datetrunc_eq( + left: exp.Expression, date: datetime.date, unit: str +) -> t.Optional[exp.Expression]: + drange = _datetrunc_range(date, unit) + if not drange: + return None + + return _datetrunc_eq_expression(left, drange) + + +def _datetrunc_neq( + left: exp.Expression, date: datetime.date, unit: str +) -> t.Optional[exp.Expression]: + drange = _datetrunc_range(date, unit) + if not drange: + return None + + return exp.and_( + left < date_literal(drange[0]), + left >= date_literal(drange[1]), + copy=False, + ) + + +DateTruncBinaryTransform = t.Callable[ + [exp.Expression, datetime.date, str], t.Optional[exp.Expression] +] +DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = { + exp.LT: lambda l, d, u: l < date_literal(date_floor(d, u)), + exp.GT: lambda l, d, u: l >= date_literal(date_floor(d, u) + interval(u)), + exp.LTE: lambda l, d, u: l < date_literal(date_floor(d, u) + interval(u)), + exp.GTE: lambda l, d, u: l >= date_literal(date_ceil(d, u)), + exp.EQ: _datetrunc_eq, + exp.NEQ: _datetrunc_neq, +} +DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS} + + +def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool: + return ( + isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)) + and isinstance(right, exp.Cast) + and right.is_type(*exp.DataType.TEMPORAL_TYPES) + ) + + +@catch(ModuleNotFoundError, UnsupportedUnit) +def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression: + """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`""" + comparison = expression.__class__ + + if comparison not in DATETRUNC_COMPARISONS: + return expression + + if isinstance(expression, exp.Binary): + l, r = expression.left, expression.right + + if _is_datetrunc_predicate(l, r): + pass + elif _is_datetrunc_predicate(r, l): + comparison = INVERSE_COMPARISONS.get(comparison, comparison) + l, r = r, l + else: + return expression + + unit = l.unit.name.lower() + date = extract_date(r) + + return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression + elif isinstance(expression, exp.In): + l = expression.this + rs = expression.expressions + + if all(_is_datetrunc_predicate(l, r) for r in rs): + unit = l.unit.name.lower() + + ranges = [r for r in [_datetrunc_range(extract_date(r), unit) for r in rs] if r] + if not ranges: + return expression + + ranges = merge_ranges(ranges) + + return exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False) + + return expression + + # CROSS joins result in an empty table if the right table is empty. # So we can only simplify certain types of joins to CROSS. # Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x @@ -603,31 +823,76 @@ def extract_date(cast): return None -def extract_interval(interval): +def extract_interval(expression): + n = int(expression.name) + unit = expression.text("unit").lower() + try: - from dateutil.relativedelta import relativedelta # type: ignore - except ModuleNotFoundError: + return interval(unit, n) + except (UnsupportedUnit, ModuleNotFoundError): return None - n = int(interval.name) - unit = interval.text("unit").lower() + +def date_literal(date): + return exp.cast( + exp.Literal.string(date), + "DATETIME" if isinstance(date, datetime.datetime) else "DATE", + ) + + +def interval(unit: str, n: int = 1): + from dateutil.relativedelta import relativedelta if unit == "year": - return relativedelta(years=n) + return relativedelta(years=1 * n) + if unit == "quarter": + return relativedelta(months=3 * n) if unit == "month": - return relativedelta(months=n) + return relativedelta(months=1 * n) if unit == "week": - return relativedelta(weeks=n) + return relativedelta(weeks=1 * n) if unit == "day": - return relativedelta(days=n) - return None + return relativedelta(days=1 * n) + if unit == "hour": + return relativedelta(hours=1 * n) + if unit == "minute": + return relativedelta(minutes=1 * n) + if unit == "second": + return relativedelta(seconds=1 * n) + raise UnsupportedUnit(f"Unsupported unit: {unit}") -def date_literal(date): - return exp.cast( - exp.Literal.string(date), - "DATETIME" if isinstance(date, datetime.datetime) else "DATE", - ) + +def date_floor(d: datetime.date, unit: str) -> datetime.date: + if unit == "year": + return d.replace(month=1, day=1) + if unit == "quarter": + if d.month <= 3: + return d.replace(month=1, day=1) + elif d.month <= 6: + return d.replace(month=4, day=1) + elif d.month <= 9: + return d.replace(month=7, day=1) + else: + return d.replace(month=10, day=1) + if unit == "month": + return d.replace(month=d.month, day=1) + if unit == "week": + # Assuming week starts on Monday (0) and ends on Sunday (6) + return d - datetime.timedelta(days=d.weekday()) + if unit == "day": + return d + + raise UnsupportedUnit(f"Unsupported unit: {unit}") + + +def date_ceil(d: datetime.date, unit: str) -> datetime.date: + floor = date_floor(d, unit) + + if floor == d: + return d + + return floor + interval(unit) def boolean_literal(condition): diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py index 816f5fb..242fc87 100644 --- a/sqlglot/optimizer/unnest_subqueries.py +++ b/sqlglot/optimizer/unnest_subqueries.py @@ -43,7 +43,11 @@ def unnest(select, parent_select, next_alias_name): predicate = select.find_ancestor(exp.Condition) alias = next_alias_name() - if not predicate or parent_select is not predicate.parent_select: + if ( + not predicate + or parent_select is not predicate.parent_select + or not parent_select.args.get("from") + ): return # This subquery returns a scalar and can just be converted to a cross join diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 06bc1eb..84b2639 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -278,6 +278,7 @@ class Parser(metaclass=_Parser): TokenType.ISNULL, TokenType.INTERVAL, TokenType.KEEP, + TokenType.KILL, TokenType.LEFT, TokenType.LOAD, TokenType.MERGE, @@ -285,6 +286,7 @@ class Parser(metaclass=_Parser): TokenType.NEXT, TokenType.OFFSET, TokenType.ORDINALITY, + TokenType.OVERLAPS, TokenType.OVERWRITE, TokenType.PARTITION, TokenType.PERCENT, @@ -316,6 +318,7 @@ class Parser(metaclass=_Parser): INTERVAL_VARS = ID_VAR_TOKENS - {TokenType.END} TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - { + TokenType.ANTI, TokenType.APPLY, TokenType.ASOF, TokenType.FULL, @@ -324,6 +327,7 @@ class Parser(metaclass=_Parser): TokenType.NATURAL, TokenType.OFFSET, TokenType.RIGHT, + TokenType.SEMI, TokenType.WINDOW, } @@ -541,6 +545,7 @@ class Parser(metaclass=_Parser): TokenType.DESCRIBE: lambda self: self._parse_describe(), TokenType.DROP: lambda self: self._parse_drop(), TokenType.INSERT: lambda self: self._parse_insert(), + TokenType.KILL: lambda self: self._parse_kill(), TokenType.LOAD: lambda self: self._parse_load(), TokenType.MERGE: lambda self: self._parse_merge(), TokenType.PIVOT: lambda self: self._parse_simplified_pivot(), @@ -856,6 +861,8 @@ class Parser(metaclass=_Parser): DISTINCT_TOKENS = {TokenType.DISTINCT} + NULL_TOKENS = {TokenType.NULL} + STRICT_CAST = True # A NULL arg in CONCAT yields NULL by default @@ -873,6 +880,9 @@ class Parser(metaclass=_Parser): # Whether or not the table sample clause expects CSV syntax TABLESAMPLE_CSV = False + # Whether or not the SET command needs a delimiter (e.g. "=") for assignments. + SET_REQUIRES_ASSIGNMENT_DELIMITER = True + __slots__ = ( "error_level", "error_message_context", @@ -1280,7 +1290,14 @@ class Parser(metaclass=_Parser): else: begin = self._match(TokenType.BEGIN) return_ = self._match_text_seq("RETURN") - expression = self._parse_statement() + + if self._match(TokenType.STRING, advance=False): + # Takes care of BigQuery's JavaScript UDF definitions that end in an OPTIONS property + # # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_function_statement + expression = self._parse_string() + extend_props(self._parse_properties()) + else: + expression = self._parse_statement() if return_: expression = self.expression(exp.Return, this=expression) @@ -1400,20 +1417,18 @@ class Parser(metaclass=_Parser): if self._match_text_seq("SQL", "SECURITY"): return self.expression(exp.SqlSecurityProperty, definer=self._match_text_seq("DEFINER")) - assignment = self._match_pair( - TokenType.VAR, TokenType.EQ, advance=False - ) or self._match_pair(TokenType.STRING, TokenType.EQ, advance=False) + index = self._index + key = self._parse_column() - if assignment: - key = self._parse_var_or_string() - self._match(TokenType.EQ) - return self.expression( - exp.Property, - this=key, - value=self._parse_column() or self._parse_var(any_token=True), - ) + if not self._match(TokenType.EQ): + self._retreat(index) + return None - return None + return self.expression( + exp.Property, + this=key.to_dot() if isinstance(key, exp.Column) else key, + value=self._parse_column() or self._parse_var(any_token=True), + ) def _parse_stored(self) -> exp.FileFormatProperty: self._match(TokenType.ALIAS) @@ -1818,6 +1833,15 @@ class Parser(metaclass=_Parser): ignore=ignore, ) + def _parse_kill(self) -> exp.Kill: + kind = exp.var(self._prev.text) if self._match_texts(("CONNECTION", "QUERY")) else None + + return self.expression( + exp.Kill, + this=self._parse_primary(), + kind=kind, + ) + def _parse_on_conflict(self) -> t.Optional[exp.OnConflict]: conflict = self._match_text_seq("ON", "CONFLICT") duplicate = self._match_text_seq("ON", "DUPLICATE", "KEY") @@ -2459,7 +2483,7 @@ class Parser(metaclass=_Parser): index = self._parse_id_var() table = None - using = self._parse_field() if self._match(TokenType.USING) else None + using = self._parse_var(any_token=True) if self._match(TokenType.USING) else None if self._match(TokenType.L_PAREN, advance=False): columns = self._parse_wrapped_csv(self._parse_ordered) @@ -2476,6 +2500,7 @@ class Parser(metaclass=_Parser): primary=primary, amp=amp, partition_by=self._parse_partition_by(), + where=self._parse_where(), ) def _parse_table_hints(self) -> t.Optional[t.List[exp.Expression]]: @@ -2634,25 +2659,27 @@ class Parser(metaclass=_Parser): return None expressions = self._parse_wrapped_csv(self._parse_type) - ordinality = self._match_pair(TokenType.WITH, TokenType.ORDINALITY) + offset = self._match_pair(TokenType.WITH, TokenType.ORDINALITY) alias = self._parse_table_alias() if with_alias else None - if alias and self.UNNEST_COLUMN_ONLY: - if alias.args.get("columns"): - self.raise_error("Unexpected extra column alias in unnest.") + if alias: + if self.UNNEST_COLUMN_ONLY: + if alias.args.get("columns"): + self.raise_error("Unexpected extra column alias in unnest.") + + alias.set("columns", [alias.this]) + alias.set("this", None) - alias.set("columns", [alias.this]) - alias.set("this", None) + columns = alias.args.get("columns") or [] + if offset and len(expressions) < len(columns): + offset = columns.pop() - offset = None - if self._match_pair(TokenType.WITH, TokenType.OFFSET): + if not offset and self._match_pair(TokenType.WITH, TokenType.OFFSET): self._match(TokenType.ALIAS) offset = self._parse_id_var() or exp.to_identifier("offset") - return self.expression( - exp.Unnest, expressions=expressions, ordinality=ordinality, alias=alias, offset=offset - ) + return self.expression(exp.Unnest, expressions=expressions, alias=alias, offset=offset) def _parse_derived_table_values(self) -> t.Optional[exp.Values]: is_derived = self._match_pair(TokenType.L_PAREN, TokenType.VALUES) @@ -2940,20 +2967,20 @@ class Parser(metaclass=_Parser): def _parse_ordered(self) -> exp.Ordered: this = self._parse_conjunction() - self._match(TokenType.ASC) - is_desc = self._match(TokenType.DESC) + asc = self._match(TokenType.ASC) + desc = self._match(TokenType.DESC) or (asc and False) + is_nulls_first = self._match_text_seq("NULLS", "FIRST") is_nulls_last = self._match_text_seq("NULLS", "LAST") - desc = is_desc or False - asc = not desc + nulls_first = is_nulls_first or False explicitly_null_ordered = is_nulls_first or is_nulls_last if ( not explicitly_null_ordered and ( - (asc and self.NULL_ORDERING == "nulls_are_small") + (not desc and self.NULL_ORDERING == "nulls_are_small") or (desc and self.NULL_ORDERING != "nulls_are_small") ) and self.NULL_ORDERING != "nulls_are_last" @@ -3227,8 +3254,8 @@ class Parser(metaclass=_Parser): return self.UNARY_PARSERS[self._prev.token_type](self) return self._parse_at_time_zone(self._parse_type()) - def _parse_type(self) -> t.Optional[exp.Expression]: - interval = self._parse_interval() + def _parse_type(self, parse_interval: bool = True) -> t.Optional[exp.Expression]: + interval = parse_interval and self._parse_interval() if interval: return interval @@ -3247,7 +3274,7 @@ class Parser(metaclass=_Parser): return self._parse_column() return self._parse_column_ops(data_type) - return this + return this and self._parse_column_ops(this) def _parse_type_size(self) -> t.Optional[exp.DataTypeParam]: this = self._parse_type() @@ -3404,7 +3431,7 @@ class Parser(metaclass=_Parser): return this def _parse_struct_types(self) -> t.Optional[exp.Expression]: - this = self._parse_type() or self._parse_id_var() + this = self._parse_type(parse_interval=False) or self._parse_id_var() self._match(TokenType.COLON) return self._parse_column_def(this) @@ -3847,6 +3874,8 @@ class Parser(metaclass=_Parser): action = "NO ACTION" elif self._match_text_seq("CASCADE"): action = "CASCADE" + elif self._match_text_seq("RESTRICT"): + action = "RESTRICT" elif self._match_pair(TokenType.SET, TokenType.NULL): action = "SET NULL" elif self._match_pair(TokenType.SET, TokenType.DEFAULT): @@ -4573,7 +4602,7 @@ class Parser(metaclass=_Parser): return self._parse_var() or self._parse_string() def _parse_null(self) -> t.Optional[exp.Expression]: - if self._match(TokenType.NULL): + if self._match_set(self.NULL_TOKENS): return self.PRIMARY_PARSERS[TokenType.NULL](self, self._prev) return self._parse_placeholder() @@ -4608,14 +4637,18 @@ class Parser(metaclass=_Parser): return None if self._match(TokenType.L_PAREN, advance=False): return self._parse_wrapped_csv(self._parse_column) - return self._parse_csv(self._parse_column) + + except_column = self._parse_column() + return [except_column] if except_column else None def _parse_replace(self) -> t.Optional[t.List[exp.Expression]]: if not self._match(TokenType.REPLACE): return None if self._match(TokenType.L_PAREN, advance=False): return self._parse_wrapped_csv(self._parse_expression) - return self._parse_expressions() + + replace_expression = self._parse_expression() + return [replace_expression] if replace_expression else None def _parse_csv( self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA @@ -4931,8 +4964,9 @@ class Parser(metaclass=_Parser): return self._parse_set_transaction(global_=kind == "GLOBAL") left = self._parse_primary() or self._parse_id_var() + assignment_delimiter = self._match_texts(("=", "TO")) - if not self._match_texts(("=", "TO")): + if not left or (self.SET_REQUIRES_ASSIGNMENT_DELIMITER and not assignment_delimiter): self._retreat(index) return None diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index ce255c3..4d5f198 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -247,6 +247,7 @@ class TokenType(AutoName): JOIN = auto() JOIN_MARKER = auto() KEEP = auto() + KILL = auto() LANGUAGE = auto() LATERAL = auto() LEFT = auto() @@ -595,6 +596,7 @@ class Tokenizer(metaclass=_Tokenizer): "ISNULL": TokenType.ISNULL, "JOIN": TokenType.JOIN, "KEEP": TokenType.KEEP, + "KILL": TokenType.KILL, "LATERAL": TokenType.LATERAL, "LEFT": TokenType.LEFT, "LIKE": TokenType.LIKE, diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index 70b9a31..ac9dd81 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -146,7 +146,7 @@ def unnest_to_explode(expression: exp.Expression) -> exp.Expression: if isinstance(unnest, exp.Unnest): alias = unnest.args.get("alias") - udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode + udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode expression.args["joins"].remove(join) @@ -163,65 +163,134 @@ def unnest_to_explode(expression: exp.Expression) -> exp.Expression: return expression -def explode_to_unnest(expression: exp.Expression) -> exp.Expression: - """Convert explode/posexplode into unnest (used in hive -> presto).""" - if isinstance(expression, exp.Select): - from sqlglot.optimizer.scope import Scope - - taken_select_names = set(expression.named_selects) - taken_source_names = {name for name, _ in Scope(expression).references} - - for select in expression.selects: - to_replace = select - - pos_alias = "" - explode_alias = "" - - if isinstance(select, exp.Alias): - explode_alias = select.alias - select = select.this - elif isinstance(select, exp.Aliases): - pos_alias = select.aliases[0].name - explode_alias = select.aliases[1].name - select = select.this - - if isinstance(select, (exp.Explode, exp.Posexplode)): - is_posexplode = isinstance(select, exp.Posexplode) - - explode_arg = select.this - unnest = exp.Unnest(expressions=[explode_arg.copy()], ordinality=is_posexplode) +def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]: + def _explode_to_unnest(expression: exp.Expression) -> exp.Expression: + """Convert explode/posexplode into unnest (used in hive -> presto).""" + if isinstance(expression, exp.Select): + from sqlglot.optimizer.scope import Scope + + taken_select_names = set(expression.named_selects) + taken_source_names = {name for name, _ in Scope(expression).references} + + def new_name(names: t.Set[str], name: str) -> str: + name = find_new_name(names, name) + names.add(name) + return name + + arrays: t.List[exp.Condition] = [] + series_alias = new_name(taken_select_names, "pos") + series = exp.alias_( + exp.Unnest( + expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))] + ), + new_name(taken_source_names, "_u"), + table=[series_alias], + ) + + # we use list here because expression.selects is mutated inside the loop + for select in expression.selects.copy(): + explode = select.find(exp.Explode, exp.Posexplode) + + if isinstance(explode, (exp.Explode, exp.Posexplode)): + pos_alias = "" + explode_alias = "" + + if isinstance(select, exp.Alias): + explode_alias = select.alias + alias = select + elif isinstance(select, exp.Aliases): + pos_alias = select.aliases[0].name + explode_alias = select.aliases[1].name + alias = select.replace(exp.alias_(select.this, "", copy=False)) + else: + alias = select.replace(exp.alias_(select, "")) + explode = alias.find(exp.Explode, exp.Posexplode) + assert explode + + is_posexplode = isinstance(explode, exp.Posexplode) + explode_arg = explode.this + + # This ensures that we won't use [POS]EXPLODE's argument as a new selection + if isinstance(explode_arg, exp.Column): + taken_select_names.add(explode_arg.output_name) + + unnest_source_alias = new_name(taken_source_names, "_u") + + if not explode_alias: + explode_alias = new_name(taken_select_names, "col") + + if is_posexplode: + pos_alias = new_name(taken_select_names, "pos") + + if not pos_alias: + pos_alias = new_name(taken_select_names, "pos") + + alias.set("alias", exp.to_identifier(explode_alias)) + + column = exp.If( + this=exp.column(series_alias).eq(exp.column(pos_alias)), + true=exp.column(explode_alias), + ) - # This ensures that we won't use [POS]EXPLODE's argument as a new selection - if isinstance(explode_arg, exp.Column): - taken_select_names.add(explode_arg.output_name) + explode.replace(column) - unnest_source_alias = find_new_name(taken_source_names, "_u") - taken_source_names.add(unnest_source_alias) + if is_posexplode: + expressions = expression.expressions + expressions.insert( + expressions.index(alias) + 1, + exp.If( + this=exp.column(series_alias).eq(exp.column(pos_alias)), + true=exp.column(pos_alias), + ).as_(pos_alias), + ) + expression.set("expressions", expressions) + + if not arrays: + if expression.args.get("from"): + expression.join(series, copy=False) + else: + expression.from_(series, copy=False) + + size: exp.Condition = exp.ArraySize(this=explode_arg.copy()) + arrays.append(size) + + # trino doesn't support left join unnest with on conditions + # if it did, this would be much simpler + expression.join( + exp.alias_( + exp.Unnest( + expressions=[explode_arg.copy()], + offset=exp.to_identifier(pos_alias), + ), + unnest_source_alias, + table=[explode_alias], + ), + join_type="CROSS", + copy=False, + ) - if not explode_alias: - explode_alias = find_new_name(taken_select_names, "col") - taken_select_names.add(explode_alias) + if index_offset != 1: + size = size - 1 - if is_posexplode: - pos_alias = find_new_name(taken_select_names, "pos") - taken_select_names.add(pos_alias) + expression.where( + exp.column(series_alias) + .eq(exp.column(pos_alias)) + .or_( + (exp.column(series_alias) > size).and_(exp.column(pos_alias).eq(size)) + ), + copy=False, + ) - if is_posexplode: - column_names = [explode_alias, pos_alias] - to_replace.pop() - expression.select(pos_alias, explode_alias, copy=False) - else: - column_names = [explode_alias] - to_replace.replace(exp.column(explode_alias)) + if arrays: + end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:]) - unnest = exp.alias_(unnest, unnest_source_alias, table=column_names) + if index_offset != 1: + end = end - (1 - index_offset) + series.expressions[0].set("end", end) - if not expression.args.get("from"): - expression.from_(unnest, copy=False) - else: - expression.join(unnest, join_type="CROSS", copy=False) + return expression - return expression + return _explode_to_unnest PERCENTILES = (exp.PercentileCont, exp.PercentileDisc) @@ -283,6 +352,31 @@ def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression: return expression +def timestamp_to_cast(expression: exp.Expression) -> exp.Expression: + if isinstance(expression, exp.Timestamp) and not expression.expression: + return exp.cast( + expression.this, + to=exp.DataType.Type.TIMESTAMP, + ) + return expression + + +def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression: + if isinstance(expression, exp.Select): + for join in expression.args.get("joins") or []: + on = join.args.get("on") + if on and join.kind in ("SEMI", "ANTI"): + subquery = exp.select("1").from_(join.this).where(on) + exists = exp.Exists(this=subquery) + if join.kind == "ANTI": + exists = exists.not_(copy=False) + + join.pop() + expression.where(exists, copy=False) + + return expression + + def preprocess( transforms: t.List[t.Callable[[exp.Expression], exp.Expression]], ) -> t.Callable[[Generator, exp.Expression], str]: @@ -327,12 +421,3 @@ def preprocess( raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.") return _to_sql - - -def timestamp_to_cast(expression: exp.Expression) -> exp.Expression: - if isinstance(expression, exp.Timestamp) and not expression.expression: - return exp.cast( - expression.this, - to=exp.DataType.Type.TIMESTAMP, - ) - return expression |