summaryrefslogtreecommitdiffstats
path: root/sqlglot
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2023-11-01 05:12:42 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2023-11-01 05:12:42 +0000
commitc51a9844b869fd7cd69e5cc7658d34f61a865185 (patch)
tree55706c65ce7e19626aabf7ff4dde0e1a51b739db /sqlglot
parentReleasing debian version 18.17.0-1. (diff)
downloadsqlglot-c51a9844b869fd7cd69e5cc7658d34f61a865185.tar.xz
sqlglot-c51a9844b869fd7cd69e5cc7658d34f61a865185.zip
Merging upstream version 19.0.1.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
-rw-r--r--sqlglot/__init__.py2
-rw-r--r--sqlglot/dialects/bigquery.py2
-rw-r--r--sqlglot/dialects/clickhouse.py1
-rw-r--r--sqlglot/dialects/databricks.py9
-rw-r--r--sqlglot/dialects/dialect.py29
-rw-r--r--sqlglot/dialects/drill.py6
-rw-r--r--sqlglot/dialects/duckdb.py6
-rw-r--r--sqlglot/dialects/hive.py17
-rw-r--r--sqlglot/dialects/mysql.py38
-rw-r--r--sqlglot/dialects/postgres.py12
-rw-r--r--sqlglot/dialects/presto.py9
-rw-r--r--sqlglot/dialects/redshift.py18
-rw-r--r--sqlglot/dialects/snowflake.py7
-rw-r--r--sqlglot/dialects/spark.py2
-rw-r--r--sqlglot/dialects/spark2.py6
-rw-r--r--sqlglot/dialects/sqlite.py2
-rw-r--r--sqlglot/dialects/teradata.py11
-rw-r--r--sqlglot/dialects/tsql.py10
-rw-r--r--sqlglot/expressions.py86
-rw-r--r--sqlglot/generator.py125
-rw-r--r--sqlglot/helper.py4
-rw-r--r--sqlglot/optimizer/normalize.py27
-rw-r--r--sqlglot/optimizer/normalize_identifiers.py4
-rw-r--r--sqlglot/optimizer/qualify_tables.py2
-rw-r--r--sqlglot/optimizer/simplify.py77
-rw-r--r--sqlglot/parser.py89
-rw-r--r--sqlglot/schema.py14
-rw-r--r--sqlglot/transforms.py70
28 files changed, 477 insertions, 208 deletions
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py
index be10f3d..35feaad 100644
--- a/sqlglot/__init__.py
+++ b/sqlglot/__init__.py
@@ -158,6 +158,6 @@ def transpile(
"""
write = (read if write is None else write) if identity else write
return [
- Dialect.get_or_raise(write)().generate(expression, **opts)
+ Dialect.get_or_raise(write)().generate(expression, copy=False, **opts) if expression else ""
for expression in parse(sql, read, error_level=error_level)
]
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 51baba2..fc9a3ae 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -69,7 +69,6 @@ def _create_sql(self: BigQuery.Generator, expression: exp.Create) -> str:
returns = expression.find(exp.ReturnsProperty)
if kind.upper() == "FUNCTION" and returns and returns.args.get("is_table"):
- expression = expression.copy()
expression.set("kind", "TABLE FUNCTION")
if isinstance(expression.expression, (exp.Subquery, exp.Literal)):
@@ -699,6 +698,5 @@ class BigQuery(Dialect):
def version_sql(self, expression: exp.Version) -> str:
if expression.name == "TIMESTAMP":
- expression = expression.copy()
expression.set("this", "SYSTEM_TIME")
return super().version_sql(expression)
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index 30f728c..394a922 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -461,7 +461,6 @@ class ClickHouse(Dialect):
def safeconcat_sql(self, expression: exp.SafeConcat) -> str:
# Clickhouse errors out if we try to cast a NULL value to TEXT
- expression = expression.copy()
return self.func(
"CONCAT",
*[
diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py
index 314a821..b777db0 100644
--- a/sqlglot/dialects/databricks.py
+++ b/sqlglot/dialects/databricks.py
@@ -35,7 +35,7 @@ class Databricks(Spark):
exp.DatetimeSub: lambda self, e: self.func(
"TIMESTAMPADD",
e.text("unit"),
- exp.Mul(this=e.expression.copy(), expression=exp.Literal.number(-1)),
+ exp.Mul(this=e.expression, expression=exp.Literal.number(-1)),
e.this,
),
exp.DatetimeDiff: lambda self, e: self.func(
@@ -63,21 +63,14 @@ class Databricks(Spark):
and kind.this in exp.DataType.INTEGER_TYPES
):
# only BIGINT generated identity constraints are supported
- expression = expression.copy()
expression.set("kind", exp.DataType.build("bigint"))
return super().columndef_sql(expression, sep)
def generatedasidentitycolumnconstraint_sql(
self, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
- expression = expression.copy()
expression.set("this", True) # trigger ALWAYS in super class
return super().generatedasidentitycolumnconstraint_sql(expression)
class Tokenizer(Spark.Tokenizer):
HEX_STRINGS = []
-
- SINGLE_TOKENS = {
- **Spark.Tokenizer.SINGLE_TOKENS,
- "$": TokenType.PARAMETER,
- }
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 739e8d7..21e7889 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -315,11 +315,14 @@ class Dialect(metaclass=_Dialect):
) -> t.List[t.Optional[exp.Expression]]:
return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
- def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
- return self.generator(**opts).generate(expression)
+ def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
+ return self.generator(**opts).generate(expression, copy=copy)
def transpile(self, sql: str, **opts) -> t.List[str]:
- return [self.generate(expression, **opts) for expression in self.parse(sql)]
+ return [
+ self.generate(expression, copy=False, **opts) if expression else ""
+ for expression in self.parse(sql)
+ ]
def tokenize(self, sql: str) -> t.List[Token]:
return self.tokenizer.tokenize(sql)
@@ -380,9 +383,7 @@ def inline_array_sql(self: Generator, expression: exp.Array) -> str:
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
return self.like_sql(
- exp.Like(
- this=exp.Lower(this=expression.this.copy()), expression=expression.expression.copy()
- )
+ exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
)
@@ -518,7 +519,6 @@ def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")
if has_schema and is_partitionable:
- expression = expression.copy()
prop = expression.find(exp.PartitionedByProperty)
if prop and prop.this and not isinstance(prop.this, exp.Schema):
schema = expression.this
@@ -583,7 +583,7 @@ def date_add_interval_sql(
this = self.sql(expression, "this")
unit = expression.args.get("unit")
unit = exp.var(unit.name.upper() if unit else "DAY")
- interval = exp.Interval(this=expression.expression.copy(), unit=unit)
+ interval = exp.Interval(this=expression.expression, unit=unit)
return f"{data_type}_{kind}({this}, {self.sql(interval)})"
return func
@@ -621,7 +621,6 @@ def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> s
def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
- expression = expression.copy()
return self.sql(
exp.Substring(
this=expression.this, start=exp.Literal.number(1), length=expression.expression
@@ -630,7 +629,6 @@ def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
def right_to_substring_sql(self: Generator, expression: exp.Left) -> str:
- expression = expression.copy()
return self.sql(
exp.Substring(
this=expression.this,
@@ -675,7 +673,7 @@ def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
cond = expression.this.expressions[0]
self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
- return self.func("sum", exp.func("if", cond.copy(), 1, 0))
+ return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
@@ -716,12 +714,10 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
- expression = expression.copy()
return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
- expression = expression.copy()
delim, *rest_args = expression.expressions
return self.sql(
reduce(
@@ -809,13 +805,6 @@ def isnull_to_is_null(args: t.List) -> exp.Expression:
return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
-def move_insert_cte_sql(self: Generator, expression: exp.Insert) -> str:
- if expression.expression.args.get("with"):
- expression = expression.copy()
- expression.set("with", expression.expression.args["with"].pop())
- return self.insert_sql(expression)
-
-
def generatedasidentitycolumnconstraint_sql(
self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
index 8b2e708..42453fd 100644
--- a/sqlglot/dialects/drill.py
+++ b/sqlglot/dialects/drill.py
@@ -20,7 +20,9 @@ def _date_add_sql(kind: str) -> t.Callable[[Drill.Generator, exp.DateAdd | exp.D
def func(self: Drill.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
this = self.sql(expression, "this")
unit = exp.var(expression.text("unit").upper() or "DAY")
- return f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))})"
+ return (
+ f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})"
+ )
return func
@@ -147,7 +149,7 @@ class Drill(Dialect):
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.TryCast: no_trycast_sql,
- exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression.copy(), unit=exp.var('DAY')))})",
+ exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.var('DAY')))})",
exp.TsOrDsToDate: ts_or_ds_to_date_sql("drill"),
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
}
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index 287e03a..d8d9f90 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -36,14 +36,14 @@ from sqlglot.tokens import TokenType
def _ts_or_ds_add_sql(self: DuckDB.Generator, expression: exp.TsOrDsAdd) -> str:
this = self.sql(expression, "this")
unit = self.sql(expression, "unit").strip("'") or "DAY"
- return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))}"
+ return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
def _date_delta_sql(self: DuckDB.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
this = self.sql(expression, "this")
unit = self.sql(expression, "unit").strip("'") or "DAY"
op = "+" if isinstance(expression, exp.DateAdd) else "-"
- return f"{this} {op} {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))}"
+ return f"{this} {op} {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
# BigQuery -> DuckDB conversion for the DATE function
@@ -365,7 +365,7 @@ class DuckDB(Dialect):
multiplier = 90
if multiplier:
- return f"({multiplier} * {super().interval_sql(exp.Interval(this=expression.this.copy(), unit=exp.var('day')))})"
+ return f"({multiplier} * {super().interval_sql(exp.Interval(this=expression.this, unit=exp.var('day')))})"
return super().interval_sql(expression)
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 7bff553..3b1c8de 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -53,8 +53,6 @@ DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH")
def _create_sql(self, expression: exp.Create) -> str:
- expression = expression.copy()
-
# remove UNIQUE column constraints
for constraint in expression.find_all(exp.UniqueColumnConstraint):
if constraint.parent:
@@ -88,7 +86,7 @@ def _add_date_sql(self: Hive.Generator, expression: exp.DateAdd | exp.DateSub) -
if expression.expression.is_number:
modified_increment = exp.Literal.number(int(expression.text("expression")) * multiplier)
else:
- modified_increment = expression.expression.copy()
+ modified_increment = expression.expression
if multiplier != 1:
modified_increment = exp.Mul( # type: ignore
this=modified_increment, expression=exp.Literal.number(multiplier)
@@ -229,6 +227,11 @@ class Hive(Dialect):
STRING_ESCAPES = ["\\"]
ENCODE = "utf-8"
+ SINGLE_TOKENS = {
+ **tokens.Tokenizer.SINGLE_TOKENS,
+ "$": TokenType.PARAMETER,
+ }
+
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"ADD ARCHIVE": TokenType.COMMAND,
@@ -408,6 +411,7 @@ class Hive(Dialect):
INDEX_ON = "ON TABLE"
EXTRACT_ALLOWS_QUOTES = False
NVL2_SUPPORTED = False
+ SUPPORTS_NESTED_CTES = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@@ -521,7 +525,10 @@ class Hive(Dialect):
def parameter_sql(self, expression: exp.Parameter) -> str:
this = self.sql(expression, "this")
+ expression_sql = self.sql(expression, "expression")
+
parent = expression.parent
+ this = f"{this}:{expression_sql}" if expression_sql else this
if isinstance(parent, exp.EQ) and isinstance(parent.parent, exp.SetItem):
# We need to produce SET key = value instead of SET ${key} = value
@@ -530,8 +537,6 @@ class Hive(Dialect):
return f"${{{this}}}"
def schema_sql(self, expression: exp.Schema) -> str:
- expression = expression.copy()
-
for ordered in expression.find_all(exp.Ordered):
if ordered.args.get("desc") is False:
ordered.set("desc", None)
@@ -539,8 +544,6 @@ class Hive(Dialect):
return super().schema_sql(expression)
def constraint_sql(self, expression: exp.Constraint) -> str:
- expression = expression.copy()
-
for prop in list(expression.find_all(exp.Properties)):
prop.pop()
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 2185a85..c78aa9e 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -60,9 +60,33 @@ def _date_trunc_sql(self: MySQL.Generator, expression: exp.DateTrunc) -> str:
return f"STR_TO_DATE({concat}, '{date_format}')"
-def _str_to_date(args: t.List) -> exp.StrToDate:
- date_format = MySQL.format_time(seq_get(args, 1))
- return exp.StrToDate(this=seq_get(args, 0), format=date_format)
+# All specifiers for time parts (as opposed to date parts)
+# https://dev.mysql.com/doc/refman/8.0/en/date-and-time-functions.html#function_date-format
+TIME_SPECIFIERS = {"f", "H", "h", "I", "i", "k", "l", "p", "r", "S", "s", "T"}
+
+
+def _has_time_specifier(date_format: str) -> bool:
+ i = 0
+ length = len(date_format)
+
+ while i < length:
+ if date_format[i] == "%":
+ i += 1
+ if i < length and date_format[i] in TIME_SPECIFIERS:
+ return True
+ i += 1
+ return False
+
+
+def _str_to_date(args: t.List) -> exp.StrToDate | exp.StrToTime:
+ mysql_date_format = seq_get(args, 1)
+ date_format = MySQL.format_time(mysql_date_format)
+ this = seq_get(args, 0)
+
+ if mysql_date_format and _has_time_specifier(mysql_date_format.name):
+ return exp.StrToTime(this=this, format=date_format)
+
+ return exp.StrToDate(this=this, format=date_format)
def _str_to_date_sql(
@@ -93,7 +117,9 @@ def _date_add_sql(
def func(self: MySQL.Generator, expression: exp.Expression) -> str:
this = self.sql(expression, "this")
unit = expression.text("unit").upper() or "DAY"
- return f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))})"
+ return (
+ f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})"
+ )
return func
@@ -110,8 +136,6 @@ def _remove_ts_or_ds_to_date(
args: t.Tuple[str, ...] = ("this",),
) -> t.Callable[[MySQL.Generator, exp.Func], str]:
def func(self: MySQL.Generator, expression: exp.Func) -> str:
- expression = expression.copy()
-
for arg_key in args:
arg = expression.args.get(arg_key)
if isinstance(arg, exp.TsOrDsToDate) and not arg.args.get("format"):
@@ -629,6 +653,7 @@ class MySQL(Dialect):
transforms.eliminate_distinct_on,
transforms.eliminate_semi_and_anti_joins,
transforms.eliminate_qualify,
+ transforms.eliminate_full_outer_join,
]
),
exp.StrPosition: strposition_to_locate_sql,
@@ -728,7 +753,6 @@ class MySQL(Dialect):
to = self.CAST_MAPPING.get(expression.to.this)
if to:
- expression = expression.copy()
expression.to.set("this", to)
return super().cast_sql(expression)
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 086b278..27c6851 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -43,8 +43,6 @@ DATE_DIFF_FACTOR = {
def _date_add_sql(kind: str) -> t.Callable[[Postgres.Generator, exp.DateAdd | exp.DateSub], str]:
def func(self: Postgres.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
- expression = expression.copy()
-
this = self.sql(expression, "this")
unit = expression.args.get("unit")
@@ -96,7 +94,6 @@ def _substring_sql(self: Postgres.Generator, expression: exp.Substring) -> str:
def _string_agg_sql(self: Postgres.Generator, expression: exp.GroupConcat) -> str:
- expression = expression.copy()
separator = expression.args.get("separator") or exp.Literal.string(",")
order = ""
@@ -119,7 +116,6 @@ def _auto_increment_to_serial(expression: exp.Expression) -> exp.Expression:
auto = expression.find(exp.AutoIncrementColumnConstraint)
if auto:
- expression = expression.copy()
expression.args["constraints"].remove(auto.parent)
kind = expression.args["kind"]
@@ -134,7 +130,9 @@ def _auto_increment_to_serial(expression: exp.Expression) -> exp.Expression:
def _serial_to_generated(expression: exp.Expression) -> exp.Expression:
- kind = expression.args["kind"]
+ kind = expression.args.get("kind")
+ if not kind:
+ return expression
if kind.this == exp.DataType.Type.SERIAL:
data_type = exp.DataType(this=exp.DataType.Type.INT)
@@ -146,7 +144,6 @@ def _serial_to_generated(expression: exp.Expression) -> exp.Expression:
data_type = None
if data_type:
- expression = expression.copy()
expression.args["kind"].replace(data_type)
constraints = expression.args["constraints"]
generated = exp.ColumnConstraint(kind=exp.GeneratedAsIdentityColumnConstraint(this=False))
@@ -409,6 +406,7 @@ class Postgres(Dialect):
exp.MapFromEntries: no_map_from_entries_sql,
exp.Min: min_or_least,
exp.Merge: transforms.preprocess([_remove_target_from_merge]),
+ exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.PercentileCont: transforms.preprocess(
[transforms.add_within_group_for_percentiles]
),
@@ -445,6 +443,7 @@ class Postgres(Dialect):
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION,
+ exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
exp.TransientProperty: exp.Properties.Location.UNSUPPORTED,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
@@ -452,7 +451,6 @@ class Postgres(Dialect):
def bracket_sql(self, expression: exp.Bracket) -> str:
"""Forms like ARRAY[1, 2, 3][3] aren't allowed; we need to wrap the ARRAY."""
if isinstance(expression.this, exp.Array):
- expression = expression.copy()
expression.set("this", exp.paren(expression.this, copy=False))
return super().bracket_sql(expression)
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index aac368c..ded3655 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -36,7 +36,6 @@ def _approx_distinct_sql(self: Presto.Generator, expression: exp.ApproxDistinct)
def _explode_to_unnest_sql(self: Presto.Generator, expression: exp.Lateral) -> str:
if isinstance(expression.this, exp.Explode):
- expression = expression.copy()
return self.sql(
exp.Join(
this=exp.Unnest(
@@ -72,7 +71,6 @@ def _schema_sql(self: Presto.Generator, expression: exp.Schema) -> str:
for schema in expression.parent.find_all(exp.Schema):
column_defs = schema.find_all(exp.ColumnDef)
if column_defs and isinstance(schema.parent, exp.Property):
- expression = expression.copy()
expression.expressions.extend(column_defs)
return self.schema_sql(expression)
@@ -407,12 +405,10 @@ class Presto(Dialect):
target_type = None
if target_type and target_type.is_type("timestamp"):
- to = target_type.copy()
-
if target_type is start.to:
- end = exp.cast(end, to)
+ end = exp.cast(end, target_type)
else:
- start = exp.cast(start, to)
+ start = exp.cast(start, target_type)
return self.func("SEQUENCE", start, end, step)
@@ -432,6 +428,5 @@ class Presto(Dialect):
kind = expression.args["kind"]
schema = expression.this
if kind == "VIEW" and schema.expressions:
- expression = expression.copy()
expression.this.set("expressions", None)
return super().create_sql(expression)
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index df70aa7..6c7ba35 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -27,6 +27,14 @@ def _parse_date_add(args: t.List) -> exp.DateAdd:
)
+def _parse_datediff(args: t.List) -> exp.DateDiff:
+ return exp.DateDiff(
+ this=exp.TsOrDsToDate(this=seq_get(args, 2)),
+ expression=exp.TsOrDsToDate(this=seq_get(args, 1)),
+ unit=seq_get(args, 0),
+ )
+
+
class Redshift(Postgres):
# https://docs.aws.amazon.com/redshift/latest/dg/r_names.html
RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
@@ -51,11 +59,9 @@ class Redshift(Postgres):
),
"DATEADD": _parse_date_add,
"DATE_ADD": _parse_date_add,
- "DATEDIFF": lambda args: exp.DateDiff(
- this=exp.TsOrDsToDate(this=seq_get(args, 2)),
- expression=exp.TsOrDsToDate(this=seq_get(args, 1)),
- unit=seq_get(args, 0),
- ),
+ "DATEDIFF": _parse_datediff,
+ "DATE_DIFF": _parse_datediff,
+ "LISTAGG": exp.GroupConcat.from_arg_list,
"STRTOL": exp.FromBase.from_arg_list,
}
@@ -175,6 +181,7 @@ class Redshift(Postgres):
exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql,
exp.JSONExtract: _json_sql,
exp.JSONExtractScalar: _json_sql,
+ exp.GroupConcat: rename_func("LISTAGG"),
exp.ParseJSON: rename_func("JSON_PARSE"),
exp.SafeConcat: concat_to_dpipe_sql,
exp.Select: transforms.preprocess(
@@ -207,7 +214,6 @@ class Redshift(Postgres):
`TEXT` to `VARCHAR`.
"""
if expression.is_type("text"):
- expression = expression.copy()
expression.set("this", exp.DataType.Type.VARCHAR)
precision = expression.args.get("expressions")
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 07be65b..01f7512 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -32,7 +32,7 @@ def _check_int(s: str) -> bool:
# from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html
-def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime]:
+def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime, exp.TimeStrToTime]:
if len(args) == 2:
first_arg, second_arg = args
if second_arg.is_string:
@@ -60,8 +60,8 @@ def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime]:
# reduce it using `simplify_literals` first and then check if it's a Literal.
first_arg = seq_get(args, 0)
if not isinstance(simplify_literals(first_arg, root=True), Literal):
- # case: <variant_expr>
- return format_time_lambda(exp.StrToTime, "snowflake", default=True)(args)
+ # case: <variant_expr> or other expressions such as columns
+ return exp.TimeStrToTime.from_arg_list(args)
if first_arg.is_string:
if _check_int(first_arg.this):
@@ -560,7 +560,6 @@ class Snowflake(Dialect):
offset = expression.args.get("offset")
if offset:
if unnest_alias:
- expression = expression.copy()
unnest_alias.append("columns", offset.pop())
selects.append("index")
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 8461920..1abfce6 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -63,6 +63,8 @@ class Spark(Spark2):
return this
class Generator(Spark2.Generator):
+ SUPPORTS_NESTED_CTES = True
+
TYPE_MAPPING = {
**Spark2.Generator.TYPE_MAPPING,
exp.DataType.Type.MONEY: "DECIMAL(15, 4)",
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
index 2fd4f4e..da84bd8 100644
--- a/sqlglot/dialects/spark2.py
+++ b/sqlglot/dialects/spark2.py
@@ -7,7 +7,6 @@ from sqlglot.dialects.dialect import (
binary_from_function,
format_time_lambda,
is_parse_json,
- move_insert_cte_sql,
pivot_column_names,
rename_func,
trim_sql,
@@ -70,7 +69,9 @@ def _unalias_pivot(expression: exp.Expression) -> exp.Expression:
alias = pivot.args["alias"].pop()
return exp.From(
this=expression.this.replace(
- exp.select("*").from_(expression.this.copy()).subquery(alias=alias)
+ exp.select("*")
+ .from_(expression.this.copy(), copy=False)
+ .subquery(alias=alias, copy=False)
)
)
@@ -188,7 +189,6 @@ class Spark2(Hive):
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
exp.From: transforms.preprocess([_unalias_pivot]),
- exp.Insert: move_insert_cte_sql,
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.LogicalOr: rename_func("BOOL_OR"),
exp.Map: _map_sql,
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index 1edfa9d..1fa730d 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -50,7 +50,7 @@ def _transform_create(expression: exp.Expression) -> exp.Expression:
else:
for column in defs.values():
auto_increment = None
- for constraint in column.constraints.copy():
+ for constraint in column.constraints:
if isinstance(constraint.kind, exp.PrimaryKeyColumnConstraint):
break
if isinstance(constraint.kind, exp.AutoIncrementColumnConstraint):
diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py
index 152afa6..e8162c2 100644
--- a/sqlglot/dialects/teradata.py
+++ b/sqlglot/dialects/teradata.py
@@ -38,12 +38,15 @@ class Teradata(Dialect):
"^=": TokenType.NEQ,
"BYTEINT": TokenType.SMALLINT,
"COLLECT": TokenType.COMMAND,
+ "DEL": TokenType.DELETE,
+ "EQ": TokenType.EQ,
"GE": TokenType.GTE,
"GT": TokenType.GT,
"HELP": TokenType.COMMAND,
"INS": TokenType.INSERT,
"LE": TokenType.LTE,
"LT": TokenType.LT,
+ "MINUS": TokenType.EXCEPT,
"MOD": TokenType.MOD,
"NE": TokenType.NEQ,
"NOT=": TokenType.NEQ,
@@ -51,6 +54,7 @@ class Teradata(Dialect):
"SEL": TokenType.SELECT,
"ST_GEOMETRY": TokenType.GEOMETRY,
"TOP": TokenType.TOP,
+ "UPD": TokenType.UPDATE,
}
# Teradata does not support % as a modulo operator
@@ -181,6 +185,13 @@ class Teradata(Dialect):
exp.Use: lambda self, e: f"DATABASE {self.sql(e, 'this')}",
}
+ def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
+ if expression.to.this == exp.DataType.Type.UNKNOWN and expression.args.get("format"):
+ # We don't actually want to print the unknown type in CAST(<value> AS FORMAT <format>)
+ expression.to.pop()
+
+ return super().cast_sql(expression, safe_prefix=safe_prefix)
+
def tablesample_sql(
self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
) -> str:
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 867e4e4..a281297 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -11,7 +11,6 @@ from sqlglot.dialects.dialect import (
generatedasidentitycolumnconstraint_sql,
max_or_greatest,
min_or_least,
- move_insert_cte_sql,
parse_date_delta,
rename_func,
timestrtotime_sql,
@@ -158,8 +157,6 @@ def _format_sql(self: TSQL.Generator, expression: exp.NumberToStr | exp.TimeToSt
def _string_agg_sql(self: TSQL.Generator, expression: exp.GroupConcat) -> str:
- expression = expression.copy()
-
this = expression.this
distinct = expression.find(exp.Distinct)
if distinct:
@@ -246,6 +243,7 @@ class TSQL(Dialect):
"MMM": "%b",
"MM": "%m",
"M": "%-m",
+ "dddd": "%A",
"dd": "%d",
"d": "%-d",
"HH": "%H",
@@ -596,6 +594,8 @@ class TSQL(Dialect):
ALTER_TABLE_ADD_COLUMN_KEYWORD = False
LIMIT_FETCH = "FETCH"
COMPUTED_COLUMN_WITH_TYPE = False
+ SUPPORTS_NESTED_CTES = False
+ CTE_RECURSIVE_KEYWORD_REQUIRED = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@@ -622,7 +622,6 @@ class TSQL(Dialect):
exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql,
exp.GroupConcat: _string_agg_sql,
exp.If: rename_func("IIF"),
- exp.Insert: move_insert_cte_sql,
exp.Max: max_or_greatest,
exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this),
exp.Min: min_or_least,
@@ -685,7 +684,6 @@ class TSQL(Dialect):
return sql
def create_sql(self, expression: exp.Create) -> str:
- expression = expression.copy()
kind = self.sql(expression, "kind").upper()
exists = expression.args.pop("exists", None)
sql = super().create_sql(expression)
@@ -714,7 +712,7 @@ class TSQL(Dialect):
elif expression.args.get("replace"):
sql = sql.replace("CREATE OR REPLACE ", "CREATE OR ALTER ", 1)
- return sql
+ return self.prepend_ctes(expression, sql)
def offset_sql(self, expression: exp.Offset) -> str:
return f"{super().offset_sql(expression)} ROWS"
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 5b012b1..99ebfb3 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -2145,6 +2145,22 @@ class PartitionedByProperty(Property):
arg_types = {"this": True}
+# https://www.postgresql.org/docs/current/sql-createtable.html
+class PartitionBoundSpec(Expression):
+ # this -> IN / MODULUS, expression -> REMAINDER, from_expressions -> FROM (...), to_expressions -> TO (...)
+ arg_types = {
+ "this": False,
+ "expression": False,
+ "from_expressions": False,
+ "to_expressions": False,
+ }
+
+
+class PartitionedOfProperty(Property):
+ # this -> parent_table (schema), expression -> FOR VALUES ... / DEFAULT
+ arg_types = {"this": True, "expression": True}
+
+
class RemoteWithConnectionModelProperty(Property):
arg_types = {"this": True}
@@ -2486,6 +2502,7 @@ class Table(Expression):
"format": False,
"pattern": False,
"index": False,
+ "ordinality": False,
}
@property
@@ -2649,11 +2666,7 @@ class Update(Expression):
class Values(UDTF):
- arg_types = {
- "expressions": True,
- "ordinality": False,
- "alias": False,
- }
+ arg_types = {"expressions": True, "alias": False}
class Var(Expression):
@@ -3501,7 +3514,7 @@ class Star(Expression):
class Parameter(Condition):
- arg_types = {"this": True, "wrapped": False}
+ arg_types = {"this": True, "expression": False}
class SessionParameter(Condition):
@@ -5036,7 +5049,7 @@ class FromBase(Func):
class Struct(Func):
- arg_types = {"expressions": True}
+ arg_types = {"expressions": False}
is_var_len_args = True
@@ -5171,7 +5184,7 @@ class Use(Expression):
class Merge(Expression):
- arg_types = {"this": True, "using": True, "on": True, "expressions": True}
+ arg_types = {"this": True, "using": True, "on": True, "expressions": True, "with": False}
class When(Func):
@@ -5459,7 +5472,12 @@ def _wrap(expression: E, kind: t.Type[Expression]) -> E | Paren:
def union(
- left: ExpOrStr, right: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts
+ left: ExpOrStr,
+ right: ExpOrStr,
+ distinct: bool = True,
+ dialect: DialectType = None,
+ copy: bool = True,
+ **opts,
) -> Union:
"""
Initializes a syntax tree from one UNION expression.
@@ -5475,19 +5493,25 @@ def union(
If an `Expression` instance is passed, it will be used as-is.
distinct: set the DISTINCT flag if and only if this is true.
dialect: the dialect used to parse the input expression.
+ copy: whether or not to copy the expression.
opts: other options to use to parse the input expressions.
Returns:
The new Union instance.
"""
- left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts)
- right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts)
+ left = maybe_parse(sql_or_expression=left, dialect=dialect, copy=copy, **opts)
+ right = maybe_parse(sql_or_expression=right, dialect=dialect, copy=copy, **opts)
return Union(this=left, expression=right, distinct=distinct)
def intersect(
- left: ExpOrStr, right: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts
+ left: ExpOrStr,
+ right: ExpOrStr,
+ distinct: bool = True,
+ dialect: DialectType = None,
+ copy: bool = True,
+ **opts,
) -> Intersect:
"""
Initializes a syntax tree from one INTERSECT expression.
@@ -5503,19 +5527,25 @@ def intersect(
If an `Expression` instance is passed, it will be used as-is.
distinct: set the DISTINCT flag if and only if this is true.
dialect: the dialect used to parse the input expression.
+ copy: whether or not to copy the expression.
opts: other options to use to parse the input expressions.
Returns:
The new Intersect instance.
"""
- left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts)
- right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts)
+ left = maybe_parse(sql_or_expression=left, dialect=dialect, copy=copy, **opts)
+ right = maybe_parse(sql_or_expression=right, dialect=dialect, copy=copy, **opts)
return Intersect(this=left, expression=right, distinct=distinct)
def except_(
- left: ExpOrStr, right: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts
+ left: ExpOrStr,
+ right: ExpOrStr,
+ distinct: bool = True,
+ dialect: DialectType = None,
+ copy: bool = True,
+ **opts,
) -> Except:
"""
Initializes a syntax tree from one EXCEPT expression.
@@ -5531,13 +5561,14 @@ def except_(
If an `Expression` instance is passed, it will be used as-is.
distinct: set the DISTINCT flag if and only if this is true.
dialect: the dialect used to parse the input expression.
+ copy: whether or not to copy the expression.
opts: other options to use to parse the input expressions.
Returns:
The new Except instance.
"""
- left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts)
- right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts)
+ left = maybe_parse(sql_or_expression=left, dialect=dialect, copy=copy, **opts)
+ right = maybe_parse(sql_or_expression=right, dialect=dialect, copy=copy, **opts)
return Except(this=left, expression=right, distinct=distinct)
@@ -5861,7 +5892,7 @@ def to_identifier(name, quoted=None, copy=True):
Args:
name: The name to turn into an identifier.
quoted: Whether or not force quote the identifier.
- copy: Whether or not to copy a passed in Identefier node.
+ copy: Whether or not to copy name if it's an Identifier.
Returns:
The identifier ast node.
@@ -5882,6 +5913,25 @@ def to_identifier(name, quoted=None, copy=True):
return identifier
+def parse_identifier(name: str, dialect: DialectType = None) -> Identifier:
+ """
+ Parses a given string into an identifier.
+
+ Args:
+ name: The name to parse into an identifier.
+ dialect: The dialect to parse against.
+
+ Returns:
+ The identifier ast node.
+ """
+ try:
+ expression = maybe_parse(name, dialect=dialect, into=Identifier)
+ except ParseError:
+ expression = to_identifier(name)
+
+ return expression
+
+
INTERVAL_STRING_RE = re.compile(r"\s*([0-9]+)\s*([a-zA-Z]+)\s*")
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 0d6778a..4916cf8 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -230,6 +230,12 @@ class Generator:
# Whether or not data types support additional specifiers like e.g. CHAR or BYTE (oracle)
DATA_TYPE_SPECIFIERS_ALLOWED = False
+ # Whether or not nested CTEs (e.g. defined inside of subqueries) are allowed
+ SUPPORTS_NESTED_CTES = True
+
+ # Whether or not the "RECURSIVE" keyword is required when defining recursive CTEs
+ CTE_RECURSIVE_KEYWORD_REQUIRED = True
+
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
@@ -304,6 +310,7 @@ class Generator:
exp.Order: exp.Properties.Location.POST_SCHEMA,
exp.OutputModelProperty: exp.Properties.Location.POST_SCHEMA,
exp.PartitionedByProperty: exp.Properties.Location.POST_WITH,
+ exp.PartitionedOfProperty: exp.Properties.Location.POST_SCHEMA,
exp.PrimaryKey: exp.Properties.Location.POST_SCHEMA,
exp.Property: exp.Properties.Location.POST_WITH,
exp.RemoteWithConnectionModelProperty: exp.Properties.Location.POST_SCHEMA,
@@ -407,7 +414,6 @@ class Generator:
"unsupported_messages",
"_escaped_quote_end",
"_escaped_identifier_end",
- "_cache",
)
def __init__(
@@ -447,30 +453,38 @@ class Generator:
self._escaped_identifier_end: str = (
self.TOKENIZER_CLASS.IDENTIFIER_ESCAPES[0] + self.IDENTIFIER_END
)
- self._cache: t.Optional[t.Dict[int, str]] = None
- def generate(
- self,
- expression: t.Optional[exp.Expression],
- cache: t.Optional[t.Dict[int, str]] = None,
- ) -> str:
+ def generate(self, expression: exp.Expression, copy: bool = True) -> str:
"""
Generates the SQL string corresponding to the given syntax tree.
Args:
expression: The syntax tree.
- cache: An optional sql string cache. This leverages the hash of an Expression
- which can be slow to compute, so only use it if you set _hash on each node.
+ copy: Whether or not to copy the expression. The generator performs mutations so
+ it is safer to copy.
Returns:
The SQL string corresponding to `expression`.
"""
- if cache is not None:
- self._cache = cache
+ if copy:
+ expression = expression.copy()
+
+ # Some dialects only support CTEs at the top level expression, so we need to bubble up nested
+ # CTEs to that level in order to produce a syntactically valid expression. This transformation
+ # happens here to minimize code duplication, since many expressions support CTEs.
+ if (
+ not self.SUPPORTS_NESTED_CTES
+ and isinstance(expression, exp.Expression)
+ and not expression.parent
+ and "with" in expression.arg_types
+ and any(node.parent is not expression for node in expression.find_all(exp.With))
+ ):
+ from sqlglot.transforms import move_ctes_to_top_level
+
+ expression = move_ctes_to_top_level(expression)
self.unsupported_messages = []
sql = self.sql(expression).strip()
- self._cache = None
if self.unsupported_level == ErrorLevel.IGNORE:
return sql
@@ -595,12 +609,6 @@ class Generator:
return self.sql(value)
return ""
- if self._cache is not None:
- expression_id = hash(expression)
-
- if expression_id in self._cache:
- return self._cache[expression_id]
-
transform = self.TRANSFORMS.get(expression.__class__)
if callable(transform):
@@ -621,11 +629,7 @@ class Generator:
else:
raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}")
- sql = self.maybe_comment(sql, expression) if self.comments and comment else sql
-
- if self._cache is not None:
- self._cache[expression_id] = sql
- return sql
+ return self.maybe_comment(sql, expression) if self.comments and comment else sql
def uncache_sql(self, expression: exp.Uncache) -> str:
table = self.sql(expression, "this")
@@ -879,7 +883,11 @@ class Generator:
def with_sql(self, expression: exp.With) -> str:
sql = self.expressions(expression, flat=True)
- recursive = "RECURSIVE " if expression.args.get("recursive") else ""
+ recursive = (
+ "RECURSIVE "
+ if self.CTE_RECURSIVE_KEYWORD_REQUIRED and expression.args.get("recursive")
+ else ""
+ )
return f"WITH {recursive}{sql}"
@@ -1022,7 +1030,7 @@ class Generator:
where = self.sql(expression, "expression").strip()
return f"{this} FILTER({where})"
- agg = expression.this.copy()
+ agg = expression.this
agg_arg = agg.this
cond = expression.expression.this
agg_arg.replace(exp.If(this=cond.copy(), true=agg_arg.copy()))
@@ -1088,9 +1096,9 @@ class Generator:
for p in expression.expressions:
p_loc = self.PROPERTIES_LOCATION[p.__class__]
if p_loc == exp.Properties.Location.POST_WITH:
- with_properties.append(p.copy())
+ with_properties.append(p)
elif p_loc == exp.Properties.Location.POST_SCHEMA:
- root_properties.append(p.copy())
+ root_properties.append(p)
return self.root_properties(
exp.Properties(expressions=root_properties)
@@ -1124,7 +1132,7 @@ class Generator:
for p in properties.expressions:
p_loc = self.PROPERTIES_LOCATION[p.__class__]
if p_loc != exp.Properties.Location.UNSUPPORTED:
- properties_locs[p_loc].append(p.copy())
+ properties_locs[p_loc].append(p)
else:
self.unsupported(f"Unsupported property {p.key}")
@@ -1238,6 +1246,29 @@ class Generator:
for_ = " FOR NONE"
return f"WITH{no}{concurrent} ISOLATED LOADING{for_}"
+ def partitionboundspec_sql(self, expression: exp.PartitionBoundSpec) -> str:
+ if isinstance(expression.this, list):
+ return f"IN ({self.expressions(expression, key='this', flat=True)})"
+ if expression.this:
+ modulus = self.sql(expression, "this")
+ remainder = self.sql(expression, "expression")
+ return f"WITH (MODULUS {modulus}, REMAINDER {remainder})"
+
+ from_expressions = self.expressions(expression, key="from_expressions", flat=True)
+ to_expressions = self.expressions(expression, key="to_expressions", flat=True)
+ return f"FROM ({from_expressions}) TO ({to_expressions})"
+
+ def partitionedofproperty_sql(self, expression: exp.PartitionedOfProperty) -> str:
+ this = self.sql(expression, "this")
+
+ for_values_or_default = expression.expression
+ if isinstance(for_values_or_default, exp.PartitionBoundSpec):
+ for_values_or_default = f" FOR VALUES {self.sql(for_values_or_default)}"
+ else:
+ for_values_or_default = " DEFAULT"
+
+ return f"PARTITION OF {this}{for_values_or_default}"
+
def lockingproperty_sql(self, expression: exp.LockingProperty) -> str:
kind = expression.args.get("kind")
this = f" {self.sql(expression, 'this')}" if expression.this else ""
@@ -1385,7 +1416,12 @@ class Generator:
index = self.sql(expression, "index")
index = f" AT {index}" if index else ""
- return f"{table}{version}{file_format}{alias}{index}{hints}{pivots}{joins}{laterals}"
+ ordinality = expression.args.get("ordinality") or ""
+ if ordinality:
+ ordinality = f" WITH ORDINALITY{alias}"
+ alias = ""
+
+ return f"{table}{version}{file_format}{alias}{index}{hints}{pivots}{joins}{laterals}{ordinality}"
def tablesample_sql(
self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
@@ -1489,7 +1525,6 @@ class Generator:
return f"{values} AS {alias}" if alias else values
# Converts `VALUES...` expression into a series of select unions.
- expression = expression.copy()
alias_node = expression.args.get("alias")
column_names = alias_node and alias_node.columns
@@ -1972,8 +2007,7 @@ class Generator:
if self.UNNEST_WITH_ORDINALITY:
if alias and isinstance(offset, exp.Expression):
- alias = alias.copy()
- alias.append("columns", offset.copy())
+ alias.append("columns", offset)
if alias and self.UNNEST_COLUMN_ONLY:
columns = alias.columns
@@ -2138,7 +2172,6 @@ class Generator:
return f"PRIMARY KEY ({expressions}){options}"
def if_sql(self, expression: exp.If) -> str:
- expression = expression.copy()
return self.case_sql(exp.Case(ifs=[expression], default=expression.args.get("false")))
def matchagainst_sql(self, expression: exp.MatchAgainst) -> str:
@@ -2367,7 +2400,9 @@ class Generator:
def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
format_sql = self.sql(expression, "format")
format_sql = f" FORMAT {format_sql}" if format_sql else ""
- return f"{safe_prefix or ''}CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')}{format_sql})"
+ to_sql = self.sql(expression, "to")
+ to_sql = f" {to_sql}" if to_sql else ""
+ return f"{safe_prefix or ''}CAST({self.sql(expression, 'this')} AS{to_sql}{format_sql})"
def currentdate_sql(self, expression: exp.CurrentDate) -> str:
zone = self.sql(expression, "this")
@@ -2510,7 +2545,7 @@ class Generator:
def intdiv_sql(self, expression: exp.IntDiv) -> str:
return self.sql(
exp.Cast(
- this=exp.Div(this=expression.this.copy(), expression=expression.expression.copy()),
+ this=exp.Div(this=expression.this, expression=expression.expression),
to=exp.DataType(this=exp.DataType.Type.INT),
)
)
@@ -2779,7 +2814,6 @@ class Generator:
hints = table.args.get("hints")
if hints and table.alias and isinstance(hints[0], exp.WithTableHint):
# T-SQL syntax is MERGE ... <target_table> [WITH (<merge_hint>)] [[AS] table_alias]
- table = table.copy()
table_alias = f" AS {self.sql(table.args['alias'].pop())}"
this = self.sql(table)
@@ -2787,7 +2821,9 @@ class Generator:
on = f"ON {self.sql(expression, 'on')}"
expressions = self.expressions(expression, sep=" ")
- return f"MERGE INTO {this}{table_alias} {using} {on} {expressions}"
+ return self.prepend_ctes(
+ expression, f"MERGE INTO {this}{table_alias} {using} {on} {expressions}"
+ )
def tochar_sql(self, expression: exp.ToChar) -> str:
if expression.args.get("format"):
@@ -2896,12 +2932,12 @@ class Generator:
case = exp.Case().when(
expression.this.is_(exp.null()).not_(copy=False),
- expression.args["true"].copy(),
+ expression.args["true"],
copy=False,
)
else_cond = expression.args.get("false")
if else_cond:
- case.else_(else_cond.copy(), copy=False)
+ case.else_(else_cond, copy=False)
return self.sql(case)
@@ -2931,15 +2967,6 @@ class Generator:
if not isinstance(expression, exp.Literal):
from sqlglot.optimizer.simplify import simplify
- expression = simplify(expression.copy())
+ expression = simplify(expression)
return expression
-
-
-def cached_generator(
- cache: t.Optional[t.Dict[int, str]] = None
-) -> t.Callable[[exp.Expression], str]:
- """Returns a cached generator."""
- cache = {} if cache is None else cache
- generator = Generator(normalize=True, identify="safe")
- return lambda e: generator.generate(e, cache)
diff --git a/sqlglot/helper.py b/sqlglot/helper.py
index 74b61e3..ee41557 100644
--- a/sqlglot/helper.py
+++ b/sqlglot/helper.py
@@ -184,9 +184,7 @@ def apply_index_offset(
annotate_types(expression)
if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
logger.warning("Applying array index offset (%s)", offset)
- expression = simplify(
- exp.Add(this=expression.copy(), expression=exp.Literal.number(offset))
- )
+ expression = simplify(exp.Add(this=expression, expression=exp.Literal.number(offset)))
return [expression]
return expressions
diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py
index 8d82b2d..6df36af 100644
--- a/sqlglot/optimizer/normalize.py
+++ b/sqlglot/optimizer/normalize.py
@@ -4,7 +4,6 @@ import logging
from sqlglot import exp
from sqlglot.errors import OptimizeError
-from sqlglot.generator import cached_generator
from sqlglot.helper import while_changing
from sqlglot.optimizer.scope import find_all_in_scope
from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort
@@ -29,8 +28,6 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
Returns:
sqlglot.Expression: normalized expression
"""
- generate = cached_generator()
-
for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))):
if isinstance(node, exp.Connector):
if normalized(node, dnf=dnf):
@@ -49,7 +46,7 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
try:
node = node.replace(
- while_changing(node, lambda e: distributive_law(e, dnf, max_distance, generate))
+ while_changing(node, lambda e: distributive_law(e, dnf, max_distance))
)
except OptimizeError as e:
logger.info(e)
@@ -133,7 +130,7 @@ def _predicate_lengths(expression, dnf):
return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
-def distributive_law(expression, dnf, max_distance, generate):
+def distributive_law(expression, dnf, max_distance):
"""
x OR (y AND z) -> (x OR y) AND (x OR z)
(x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)
@@ -146,7 +143,7 @@ def distributive_law(expression, dnf, max_distance, generate):
if distance > max_distance:
raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")
- exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, generate))
+ exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance))
to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)
if isinstance(expression, from_exp):
@@ -157,30 +154,30 @@ def distributive_law(expression, dnf, max_distance, generate):
if isinstance(a, to_exp) and isinstance(b, to_exp):
if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
- return _distribute(a, b, from_func, to_func, generate)
- return _distribute(b, a, from_func, to_func, generate)
+ return _distribute(a, b, from_func, to_func)
+ return _distribute(b, a, from_func, to_func)
if isinstance(a, to_exp):
- return _distribute(b, a, from_func, to_func, generate)
+ return _distribute(b, a, from_func, to_func)
if isinstance(b, to_exp):
- return _distribute(a, b, from_func, to_func, generate)
+ return _distribute(a, b, from_func, to_func)
return expression
-def _distribute(a, b, from_func, to_func, generate):
+def _distribute(a, b, from_func, to_func):
if isinstance(a, exp.Connector):
exp.replace_children(
a,
lambda c: to_func(
- uniq_sort(flatten(from_func(c, b.left)), generate),
- uniq_sort(flatten(from_func(c, b.right)), generate),
+ uniq_sort(flatten(from_func(c, b.left))),
+ uniq_sort(flatten(from_func(c, b.right))),
copy=False,
),
)
else:
a = to_func(
- uniq_sort(flatten(from_func(a, b.left)), generate),
- uniq_sort(flatten(from_func(a, b.right)), generate),
+ uniq_sort(flatten(from_func(a, b.left))),
+ uniq_sort(flatten(from_func(a, b.right))),
copy=False,
)
diff --git a/sqlglot/optimizer/normalize_identifiers.py b/sqlglot/optimizer/normalize_identifiers.py
index ecea6a0..154256e 100644
--- a/sqlglot/optimizer/normalize_identifiers.py
+++ b/sqlglot/optimizer/normalize_identifiers.py
@@ -2,7 +2,7 @@ from __future__ import annotations
import typing as t
-from sqlglot import exp, parse_one
+from sqlglot import exp
from sqlglot._typing import E
from sqlglot.dialects.dialect import Dialect, DialectType
@@ -49,7 +49,7 @@ def normalize_identifiers(expression, dialect=None):
The transformed expression.
"""
if isinstance(expression, str):
- expression = parse_one(expression, dialect=dialect, into=exp.Identifier)
+ expression = exp.parse_identifier(expression, dialect=dialect)
dialect = Dialect.get_or_raise(dialect)
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py
index 68aebdb..3a43e8f 100644
--- a/sqlglot/optimizer/qualify_tables.py
+++ b/sqlglot/optimizer/qualify_tables.py
@@ -62,7 +62,7 @@ def qualify_tables(
if isinstance(source.this, exp.Identifier):
if not source.args.get("db"):
source.set("db", exp.to_identifier(db))
- if not source.args.get("catalog"):
+ if not source.args.get("catalog") and source.args.get("db"):
source.set("catalog", exp.to_identifier(catalog))
if not source.alias:
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index 30de75b..af03332 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -7,8 +7,7 @@ from decimal import Decimal
import sqlglot
from sqlglot import exp
-from sqlglot.generator import cached_generator
-from sqlglot.helper import first, merge_ranges, while_changing
+from sqlglot.helper import first, is_iterable, merge_ranges, while_changing
from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
# Final means that an expression should not be simplified
@@ -37,8 +36,6 @@ def simplify(expression, constant_propagation=False):
sqlglot.Expression: simplified expression
"""
- generate = cached_generator()
-
# group by expressions cannot be simplified, for example
# select x + 1 + 1 FROM y GROUP BY x + 1 + 1
# the projection must exactly match the group by key
@@ -67,7 +64,7 @@ def simplify(expression, constant_propagation=False):
# Pre-order transformations
node = expression
node = rewrite_between(node)
- node = uniq_sort(node, generate, root)
+ node = uniq_sort(node, root)
node = absorb_and_eliminate(node, root)
node = simplify_concat(node)
node = simplify_conditionals(node)
@@ -311,7 +308,7 @@ def remove_complements(expression, root=True):
return expression
-def uniq_sort(expression, generate, root=True):
+def uniq_sort(expression, root=True):
"""
Uniq and sort a connector.
@@ -320,7 +317,7 @@ def uniq_sort(expression, generate, root=True):
if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
flattened = tuple(expression.flatten())
- deduped = {generate(e): e for e in flattened}
+ deduped = {gen(e): e for e in flattened}
arr = tuple(deduped.items())
# check if the operands are already sorted, if not sort them
@@ -1070,3 +1067,69 @@ def _flat_simplify(expression, simplifier, root=True):
lambda a, b: expression.__class__(this=a, expression=b), operands
)
return expression
+
+
+def gen(expression: t.Any) -> str:
+ """Simple pseudo SQL generator for quickly generating sortable and unique strings.
+
+ Sorting and deduping SQL is a necessary step for optimization. Calling the actual
+ generator is expensive, so we keep a bare-minimum SQL generator here.
+ """
+ if expression is None:
+ return "_"
+ if is_iterable(expression):
+ return ",".join(gen(e) for e in expression)
+ if not isinstance(expression, exp.Expression):
+ return str(expression)
+
+ etype = type(expression)
+ if etype in GEN_MAP:
+ return GEN_MAP[etype](expression)
+ return f"{expression.key} {gen(expression.args.values())}"
+
+
+GEN_MAP = {
+ exp.Add: lambda e: _binary(e, "+"),
+ exp.And: lambda e: _binary(e, "AND"),
+ exp.Anonymous: lambda e: f"{e.this} {','.join(gen(e) for e in e.expressions)}",
+ exp.Between: lambda e: f"{gen(e.this)} BETWEEN {gen(e.args.get('low'))} AND {gen(e.args.get('high'))}",
+ exp.Boolean: lambda e: "TRUE" if e.this else "FALSE",
+ exp.Bracket: lambda e: f"{gen(e.this)}[{gen(e.expressions)}]",
+ exp.Column: lambda e: ".".join(gen(p) for p in e.parts),
+ exp.DataType: lambda e: f"{e.this.name} {gen(tuple(e.args.values())[1:])}",
+ exp.Div: lambda e: _binary(e, "/"),
+ exp.Dot: lambda e: _binary(e, "."),
+ exp.DPipe: lambda e: _binary(e, "||"),
+ exp.SafeDPipe: lambda e: _binary(e, "||"),
+ exp.EQ: lambda e: _binary(e, "="),
+ exp.GT: lambda e: _binary(e, ">"),
+ exp.GTE: lambda e: _binary(e, ">="),
+ exp.Identifier: lambda e: f'"{e.name}"' if e.quoted else e.name,
+ exp.ILike: lambda e: _binary(e, "ILIKE"),
+ exp.In: lambda e: f"{gen(e.this)} IN ({gen(tuple(e.args.values())[1:])})",
+ exp.Is: lambda e: _binary(e, "IS"),
+ exp.Like: lambda e: _binary(e, "LIKE"),
+ exp.Literal: lambda e: f"'{e.name}'" if e.is_string else e.name,
+ exp.LT: lambda e: _binary(e, "<"),
+ exp.LTE: lambda e: _binary(e, "<="),
+ exp.Mod: lambda e: _binary(e, "%"),
+ exp.Mul: lambda e: _binary(e, "*"),
+ exp.Neg: lambda e: _unary(e, "-"),
+ exp.NEQ: lambda e: _binary(e, "<>"),
+ exp.Not: lambda e: _unary(e, "NOT"),
+ exp.Null: lambda e: "NULL",
+ exp.Or: lambda e: _binary(e, "OR"),
+ exp.Paren: lambda e: f"({gen(e.this)})",
+ exp.Sub: lambda e: _binary(e, "-"),
+ exp.Subquery: lambda e: f"({gen(e.args.values())})",
+ exp.Table: lambda e: gen(e.args.values()),
+ exp.Var: lambda e: e.name,
+}
+
+
+def _binary(e: exp.Binary, op: str) -> str:
+ return f"{gen(e.left)} {op} {gen(e.right)}"
+
+
+def _unary(e: exp.Unary, op: str) -> str:
+ return f"{op} {gen(e.this)}"
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index b7f91ab..1dab600 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -674,6 +674,7 @@ class Parser(metaclass=_Parser):
"ON": lambda self: self._parse_on_property(),
"ORDER BY": lambda self: self._parse_order(skip_order_token=True),
"OUTPUT": lambda self: self.expression(exp.OutputModelProperty, this=self._parse_schema()),
+ "PARTITION": lambda self: self._parse_partitioned_of(),
"PARTITION BY": lambda self: self._parse_partitioned_by(),
"PARTITIONED BY": lambda self: self._parse_partitioned_by(),
"PARTITIONED_BY": lambda self: self._parse_partitioned_by(),
@@ -1743,6 +1744,58 @@ class Parser(metaclass=_Parser):
return self._parse_csv(self._parse_conjunction)
return []
+ def _parse_partition_bound_spec(self) -> exp.PartitionBoundSpec:
+ def _parse_partition_bound_expr() -> t.Optional[exp.Expression]:
+ if self._match_text_seq("MINVALUE"):
+ return exp.var("MINVALUE")
+ if self._match_text_seq("MAXVALUE"):
+ return exp.var("MAXVALUE")
+ return self._parse_bitwise()
+
+ this: t.Optional[exp.Expression | t.List[exp.Expression]] = None
+ expression = None
+ from_expressions = None
+ to_expressions = None
+
+ if self._match(TokenType.IN):
+ this = self._parse_wrapped_csv(self._parse_bitwise)
+ elif self._match(TokenType.FROM):
+ from_expressions = self._parse_wrapped_csv(_parse_partition_bound_expr)
+ self._match_text_seq("TO")
+ to_expressions = self._parse_wrapped_csv(_parse_partition_bound_expr)
+ elif self._match_text_seq("WITH", "(", "MODULUS"):
+ this = self._parse_number()
+ self._match_text_seq(",", "REMAINDER")
+ expression = self._parse_number()
+ self._match_r_paren()
+ else:
+ self.raise_error("Failed to parse partition bound spec.")
+
+ return self.expression(
+ exp.PartitionBoundSpec,
+ this=this,
+ expression=expression,
+ from_expressions=from_expressions,
+ to_expressions=to_expressions,
+ )
+
+ # https://www.postgresql.org/docs/current/sql-createtable.html
+ def _parse_partitioned_of(self) -> t.Optional[exp.PartitionedOfProperty]:
+ if not self._match_text_seq("OF"):
+ self._retreat(self._index - 1)
+ return None
+
+ this = self._parse_table(schema=True)
+
+ if self._match(TokenType.DEFAULT):
+ expression: exp.Var | exp.PartitionBoundSpec = exp.var("DEFAULT")
+ elif self._match_text_seq("FOR", "VALUES"):
+ expression = self._parse_partition_bound_spec()
+ else:
+ self.raise_error("Expecting either DEFAULT or FOR VALUES clause.")
+
+ return self.expression(exp.PartitionedOfProperty, this=this, expression=expression)
+
def _parse_partitioned_by(self) -> exp.PartitionedByProperty:
self._match(TokenType.EQ)
return self.expression(
@@ -2682,6 +2735,10 @@ class Parser(metaclass=_Parser):
for join in iter(self._parse_join, None):
this.append("joins", join)
+ if self._match_pair(TokenType.WITH, TokenType.ORDINALITY):
+ this.set("ordinality", True)
+ this.set("alias", self._parse_table_alias())
+
return this
def _parse_version(self) -> t.Optional[exp.Version]:
@@ -4189,17 +4246,12 @@ class Parser(metaclass=_Parser):
fmt = None
to = self._parse_types()
- if not to:
- self.raise_error("Expected TYPE after CAST")
- elif isinstance(to, exp.Identifier):
- to = exp.DataType.build(to.name, udt=True)
- elif to.this == exp.DataType.Type.CHAR:
- if self._match(TokenType.CHARACTER_SET):
- to = self.expression(exp.CharacterSet, this=self._parse_var_or_string())
- elif self._match(TokenType.FORMAT):
+ if self._match(TokenType.FORMAT):
fmt_string = self._parse_string()
fmt = self._parse_at_time_zone(fmt_string)
+ if not to:
+ to = exp.DataType.build(exp.DataType.Type.UNKNOWN)
if to.this in exp.DataType.TEMPORAL_TYPES:
this = self.expression(
exp.StrToDate if to.this == exp.DataType.Type.DATE else exp.StrToTime,
@@ -4215,8 +4267,14 @@ class Parser(metaclass=_Parser):
if isinstance(fmt, exp.AtTimeZone) and isinstance(this, exp.StrToTime):
this.set("zone", fmt.args["zone"])
-
return this
+ elif not to:
+ self.raise_error("Expected TYPE after CAST")
+ elif isinstance(to, exp.Identifier):
+ to = exp.DataType.build(to.name, udt=True)
+ elif to.this == exp.DataType.Type.CHAR:
+ if self._match(TokenType.CHARACTER_SET):
+ to = self.expression(exp.CharacterSet, this=self._parse_var_or_string())
return self.expression(
exp.Cast if strict else exp.TryCast, this=this, to=to, format=fmt, safe=safe
@@ -4789,10 +4847,17 @@ class Parser(metaclass=_Parser):
return self._parse_placeholder()
def _parse_parameter(self) -> exp.Parameter:
- wrapped = self._match(TokenType.L_BRACE)
- this = self._parse_var() or self._parse_identifier() or self._parse_primary()
+ def _parse_parameter_part() -> t.Optional[exp.Expression]:
+ return (
+ self._parse_identifier() or self._parse_primary() or self._parse_var(any_token=True)
+ )
+
+ self._match(TokenType.L_BRACE)
+ this = _parse_parameter_part()
+ expression = self._match(TokenType.COLON) and _parse_parameter_part()
self._match(TokenType.R_BRACE)
- return self.expression(exp.Parameter, this=this, wrapped=wrapped)
+
+ return self.expression(exp.Parameter, this=this, expression=expression)
def _parse_placeholder(self) -> t.Optional[exp.Expression]:
if self._match_set(self.PLACEHOLDER_PARSERS):
diff --git a/sqlglot/schema.py b/sqlglot/schema.py
index 778378c..acf9bc4 100644
--- a/sqlglot/schema.py
+++ b/sqlglot/schema.py
@@ -3,10 +3,9 @@ from __future__ import annotations
import abc
import typing as t
-import sqlglot
from sqlglot import expressions as exp
from sqlglot.dialects.dialect import Dialect
-from sqlglot.errors import ParseError, SchemaError
+from sqlglot.errors import SchemaError
from sqlglot.helper import dict_depth
from sqlglot.trie import TrieResult, in_trie, new_trie
@@ -448,19 +447,16 @@ class MappingSchema(AbstractMappingSchema, Schema):
def normalize_name(
- name: str | exp.Identifier,
+ identifier: str | exp.Identifier,
dialect: DialectType = None,
is_table: bool = False,
normalize: t.Optional[bool] = True,
) -> str:
- try:
- identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
- except ParseError:
- return name if isinstance(name, str) else name.name
+ if isinstance(identifier, str):
+ identifier = exp.parse_identifier(identifier, dialect=dialect)
- name = identifier.name
if not normalize:
- return name
+ return identifier.name
# This can be useful for normalize_identifier
identifier.meta["is_table"] = is_table
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py
index e0fd68f..445fda6 100644
--- a/sqlglot/transforms.py
+++ b/sqlglot/transforms.py
@@ -67,7 +67,7 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
order = expression.args.get("order")
if order:
- window.set("order", order.pop().copy())
+ window.set("order", order.pop())
else:
window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))
@@ -75,9 +75,9 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
expression.select(window, copy=False)
return (
- exp.select(*outer_selects)
- .from_(expression.subquery("_t"))
- .where(exp.column(row_number).eq(1))
+ exp.select(*outer_selects, copy=False)
+ .from_(expression.subquery("_t", copy=False), copy=False)
+ .where(exp.column(row_number).eq(1), copy=False)
)
return expression
@@ -120,7 +120,9 @@ def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
elif expr.name not in expression.named_selects:
expression.select(expr.copy(), copy=False)
- return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters)
+ return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
+ qualify_filters, copy=False
+ )
return expression
@@ -189,7 +191,7 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp
)
# we use list here because expression.selects is mutated inside the loop
- for select in expression.selects.copy():
+ for select in list(expression.selects):
explode = select.find(exp.Explode)
if explode:
@@ -374,6 +376,60 @@ def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
return expression
+def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
+ """
+ Converts a query with a FULL OUTER join to a union of identical queries that
+ use LEFT/RIGHT OUTER joins instead. This transformation currently only works
+ for queries that have a single FULL OUTER join.
+ """
+ if isinstance(expression, exp.Select):
+ full_outer_joins = [
+ (index, join)
+ for index, join in enumerate(expression.args.get("joins") or [])
+ if join.side == "FULL" and join.kind == "OUTER"
+ ]
+
+ if len(full_outer_joins) == 1:
+ expression_copy = expression.copy()
+ index, full_outer_join = full_outer_joins[0]
+ full_outer_join.set("side", "left")
+ expression_copy.args["joins"][index].set("side", "right")
+
+ return exp.union(expression, expression_copy, copy=False)
+
+ return expression
+
+
+def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
+ """
+ Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
+ defined at the top-level, so for example queries like:
+
+ SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq
+
+ are invalid in those dialects. This transformation can be used to ensure all CTEs are
+ moved to the top level so that the final SQL code is valid from a syntax standpoint.
+
+ TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
+ """
+ top_level_with = expression.args.get("with")
+ for node in expression.find_all(exp.With):
+ if node.parent is expression:
+ continue
+
+ inner_with = node.pop()
+ if not top_level_with:
+ top_level_with = inner_with
+ expression.set("with", top_level_with)
+ else:
+ if inner_with.recursive:
+ top_level_with.set("recursive", True)
+
+ top_level_with.expressions.extend(inner_with.expressions)
+
+ return expression
+
+
def preprocess(
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
@@ -392,7 +448,7 @@ def preprocess(
def _to_sql(self, expression: exp.Expression) -> str:
expression_type = type(expression)
- expression = transforms[0](expression.copy())
+ expression = transforms[0](expression)
for t in transforms[1:]:
expression = t(expression)