diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-11-01 05:12:42 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-11-01 05:12:42 +0000 |
commit | c51a9844b869fd7cd69e5cc7658d34f61a865185 (patch) | |
tree | 55706c65ce7e19626aabf7ff4dde0e1a51b739db /sqlglot | |
parent | Releasing debian version 18.17.0-1. (diff) | |
download | sqlglot-c51a9844b869fd7cd69e5cc7658d34f61a865185.tar.xz sqlglot-c51a9844b869fd7cd69e5cc7658d34f61a865185.zip |
Merging upstream version 19.0.1.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
28 files changed, 477 insertions, 208 deletions
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index be10f3d..35feaad 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -158,6 +158,6 @@ def transpile( """ write = (read if write is None else write) if identity else write return [ - Dialect.get_or_raise(write)().generate(expression, **opts) + Dialect.get_or_raise(write)().generate(expression, copy=False, **opts) if expression else "" for expression in parse(sql, read, error_level=error_level) ] diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 51baba2..fc9a3ae 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -69,7 +69,6 @@ def _create_sql(self: BigQuery.Generator, expression: exp.Create) -> str: returns = expression.find(exp.ReturnsProperty) if kind.upper() == "FUNCTION" and returns and returns.args.get("is_table"): - expression = expression.copy() expression.set("kind", "TABLE FUNCTION") if isinstance(expression.expression, (exp.Subquery, exp.Literal)): @@ -699,6 +698,5 @@ class BigQuery(Dialect): def version_sql(self, expression: exp.Version) -> str: if expression.name == "TIMESTAMP": - expression = expression.copy() expression.set("this", "SYSTEM_TIME") return super().version_sql(expression) diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index 30f728c..394a922 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -461,7 +461,6 @@ class ClickHouse(Dialect): def safeconcat_sql(self, expression: exp.SafeConcat) -> str: # Clickhouse errors out if we try to cast a NULL value to TEXT - expression = expression.copy() return self.func( "CONCAT", *[ diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py index 314a821..b777db0 100644 --- a/sqlglot/dialects/databricks.py +++ b/sqlglot/dialects/databricks.py @@ -35,7 +35,7 @@ class Databricks(Spark): exp.DatetimeSub: lambda self, e: self.func( "TIMESTAMPADD", e.text("unit"), - exp.Mul(this=e.expression.copy(), expression=exp.Literal.number(-1)), + exp.Mul(this=e.expression, expression=exp.Literal.number(-1)), e.this, ), exp.DatetimeDiff: lambda self, e: self.func( @@ -63,21 +63,14 @@ class Databricks(Spark): and kind.this in exp.DataType.INTEGER_TYPES ): # only BIGINT generated identity constraints are supported - expression = expression.copy() expression.set("kind", exp.DataType.build("bigint")) return super().columndef_sql(expression, sep) def generatedasidentitycolumnconstraint_sql( self, expression: exp.GeneratedAsIdentityColumnConstraint ) -> str: - expression = expression.copy() expression.set("this", True) # trigger ALWAYS in super class return super().generatedasidentitycolumnconstraint_sql(expression) class Tokenizer(Spark.Tokenizer): HEX_STRINGS = [] - - SINGLE_TOKENS = { - **Spark.Tokenizer.SINGLE_TOKENS, - "$": TokenType.PARAMETER, - } diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 739e8d7..21e7889 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -315,11 +315,14 @@ class Dialect(metaclass=_Dialect): ) -> t.List[t.Optional[exp.Expression]]: return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) - def generate(self, expression: t.Optional[exp.Expression], **opts) -> str: - return self.generator(**opts).generate(expression) + def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: + return self.generator(**opts).generate(expression, copy=copy) def transpile(self, sql: str, **opts) -> t.List[str]: - return [self.generate(expression, **opts) for expression in self.parse(sql)] + return [ + self.generate(expression, copy=False, **opts) if expression else "" + for expression in self.parse(sql) + ] def tokenize(self, sql: str) -> t.List[Token]: return self.tokenizer.tokenize(sql) @@ -380,9 +383,7 @@ def inline_array_sql(self: Generator, expression: exp.Array) -> str: def no_ilike_sql(self: Generator, expression: exp.ILike) -> str: return self.like_sql( - exp.Like( - this=exp.Lower(this=expression.this.copy()), expression=expression.expression.copy() - ) + exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression) ) @@ -518,7 +519,6 @@ def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str: is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW") if has_schema and is_partitionable: - expression = expression.copy() prop = expression.find(exp.PartitionedByProperty) if prop and prop.this and not isinstance(prop.this, exp.Schema): schema = expression.this @@ -583,7 +583,7 @@ def date_add_interval_sql( this = self.sql(expression, "this") unit = expression.args.get("unit") unit = exp.var(unit.name.upper() if unit else "DAY") - interval = exp.Interval(this=expression.expression.copy(), unit=unit) + interval = exp.Interval(this=expression.expression, unit=unit) return f"{data_type}_{kind}({this}, {self.sql(interval)})" return func @@ -621,7 +621,6 @@ def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> s def left_to_substring_sql(self: Generator, expression: exp.Left) -> str: - expression = expression.copy() return self.sql( exp.Substring( this=expression.this, start=exp.Literal.number(1), length=expression.expression @@ -630,7 +629,6 @@ def left_to_substring_sql(self: Generator, expression: exp.Left) -> str: def right_to_substring_sql(self: Generator, expression: exp.Left) -> str: - expression = expression.copy() return self.sql( exp.Substring( this=expression.this, @@ -675,7 +673,7 @@ def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: cond = expression.this.expressions[0] self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") - return self.func("sum", exp.func("if", cond.copy(), 1, 0)) + return self.func("sum", exp.func("if", cond, 1, 0)) def trim_sql(self: Generator, expression: exp.Trim) -> str: @@ -716,12 +714,10 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable: def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str: - expression = expression.copy() return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions)) def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str: - expression = expression.copy() delim, *rest_args = expression.expressions return self.sql( reduce( @@ -809,13 +805,6 @@ def isnull_to_is_null(args: t.List) -> exp.Expression: return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null())) -def move_insert_cte_sql(self: Generator, expression: exp.Insert) -> str: - if expression.expression.args.get("with"): - expression = expression.copy() - expression.set("with", expression.expression.args["with"].pop()) - return self.insert_sql(expression) - - def generatedasidentitycolumnconstraint_sql( self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint ) -> str: diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py index 8b2e708..42453fd 100644 --- a/sqlglot/dialects/drill.py +++ b/sqlglot/dialects/drill.py @@ -20,7 +20,9 @@ def _date_add_sql(kind: str) -> t.Callable[[Drill.Generator, exp.DateAdd | exp.D def func(self: Drill.Generator, expression: exp.DateAdd | exp.DateSub) -> str: this = self.sql(expression, "this") unit = exp.var(expression.text("unit").upper() or "DAY") - return f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))})" + return ( + f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})" + ) return func @@ -147,7 +149,7 @@ class Drill(Dialect): exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"), exp.ToChar: lambda self, e: self.function_fallback_sql(e), exp.TryCast: no_trycast_sql, - exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression.copy(), unit=exp.var('DAY')))})", + exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.var('DAY')))})", exp.TsOrDsToDate: ts_or_ds_to_date_sql("drill"), exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)", } diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index 287e03a..d8d9f90 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -36,14 +36,14 @@ from sqlglot.tokens import TokenType def _ts_or_ds_add_sql(self: DuckDB.Generator, expression: exp.TsOrDsAdd) -> str: this = self.sql(expression, "this") unit = self.sql(expression, "unit").strip("'") or "DAY" - return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))}" + return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression, unit=unit))}" def _date_delta_sql(self: DuckDB.Generator, expression: exp.DateAdd | exp.DateSub) -> str: this = self.sql(expression, "this") unit = self.sql(expression, "unit").strip("'") or "DAY" op = "+" if isinstance(expression, exp.DateAdd) else "-" - return f"{this} {op} {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))}" + return f"{this} {op} {self.sql(exp.Interval(this=expression.expression, unit=unit))}" # BigQuery -> DuckDB conversion for the DATE function @@ -365,7 +365,7 @@ class DuckDB(Dialect): multiplier = 90 if multiplier: - return f"({multiplier} * {super().interval_sql(exp.Interval(this=expression.this.copy(), unit=exp.var('day')))})" + return f"({multiplier} * {super().interval_sql(exp.Interval(this=expression.this, unit=exp.var('day')))})" return super().interval_sql(expression) diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 7bff553..3b1c8de 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -53,8 +53,6 @@ DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH") def _create_sql(self, expression: exp.Create) -> str: - expression = expression.copy() - # remove UNIQUE column constraints for constraint in expression.find_all(exp.UniqueColumnConstraint): if constraint.parent: @@ -88,7 +86,7 @@ def _add_date_sql(self: Hive.Generator, expression: exp.DateAdd | exp.DateSub) - if expression.expression.is_number: modified_increment = exp.Literal.number(int(expression.text("expression")) * multiplier) else: - modified_increment = expression.expression.copy() + modified_increment = expression.expression if multiplier != 1: modified_increment = exp.Mul( # type: ignore this=modified_increment, expression=exp.Literal.number(multiplier) @@ -229,6 +227,11 @@ class Hive(Dialect): STRING_ESCAPES = ["\\"] ENCODE = "utf-8" + SINGLE_TOKENS = { + **tokens.Tokenizer.SINGLE_TOKENS, + "$": TokenType.PARAMETER, + } + KEYWORDS = { **tokens.Tokenizer.KEYWORDS, "ADD ARCHIVE": TokenType.COMMAND, @@ -408,6 +411,7 @@ class Hive(Dialect): INDEX_ON = "ON TABLE" EXTRACT_ALLOWS_QUOTES = False NVL2_SUPPORTED = False + SUPPORTS_NESTED_CTES = False TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, @@ -521,7 +525,10 @@ class Hive(Dialect): def parameter_sql(self, expression: exp.Parameter) -> str: this = self.sql(expression, "this") + expression_sql = self.sql(expression, "expression") + parent = expression.parent + this = f"{this}:{expression_sql}" if expression_sql else this if isinstance(parent, exp.EQ) and isinstance(parent.parent, exp.SetItem): # We need to produce SET key = value instead of SET ${key} = value @@ -530,8 +537,6 @@ class Hive(Dialect): return f"${{{this}}}" def schema_sql(self, expression: exp.Schema) -> str: - expression = expression.copy() - for ordered in expression.find_all(exp.Ordered): if ordered.args.get("desc") is False: ordered.set("desc", None) @@ -539,8 +544,6 @@ class Hive(Dialect): return super().schema_sql(expression) def constraint_sql(self, expression: exp.Constraint) -> str: - expression = expression.copy() - for prop in list(expression.find_all(exp.Properties)): prop.pop() diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 2185a85..c78aa9e 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -60,9 +60,33 @@ def _date_trunc_sql(self: MySQL.Generator, expression: exp.DateTrunc) -> str: return f"STR_TO_DATE({concat}, '{date_format}')" -def _str_to_date(args: t.List) -> exp.StrToDate: - date_format = MySQL.format_time(seq_get(args, 1)) - return exp.StrToDate(this=seq_get(args, 0), format=date_format) +# All specifiers for time parts (as opposed to date parts) +# https://dev.mysql.com/doc/refman/8.0/en/date-and-time-functions.html#function_date-format +TIME_SPECIFIERS = {"f", "H", "h", "I", "i", "k", "l", "p", "r", "S", "s", "T"} + + +def _has_time_specifier(date_format: str) -> bool: + i = 0 + length = len(date_format) + + while i < length: + if date_format[i] == "%": + i += 1 + if i < length and date_format[i] in TIME_SPECIFIERS: + return True + i += 1 + return False + + +def _str_to_date(args: t.List) -> exp.StrToDate | exp.StrToTime: + mysql_date_format = seq_get(args, 1) + date_format = MySQL.format_time(mysql_date_format) + this = seq_get(args, 0) + + if mysql_date_format and _has_time_specifier(mysql_date_format.name): + return exp.StrToTime(this=this, format=date_format) + + return exp.StrToDate(this=this, format=date_format) def _str_to_date_sql( @@ -93,7 +117,9 @@ def _date_add_sql( def func(self: MySQL.Generator, expression: exp.Expression) -> str: this = self.sql(expression, "this") unit = expression.text("unit").upper() or "DAY" - return f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))})" + return ( + f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})" + ) return func @@ -110,8 +136,6 @@ def _remove_ts_or_ds_to_date( args: t.Tuple[str, ...] = ("this",), ) -> t.Callable[[MySQL.Generator, exp.Func], str]: def func(self: MySQL.Generator, expression: exp.Func) -> str: - expression = expression.copy() - for arg_key in args: arg = expression.args.get(arg_key) if isinstance(arg, exp.TsOrDsToDate) and not arg.args.get("format"): @@ -629,6 +653,7 @@ class MySQL(Dialect): transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins, transforms.eliminate_qualify, + transforms.eliminate_full_outer_join, ] ), exp.StrPosition: strposition_to_locate_sql, @@ -728,7 +753,6 @@ class MySQL(Dialect): to = self.CAST_MAPPING.get(expression.to.this) if to: - expression = expression.copy() expression.to.set("this", to) return super().cast_sql(expression) diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 086b278..27c6851 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -43,8 +43,6 @@ DATE_DIFF_FACTOR = { def _date_add_sql(kind: str) -> t.Callable[[Postgres.Generator, exp.DateAdd | exp.DateSub], str]: def func(self: Postgres.Generator, expression: exp.DateAdd | exp.DateSub) -> str: - expression = expression.copy() - this = self.sql(expression, "this") unit = expression.args.get("unit") @@ -96,7 +94,6 @@ def _substring_sql(self: Postgres.Generator, expression: exp.Substring) -> str: def _string_agg_sql(self: Postgres.Generator, expression: exp.GroupConcat) -> str: - expression = expression.copy() separator = expression.args.get("separator") or exp.Literal.string(",") order = "" @@ -119,7 +116,6 @@ def _auto_increment_to_serial(expression: exp.Expression) -> exp.Expression: auto = expression.find(exp.AutoIncrementColumnConstraint) if auto: - expression = expression.copy() expression.args["constraints"].remove(auto.parent) kind = expression.args["kind"] @@ -134,7 +130,9 @@ def _auto_increment_to_serial(expression: exp.Expression) -> exp.Expression: def _serial_to_generated(expression: exp.Expression) -> exp.Expression: - kind = expression.args["kind"] + kind = expression.args.get("kind") + if not kind: + return expression if kind.this == exp.DataType.Type.SERIAL: data_type = exp.DataType(this=exp.DataType.Type.INT) @@ -146,7 +144,6 @@ def _serial_to_generated(expression: exp.Expression) -> exp.Expression: data_type = None if data_type: - expression = expression.copy() expression.args["kind"].replace(data_type) constraints = expression.args["constraints"] generated = exp.ColumnConstraint(kind=exp.GeneratedAsIdentityColumnConstraint(this=False)) @@ -409,6 +406,7 @@ class Postgres(Dialect): exp.MapFromEntries: no_map_from_entries_sql, exp.Min: min_or_least, exp.Merge: transforms.preprocess([_remove_target_from_merge]), + exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", exp.PercentileCont: transforms.preprocess( [transforms.add_within_group_for_percentiles] ), @@ -445,6 +443,7 @@ class Postgres(Dialect): PROPERTIES_LOCATION = { **generator.Generator.PROPERTIES_LOCATION, + exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA, exp.TransientProperty: exp.Properties.Location.UNSUPPORTED, exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } @@ -452,7 +451,6 @@ class Postgres(Dialect): def bracket_sql(self, expression: exp.Bracket) -> str: """Forms like ARRAY[1, 2, 3][3] aren't allowed; we need to wrap the ARRAY.""" if isinstance(expression.this, exp.Array): - expression = expression.copy() expression.set("this", exp.paren(expression.this, copy=False)) return super().bracket_sql(expression) diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index aac368c..ded3655 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -36,7 +36,6 @@ def _approx_distinct_sql(self: Presto.Generator, expression: exp.ApproxDistinct) def _explode_to_unnest_sql(self: Presto.Generator, expression: exp.Lateral) -> str: if isinstance(expression.this, exp.Explode): - expression = expression.copy() return self.sql( exp.Join( this=exp.Unnest( @@ -72,7 +71,6 @@ def _schema_sql(self: Presto.Generator, expression: exp.Schema) -> str: for schema in expression.parent.find_all(exp.Schema): column_defs = schema.find_all(exp.ColumnDef) if column_defs and isinstance(schema.parent, exp.Property): - expression = expression.copy() expression.expressions.extend(column_defs) return self.schema_sql(expression) @@ -407,12 +405,10 @@ class Presto(Dialect): target_type = None if target_type and target_type.is_type("timestamp"): - to = target_type.copy() - if target_type is start.to: - end = exp.cast(end, to) + end = exp.cast(end, target_type) else: - start = exp.cast(start, to) + start = exp.cast(start, target_type) return self.func("SEQUENCE", start, end, step) @@ -432,6 +428,5 @@ class Presto(Dialect): kind = expression.args["kind"] schema = expression.this if kind == "VIEW" and schema.expressions: - expression = expression.copy() expression.this.set("expressions", None) return super().create_sql(expression) diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index df70aa7..6c7ba35 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -27,6 +27,14 @@ def _parse_date_add(args: t.List) -> exp.DateAdd: ) +def _parse_datediff(args: t.List) -> exp.DateDiff: + return exp.DateDiff( + this=exp.TsOrDsToDate(this=seq_get(args, 2)), + expression=exp.TsOrDsToDate(this=seq_get(args, 1)), + unit=seq_get(args, 0), + ) + + class Redshift(Postgres): # https://docs.aws.amazon.com/redshift/latest/dg/r_names.html RESOLVES_IDENTIFIERS_AS_UPPERCASE = None @@ -51,11 +59,9 @@ class Redshift(Postgres): ), "DATEADD": _parse_date_add, "DATE_ADD": _parse_date_add, - "DATEDIFF": lambda args: exp.DateDiff( - this=exp.TsOrDsToDate(this=seq_get(args, 2)), - expression=exp.TsOrDsToDate(this=seq_get(args, 1)), - unit=seq_get(args, 0), - ), + "DATEDIFF": _parse_datediff, + "DATE_DIFF": _parse_datediff, + "LISTAGG": exp.GroupConcat.from_arg_list, "STRTOL": exp.FromBase.from_arg_list, } @@ -175,6 +181,7 @@ class Redshift(Postgres): exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql, exp.JSONExtract: _json_sql, exp.JSONExtractScalar: _json_sql, + exp.GroupConcat: rename_func("LISTAGG"), exp.ParseJSON: rename_func("JSON_PARSE"), exp.SafeConcat: concat_to_dpipe_sql, exp.Select: transforms.preprocess( @@ -207,7 +214,6 @@ class Redshift(Postgres): `TEXT` to `VARCHAR`. """ if expression.is_type("text"): - expression = expression.copy() expression.set("this", exp.DataType.Type.VARCHAR) precision = expression.args.get("expressions") diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 07be65b..01f7512 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -32,7 +32,7 @@ def _check_int(s: str) -> bool: # from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html -def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime]: +def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime, exp.TimeStrToTime]: if len(args) == 2: first_arg, second_arg = args if second_arg.is_string: @@ -60,8 +60,8 @@ def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime]: # reduce it using `simplify_literals` first and then check if it's a Literal. first_arg = seq_get(args, 0) if not isinstance(simplify_literals(first_arg, root=True), Literal): - # case: <variant_expr> - return format_time_lambda(exp.StrToTime, "snowflake", default=True)(args) + # case: <variant_expr> or other expressions such as columns + return exp.TimeStrToTime.from_arg_list(args) if first_arg.is_string: if _check_int(first_arg.this): @@ -560,7 +560,6 @@ class Snowflake(Dialect): offset = expression.args.get("offset") if offset: if unnest_alias: - expression = expression.copy() unnest_alias.append("columns", offset.pop()) selects.append("index") diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 8461920..1abfce6 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -63,6 +63,8 @@ class Spark(Spark2): return this class Generator(Spark2.Generator): + SUPPORTS_NESTED_CTES = True + TYPE_MAPPING = { **Spark2.Generator.TYPE_MAPPING, exp.DataType.Type.MONEY: "DECIMAL(15, 4)", diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py index 2fd4f4e..da84bd8 100644 --- a/sqlglot/dialects/spark2.py +++ b/sqlglot/dialects/spark2.py @@ -7,7 +7,6 @@ from sqlglot.dialects.dialect import ( binary_from_function, format_time_lambda, is_parse_json, - move_insert_cte_sql, pivot_column_names, rename_func, trim_sql, @@ -70,7 +69,9 @@ def _unalias_pivot(expression: exp.Expression) -> exp.Expression: alias = pivot.args["alias"].pop() return exp.From( this=expression.this.replace( - exp.select("*").from_(expression.this.copy()).subquery(alias=alias) + exp.select("*") + .from_(expression.this.copy(), copy=False) + .subquery(alias=alias, copy=False) ) ) @@ -188,7 +189,6 @@ class Spark2(Hive): exp.DayOfYear: rename_func("DAYOFYEAR"), exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}", exp.From: transforms.preprocess([_unalias_pivot]), - exp.Insert: move_insert_cte_sql, exp.LogicalAnd: rename_func("BOOL_AND"), exp.LogicalOr: rename_func("BOOL_OR"), exp.Map: _map_sql, diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index 1edfa9d..1fa730d 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -50,7 +50,7 @@ def _transform_create(expression: exp.Expression) -> exp.Expression: else: for column in defs.values(): auto_increment = None - for constraint in column.constraints.copy(): + for constraint in column.constraints: if isinstance(constraint.kind, exp.PrimaryKeyColumnConstraint): break if isinstance(constraint.kind, exp.AutoIncrementColumnConstraint): diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index 152afa6..e8162c2 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -38,12 +38,15 @@ class Teradata(Dialect): "^=": TokenType.NEQ, "BYTEINT": TokenType.SMALLINT, "COLLECT": TokenType.COMMAND, + "DEL": TokenType.DELETE, + "EQ": TokenType.EQ, "GE": TokenType.GTE, "GT": TokenType.GT, "HELP": TokenType.COMMAND, "INS": TokenType.INSERT, "LE": TokenType.LTE, "LT": TokenType.LT, + "MINUS": TokenType.EXCEPT, "MOD": TokenType.MOD, "NE": TokenType.NEQ, "NOT=": TokenType.NEQ, @@ -51,6 +54,7 @@ class Teradata(Dialect): "SEL": TokenType.SELECT, "ST_GEOMETRY": TokenType.GEOMETRY, "TOP": TokenType.TOP, + "UPD": TokenType.UPDATE, } # Teradata does not support % as a modulo operator @@ -181,6 +185,13 @@ class Teradata(Dialect): exp.Use: lambda self, e: f"DATABASE {self.sql(e, 'this')}", } + def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str: + if expression.to.this == exp.DataType.Type.UNKNOWN and expression.args.get("format"): + # We don't actually want to print the unknown type in CAST(<value> AS FORMAT <format>) + expression.to.pop() + + return super().cast_sql(expression, safe_prefix=safe_prefix) + def tablesample_sql( self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS " ) -> str: diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 867e4e4..a281297 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -11,7 +11,6 @@ from sqlglot.dialects.dialect import ( generatedasidentitycolumnconstraint_sql, max_or_greatest, min_or_least, - move_insert_cte_sql, parse_date_delta, rename_func, timestrtotime_sql, @@ -158,8 +157,6 @@ def _format_sql(self: TSQL.Generator, expression: exp.NumberToStr | exp.TimeToSt def _string_agg_sql(self: TSQL.Generator, expression: exp.GroupConcat) -> str: - expression = expression.copy() - this = expression.this distinct = expression.find(exp.Distinct) if distinct: @@ -246,6 +243,7 @@ class TSQL(Dialect): "MMM": "%b", "MM": "%m", "M": "%-m", + "dddd": "%A", "dd": "%d", "d": "%-d", "HH": "%H", @@ -596,6 +594,8 @@ class TSQL(Dialect): ALTER_TABLE_ADD_COLUMN_KEYWORD = False LIMIT_FETCH = "FETCH" COMPUTED_COLUMN_WITH_TYPE = False + SUPPORTS_NESTED_CTES = False + CTE_RECURSIVE_KEYWORD_REQUIRED = False TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, @@ -622,7 +622,6 @@ class TSQL(Dialect): exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql, exp.GroupConcat: _string_agg_sql, exp.If: rename_func("IIF"), - exp.Insert: move_insert_cte_sql, exp.Max: max_or_greatest, exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this), exp.Min: min_or_least, @@ -685,7 +684,6 @@ class TSQL(Dialect): return sql def create_sql(self, expression: exp.Create) -> str: - expression = expression.copy() kind = self.sql(expression, "kind").upper() exists = expression.args.pop("exists", None) sql = super().create_sql(expression) @@ -714,7 +712,7 @@ class TSQL(Dialect): elif expression.args.get("replace"): sql = sql.replace("CREATE OR REPLACE ", "CREATE OR ALTER ", 1) - return sql + return self.prepend_ctes(expression, sql) def offset_sql(self, expression: exp.Offset) -> str: return f"{super().offset_sql(expression)} ROWS" diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 5b012b1..99ebfb3 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -2145,6 +2145,22 @@ class PartitionedByProperty(Property): arg_types = {"this": True} +# https://www.postgresql.org/docs/current/sql-createtable.html +class PartitionBoundSpec(Expression): + # this -> IN / MODULUS, expression -> REMAINDER, from_expressions -> FROM (...), to_expressions -> TO (...) + arg_types = { + "this": False, + "expression": False, + "from_expressions": False, + "to_expressions": False, + } + + +class PartitionedOfProperty(Property): + # this -> parent_table (schema), expression -> FOR VALUES ... / DEFAULT + arg_types = {"this": True, "expression": True} + + class RemoteWithConnectionModelProperty(Property): arg_types = {"this": True} @@ -2486,6 +2502,7 @@ class Table(Expression): "format": False, "pattern": False, "index": False, + "ordinality": False, } @property @@ -2649,11 +2666,7 @@ class Update(Expression): class Values(UDTF): - arg_types = { - "expressions": True, - "ordinality": False, - "alias": False, - } + arg_types = {"expressions": True, "alias": False} class Var(Expression): @@ -3501,7 +3514,7 @@ class Star(Expression): class Parameter(Condition): - arg_types = {"this": True, "wrapped": False} + arg_types = {"this": True, "expression": False} class SessionParameter(Condition): @@ -5036,7 +5049,7 @@ class FromBase(Func): class Struct(Func): - arg_types = {"expressions": True} + arg_types = {"expressions": False} is_var_len_args = True @@ -5171,7 +5184,7 @@ class Use(Expression): class Merge(Expression): - arg_types = {"this": True, "using": True, "on": True, "expressions": True} + arg_types = {"this": True, "using": True, "on": True, "expressions": True, "with": False} class When(Func): @@ -5459,7 +5472,12 @@ def _wrap(expression: E, kind: t.Type[Expression]) -> E | Paren: def union( - left: ExpOrStr, right: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts + left: ExpOrStr, + right: ExpOrStr, + distinct: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, ) -> Union: """ Initializes a syntax tree from one UNION expression. @@ -5475,19 +5493,25 @@ def union( If an `Expression` instance is passed, it will be used as-is. distinct: set the DISTINCT flag if and only if this is true. dialect: the dialect used to parse the input expression. + copy: whether or not to copy the expression. opts: other options to use to parse the input expressions. Returns: The new Union instance. """ - left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts) - right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts) + left = maybe_parse(sql_or_expression=left, dialect=dialect, copy=copy, **opts) + right = maybe_parse(sql_or_expression=right, dialect=dialect, copy=copy, **opts) return Union(this=left, expression=right, distinct=distinct) def intersect( - left: ExpOrStr, right: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts + left: ExpOrStr, + right: ExpOrStr, + distinct: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, ) -> Intersect: """ Initializes a syntax tree from one INTERSECT expression. @@ -5503,19 +5527,25 @@ def intersect( If an `Expression` instance is passed, it will be used as-is. distinct: set the DISTINCT flag if and only if this is true. dialect: the dialect used to parse the input expression. + copy: whether or not to copy the expression. opts: other options to use to parse the input expressions. Returns: The new Intersect instance. """ - left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts) - right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts) + left = maybe_parse(sql_or_expression=left, dialect=dialect, copy=copy, **opts) + right = maybe_parse(sql_or_expression=right, dialect=dialect, copy=copy, **opts) return Intersect(this=left, expression=right, distinct=distinct) def except_( - left: ExpOrStr, right: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts + left: ExpOrStr, + right: ExpOrStr, + distinct: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, ) -> Except: """ Initializes a syntax tree from one EXCEPT expression. @@ -5531,13 +5561,14 @@ def except_( If an `Expression` instance is passed, it will be used as-is. distinct: set the DISTINCT flag if and only if this is true. dialect: the dialect used to parse the input expression. + copy: whether or not to copy the expression. opts: other options to use to parse the input expressions. Returns: The new Except instance. """ - left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts) - right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts) + left = maybe_parse(sql_or_expression=left, dialect=dialect, copy=copy, **opts) + right = maybe_parse(sql_or_expression=right, dialect=dialect, copy=copy, **opts) return Except(this=left, expression=right, distinct=distinct) @@ -5861,7 +5892,7 @@ def to_identifier(name, quoted=None, copy=True): Args: name: The name to turn into an identifier. quoted: Whether or not force quote the identifier. - copy: Whether or not to copy a passed in Identefier node. + copy: Whether or not to copy name if it's an Identifier. Returns: The identifier ast node. @@ -5882,6 +5913,25 @@ def to_identifier(name, quoted=None, copy=True): return identifier +def parse_identifier(name: str, dialect: DialectType = None) -> Identifier: + """ + Parses a given string into an identifier. + + Args: + name: The name to parse into an identifier. + dialect: The dialect to parse against. + + Returns: + The identifier ast node. + """ + try: + expression = maybe_parse(name, dialect=dialect, into=Identifier) + except ParseError: + expression = to_identifier(name) + + return expression + + INTERVAL_STRING_RE = re.compile(r"\s*([0-9]+)\s*([a-zA-Z]+)\s*") diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 0d6778a..4916cf8 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -230,6 +230,12 @@ class Generator: # Whether or not data types support additional specifiers like e.g. CHAR or BYTE (oracle) DATA_TYPE_SPECIFIERS_ALLOWED = False + # Whether or not nested CTEs (e.g. defined inside of subqueries) are allowed + SUPPORTS_NESTED_CTES = True + + # Whether or not the "RECURSIVE" keyword is required when defining recursive CTEs + CTE_RECURSIVE_KEYWORD_REQUIRED = True + TYPE_MAPPING = { exp.DataType.Type.NCHAR: "CHAR", exp.DataType.Type.NVARCHAR: "VARCHAR", @@ -304,6 +310,7 @@ class Generator: exp.Order: exp.Properties.Location.POST_SCHEMA, exp.OutputModelProperty: exp.Properties.Location.POST_SCHEMA, exp.PartitionedByProperty: exp.Properties.Location.POST_WITH, + exp.PartitionedOfProperty: exp.Properties.Location.POST_SCHEMA, exp.PrimaryKey: exp.Properties.Location.POST_SCHEMA, exp.Property: exp.Properties.Location.POST_WITH, exp.RemoteWithConnectionModelProperty: exp.Properties.Location.POST_SCHEMA, @@ -407,7 +414,6 @@ class Generator: "unsupported_messages", "_escaped_quote_end", "_escaped_identifier_end", - "_cache", ) def __init__( @@ -447,30 +453,38 @@ class Generator: self._escaped_identifier_end: str = ( self.TOKENIZER_CLASS.IDENTIFIER_ESCAPES[0] + self.IDENTIFIER_END ) - self._cache: t.Optional[t.Dict[int, str]] = None - def generate( - self, - expression: t.Optional[exp.Expression], - cache: t.Optional[t.Dict[int, str]] = None, - ) -> str: + def generate(self, expression: exp.Expression, copy: bool = True) -> str: """ Generates the SQL string corresponding to the given syntax tree. Args: expression: The syntax tree. - cache: An optional sql string cache. This leverages the hash of an Expression - which can be slow to compute, so only use it if you set _hash on each node. + copy: Whether or not to copy the expression. The generator performs mutations so + it is safer to copy. Returns: The SQL string corresponding to `expression`. """ - if cache is not None: - self._cache = cache + if copy: + expression = expression.copy() + + # Some dialects only support CTEs at the top level expression, so we need to bubble up nested + # CTEs to that level in order to produce a syntactically valid expression. This transformation + # happens here to minimize code duplication, since many expressions support CTEs. + if ( + not self.SUPPORTS_NESTED_CTES + and isinstance(expression, exp.Expression) + and not expression.parent + and "with" in expression.arg_types + and any(node.parent is not expression for node in expression.find_all(exp.With)) + ): + from sqlglot.transforms import move_ctes_to_top_level + + expression = move_ctes_to_top_level(expression) self.unsupported_messages = [] sql = self.sql(expression).strip() - self._cache = None if self.unsupported_level == ErrorLevel.IGNORE: return sql @@ -595,12 +609,6 @@ class Generator: return self.sql(value) return "" - if self._cache is not None: - expression_id = hash(expression) - - if expression_id in self._cache: - return self._cache[expression_id] - transform = self.TRANSFORMS.get(expression.__class__) if callable(transform): @@ -621,11 +629,7 @@ class Generator: else: raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}") - sql = self.maybe_comment(sql, expression) if self.comments and comment else sql - - if self._cache is not None: - self._cache[expression_id] = sql - return sql + return self.maybe_comment(sql, expression) if self.comments and comment else sql def uncache_sql(self, expression: exp.Uncache) -> str: table = self.sql(expression, "this") @@ -879,7 +883,11 @@ class Generator: def with_sql(self, expression: exp.With) -> str: sql = self.expressions(expression, flat=True) - recursive = "RECURSIVE " if expression.args.get("recursive") else "" + recursive = ( + "RECURSIVE " + if self.CTE_RECURSIVE_KEYWORD_REQUIRED and expression.args.get("recursive") + else "" + ) return f"WITH {recursive}{sql}" @@ -1022,7 +1030,7 @@ class Generator: where = self.sql(expression, "expression").strip() return f"{this} FILTER({where})" - agg = expression.this.copy() + agg = expression.this agg_arg = agg.this cond = expression.expression.this agg_arg.replace(exp.If(this=cond.copy(), true=agg_arg.copy())) @@ -1088,9 +1096,9 @@ class Generator: for p in expression.expressions: p_loc = self.PROPERTIES_LOCATION[p.__class__] if p_loc == exp.Properties.Location.POST_WITH: - with_properties.append(p.copy()) + with_properties.append(p) elif p_loc == exp.Properties.Location.POST_SCHEMA: - root_properties.append(p.copy()) + root_properties.append(p) return self.root_properties( exp.Properties(expressions=root_properties) @@ -1124,7 +1132,7 @@ class Generator: for p in properties.expressions: p_loc = self.PROPERTIES_LOCATION[p.__class__] if p_loc != exp.Properties.Location.UNSUPPORTED: - properties_locs[p_loc].append(p.copy()) + properties_locs[p_loc].append(p) else: self.unsupported(f"Unsupported property {p.key}") @@ -1238,6 +1246,29 @@ class Generator: for_ = " FOR NONE" return f"WITH{no}{concurrent} ISOLATED LOADING{for_}" + def partitionboundspec_sql(self, expression: exp.PartitionBoundSpec) -> str: + if isinstance(expression.this, list): + return f"IN ({self.expressions(expression, key='this', flat=True)})" + if expression.this: + modulus = self.sql(expression, "this") + remainder = self.sql(expression, "expression") + return f"WITH (MODULUS {modulus}, REMAINDER {remainder})" + + from_expressions = self.expressions(expression, key="from_expressions", flat=True) + to_expressions = self.expressions(expression, key="to_expressions", flat=True) + return f"FROM ({from_expressions}) TO ({to_expressions})" + + def partitionedofproperty_sql(self, expression: exp.PartitionedOfProperty) -> str: + this = self.sql(expression, "this") + + for_values_or_default = expression.expression + if isinstance(for_values_or_default, exp.PartitionBoundSpec): + for_values_or_default = f" FOR VALUES {self.sql(for_values_or_default)}" + else: + for_values_or_default = " DEFAULT" + + return f"PARTITION OF {this}{for_values_or_default}" + def lockingproperty_sql(self, expression: exp.LockingProperty) -> str: kind = expression.args.get("kind") this = f" {self.sql(expression, 'this')}" if expression.this else "" @@ -1385,7 +1416,12 @@ class Generator: index = self.sql(expression, "index") index = f" AT {index}" if index else "" - return f"{table}{version}{file_format}{alias}{index}{hints}{pivots}{joins}{laterals}" + ordinality = expression.args.get("ordinality") or "" + if ordinality: + ordinality = f" WITH ORDINALITY{alias}" + alias = "" + + return f"{table}{version}{file_format}{alias}{index}{hints}{pivots}{joins}{laterals}{ordinality}" def tablesample_sql( self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS " @@ -1489,7 +1525,6 @@ class Generator: return f"{values} AS {alias}" if alias else values # Converts `VALUES...` expression into a series of select unions. - expression = expression.copy() alias_node = expression.args.get("alias") column_names = alias_node and alias_node.columns @@ -1972,8 +2007,7 @@ class Generator: if self.UNNEST_WITH_ORDINALITY: if alias and isinstance(offset, exp.Expression): - alias = alias.copy() - alias.append("columns", offset.copy()) + alias.append("columns", offset) if alias and self.UNNEST_COLUMN_ONLY: columns = alias.columns @@ -2138,7 +2172,6 @@ class Generator: return f"PRIMARY KEY ({expressions}){options}" def if_sql(self, expression: exp.If) -> str: - expression = expression.copy() return self.case_sql(exp.Case(ifs=[expression], default=expression.args.get("false"))) def matchagainst_sql(self, expression: exp.MatchAgainst) -> str: @@ -2367,7 +2400,9 @@ class Generator: def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str: format_sql = self.sql(expression, "format") format_sql = f" FORMAT {format_sql}" if format_sql else "" - return f"{safe_prefix or ''}CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')}{format_sql})" + to_sql = self.sql(expression, "to") + to_sql = f" {to_sql}" if to_sql else "" + return f"{safe_prefix or ''}CAST({self.sql(expression, 'this')} AS{to_sql}{format_sql})" def currentdate_sql(self, expression: exp.CurrentDate) -> str: zone = self.sql(expression, "this") @@ -2510,7 +2545,7 @@ class Generator: def intdiv_sql(self, expression: exp.IntDiv) -> str: return self.sql( exp.Cast( - this=exp.Div(this=expression.this.copy(), expression=expression.expression.copy()), + this=exp.Div(this=expression.this, expression=expression.expression), to=exp.DataType(this=exp.DataType.Type.INT), ) ) @@ -2779,7 +2814,6 @@ class Generator: hints = table.args.get("hints") if hints and table.alias and isinstance(hints[0], exp.WithTableHint): # T-SQL syntax is MERGE ... <target_table> [WITH (<merge_hint>)] [[AS] table_alias] - table = table.copy() table_alias = f" AS {self.sql(table.args['alias'].pop())}" this = self.sql(table) @@ -2787,7 +2821,9 @@ class Generator: on = f"ON {self.sql(expression, 'on')}" expressions = self.expressions(expression, sep=" ") - return f"MERGE INTO {this}{table_alias} {using} {on} {expressions}" + return self.prepend_ctes( + expression, f"MERGE INTO {this}{table_alias} {using} {on} {expressions}" + ) def tochar_sql(self, expression: exp.ToChar) -> str: if expression.args.get("format"): @@ -2896,12 +2932,12 @@ class Generator: case = exp.Case().when( expression.this.is_(exp.null()).not_(copy=False), - expression.args["true"].copy(), + expression.args["true"], copy=False, ) else_cond = expression.args.get("false") if else_cond: - case.else_(else_cond.copy(), copy=False) + case.else_(else_cond, copy=False) return self.sql(case) @@ -2931,15 +2967,6 @@ class Generator: if not isinstance(expression, exp.Literal): from sqlglot.optimizer.simplify import simplify - expression = simplify(expression.copy()) + expression = simplify(expression) return expression - - -def cached_generator( - cache: t.Optional[t.Dict[int, str]] = None -) -> t.Callable[[exp.Expression], str]: - """Returns a cached generator.""" - cache = {} if cache is None else cache - generator = Generator(normalize=True, identify="safe") - return lambda e: generator.generate(e, cache) diff --git a/sqlglot/helper.py b/sqlglot/helper.py index 74b61e3..ee41557 100644 --- a/sqlglot/helper.py +++ b/sqlglot/helper.py @@ -184,9 +184,7 @@ def apply_index_offset( annotate_types(expression) if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES: logger.warning("Applying array index offset (%s)", offset) - expression = simplify( - exp.Add(this=expression.copy(), expression=exp.Literal.number(offset)) - ) + expression = simplify(exp.Add(this=expression, expression=exp.Literal.number(offset))) return [expression] return expressions diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py index 8d82b2d..6df36af 100644 --- a/sqlglot/optimizer/normalize.py +++ b/sqlglot/optimizer/normalize.py @@ -4,7 +4,6 @@ import logging from sqlglot import exp from sqlglot.errors import OptimizeError -from sqlglot.generator import cached_generator from sqlglot.helper import while_changing from sqlglot.optimizer.scope import find_all_in_scope from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort @@ -29,8 +28,6 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = Returns: sqlglot.Expression: normalized expression """ - generate = cached_generator() - for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))): if isinstance(node, exp.Connector): if normalized(node, dnf=dnf): @@ -49,7 +46,7 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = try: node = node.replace( - while_changing(node, lambda e: distributive_law(e, dnf, max_distance, generate)) + while_changing(node, lambda e: distributive_law(e, dnf, max_distance)) ) except OptimizeError as e: logger.info(e) @@ -133,7 +130,7 @@ def _predicate_lengths(expression, dnf): return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf) -def distributive_law(expression, dnf, max_distance, generate): +def distributive_law(expression, dnf, max_distance): """ x OR (y AND z) -> (x OR y) AND (x OR z) (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z) @@ -146,7 +143,7 @@ def distributive_law(expression, dnf, max_distance, generate): if distance > max_distance: raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}") - exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, generate)) + exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance)) to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or) if isinstance(expression, from_exp): @@ -157,30 +154,30 @@ def distributive_law(expression, dnf, max_distance, generate): if isinstance(a, to_exp) and isinstance(b, to_exp): if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))): - return _distribute(a, b, from_func, to_func, generate) - return _distribute(b, a, from_func, to_func, generate) + return _distribute(a, b, from_func, to_func) + return _distribute(b, a, from_func, to_func) if isinstance(a, to_exp): - return _distribute(b, a, from_func, to_func, generate) + return _distribute(b, a, from_func, to_func) if isinstance(b, to_exp): - return _distribute(a, b, from_func, to_func, generate) + return _distribute(a, b, from_func, to_func) return expression -def _distribute(a, b, from_func, to_func, generate): +def _distribute(a, b, from_func, to_func): if isinstance(a, exp.Connector): exp.replace_children( a, lambda c: to_func( - uniq_sort(flatten(from_func(c, b.left)), generate), - uniq_sort(flatten(from_func(c, b.right)), generate), + uniq_sort(flatten(from_func(c, b.left))), + uniq_sort(flatten(from_func(c, b.right))), copy=False, ), ) else: a = to_func( - uniq_sort(flatten(from_func(a, b.left)), generate), - uniq_sort(flatten(from_func(a, b.right)), generate), + uniq_sort(flatten(from_func(a, b.left))), + uniq_sort(flatten(from_func(a, b.right))), copy=False, ) diff --git a/sqlglot/optimizer/normalize_identifiers.py b/sqlglot/optimizer/normalize_identifiers.py index ecea6a0..154256e 100644 --- a/sqlglot/optimizer/normalize_identifiers.py +++ b/sqlglot/optimizer/normalize_identifiers.py @@ -2,7 +2,7 @@ from __future__ import annotations import typing as t -from sqlglot import exp, parse_one +from sqlglot import exp from sqlglot._typing import E from sqlglot.dialects.dialect import Dialect, DialectType @@ -49,7 +49,7 @@ def normalize_identifiers(expression, dialect=None): The transformed expression. """ if isinstance(expression, str): - expression = parse_one(expression, dialect=dialect, into=exp.Identifier) + expression = exp.parse_identifier(expression, dialect=dialect) dialect = Dialect.get_or_raise(dialect) diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index 68aebdb..3a43e8f 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -62,7 +62,7 @@ def qualify_tables( if isinstance(source.this, exp.Identifier): if not source.args.get("db"): source.set("db", exp.to_identifier(db)) - if not source.args.get("catalog"): + if not source.args.get("catalog") and source.args.get("db"): source.set("catalog", exp.to_identifier(catalog)) if not source.alias: diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 30de75b..af03332 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -7,8 +7,7 @@ from decimal import Decimal import sqlglot from sqlglot import exp -from sqlglot.generator import cached_generator -from sqlglot.helper import first, merge_ranges, while_changing +from sqlglot.helper import first, is_iterable, merge_ranges, while_changing from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope # Final means that an expression should not be simplified @@ -37,8 +36,6 @@ def simplify(expression, constant_propagation=False): sqlglot.Expression: simplified expression """ - generate = cached_generator() - # group by expressions cannot be simplified, for example # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 # the projection must exactly match the group by key @@ -67,7 +64,7 @@ def simplify(expression, constant_propagation=False): # Pre-order transformations node = expression node = rewrite_between(node) - node = uniq_sort(node, generate, root) + node = uniq_sort(node, root) node = absorb_and_eliminate(node, root) node = simplify_concat(node) node = simplify_conditionals(node) @@ -311,7 +308,7 @@ def remove_complements(expression, root=True): return expression -def uniq_sort(expression, generate, root=True): +def uniq_sort(expression, root=True): """ Uniq and sort a connector. @@ -320,7 +317,7 @@ def uniq_sort(expression, generate, root=True): if isinstance(expression, exp.Connector) and (root or not expression.same_parent): result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_ flattened = tuple(expression.flatten()) - deduped = {generate(e): e for e in flattened} + deduped = {gen(e): e for e in flattened} arr = tuple(deduped.items()) # check if the operands are already sorted, if not sort them @@ -1070,3 +1067,69 @@ def _flat_simplify(expression, simplifier, root=True): lambda a, b: expression.__class__(this=a, expression=b), operands ) return expression + + +def gen(expression: t.Any) -> str: + """Simple pseudo sql generator for quickly generating sortable and uniq strings. + + Sorting and deduping sql is a necessary step for optimization. Calling the actual + generator is expensive so we have a bare minimum sql generator here. + """ + if expression is None: + return "_" + if is_iterable(expression): + return ",".join(gen(e) for e in expression) + if not isinstance(expression, exp.Expression): + return str(expression) + + etype = type(expression) + if etype in GEN_MAP: + return GEN_MAP[etype](expression) + return f"{expression.key} {gen(expression.args.values())}" + + +GEN_MAP = { + exp.Add: lambda e: _binary(e, "+"), + exp.And: lambda e: _binary(e, "AND"), + exp.Anonymous: lambda e: f"{e.this} {','.join(gen(e) for e in e.expressions)}", + exp.Between: lambda e: f"{gen(e.this)} BETWEEN {gen(e.args.get('low'))} AND {gen(e.args.get('high'))}", + exp.Boolean: lambda e: "TRUE" if e.this else "FALSE", + exp.Bracket: lambda e: f"{gen(e.this)}[{gen(e.expressions)}]", + exp.Column: lambda e: ".".join(gen(p) for p in e.parts), + exp.DataType: lambda e: f"{e.this.name} {gen(tuple(e.args.values())[1:])}", + exp.Div: lambda e: _binary(e, "/"), + exp.Dot: lambda e: _binary(e, "."), + exp.DPipe: lambda e: _binary(e, "||"), + exp.SafeDPipe: lambda e: _binary(e, "||"), + exp.EQ: lambda e: _binary(e, "="), + exp.GT: lambda e: _binary(e, ">"), + exp.GTE: lambda e: _binary(e, ">="), + exp.Identifier: lambda e: f'"{e.name}"' if e.quoted else e.name, + exp.ILike: lambda e: _binary(e, "ILIKE"), + exp.In: lambda e: f"{gen(e.this)} IN ({gen(tuple(e.args.values())[1:])})", + exp.Is: lambda e: _binary(e, "IS"), + exp.Like: lambda e: _binary(e, "LIKE"), + exp.Literal: lambda e: f"'{e.name}'" if e.is_string else e.name, + exp.LT: lambda e: _binary(e, "<"), + exp.LTE: lambda e: _binary(e, "<="), + exp.Mod: lambda e: _binary(e, "%"), + exp.Mul: lambda e: _binary(e, "*"), + exp.Neg: lambda e: _unary(e, "-"), + exp.NEQ: lambda e: _binary(e, "<>"), + exp.Not: lambda e: _unary(e, "NOT"), + exp.Null: lambda e: "NULL", + exp.Or: lambda e: _binary(e, "OR"), + exp.Paren: lambda e: f"({gen(e.this)})", + exp.Sub: lambda e: _binary(e, "-"), + exp.Subquery: lambda e: f"({gen(e.args.values())})", + exp.Table: lambda e: gen(e.args.values()), + exp.Var: lambda e: e.name, +} + + +def _binary(e: exp.Binary, op: str) -> str: + return f"{gen(e.left)} {op} {gen(e.right)}" + + +def _unary(e: exp.Unary, op: str) -> str: + return f"{op} {gen(e.this)}" diff --git a/sqlglot/parser.py b/sqlglot/parser.py index b7f91ab..1dab600 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -674,6 +674,7 @@ class Parser(metaclass=_Parser): "ON": lambda self: self._parse_on_property(), "ORDER BY": lambda self: self._parse_order(skip_order_token=True), "OUTPUT": lambda self: self.expression(exp.OutputModelProperty, this=self._parse_schema()), + "PARTITION": lambda self: self._parse_partitioned_of(), "PARTITION BY": lambda self: self._parse_partitioned_by(), "PARTITIONED BY": lambda self: self._parse_partitioned_by(), "PARTITIONED_BY": lambda self: self._parse_partitioned_by(), @@ -1743,6 +1744,58 @@ class Parser(metaclass=_Parser): return self._parse_csv(self._parse_conjunction) return [] + def _parse_partition_bound_spec(self) -> exp.PartitionBoundSpec: + def _parse_partition_bound_expr() -> t.Optional[exp.Expression]: + if self._match_text_seq("MINVALUE"): + return exp.var("MINVALUE") + if self._match_text_seq("MAXVALUE"): + return exp.var("MAXVALUE") + return self._parse_bitwise() + + this: t.Optional[exp.Expression | t.List[exp.Expression]] = None + expression = None + from_expressions = None + to_expressions = None + + if self._match(TokenType.IN): + this = self._parse_wrapped_csv(self._parse_bitwise) + elif self._match(TokenType.FROM): + from_expressions = self._parse_wrapped_csv(_parse_partition_bound_expr) + self._match_text_seq("TO") + to_expressions = self._parse_wrapped_csv(_parse_partition_bound_expr) + elif self._match_text_seq("WITH", "(", "MODULUS"): + this = self._parse_number() + self._match_text_seq(",", "REMAINDER") + expression = self._parse_number() + self._match_r_paren() + else: + self.raise_error("Failed to parse partition bound spec.") + + return self.expression( + exp.PartitionBoundSpec, + this=this, + expression=expression, + from_expressions=from_expressions, + to_expressions=to_expressions, + ) + + # https://www.postgresql.org/docs/current/sql-createtable.html + def _parse_partitioned_of(self) -> t.Optional[exp.PartitionedOfProperty]: + if not self._match_text_seq("OF"): + self._retreat(self._index - 1) + return None + + this = self._parse_table(schema=True) + + if self._match(TokenType.DEFAULT): + expression: exp.Var | exp.PartitionBoundSpec = exp.var("DEFAULT") + elif self._match_text_seq("FOR", "VALUES"): + expression = self._parse_partition_bound_spec() + else: + self.raise_error("Expecting either DEFAULT or FOR VALUES clause.") + + return self.expression(exp.PartitionedOfProperty, this=this, expression=expression) + def _parse_partitioned_by(self) -> exp.PartitionedByProperty: self._match(TokenType.EQ) return self.expression( @@ -2682,6 +2735,10 @@ class Parser(metaclass=_Parser): for join in iter(self._parse_join, None): this.append("joins", join) + if self._match_pair(TokenType.WITH, TokenType.ORDINALITY): + this.set("ordinality", True) + this.set("alias", self._parse_table_alias()) + return this def _parse_version(self) -> t.Optional[exp.Version]: @@ -4189,17 +4246,12 @@ class Parser(metaclass=_Parser): fmt = None to = self._parse_types() - if not to: - self.raise_error("Expected TYPE after CAST") - elif isinstance(to, exp.Identifier): - to = exp.DataType.build(to.name, udt=True) - elif to.this == exp.DataType.Type.CHAR: - if self._match(TokenType.CHARACTER_SET): - to = self.expression(exp.CharacterSet, this=self._parse_var_or_string()) - elif self._match(TokenType.FORMAT): + if self._match(TokenType.FORMAT): fmt_string = self._parse_string() fmt = self._parse_at_time_zone(fmt_string) + if not to: + to = exp.DataType.build(exp.DataType.Type.UNKNOWN) if to.this in exp.DataType.TEMPORAL_TYPES: this = self.expression( exp.StrToDate if to.this == exp.DataType.Type.DATE else exp.StrToTime, @@ -4215,8 +4267,14 @@ class Parser(metaclass=_Parser): if isinstance(fmt, exp.AtTimeZone) and isinstance(this, exp.StrToTime): this.set("zone", fmt.args["zone"]) - return this + elif not to: + self.raise_error("Expected TYPE after CAST") + elif isinstance(to, exp.Identifier): + to = exp.DataType.build(to.name, udt=True) + elif to.this == exp.DataType.Type.CHAR: + if self._match(TokenType.CHARACTER_SET): + to = self.expression(exp.CharacterSet, this=self._parse_var_or_string()) return self.expression( exp.Cast if strict else exp.TryCast, this=this, to=to, format=fmt, safe=safe @@ -4789,10 +4847,17 @@ class Parser(metaclass=_Parser): return self._parse_placeholder() def _parse_parameter(self) -> exp.Parameter: - wrapped = self._match(TokenType.L_BRACE) - this = self._parse_var() or self._parse_identifier() or self._parse_primary() + def _parse_parameter_part() -> t.Optional[exp.Expression]: + return ( + self._parse_identifier() or self._parse_primary() or self._parse_var(any_token=True) + ) + + self._match(TokenType.L_BRACE) + this = _parse_parameter_part() + expression = self._match(TokenType.COLON) and _parse_parameter_part() self._match(TokenType.R_BRACE) - return self.expression(exp.Parameter, this=this, wrapped=wrapped) + + return self.expression(exp.Parameter, this=this, expression=expression) def _parse_placeholder(self) -> t.Optional[exp.Expression]: if self._match_set(self.PLACEHOLDER_PARSERS): diff --git a/sqlglot/schema.py b/sqlglot/schema.py index 778378c..acf9bc4 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -3,10 +3,9 @@ from __future__ import annotations import abc import typing as t -import sqlglot from sqlglot import expressions as exp from sqlglot.dialects.dialect import Dialect -from sqlglot.errors import ParseError, SchemaError +from sqlglot.errors import SchemaError from sqlglot.helper import dict_depth from sqlglot.trie import TrieResult, in_trie, new_trie @@ -448,19 +447,16 @@ class MappingSchema(AbstractMappingSchema, Schema): def normalize_name( - name: str | exp.Identifier, + identifier: str | exp.Identifier, dialect: DialectType = None, is_table: bool = False, normalize: t.Optional[bool] = True, ) -> str: - try: - identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier) - except ParseError: - return name if isinstance(name, str) else name.name + if isinstance(identifier, str): + identifier = exp.parse_identifier(identifier, dialect=dialect) - name = identifier.name if not normalize: - return name + return identifier.name # This can be useful for normalize_identifier identifier.meta["is_table"] = is_table diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index e0fd68f..445fda6 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -67,7 +67,7 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression: order = expression.args.get("order") if order: - window.set("order", order.pop().copy()) + window.set("order", order.pop()) else: window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols])) @@ -75,9 +75,9 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression: expression.select(window, copy=False) return ( - exp.select(*outer_selects) - .from_(expression.subquery("_t")) - .where(exp.column(row_number).eq(1)) + exp.select(*outer_selects, copy=False) + .from_(expression.subquery("_t", copy=False), copy=False) + .where(exp.column(row_number).eq(1), copy=False) ) return expression @@ -120,7 +120,9 @@ def eliminate_qualify(expression: exp.Expression) -> exp.Expression: elif expr.name not in expression.named_selects: expression.select(expr.copy(), copy=False) - return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters) + return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where( + qualify_filters, copy=False + ) return expression @@ -189,7 +191,7 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp ) # we use list here because expression.selects is mutated inside the loop - for select in expression.selects.copy(): + for select in list(expression.selects): explode = select.find(exp.Explode) if explode: @@ -374,6 +376,60 @@ def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression: return expression +def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression: + """ + Converts a query with a FULL OUTER join to a union of identical queries that + use LEFT/RIGHT OUTER joins instead. This transformation currently only works + for queries that have a single FULL OUTER join. + """ + if isinstance(expression, exp.Select): + full_outer_joins = [ + (index, join) + for index, join in enumerate(expression.args.get("joins") or []) + if join.side == "FULL" and join.kind == "OUTER" + ] + + if len(full_outer_joins) == 1: + expression_copy = expression.copy() + index, full_outer_join = full_outer_joins[0] + full_outer_join.set("side", "left") + expression_copy.args["joins"][index].set("side", "right") + + return exp.union(expression, expression_copy, copy=False) + + return expression + + +def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression: + """ + Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be + defined at the top-level, so for example queries like: + + SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq + + are invalid in those dialects. This transformation can be used to ensure all CTEs are + moved to the top level so that the final SQL code is valid from a syntax standpoint. + + TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly). + """ + top_level_with = expression.args.get("with") + for node in expression.find_all(exp.With): + if node.parent is expression: + continue + + inner_with = node.pop() + if not top_level_with: + top_level_with = inner_with + expression.set("with", top_level_with) + else: + if inner_with.recursive: + top_level_with.set("recursive", True) + + top_level_with.expressions.extend(inner_with.expressions) + + return expression + + def preprocess( transforms: t.List[t.Callable[[exp.Expression], exp.Expression]], ) -> t.Callable[[Generator, exp.Expression], str]: @@ -392,7 +448,7 @@ def preprocess( def _to_sql(self, expression: exp.Expression) -> str: expression_type = type(expression) - expression = transforms[0](expression.copy()) + expression = transforms[0](expression) for t in transforms[1:]: expression = t(expression) |