Diffstat (limited to 'sqlglot/dialects')

 sqlglot/dialects/bigquery.py   | 32
 sqlglot/dialects/clickhouse.py |  6
 sqlglot/dialects/databricks.py |  2
 sqlglot/dialects/dialect.py    | 41
 sqlglot/dialects/drill.py      |  5
 sqlglot/dialects/duckdb.py     | 10
 sqlglot/dialects/hive.py       | 38
 sqlglot/dialects/mysql.py      |  9
 sqlglot/dialects/postgres.py   | 29
 sqlglot/dialects/presto.py     | 18
 sqlglot/dialects/redshift.py   |  4
 sqlglot/dialects/snowflake.py  | 51
 sqlglot/dialects/spark2.py     | 10
 sqlglot/dialects/sqlite.py     |  7
 sqlglot/dialects/teradata.py   |  6
 sqlglot/dialects/tsql.py       | 38

16 files changed, 244 insertions(+), 62 deletions(-)
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 6c71624..1349c56 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -12,6 +12,7 @@ from sqlglot.dialects.dialect import (
     date_add_interval_sql,
     datestrtodate_sql,
     format_time_lambda,
+    if_sql,
     inline_array_sql,
     json_keyvalue_comma_sql,
     max_or_greatest,
@@ -176,6 +177,8 @@ def _parse_to_hex(args: t.List) -> exp.Hex | exp.MD5:
 class BigQuery(Dialect):
     UNNEST_COLUMN_ONLY = True
     SUPPORTS_USER_DEFINED_TYPES = False
+    SUPPORTS_SEMI_ANTI_JOIN = False
+    LOG_BASE_FIRST = False
 
     # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#case_sensitivity
     RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
@@ -256,7 +259,6 @@ class BigQuery(Dialect):
             "RECORD": TokenType.STRUCT,
             "TIMESTAMP": TokenType.TIMESTAMPTZ,
             "NOT DETERMINISTIC": TokenType.VOLATILE,
-            "UNKNOWN": TokenType.NULL,
             "FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT,
         }
         KEYWORDS.pop("DIV")
@@ -264,7 +266,6 @@ class BigQuery(Dialect):
 
     class Parser(parser.Parser):
         PREFIXED_PIVOT_COLUMNS = True
-        LOG_BASE_FIRST = False
         LOG_DEFAULTS_TO_LN = True
 
         FUNCTIONS = {
@@ -292,9 +293,7 @@ class BigQuery(Dialect):
                 expression=seq_get(args, 1),
                 position=seq_get(args, 2),
                 occurrence=seq_get(args, 3),
-                group=exp.Literal.number(1)
-                if re.compile(str(seq_get(args, 1))).groups == 1
-                else None,
+                group=exp.Literal.number(1) if re.compile(args[1].name).groups == 1 else None,
             ),
             "SHA256": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(256)),
             "SHA512": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(512)),
@@ -344,6 +343,11 @@ class BigQuery(Dialect):
             "OPTIONS": lambda self: exp.Properties(expressions=self._parse_with_property()),
         }
 
+        RANGE_PARSERS = parser.Parser.RANGE_PARSERS.copy()
+        RANGE_PARSERS.pop(TokenType.OVERLAPS, None)
+
+        NULL_TOKENS = {TokenType.NULL, TokenType.UNKNOWN}
+
         def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]:
             this = super()._parse_table_part(schema=schema) or self._parse_number()
 
@@ -413,8 +417,8 @@ class BigQuery(Dialect):
         TABLE_HINTS = False
         LIMIT_FETCH = "LIMIT"
         RENAME_TABLE_WITH_DB = False
-        ESCAPE_LINE_BREAK = True
         NVL2_SUPPORTED = False
+        UNNEST_WITH_ORDINALITY = False
 
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
@@ -434,6 +438,7 @@ class BigQuery(Dialect):
             exp.GenerateSeries: rename_func("GENERATE_ARRAY"),
             exp.GroupConcat: rename_func("STRING_AGG"),
             exp.Hex: rename_func("TO_HEX"),
+            exp.If: if_sql(false_value="NULL"),
             exp.ILike: no_ilike_sql,
             exp.IntDiv: rename_func("DIV"),
             exp.JSONFormat: rename_func("TO_JSON_STRING"),
@@ -455,10 +460,11 @@ class BigQuery(Dialect):
             exp.ReturnsProperty: _returnsproperty_sql,
             exp.Select: transforms.preprocess(
                 [
-                    transforms.explode_to_unnest,
+                    transforms.explode_to_unnest(),
                     _unqualify_unnest,
                     transforms.eliminate_distinct_on,
                     _alias_ordered_group,
+                    transforms.eliminate_semi_and_anti_joins,
                 ]
             ),
             exp.SHA2: lambda self, e: self.func(
@@ -514,6 +520,18 @@ class BigQuery(Dialect):
             exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
         }
 
+        UNESCAPED_SEQUENCE_TABLE = str.maketrans(  # type: ignore
+            {
+                "\a": "\\a",
+                "\b": "\\b",
+                "\f": "\\f",
+                "\n": "\\n",
+                "\r": "\\r",
+                "\t": "\\t",
+                "\v": "\\v",
+            }
+        )
+
         # from: https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#reserved_keywords
         RESERVED_KEYWORDS = {
             *generator.Generator.RESERVED_KEYWORDS,
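BigQuery has no native SEMI/ANTI join syntax, so the new eliminate_semi_and_anti_joins transform rewrites those joins into EXISTS/NOT EXISTS predicates. A minimal sketch of the intended behavior, assuming a sqlglot build that includes this patch (exact output formatting may differ):

    import sqlglot

    # Spark's LEFT ANTI JOIN has no BigQuery equivalent; the new Select
    # preprocessor should rewrite it into a NOT EXISTS predicate.
    print(
        sqlglot.transpile(
            "SELECT * FROM t1 LEFT ANTI JOIN t2 ON t1.id = t2.id",
            read="spark",
            write="bigquery",
        )[0]
    )
    # roughly: SELECT * FROM t1 WHERE NOT EXISTS(SELECT 1 FROM t2 WHERE t1.id = t2.id)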
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index d552f4c..7446081 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -113,15 +113,11 @@ class ClickHouse(Dialect):
             *parser.Parser.JOIN_KINDS,
             TokenType.ANY,
             TokenType.ASOF,
-            TokenType.ANTI,
-            TokenType.SEMI,
             TokenType.ARRAY,
         }
 
-        TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {
+        TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - {
             TokenType.ANY,
-            TokenType.SEMI,
-            TokenType.ANTI,
             TokenType.SETTINGS,
             TokenType.FORMAT,
             TokenType.ARRAY,
diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py
index 6ec0487..39daad7 100644
--- a/sqlglot/dialects/databricks.py
+++ b/sqlglot/dialects/databricks.py
@@ -51,8 +51,6 @@ class Databricks(Spark):
             exp.ToChar: lambda self, e: self.function_fallback_sql(e),
         }
 
-        PARAMETER_TOKEN = "$"
-
     class Tokenizer(Spark.Tokenizer):
         HEX_STRINGS = []
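The base parser already treats SEMI and ANTI as join kinds, so ClickHouse's local copies of those tokens were redundant; since ClickHouse keeps the default SUPPORTS_SEMI_ANTI_JOIN = True, native SEMI/ANTI joins should still parse and round-trip. A quick sanity check, assuming this build:

    import sqlglot

    # ClickHouse supports SEMI/ANTI joins natively, so nothing is rewritten:
    ast = sqlglot.parse_one(
        "SELECT * FROM t1 LEFT SEMI JOIN t2 ON t1.id = t2.id", read="clickhouse"
    )
    print(ast.sql(dialect="clickhouse"))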
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index d4811c5..ccf04da 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -125,6 +125,12 @@ class _Dialect(type):
         if not klass.STRICT_STRING_CONCAT and klass.DPIPE_IS_STRING_CONCAT:
             klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe
 
+        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
+            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
+                TokenType.ANTI,
+                TokenType.SEMI,
+            }
+
         klass.generator_class.can_identify = klass.can_identify
 
         return klass
@@ -156,9 +162,15 @@ class Dialect(metaclass=_Dialect):
     # Determines whether or not user-defined data types are supported
     SUPPORTS_USER_DEFINED_TYPES = True
 
+    # Determines whether or not SEMI/ANTI JOINs are supported
+    SUPPORTS_SEMI_ANTI_JOIN = True
+
     # Determines how function names are going to be normalized
     NORMALIZE_FUNCTIONS: bool | str = "upper"
 
+    # Determines whether the base comes first in the LOG function
+    LOG_BASE_FIRST = True
+
     # Indicates the default null ordering method to use if not explicitly set
     # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
     NULL_ORDERING = "nulls_are_small"
@@ -331,10 +343,18 @@ def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
     return self.func("APPROX_COUNT_DISTINCT", expression.this)
 
 
-def if_sql(self: Generator, expression: exp.If) -> str:
-    return self.func(
-        "IF", expression.this, expression.args.get("true"), expression.args.get("false")
-    )
+def if_sql(
+    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
+) -> t.Callable[[Generator, exp.If], str]:
+    def _if_sql(self: Generator, expression: exp.If) -> str:
+        return self.func(
+            name,
+            expression.this,
+            expression.args.get("true"),
+            expression.args.get("false") or false_value,
+        )
+
+    return _if_sql
 
 
 def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
@@ -751,6 +771,12 @@ def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
     return self.func("MAX", expression.this)
 
 
+def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
+    a = self.sql(expression.left)
+    b = self.sql(expression.right)
+    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
+
+
 # Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon
 def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str:
     return f"{self.sql(expression, 'this')}, {self.sql(expression, 'expression')}"
@@ -764,3 +790,10 @@ def is_parse_json(expression: exp.Expression) -> bool:
 
 def isnull_to_is_null(args: t.List) -> exp.Expression:
     return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
+
+
+def move_insert_cte_sql(self: Generator, expression: exp.Insert) -> str:
+    if expression.expression.args.get("with"):
+        expression = expression.copy()
+        expression.set("with", expression.expression.args["with"].pop())
+    return self.insert_sql(expression)
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
index 87fb9b5..8b2e708 100644
--- a/sqlglot/dialects/drill.py
+++ b/sqlglot/dialects/drill.py
@@ -40,6 +40,7 @@ class Drill(Dialect):
     DATEINT_FORMAT = "'yyyyMMdd'"
     TIME_FORMAT = "'yyyy-MM-dd HH:mm:ss'"
     SUPPORTS_USER_DEFINED_TYPES = False
+    SUPPORTS_SEMI_ANTI_JOIN = False
 
     TIME_MAPPING = {
         "y": "%Y",
@@ -135,7 +136,9 @@ class Drill(Dialect):
             exp.StrPosition: str_position_sql,
             exp.StrToDate: _str_to_date,
             exp.Pow: rename_func("POW"),
-            exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
+            exp.Select: transforms.preprocess(
+                [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
+            ),
             exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
             exp.TimeStrToTime: timestrtotime_sql,
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index ab7a26a..352f11a 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import (
     arrow_json_extract_scalar_sql,
     arrow_json_extract_sql,
     binary_from_function,
+    bool_xor_sql,
     date_trunc_to_time,
     datestrtodate_sql,
     encode_decode_sql,
@@ -190,6 +191,11 @@ class DuckDB(Dialect):
             ),
         }
 
+        TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - {
+            TokenType.SEMI,
+            TokenType.ANTI,
+        }
+
         def _parse_types(
             self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True
         ) -> t.Optional[exp.Expression]:
@@ -224,6 +230,7 @@ class DuckDB(Dialect):
         STRUCT_DELIMITER = ("(", ")")
         RENAME_TABLE_WITH_DB = False
         NVL2_SUPPORTED = False
+        SEMI_ANTI_JOIN_WITH_SIDE = False
 
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
@@ -234,7 +241,7 @@ class DuckDB(Dialect):
             exp.ArraySize: rename_func("ARRAY_LENGTH"),
             exp.ArraySort: _array_sort_sql,
             exp.ArraySum: rename_func("LIST_SUM"),
-            exp.BitwiseXor: lambda self, e: self.func("XOR", e.this, e.expression),
+            exp.BitwiseXor: rename_func("XOR"),
             exp.CommentColumnConstraint: no_comment_column_constraint_sql,
             exp.CurrentDate: lambda self, e: "CURRENT_DATE",
             exp.CurrentTime: lambda self, e: "CURRENT_TIME",
@@ -301,6 +308,7 @@ class DuckDB(Dialect):
             exp.UnixToTimeStr: lambda self, e: f"CAST(TO_TIMESTAMP({self.sql(e, 'this')}) AS TEXT)",
             exp.VariancePop: rename_func("VAR_POP"),
             exp.WeekOfYear: rename_func("WEEKOFYEAR"),
+            exp.Xor: bool_xor_sql,
         }
 
         TYPE_MAPPING = {
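if_sql is now a factory: each dialect picks the function name and an optional default for the false branch instead of defining bespoke lambdas. A sketch of the resulting behavior, assuming this build (output shape approximate):

    import sqlglot

    # Hive keeps plain IF via if_sql(); Snowflake now uses
    # if_sql(name="IFF", false_value="NULL"), which fills in a missing
    # false branch with an explicit NULL:
    print(sqlglot.transpile("SELECT IF(x > 1, 'a')", read="hive", write="snowflake")[0])
    # roughly: SELECT IFF(x > 1, 'a', NULL)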
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index bec27d3..a427870 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -111,7 +111,7 @@ def _array_sort_sql(self: Hive.Generator, expression: exp.ArraySort) -> str:
 
 
 def _property_sql(self: Hive.Generator, expression: exp.Property) -> str:
-    return f"'{expression.name}'={self.sql(expression, 'value')}"
+    return f"{self.property_name(expression, string_key=True)}={self.sql(expression, 'value')}"
 
 
 def _str_to_unix_sql(self: Hive.Generator, expression: exp.StrToUnix) -> str:
@@ -413,7 +413,7 @@ class Hive(Dialect):
             exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.DATEINT_FORMAT})",
             exp.FileFormatProperty: lambda self, e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}",
             exp.FromBase64: rename_func("UNBASE64"),
-            exp.If: if_sql,
+            exp.If: if_sql(),
             exp.ILike: no_ilike_sql,
             exp.IsNan: rename_func("ISNAN"),
             exp.JSONExtract: rename_func("GET_JSON_OBJECT"),
@@ -466,6 +466,11 @@ class Hive(Dialect):
             exp.NumberToStr: rename_func("FORMAT_NUMBER"),
             exp.LastDateOfMonth: rename_func("LAST_DAY"),
             exp.National: lambda self, e: self.national_sql(e, prefix=""),
+            exp.ClusteredColumnConstraint: lambda self, e: f"({self.expressions(e, 'this', indent=False)})",
+            exp.NonClusteredColumnConstraint: lambda self, e: f"({self.expressions(e, 'this', indent=False)})",
+            exp.NotForReplicationColumnConstraint: lambda self, e: "",
+            exp.OnProperty: lambda self, e: "",
+            exp.PrimaryKeyColumnConstraint: lambda self, e: "PRIMARY KEY",
         }
 
         PROPERTIES_LOCATION = {
@@ -475,6 +480,35 @@ class Hive(Dialect):
             exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
         }
 
+        def parameter_sql(self, expression: exp.Parameter) -> str:
+            this = self.sql(expression, "this")
+            parent = expression.parent
+
+            if isinstance(parent, exp.EQ) and isinstance(parent.parent, exp.SetItem):
+                # We need to produce SET key = value instead of SET ${key} = value
+                return this
+
+            return f"${{{this}}}"
+
+        def schema_sql(self, expression: exp.Schema) -> str:
+            expression = expression.copy()
+
+            for ordered in expression.find_all(exp.Ordered):
+                if ordered.args.get("desc") is False:
+                    ordered.set("desc", None)
+
+            return super().schema_sql(expression)
+
+        def constraint_sql(self, expression: exp.Constraint) -> str:
+            expression = expression.copy()
+
+            for prop in list(expression.find_all(exp.Properties)):
+                prop.pop()
+
+            this = self.sql(expression, "this")
+            expressions = self.expressions(expression, sep=" ", flat=True)
+            return f"CONSTRAINT {this} {expressions}"
+
         def rowformatserdeproperty_sql(self, expression: exp.RowFormatSerdeProperty) -> str:
             serde_props = self.sql(expression, "serde_properties")
             serde_props = f" {serde_props}" if serde_props else ""
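The new Hive transforms mostly absorb T-SQL DDL noise: CLUSTERED/NOT FOR REPLICATION decorations are dropped or reduced to a bare PRIMARY KEY, and constraint_sql strips attached properties. A hedged sketch of the intended cross-dialect behavior (output approximate):

    import sqlglot

    # T-SQL constraint decorations have no Hive equivalent; the new
    # transforms should reduce them to plain constraint syntax.
    print(
        sqlglot.transpile(
            "CREATE TABLE t (id INT, CONSTRAINT pk PRIMARY KEY CLUSTERED (id))",
            read="tsql",
            write="hive",
        )[0]
    )
    # roughly: CREATE TABLE t (id INT, CONSTRAINT pk PRIMARY KEY (id))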
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 75660f8..554241d 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -102,6 +102,7 @@ class MySQL(Dialect):
     TIME_FORMAT = "'%Y-%m-%d %T'"
     DPIPE_IS_STRING_CONCAT = False
     SUPPORTS_USER_DEFINED_TYPES = False
+    SUPPORTS_SEMI_ANTI_JOIN = False
 
     # https://prestodb.io/docs/current/functions/datetime.html#mysql-date-functions
     TIME_MAPPING = {
@@ -519,7 +520,7 @@ class MySQL(Dialect):
 
             return self.expression(exp.SetItem, this=charset, collate=collate, kind="NAMES")
 
-        def _parse_type(self) -> t.Optional[exp.Expression]:
+        def _parse_type(self, parse_interval: bool = True) -> t.Optional[exp.Expression]:
             # mysql binary is special and can work anywhere, even in order by operations
             # it operates like a no paren func
             if self._match(TokenType.BINARY, advance=False):
@@ -528,7 +529,7 @@ class MySQL(Dialect):
                 if isinstance(data_type, exp.DataType):
                     return self.expression(exp.Cast, this=self._parse_column(), to=data_type)
 
-            return super()._parse_type()
+            return super()._parse_type(parse_interval=parse_interval)
 
     class Generator(generator.Generator):
         LOCKING_READS_SUPPORTED = True
@@ -560,7 +561,9 @@ class MySQL(Dialect):
             exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
             exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
             exp.Pivot: no_pivot_sql,
-            exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
+            exp.Select: transforms.preprocess(
+                [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
+            ),
             exp.StrPosition: strposition_to_locate_sql,
             exp.StrToDate: _str_to_date_sql,
             exp.StrToTime: _str_to_date_sql,
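A side effect of declaring SUPPORTS_SEMI_ANTI_JOIN = False: the _Dialect metaclass (see the dialect.py hunk above) adds SEMI and ANTI back to TABLE_ALIAS_TOKENS, so in these dialects they should parse as ordinary table aliases rather than join modifiers. A hedged example:

    import sqlglot

    # In MySQL, "semi" and "anti" are not join keywords, so they should now
    # be accepted as plain table aliases:
    print(sqlglot.parse_one("SELECT * FROM t AS anti", read="mysql").sql("mysql"))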
f"{self.normalize_func('ARRAY')}[{self.expressions(e, flat=True)}]", + exp.Xor: bool_xor_sql, } PROPERTIES_LOCATION = { diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 9ae4c32..0d8d4ab 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -6,6 +6,7 @@ from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, binary_from_function, + bool_xor_sql, date_trunc_to_time, encode_decode_sql, format_time_lambda, @@ -40,7 +41,7 @@ def _explode_to_unnest_sql(self: Presto.Generator, expression: exp.Lateral) -> s this=exp.Unnest( expressions=[expression.this.this], alias=expression.args.get("alias"), - ordinality=isinstance(expression.this, exp.Posexplode), + offset=isinstance(expression.this, exp.Posexplode), ), kind="cross", ) @@ -173,6 +174,7 @@ class Presto(Dialect): TIME_FORMAT = MySQL.TIME_FORMAT TIME_MAPPING = MySQL.TIME_MAPPING STRICT_STRING_CONCAT = True + SUPPORTS_SEMI_ANTI_JOIN = False # https://github.com/trinodb/trino/issues/17 # https://github.com/trinodb/trino/issues/12289 @@ -308,7 +310,7 @@ class Presto(Dialect): exp.First: _first_last_sql, exp.Group: transforms.preprocess([transforms.unalias_group]), exp.Hex: rename_func("TO_HEX"), - exp.If: if_sql, + exp.If: if_sql(), exp.ILike: no_ilike_sql, exp.Initcap: _initcap_sql, exp.ParseJSON: rename_func("JSON_PARSE"), @@ -331,7 +333,8 @@ class Presto(Dialect): [ transforms.eliminate_qualify, transforms.eliminate_distinct_on, - transforms.explode_to_unnest, + transforms.explode_to_unnest(1), + transforms.eliminate_semi_and_anti_joins, ] ), exp.SortArray: _no_sort_array, @@ -340,7 +343,6 @@ class Presto(Dialect): exp.StrToMap: rename_func("SPLIT_TO_MAP"), exp.StrToTime: _str_to_time_sql, exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))", - exp.Struct: rename_func("ROW"), exp.StructExtract: struct_extract_sql, exp.Table: transforms.preprocess([_unnest_sequence]), exp.TimestampTrunc: timestamptrunc_sql, @@ -363,8 +365,16 @@ class Presto(Dialect): [transforms.remove_within_group_for_percentiles] ), exp.Timestamp: transforms.preprocess([transforms.timestamp_to_cast]), + exp.Xor: bool_xor_sql, } + def struct_sql(self, expression: exp.Struct) -> str: + if any(isinstance(arg, (exp.EQ, exp.Slice)) for arg in expression.expressions): + self.unsupported("Struct with key-value definitions is unsupported.") + return self.function_fallback_sql(expression) + + return rename_func("ROW")(self, expression) + def interval_sql(self, expression: exp.Interval) -> str: unit = self.sql(expression, "unit") if expression.this and unit.lower().startswith("week"): diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index b4c7664..2145844 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -138,7 +138,9 @@ class Redshift(Postgres): exp.JSONExtract: _json_sql, exp.JSONExtractScalar: _json_sql, exp.SafeConcat: concat_to_dpipe_sql, - exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), + exp.Select: transforms.preprocess( + [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins] + ), exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})", exp.TsOrDsToDate: ts_or_ds_to_date_sql("redshift"), } diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 5aa946e..5c49331 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ 
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 5aa946e..5c49331 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -5,9 +5,11 @@ import typing as t
 from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
     Dialect,
+    binary_from_function,
     date_trunc_to_time,
     datestrtodate_sql,
     format_time_lambda,
+    if_sql,
     inline_array_sql,
     max_or_greatest,
     min_or_least,
@@ -203,6 +205,7 @@ class Snowflake(Dialect):
     NULL_ORDERING = "nulls_are_large"
     TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'"
     SUPPORTS_USER_DEFINED_TYPES = False
+    SUPPORTS_SEMI_ANTI_JOIN = False
 
     TIME_MAPPING = {
         "YYYY": "%Y",
@@ -240,7 +243,16 @@ class Snowflake(Dialect):
             **parser.Parser.FUNCTIONS,
             "ARRAYAGG": exp.ArrayAgg.from_arg_list,
             "ARRAY_CONSTRUCT": exp.Array.from_arg_list,
+            "ARRAY_GENERATE_RANGE": lambda args: exp.GenerateSeries(
+                # ARRAY_GENERATE_RANGE has an exclusive end; we normalize it to be inclusive
+                start=seq_get(args, 0),
+                end=exp.Sub(this=seq_get(args, 1), expression=exp.Literal.number(1)),
+                step=seq_get(args, 2),
+            ),
             "ARRAY_TO_STRING": exp.ArrayJoin.from_arg_list,
+            "BITXOR": binary_from_function(exp.BitwiseXor),
+            "BIT_XOR": binary_from_function(exp.BitwiseXor),
+            "BOOLXOR": binary_from_function(exp.Xor),
             "CONVERT_TIMEZONE": _parse_convert_timezone,
             "DATE_TRUNC": date_trunc_to_time,
             "DATEADD": lambda args: exp.DateAdd(
@@ -277,7 +289,7 @@ class Snowflake(Dialect):
             ),
         }
 
-        TIMESTAMPS = parser.Parser.TIMESTAMPS.copy() - {TokenType.TIME}
+        TIMESTAMPS = parser.Parser.TIMESTAMPS - {TokenType.TIME}
 
         RANGE_PARSERS = {
             **parser.Parser.RANGE_PARSERS,
@@ -381,6 +393,7 @@ class Snowflake(Dialect):
         JOIN_HINTS = False
         TABLE_HINTS = False
         QUERY_HINTS = False
+        AGGREGATE_FILTER_SUPPORTED = False
 
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
@@ -390,6 +403,7 @@ class Snowflake(Dialect):
             exp.AtTimeZone: lambda self, e: self.func(
                 "CONVERT_TIMEZONE", e.args.get("zone"), e.this
             ),
+            exp.BitwiseXor: rename_func("BITXOR"),
             exp.DateAdd: lambda self, e: self.func("DATEADD", e.text("unit"), e.expression, e.this),
             exp.DateDiff: lambda self, e: self.func(
                 "DATEDIFF", e.text("unit"), e.expression, e.this
@@ -398,8 +412,11 @@ class Snowflake(Dialect):
             exp.DataType: _datatype_sql,
             exp.DayOfWeek: rename_func("DAYOFWEEK"),
             exp.Extract: rename_func("DATE_PART"),
+            exp.GenerateSeries: lambda self, e: self.func(
+                "ARRAY_GENERATE_RANGE", e.args["start"], e.args["end"] + 1, e.args.get("step")
+            ),
             exp.GroupConcat: rename_func("LISTAGG"),
-            exp.If: rename_func("IFF"),
+            exp.If: if_sql(name="IFF", false_value="NULL"),
             exp.LogicalAnd: rename_func("BOOLAND_AGG"),
             exp.LogicalOr: rename_func("BOOLOR_AGG"),
             exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
@@ -407,7 +424,13 @@ class Snowflake(Dialect):
             exp.Min: min_or_least,
             exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
             exp.RegexpILike: _regexpilike_sql,
-            exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
+            exp.Select: transforms.preprocess(
+                [
+                    transforms.eliminate_distinct_on,
+                    transforms.explode_to_unnest(0),
+                    transforms.eliminate_semi_and_anti_joins,
+                ]
+            ),
             exp.StarMap: rename_func("OBJECT_CONSTRUCT"),
             exp.StartsWith: rename_func("STARTSWITH"),
             exp.StrPosition: lambda self, e: self.func(
@@ -431,6 +454,7 @@ class Snowflake(Dialect):
             exp.UnixToTime: _unix_to_time_sql,
             exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
             exp.WeekOfYear: rename_func("WEEKOFYEAR"),
+            exp.Xor: rename_func("BOOLXOR"),
         }
 
         TYPE_MAPPING = {
@@ -449,6 +473,27 @@ class Snowflake(Dialect):
             exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
         }
 
+        def unnest_sql(self, expression: exp.Unnest) -> str:
+            selects = ["value"]
+            unnest_alias = expression.args.get("alias")
+
+            offset = expression.args.get("offset")
+            if offset:
+                if unnest_alias:
+                    expression = expression.copy()
+                    unnest_alias.append("columns", offset.pop())
+
+                selects.append("index")
+
+            subquery = exp.Subquery(
+                this=exp.select(*selects).from_(
+                    f"TABLE(FLATTEN(INPUT => {self.sql(expression.expressions[0])}))"
+                ),
+            )
+            alias = self.sql(unnest_alias)
+            alias = f" AS {alias}" if alias else ""
+            return f"{self.sql(subquery)}{alias}"
+
         def show_sql(self, expression: exp.Show) -> str:
             scope = self.sql(expression, "scope")
             scope = f" {scope}" if scope else ""
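Because ARRAY_GENERATE_RANGE has an exclusive upper bound while exp.GenerateSeries is inclusive, the parser subtracts 1 from the end and the generator adds it back, so the arithmetic may surface literally in round trips rather than being folded. A sketch, assuming this build:

    import sqlglot

    # The exclusive end is normalized on the way in (end - 1) and
    # denormalized on the way out (end + 1):
    print(
        sqlglot.transpile(
            "SELECT ARRAY_GENERATE_RANGE(0, 5)", read="snowflake", write="snowflake"
        )[0]
    )
    # roughly: SELECT ARRAY_GENERATE_RANGE(0, 5 - 1 + 1)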
selects = ["value"] + unnest_alias = expression.args.get("alias") + + offset = expression.args.get("offset") + if offset: + if unnest_alias: + expression = expression.copy() + unnest_alias.append("columns", offset.pop()) + + selects.append("index") + + subquery = exp.Subquery( + this=exp.select(*selects).from_( + f"TABLE(FLATTEN(INPUT => {self.sql(expression.expressions[0])}))" + ), + ) + alias = self.sql(unnest_alias) + alias = f" AS {alias}" if alias else "" + return f"{self.sql(subquery)}{alias}" + def show_sql(self, expression: exp.Show) -> str: scope = self.sql(expression, "scope") scope = f" {scope}" if scope else "" diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py index 56d33ba..3dc9838 100644 --- a/sqlglot/dialects/spark2.py +++ b/sqlglot/dialects/spark2.py @@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import ( create_with_partitions_sql, format_time_lambda, is_parse_json, + move_insert_cte_sql, pivot_column_names, rename_func, trim_sql, @@ -115,13 +116,6 @@ def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression: return expression -def _insert_sql(self: Spark2.Generator, expression: exp.Insert) -> str: - if expression.expression.args.get("with"): - expression = expression.copy() - expression.set("with", expression.expression.args.pop("with")) - return self.insert_sql(expression) - - class Spark2(Hive): class Parser(Hive.Parser): FUNCTIONS = { @@ -206,7 +200,7 @@ class Spark2(Hive): exp.DayOfYear: rename_func("DAYOFYEAR"), exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}", exp.From: transforms.preprocess([_unalias_pivot]), - exp.Insert: _insert_sql, + exp.Insert: move_insert_cte_sql, exp.LogicalAnd: rename_func("BOOL_AND"), exp.LogicalOr: rename_func("BOOL_OR"), exp.Map: _map_sql, diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index 7bfdf1c..1edfa9d 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -64,6 +64,7 @@ def _transform_create(expression: exp.Expression) -> exp.Expression: class SQLite(Dialect): # https://sqlite.org/forum/forumpost/5e575586ac5c711b?raw RESOLVES_IDENTIFIERS_AS_UPPERCASE = None + SUPPORTS_SEMI_ANTI_JOIN = False class Tokenizer(tokens.Tokenizer): IDENTIFIERS = ['"', ("[", "]"), "`"] @@ -125,7 +126,11 @@ class SQLite(Dialect): exp.Pivot: no_pivot_sql, exp.SafeConcat: concat_to_dpipe_sql, exp.Select: transforms.preprocess( - [transforms.eliminate_distinct_on, transforms.eliminate_qualify] + [ + transforms.eliminate_distinct_on, + transforms.eliminate_qualify, + transforms.eliminate_semi_and_anti_joins, + ] ), exp.TableSample: no_tablesample_sql, exp.TimeStrToTime: lambda self, e: self.sql(e, "this"), diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index d9de968..b9e925a 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -8,6 +8,8 @@ from sqlglot.tokens import TokenType class Teradata(Dialect): + SUPPORTS_SEMI_ANTI_JOIN = False + TIME_MAPPING = { "Y": "%Y", "YYYY": "%Y", @@ -168,7 +170,9 @@ class Teradata(Dialect): **generator.Generator.TRANSFORMS, exp.Max: max_or_greatest, exp.Min: min_or_least, - exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), + exp.Select: transforms.preprocess( + [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins] + ), exp.StrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})", exp.ToChar: lambda self, e: self.function_fallback_sql(e), exp.Use: lambda self, e: f"DATABASE {self.sql(e, 'this')}", diff 
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 2299310..fa62e78 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import (
     any_value_to_max_sql,
     max_or_greatest,
     min_or_least,
+    move_insert_cte_sql,
     parse_date_delta,
     rename_func,
     timestrtotime_sql,
@@ -206,6 +207,8 @@ class TSQL(Dialect):
     RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
     NULL_ORDERING = "nulls_are_small"
     TIME_FORMAT = "'yyyy-mm-dd hh:mm:ss'"
+    SUPPORTS_SEMI_ANTI_JOIN = False
+    LOG_BASE_FIRST = False
 
     TIME_MAPPING = {
         "year": "%Y",
@@ -345,6 +348,8 @@ class TSQL(Dialect):
     }
 
     class Parser(parser.Parser):
+        SET_REQUIRES_ASSIGNMENT_DELIMITER = False
+
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
             "CHARINDEX": lambda args: exp.StrPosition(
@@ -396,7 +401,6 @@ class TSQL(Dialect):
             TokenType.END: lambda self: self._parse_command(),
         }
 
-        LOG_BASE_FIRST = False
         LOG_DEFAULTS_TO_LN = True
 
         CONCAT_NULL_OUTPUTS_STRING = True
@@ -609,11 +613,14 @@ class TSQL(Dialect):
             exp.Extract: rename_func("DATEPART"),
             exp.GroupConcat: _string_agg_sql,
             exp.If: rename_func("IIF"),
+            exp.Insert: move_insert_cte_sql,
             exp.Max: max_or_greatest,
             exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this),
             exp.Min: min_or_least,
             exp.NumberToStr: _format_sql,
-            exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
+            exp.Select: transforms.preprocess(
+                [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
+            ),
             exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this),
             exp.SHA2: lambda self, e: self.func(
                 "HASHBYTES",
@@ -632,6 +639,14 @@ class TSQL(Dialect):
             exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
         }
+        def setitem_sql(self, expression: exp.SetItem) -> str:
+            this = expression.this
+            if isinstance(this, exp.EQ) and not isinstance(this.left, exp.Parameter):
+                # T-SQL does not use '=' in SET command, except when the LHS is a variable.
+                return f"{self.sql(this.left)} {self.sql(this.right)}"
+
+            return super().setitem_sql(expression)
+
         def boolean_sql(self, expression: exp.Boolean) -> str:
             if type(expression.parent) in BIT_TYPES:
                 return "1" if expression.this else "0"
@@ -661,16 +676,27 @@ class TSQL(Dialect):
             exists = expression.args.pop("exists", None)
             sql = super().create_sql(expression)
 
+            table = expression.find(exp.Table)
+
+            if kind == "TABLE" and expression.expression:
+                sql = f"SELECT * INTO {self.sql(table)} FROM ({self.sql(expression.expression)}) AS temp"
+
             if exists:
-                table = expression.find(exp.Table)
                 identifier = self.sql(exp.Literal.string(exp.table_name(table) if table else ""))
+                sql = self.sql(exp.Literal.string(sql))
                 if kind == "SCHEMA":
-                    sql = f"""IF NOT EXISTS (SELECT * FROM information_schema.schemata WHERE schema_name = {identifier}) EXEC('{sql}')"""
+                    sql = f"""IF NOT EXISTS (SELECT * FROM information_schema.schemata WHERE schema_name = {identifier}) EXEC({sql})"""
                 elif kind == "TABLE":
-                    sql = f"""IF NOT EXISTS (SELECT * FROM information_schema.tables WHERE table_name = {identifier}) EXEC('{sql}')"""
+                    assert table
+                    where = exp.and_(
+                        exp.column("table_name").eq(table.name),
+                        exp.column("table_schema").eq(table.db) if table.db else None,
+                        exp.column("table_catalog").eq(table.catalog) if table.catalog else None,
+                    )
+                    sql = f"""IF NOT EXISTS (SELECT * FROM information_schema.tables WHERE {where}) EXEC({sql})"""
                 elif kind == "INDEX":
                     index = self.sql(exp.Literal.string(expression.this.text("this")))
-                    sql = f"""IF NOT EXISTS (SELECT * FROM sys.indexes WHERE object_id = object_id({identifier}) AND name = {index}) EXEC('{sql}')"""
+                    sql = f"""IF NOT EXISTS (SELECT * FROM sys.indexes WHERE object_id = object_id({identifier}) AND name = {index}) EXEC({sql})"""
             elif expression.args.get("replace"):
                 sql = sql.replace("CREATE OR REPLACE ", "CREATE OR ALTER ", 1)
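T-SQL has no CREATE TABLE ... AS SELECT, so create_sql now rewrites CTAS into SELECT * INTO, and the EXEC bodies are built with exp.Literal.string so embedded quotes get escaped instead of being spliced in raw. A sketch, assuming this build:

    import sqlglot

    # CTAS becomes SELECT ... INTO; the "AS temp" subquery alias satisfies
    # T-SQL's requirement that derived tables be named:
    print(sqlglot.transpile("CREATE TABLE t AS SELECT 1 AS c", read="duckdb", write="tsql")[0])
    # roughly: SELECT * INTO t FROM (SELECT 1 AS c) AS temp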