diff options
Diffstat (limited to 'sqlglot/dialects')
-rw-r--r-- | sqlglot/dialects/bigquery.py | 2 | ||||
-rw-r--r-- | sqlglot/dialects/clickhouse.py | 1 | ||||
-rw-r--r-- | sqlglot/dialects/databricks.py | 9 | ||||
-rw-r--r-- | sqlglot/dialects/dialect.py | 29 | ||||
-rw-r--r-- | sqlglot/dialects/drill.py | 6 | ||||
-rw-r--r-- | sqlglot/dialects/duckdb.py | 6 | ||||
-rw-r--r-- | sqlglot/dialects/hive.py | 17 | ||||
-rw-r--r-- | sqlglot/dialects/mysql.py | 38 | ||||
-rw-r--r-- | sqlglot/dialects/postgres.py | 12 | ||||
-rw-r--r-- | sqlglot/dialects/presto.py | 9 | ||||
-rw-r--r-- | sqlglot/dialects/redshift.py | 18 | ||||
-rw-r--r-- | sqlglot/dialects/snowflake.py | 7 | ||||
-rw-r--r-- | sqlglot/dialects/spark.py | 2 | ||||
-rw-r--r-- | sqlglot/dialects/spark2.py | 6 | ||||
-rw-r--r-- | sqlglot/dialects/sqlite.py | 2 | ||||
-rw-r--r-- | sqlglot/dialects/teradata.py | 11 | ||||
-rw-r--r-- | sqlglot/dialects/tsql.py | 10 |
17 files changed, 101 insertions, 84 deletions
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 51baba2..fc9a3ae 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -69,7 +69,6 @@ def _create_sql(self: BigQuery.Generator, expression: exp.Create) -> str: returns = expression.find(exp.ReturnsProperty) if kind.upper() == "FUNCTION" and returns and returns.args.get("is_table"): - expression = expression.copy() expression.set("kind", "TABLE FUNCTION") if isinstance(expression.expression, (exp.Subquery, exp.Literal)): @@ -699,6 +698,5 @@ class BigQuery(Dialect): def version_sql(self, expression: exp.Version) -> str: if expression.name == "TIMESTAMP": - expression = expression.copy() expression.set("this", "SYSTEM_TIME") return super().version_sql(expression) diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index 30f728c..394a922 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -461,7 +461,6 @@ class ClickHouse(Dialect): def safeconcat_sql(self, expression: exp.SafeConcat) -> str: # Clickhouse errors out if we try to cast a NULL value to TEXT - expression = expression.copy() return self.func( "CONCAT", *[ diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py index 314a821..b777db0 100644 --- a/sqlglot/dialects/databricks.py +++ b/sqlglot/dialects/databricks.py @@ -35,7 +35,7 @@ class Databricks(Spark): exp.DatetimeSub: lambda self, e: self.func( "TIMESTAMPADD", e.text("unit"), - exp.Mul(this=e.expression.copy(), expression=exp.Literal.number(-1)), + exp.Mul(this=e.expression, expression=exp.Literal.number(-1)), e.this, ), exp.DatetimeDiff: lambda self, e: self.func( @@ -63,21 +63,14 @@ class Databricks(Spark): and kind.this in exp.DataType.INTEGER_TYPES ): # only BIGINT generated identity constraints are supported - expression = expression.copy() expression.set("kind", exp.DataType.build("bigint")) return super().columndef_sql(expression, sep) def generatedasidentitycolumnconstraint_sql( self, expression: exp.GeneratedAsIdentityColumnConstraint ) -> str: - expression = expression.copy() expression.set("this", True) # trigger ALWAYS in super class return super().generatedasidentitycolumnconstraint_sql(expression) class Tokenizer(Spark.Tokenizer): HEX_STRINGS = [] - - SINGLE_TOKENS = { - **Spark.Tokenizer.SINGLE_TOKENS, - "$": TokenType.PARAMETER, - } diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 739e8d7..21e7889 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -315,11 +315,14 @@ class Dialect(metaclass=_Dialect): ) -> t.List[t.Optional[exp.Expression]]: return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) - def generate(self, expression: t.Optional[exp.Expression], **opts) -> str: - return self.generator(**opts).generate(expression) + def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: + return self.generator(**opts).generate(expression, copy=copy) def transpile(self, sql: str, **opts) -> t.List[str]: - return [self.generate(expression, **opts) for expression in self.parse(sql)] + return [ + self.generate(expression, copy=False, **opts) if expression else "" + for expression in self.parse(sql) + ] def tokenize(self, sql: str) -> t.List[Token]: return self.tokenizer.tokenize(sql) @@ -380,9 +383,7 @@ def inline_array_sql(self: Generator, expression: exp.Array) -> str: def no_ilike_sql(self: Generator, expression: exp.ILike) -> str: return self.like_sql( - exp.Like( - this=exp.Lower(this=expression.this.copy()), expression=expression.expression.copy() - ) + exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression) ) @@ -518,7 +519,6 @@ def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str: is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW") if has_schema and is_partitionable: - expression = expression.copy() prop = expression.find(exp.PartitionedByProperty) if prop and prop.this and not isinstance(prop.this, exp.Schema): schema = expression.this @@ -583,7 +583,7 @@ def date_add_interval_sql( this = self.sql(expression, "this") unit = expression.args.get("unit") unit = exp.var(unit.name.upper() if unit else "DAY") - interval = exp.Interval(this=expression.expression.copy(), unit=unit) + interval = exp.Interval(this=expression.expression, unit=unit) return f"{data_type}_{kind}({this}, {self.sql(interval)})" return func @@ -621,7 +621,6 @@ def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> s def left_to_substring_sql(self: Generator, expression: exp.Left) -> str: - expression = expression.copy() return self.sql( exp.Substring( this=expression.this, start=exp.Literal.number(1), length=expression.expression @@ -630,7 +629,6 @@ def left_to_substring_sql(self: Generator, expression: exp.Left) -> str: def right_to_substring_sql(self: Generator, expression: exp.Left) -> str: - expression = expression.copy() return self.sql( exp.Substring( this=expression.this, @@ -675,7 +673,7 @@ def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: cond = expression.this.expressions[0] self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") - return self.func("sum", exp.func("if", cond.copy(), 1, 0)) + return self.func("sum", exp.func("if", cond, 1, 0)) def trim_sql(self: Generator, expression: exp.Trim) -> str: @@ -716,12 +714,10 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable: def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str: - expression = expression.copy() return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions)) def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str: - expression = expression.copy() delim, *rest_args = expression.expressions return self.sql( reduce( @@ -809,13 +805,6 @@ def isnull_to_is_null(args: t.List) -> exp.Expression: return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null())) -def move_insert_cte_sql(self: Generator, expression: exp.Insert) -> str: - if expression.expression.args.get("with"): - expression = expression.copy() - expression.set("with", expression.expression.args["with"].pop()) - return self.insert_sql(expression) - - def generatedasidentitycolumnconstraint_sql( self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint ) -> str: diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py index 8b2e708..42453fd 100644 --- a/sqlglot/dialects/drill.py +++ b/sqlglot/dialects/drill.py @@ -20,7 +20,9 @@ def _date_add_sql(kind: str) -> t.Callable[[Drill.Generator, exp.DateAdd | exp.D def func(self: Drill.Generator, expression: exp.DateAdd | exp.DateSub) -> str: this = self.sql(expression, "this") unit = exp.var(expression.text("unit").upper() or "DAY") - return f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))})" + return ( + f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})" + ) return func @@ -147,7 +149,7 @@ class Drill(Dialect): exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"), exp.ToChar: lambda self, e: self.function_fallback_sql(e), exp.TryCast: no_trycast_sql, - exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression.copy(), unit=exp.var('DAY')))})", + exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.var('DAY')))})", exp.TsOrDsToDate: ts_or_ds_to_date_sql("drill"), exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)", } diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index 287e03a..d8d9f90 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -36,14 +36,14 @@ from sqlglot.tokens import TokenType def _ts_or_ds_add_sql(self: DuckDB.Generator, expression: exp.TsOrDsAdd) -> str: this = self.sql(expression, "this") unit = self.sql(expression, "unit").strip("'") or "DAY" - return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))}" + return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression, unit=unit))}" def _date_delta_sql(self: DuckDB.Generator, expression: exp.DateAdd | exp.DateSub) -> str: this = self.sql(expression, "this") unit = self.sql(expression, "unit").strip("'") or "DAY" op = "+" if isinstance(expression, exp.DateAdd) else "-" - return f"{this} {op} {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))}" + return f"{this} {op} {self.sql(exp.Interval(this=expression.expression, unit=unit))}" # BigQuery -> DuckDB conversion for the DATE function @@ -365,7 +365,7 @@ class DuckDB(Dialect): multiplier = 90 if multiplier: - return f"({multiplier} * {super().interval_sql(exp.Interval(this=expression.this.copy(), unit=exp.var('day')))})" + return f"({multiplier} * {super().interval_sql(exp.Interval(this=expression.this, unit=exp.var('day')))})" return super().interval_sql(expression) diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 7bff553..3b1c8de 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -53,8 +53,6 @@ DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH") def _create_sql(self, expression: exp.Create) -> str: - expression = expression.copy() - # remove UNIQUE column constraints for constraint in expression.find_all(exp.UniqueColumnConstraint): if constraint.parent: @@ -88,7 +86,7 @@ def _add_date_sql(self: Hive.Generator, expression: exp.DateAdd | exp.DateSub) - if expression.expression.is_number: modified_increment = exp.Literal.number(int(expression.text("expression")) * multiplier) else: - modified_increment = expression.expression.copy() + modified_increment = expression.expression if multiplier != 1: modified_increment = exp.Mul( # type: ignore this=modified_increment, expression=exp.Literal.number(multiplier) @@ -229,6 +227,11 @@ class Hive(Dialect): STRING_ESCAPES = ["\\"] ENCODE = "utf-8" + SINGLE_TOKENS = { + **tokens.Tokenizer.SINGLE_TOKENS, + "$": TokenType.PARAMETER, + } + KEYWORDS = { **tokens.Tokenizer.KEYWORDS, "ADD ARCHIVE": TokenType.COMMAND, @@ -408,6 +411,7 @@ class Hive(Dialect): INDEX_ON = "ON TABLE" EXTRACT_ALLOWS_QUOTES = False NVL2_SUPPORTED = False + SUPPORTS_NESTED_CTES = False TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, @@ -521,7 +525,10 @@ class Hive(Dialect): def parameter_sql(self, expression: exp.Parameter) -> str: this = self.sql(expression, "this") + expression_sql = self.sql(expression, "expression") + parent = expression.parent + this = f"{this}:{expression_sql}" if expression_sql else this if isinstance(parent, exp.EQ) and isinstance(parent.parent, exp.SetItem): # We need to produce SET key = value instead of SET ${key} = value @@ -530,8 +537,6 @@ class Hive(Dialect): return f"${{{this}}}" def schema_sql(self, expression: exp.Schema) -> str: - expression = expression.copy() - for ordered in expression.find_all(exp.Ordered): if ordered.args.get("desc") is False: ordered.set("desc", None) @@ -539,8 +544,6 @@ class Hive(Dialect): return super().schema_sql(expression) def constraint_sql(self, expression: exp.Constraint) -> str: - expression = expression.copy() - for prop in list(expression.find_all(exp.Properties)): prop.pop() diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 2185a85..c78aa9e 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -60,9 +60,33 @@ def _date_trunc_sql(self: MySQL.Generator, expression: exp.DateTrunc) -> str: return f"STR_TO_DATE({concat}, '{date_format}')" -def _str_to_date(args: t.List) -> exp.StrToDate: - date_format = MySQL.format_time(seq_get(args, 1)) - return exp.StrToDate(this=seq_get(args, 0), format=date_format) +# All specifiers for time parts (as opposed to date parts) +# https://dev.mysql.com/doc/refman/8.0/en/date-and-time-functions.html#function_date-format +TIME_SPECIFIERS = {"f", "H", "h", "I", "i", "k", "l", "p", "r", "S", "s", "T"} + + +def _has_time_specifier(date_format: str) -> bool: + i = 0 + length = len(date_format) + + while i < length: + if date_format[i] == "%": + i += 1 + if i < length and date_format[i] in TIME_SPECIFIERS: + return True + i += 1 + return False + + +def _str_to_date(args: t.List) -> exp.StrToDate | exp.StrToTime: + mysql_date_format = seq_get(args, 1) + date_format = MySQL.format_time(mysql_date_format) + this = seq_get(args, 0) + + if mysql_date_format and _has_time_specifier(mysql_date_format.name): + return exp.StrToTime(this=this, format=date_format) + + return exp.StrToDate(this=this, format=date_format) def _str_to_date_sql( @@ -93,7 +117,9 @@ def _date_add_sql( def func(self: MySQL.Generator, expression: exp.Expression) -> str: this = self.sql(expression, "this") unit = expression.text("unit").upper() or "DAY" - return f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))})" + return ( + f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})" + ) return func @@ -110,8 +136,6 @@ def _remove_ts_or_ds_to_date( args: t.Tuple[str, ...] = ("this",), ) -> t.Callable[[MySQL.Generator, exp.Func], str]: def func(self: MySQL.Generator, expression: exp.Func) -> str: - expression = expression.copy() - for arg_key in args: arg = expression.args.get(arg_key) if isinstance(arg, exp.TsOrDsToDate) and not arg.args.get("format"): @@ -629,6 +653,7 @@ class MySQL(Dialect): transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins, transforms.eliminate_qualify, + transforms.eliminate_full_outer_join, ] ), exp.StrPosition: strposition_to_locate_sql, @@ -728,7 +753,6 @@ class MySQL(Dialect): to = self.CAST_MAPPING.get(expression.to.this) if to: - expression = expression.copy() expression.to.set("this", to) return super().cast_sql(expression) diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 086b278..27c6851 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -43,8 +43,6 @@ DATE_DIFF_FACTOR = { def _date_add_sql(kind: str) -> t.Callable[[Postgres.Generator, exp.DateAdd | exp.DateSub], str]: def func(self: Postgres.Generator, expression: exp.DateAdd | exp.DateSub) -> str: - expression = expression.copy() - this = self.sql(expression, "this") unit = expression.args.get("unit") @@ -96,7 +94,6 @@ def _substring_sql(self: Postgres.Generator, expression: exp.Substring) -> str: def _string_agg_sql(self: Postgres.Generator, expression: exp.GroupConcat) -> str: - expression = expression.copy() separator = expression.args.get("separator") or exp.Literal.string(",") order = "" @@ -119,7 +116,6 @@ def _auto_increment_to_serial(expression: exp.Expression) -> exp.Expression: auto = expression.find(exp.AutoIncrementColumnConstraint) if auto: - expression = expression.copy() expression.args["constraints"].remove(auto.parent) kind = expression.args["kind"] @@ -134,7 +130,9 @@ def _auto_increment_to_serial(expression: exp.Expression) -> exp.Expression: def _serial_to_generated(expression: exp.Expression) -> exp.Expression: - kind = expression.args["kind"] + kind = expression.args.get("kind") + if not kind: + return expression if kind.this == exp.DataType.Type.SERIAL: data_type = exp.DataType(this=exp.DataType.Type.INT) @@ -146,7 +144,6 @@ def _serial_to_generated(expression: exp.Expression) -> exp.Expression: data_type = None if data_type: - expression = expression.copy() expression.args["kind"].replace(data_type) constraints = expression.args["constraints"] generated = exp.ColumnConstraint(kind=exp.GeneratedAsIdentityColumnConstraint(this=False)) @@ -409,6 +406,7 @@ class Postgres(Dialect): exp.MapFromEntries: no_map_from_entries_sql, exp.Min: min_or_least, exp.Merge: transforms.preprocess([_remove_target_from_merge]), + exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", exp.PercentileCont: transforms.preprocess( [transforms.add_within_group_for_percentiles] ), @@ -445,6 +443,7 @@ class Postgres(Dialect): PROPERTIES_LOCATION = { **generator.Generator.PROPERTIES_LOCATION, + exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA, exp.TransientProperty: exp.Properties.Location.UNSUPPORTED, exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } @@ -452,7 +451,6 @@ class Postgres(Dialect): def bracket_sql(self, expression: exp.Bracket) -> str: """Forms like ARRAY[1, 2, 3][3] aren't allowed; we need to wrap the ARRAY.""" if isinstance(expression.this, exp.Array): - expression = expression.copy() expression.set("this", exp.paren(expression.this, copy=False)) return super().bracket_sql(expression) diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index aac368c..ded3655 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -36,7 +36,6 @@ def _approx_distinct_sql(self: Presto.Generator, expression: exp.ApproxDistinct) def _explode_to_unnest_sql(self: Presto.Generator, expression: exp.Lateral) -> str: if isinstance(expression.this, exp.Explode): - expression = expression.copy() return self.sql( exp.Join( this=exp.Unnest( @@ -72,7 +71,6 @@ def _schema_sql(self: Presto.Generator, expression: exp.Schema) -> str: for schema in expression.parent.find_all(exp.Schema): column_defs = schema.find_all(exp.ColumnDef) if column_defs and isinstance(schema.parent, exp.Property): - expression = expression.copy() expression.expressions.extend(column_defs) return self.schema_sql(expression) @@ -407,12 +405,10 @@ class Presto(Dialect): target_type = None if target_type and target_type.is_type("timestamp"): - to = target_type.copy() - if target_type is start.to: - end = exp.cast(end, to) + end = exp.cast(end, target_type) else: - start = exp.cast(start, to) + start = exp.cast(start, target_type) return self.func("SEQUENCE", start, end, step) @@ -432,6 +428,5 @@ class Presto(Dialect): kind = expression.args["kind"] schema = expression.this if kind == "VIEW" and schema.expressions: - expression = expression.copy() expression.this.set("expressions", None) return super().create_sql(expression) diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index df70aa7..6c7ba35 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -27,6 +27,14 @@ def _parse_date_add(args: t.List) -> exp.DateAdd: ) +def _parse_datediff(args: t.List) -> exp.DateDiff: + return exp.DateDiff( + this=exp.TsOrDsToDate(this=seq_get(args, 2)), + expression=exp.TsOrDsToDate(this=seq_get(args, 1)), + unit=seq_get(args, 0), + ) + + class Redshift(Postgres): # https://docs.aws.amazon.com/redshift/latest/dg/r_names.html RESOLVES_IDENTIFIERS_AS_UPPERCASE = None @@ -51,11 +59,9 @@ class Redshift(Postgres): ), "DATEADD": _parse_date_add, "DATE_ADD": _parse_date_add, - "DATEDIFF": lambda args: exp.DateDiff( - this=exp.TsOrDsToDate(this=seq_get(args, 2)), - expression=exp.TsOrDsToDate(this=seq_get(args, 1)), - unit=seq_get(args, 0), - ), + "DATEDIFF": _parse_datediff, + "DATE_DIFF": _parse_datediff, + "LISTAGG": exp.GroupConcat.from_arg_list, "STRTOL": exp.FromBase.from_arg_list, } @@ -175,6 +181,7 @@ class Redshift(Postgres): exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql, exp.JSONExtract: _json_sql, exp.JSONExtractScalar: _json_sql, + exp.GroupConcat: rename_func("LISTAGG"), exp.ParseJSON: rename_func("JSON_PARSE"), exp.SafeConcat: concat_to_dpipe_sql, exp.Select: transforms.preprocess( @@ -207,7 +214,6 @@ class Redshift(Postgres): `TEXT` to `VARCHAR`. """ if expression.is_type("text"): - expression = expression.copy() expression.set("this", exp.DataType.Type.VARCHAR) precision = expression.args.get("expressions") diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 07be65b..01f7512 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -32,7 +32,7 @@ def _check_int(s: str) -> bool: # from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html -def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime]: +def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime, exp.TimeStrToTime]: if len(args) == 2: first_arg, second_arg = args if second_arg.is_string: @@ -60,8 +60,8 @@ def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime]: # reduce it using `simplify_literals` first and then check if it's a Literal. first_arg = seq_get(args, 0) if not isinstance(simplify_literals(first_arg, root=True), Literal): - # case: <variant_expr> - return format_time_lambda(exp.StrToTime, "snowflake", default=True)(args) + # case: <variant_expr> or other expressions such as columns + return exp.TimeStrToTime.from_arg_list(args) if first_arg.is_string: if _check_int(first_arg.this): @@ -560,7 +560,6 @@ class Snowflake(Dialect): offset = expression.args.get("offset") if offset: if unnest_alias: - expression = expression.copy() unnest_alias.append("columns", offset.pop()) selects.append("index") diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 8461920..1abfce6 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -63,6 +63,8 @@ class Spark(Spark2): return this class Generator(Spark2.Generator): + SUPPORTS_NESTED_CTES = True + TYPE_MAPPING = { **Spark2.Generator.TYPE_MAPPING, exp.DataType.Type.MONEY: "DECIMAL(15, 4)", diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py index 2fd4f4e..da84bd8 100644 --- a/sqlglot/dialects/spark2.py +++ b/sqlglot/dialects/spark2.py @@ -7,7 +7,6 @@ from sqlglot.dialects.dialect import ( binary_from_function, format_time_lambda, is_parse_json, - move_insert_cte_sql, pivot_column_names, rename_func, trim_sql, @@ -70,7 +69,9 @@ def _unalias_pivot(expression: exp.Expression) -> exp.Expression: alias = pivot.args["alias"].pop() return exp.From( this=expression.this.replace( - exp.select("*").from_(expression.this.copy()).subquery(alias=alias) + exp.select("*") + .from_(expression.this.copy(), copy=False) + .subquery(alias=alias, copy=False) ) ) @@ -188,7 +189,6 @@ class Spark2(Hive): exp.DayOfYear: rename_func("DAYOFYEAR"), exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}", exp.From: transforms.preprocess([_unalias_pivot]), - exp.Insert: move_insert_cte_sql, exp.LogicalAnd: rename_func("BOOL_AND"), exp.LogicalOr: rename_func("BOOL_OR"), exp.Map: _map_sql, diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index 1edfa9d..1fa730d 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -50,7 +50,7 @@ def _transform_create(expression: exp.Expression) -> exp.Expression: else: for column in defs.values(): auto_increment = None - for constraint in column.constraints.copy(): + for constraint in column.constraints: if isinstance(constraint.kind, exp.PrimaryKeyColumnConstraint): break if isinstance(constraint.kind, exp.AutoIncrementColumnConstraint): diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index 152afa6..e8162c2 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -38,12 +38,15 @@ class Teradata(Dialect): "^=": TokenType.NEQ, "BYTEINT": TokenType.SMALLINT, "COLLECT": TokenType.COMMAND, + "DEL": TokenType.DELETE, + "EQ": TokenType.EQ, "GE": TokenType.GTE, "GT": TokenType.GT, "HELP": TokenType.COMMAND, "INS": TokenType.INSERT, "LE": TokenType.LTE, "LT": TokenType.LT, + "MINUS": TokenType.EXCEPT, "MOD": TokenType.MOD, "NE": TokenType.NEQ, "NOT=": TokenType.NEQ, @@ -51,6 +54,7 @@ class Teradata(Dialect): "SEL": TokenType.SELECT, "ST_GEOMETRY": TokenType.GEOMETRY, "TOP": TokenType.TOP, + "UPD": TokenType.UPDATE, } # Teradata does not support % as a modulo operator @@ -181,6 +185,13 @@ class Teradata(Dialect): exp.Use: lambda self, e: f"DATABASE {self.sql(e, 'this')}", } + def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str: + if expression.to.this == exp.DataType.Type.UNKNOWN and expression.args.get("format"): + # We don't actually want to print the unknown type in CAST(<value> AS FORMAT <format>) + expression.to.pop() + + return super().cast_sql(expression, safe_prefix=safe_prefix) + def tablesample_sql( self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS " ) -> str: diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 867e4e4..a281297 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -11,7 +11,6 @@ from sqlglot.dialects.dialect import ( generatedasidentitycolumnconstraint_sql, max_or_greatest, min_or_least, - move_insert_cte_sql, parse_date_delta, rename_func, timestrtotime_sql, @@ -158,8 +157,6 @@ def _format_sql(self: TSQL.Generator, expression: exp.NumberToStr | exp.TimeToSt def _string_agg_sql(self: TSQL.Generator, expression: exp.GroupConcat) -> str: - expression = expression.copy() - this = expression.this distinct = expression.find(exp.Distinct) if distinct: @@ -246,6 +243,7 @@ class TSQL(Dialect): "MMM": "%b", "MM": "%m", "M": "%-m", + "dddd": "%A", "dd": "%d", "d": "%-d", "HH": "%H", @@ -596,6 +594,8 @@ class TSQL(Dialect): ALTER_TABLE_ADD_COLUMN_KEYWORD = False LIMIT_FETCH = "FETCH" COMPUTED_COLUMN_WITH_TYPE = False + SUPPORTS_NESTED_CTES = False + CTE_RECURSIVE_KEYWORD_REQUIRED = False TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, @@ -622,7 +622,6 @@ class TSQL(Dialect): exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql, exp.GroupConcat: _string_agg_sql, exp.If: rename_func("IIF"), - exp.Insert: move_insert_cte_sql, exp.Max: max_or_greatest, exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this), exp.Min: min_or_least, @@ -685,7 +684,6 @@ class TSQL(Dialect): return sql def create_sql(self, expression: exp.Create) -> str: - expression = expression.copy() kind = self.sql(expression, "kind").upper() exists = expression.args.pop("exists", None) sql = super().create_sql(expression) @@ -714,7 +712,7 @@ class TSQL(Dialect): elif expression.args.get("replace"): sql = sql.replace("CREATE OR REPLACE ", "CREATE OR ALTER ", 1) - return sql + return self.prepend_ctes(expression, sql) def offset_sql(self, expression: exp.Offset) -> str: return f"{super().offset_sql(expression)} ROWS" |