author     Daniel Baumann <daniel.baumann@progress-linux.org>  2024-02-16 05:45:52 +0000
committer  Daniel Baumann <daniel.baumann@progress-linux.org>  2024-02-16 05:45:52 +0000
commit     3d48060515ba25b4c49d975a520ee0682327d1b7 (patch)
tree       e8730f509026e866d77c459f74a384505425363a /sqlglot
parent     Releasing debian version 21.0.2-1. (diff)
Merging upstream version 21.1.1.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
-rw-r--r--  sqlglot/dataframe/sql/functions.py    |  2
-rw-r--r--  sqlglot/dataframe/sql/session.py      | 10
-rw-r--r--  sqlglot/dialects/bigquery.py          | 50
-rw-r--r--  sqlglot/dialects/clickhouse.py        |  1
-rw-r--r--  sqlglot/dialects/dialect.py           | 29
-rw-r--r--  sqlglot/dialects/drill.py             |  4
-rw-r--r--  sqlglot/dialects/hive.py              | 44
-rw-r--r--  sqlglot/dialects/mysql.py             |  1
-rw-r--r--  sqlglot/dialects/oracle.py            |  1
-rw-r--r--  sqlglot/dialects/postgres.py          |  2
-rw-r--r--  sqlglot/dialects/presto.py            |  2
-rw-r--r--  sqlglot/dialects/redshift.py          |  7
-rw-r--r--  sqlglot/dialects/snowflake.py         |  2
-rw-r--r--  sqlglot/dialects/spark.py             | 28
-rw-r--r--  sqlglot/dialects/spark2.py            | 29
-rw-r--r--  sqlglot/dialects/sqlite.py            |  1
-rw-r--r--  sqlglot/dialects/tableau.py           |  6
-rw-r--r--  sqlglot/dialects/teradata.py          |  1
-rw-r--r--  sqlglot/dialects/tsql.py              |  1
-rw-r--r--  sqlglot/expressions.py                | 70
-rw-r--r--  sqlglot/generator.py                  |  3
-rw-r--r--  sqlglot/helper.py                     | 30
-rw-r--r--  sqlglot/lineage.py                    |  8
-rw-r--r--  sqlglot/optimizer/annotate_types.py   |  5
-rw-r--r--  sqlglot/optimizer/canonicalize.py     | 10
-rw-r--r--  sqlglot/optimizer/qualify_columns.py  | 50
-rw-r--r--  sqlglot/parser.py                     | 52
-rw-r--r--  sqlglot/schema.py                     |  4
-rw-r--r--  sqlglot/tokens.py                     |  2
-rw-r--r--  sqlglot/transforms.py                 | 94
30 files changed, 394 insertions, 155 deletions
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
index 29e7c55..133979a 100644
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -148,7 +148,7 @@ def atanh(col: ColumnOrName) -> Column:
def cbrt(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "CBRT")
+ return Column.invoke_expression_over_column(col, expression.Cbrt)
def ceil(col: ColumnOrName) -> Column:
diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py
index f518ac2..bfc022b 100644
--- a/sqlglot/dataframe/sql/session.py
+++ b/sqlglot/dataframe/sql/session.py
@@ -70,12 +70,10 @@ class SparkSession:
column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)}
data_expressions = [
- exp.Tuple(
- expressions=list(
- map(
- lambda x: F.lit(x).expression,
- row if not isinstance(row, dict) else row.values(),
- )
+ exp.tuple_(
+ *map(
+ lambda x: F.lit(x).expression,
+ row if not isinstance(row, dict) else row.values(),
)
)
for row in data
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 9068235..c0191b2 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -39,24 +39,31 @@ def _derived_table_values_to_unnest(self: BigQuery.Generator, expression: exp.Va
alias = expression.args.get("alias")
- structs = [
- exp.Struct(
+ return self.unnest_sql(
+ exp.Unnest(
expressions=[
- exp.alias_(value, column_name)
- for value, column_name in zip(
- t.expressions,
- (
- alias.columns
- if alias and alias.columns
- else (f"_c{i}" for i in range(len(t.expressions)))
+ exp.array(
+ *(
+ exp.Struct(
+ expressions=[
+ exp.alias_(value, column_name)
+ for value, column_name in zip(
+ t.expressions,
+ (
+ alias.columns
+ if alias and alias.columns
+ else (f"_c{i}" for i in range(len(t.expressions)))
+ ),
+ )
+ ]
+ )
+ for t in expression.find_all(exp.Tuple)
),
+ copy=False,
)
]
)
- for t in expression.find_all(exp.Tuple)
- ]
-
- return self.unnest_sql(exp.Unnest(expressions=[exp.Array(expressions=structs)]))
+ )
def _returnsproperty_sql(self: BigQuery.Generator, expression: exp.ReturnsProperty) -> str:
@@ -161,12 +168,18 @@ def _pushdown_cte_column_names(expression: exp.Expression) -> exp.Expression:
return expression
-def _parse_timestamp(args: t.List) -> exp.StrToTime:
+def _parse_parse_timestamp(args: t.List) -> exp.StrToTime:
this = format_time_lambda(exp.StrToTime, "bigquery")([seq_get(args, 1), seq_get(args, 0)])
this.set("zone", seq_get(args, 2))
return this
+def _parse_timestamp(args: t.List) -> exp.Timestamp:
+ timestamp = exp.Timestamp.from_arg_list(args)
+ timestamp.set("with_tz", True)
+ return timestamp
+
+
def _parse_date(args: t.List) -> exp.Date | exp.DateFromParts:
expr_type = exp.DateFromParts if len(args) == 3 else exp.Date
return expr_type.from_arg_list(args)
@@ -318,6 +331,7 @@ class BigQuery(Dialect):
"TIMESTAMP": TokenType.TIMESTAMPTZ,
}
KEYWORDS.pop("DIV")
+ KEYWORDS.pop("VALUES")
class Parser(parser.Parser):
PREFIXED_PIVOT_COLUMNS = True
@@ -348,7 +362,7 @@ class BigQuery(Dialect):
"PARSE_DATE": lambda args: format_time_lambda(exp.StrToDate, "bigquery")(
[seq_get(args, 1), seq_get(args, 0)]
),
- "PARSE_TIMESTAMP": _parse_timestamp,
+ "PARSE_TIMESTAMP": _parse_parse_timestamp,
"REGEXP_CONTAINS": exp.RegexpLike.from_arg_list,
"REGEXP_EXTRACT": lambda args: exp.RegexpExtract(
this=seq_get(args, 0),
@@ -367,6 +381,7 @@ class BigQuery(Dialect):
"TIME": _parse_time,
"TIME_ADD": parse_date_delta_with_interval(exp.TimeAdd),
"TIME_SUB": parse_date_delta_with_interval(exp.TimeSub),
+ "TIMESTAMP": _parse_timestamp,
"TIMESTAMP_ADD": parse_date_delta_with_interval(exp.TimestampAdd),
"TIMESTAMP_SUB": parse_date_delta_with_interval(exp.TimestampSub),
"TIMESTAMP_MICROS": lambda args: exp.UnixToTime(
@@ -395,11 +410,6 @@ class BigQuery(Dialect):
TokenType.TABLE,
}
- ID_VAR_TOKENS = {
- *parser.Parser.ID_VAR_TOKENS,
- TokenType.VALUES,
- }
-
PROPERTY_PARSERS = {
**parser.Parser.PROPERTY_PARSERS,
"NOT DETERMINISTIC": lambda self: self.expression(
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index 1ec15c5..d7be64c 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -93,6 +93,7 @@ class ClickHouse(Dialect):
"IPV6": TokenType.IPV6,
"AGGREGATEFUNCTION": TokenType.AGGREGATEFUNCTION,
"SIMPLEAGGREGATEFUNCTION": TokenType.SIMPLEAGGREGATEFUNCTION,
+ "SYSTEM": TokenType.COMMAND,
}
SINGLE_TOKENS = {
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 6e2d190..0440a99 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -654,28 +654,6 @@ def time_format(
return _time_format
-def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
- """
- In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
- PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
- columns are removed from the create statement.
- """
- has_schema = isinstance(expression.this, exp.Schema)
- is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")
-
- if has_schema and is_partitionable:
- prop = expression.find(exp.PartitionedByProperty)
- if prop and prop.this and not isinstance(prop.this, exp.Schema):
- schema = expression.this
- columns = {v.name.upper() for v in prop.this.expressions}
- partitions = [col for col in schema.expressions if col.name.upper() in columns]
- schema.set("expressions", [e for e in schema.expressions if e not in partitions])
- prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
- expression.set("this", schema)
-
- return self.create_sql(expression)
-
-
def parse_date_delta(
exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
@@ -742,7 +720,10 @@ def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
if not expression.expression:
- return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))
+ from sqlglot.optimizer.annotate_types import annotate_types
+
+ target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
+ return self.sql(exp.cast(expression.this, to=target_type))
if expression.text("expression").lower() in TIMEZONES:
return self.sql(
exp.AtTimeZone(
@@ -750,7 +731,7 @@ def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
zone=expression.expression,
)
)
- return self.function_fallback_sql(expression)
+ return self.func("TIMESTAMP", expression.this, expression.expression)
def locate_to_strposition(args: t.List) -> exp.Expression:
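The reworked no_timestamp_sql falls back to a CAST whose target type comes from the annotator; a minimal standalone sketch of that cast-building step (editor's example):

from sqlglot import exp

# a bare TIMESTAMP(x) with no zone argument is rendered as a cast;
# annotate_types may upgrade the target to TIMESTAMPTZ when with_tz is set
node = exp.Timestamp(this=exp.column("x"))
print(exp.cast(node.this, to=exp.DataType.Type.TIMESTAMP).sql())
# CAST(x AS TIMESTAMP)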
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
index be23355..409e260 100644
--- a/sqlglot/dialects/drill.py
+++ b/sqlglot/dialects/drill.py
@@ -5,7 +5,6 @@ import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
- create_with_partitions_sql,
datestrtodate_sql,
format_time_lambda,
no_trycast_sql,
@@ -13,6 +12,7 @@ from sqlglot.dialects.dialect import (
str_position_sql,
timestrtotime_sql,
)
+from sqlglot.transforms import preprocess, move_schema_columns_to_partitioned_by
def _date_add_sql(kind: str) -> t.Callable[[Drill.Generator, exp.DateAdd | exp.DateSub], str]:
@@ -125,7 +125,7 @@ class Drill(Dialect):
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.ArrayContains: rename_func("REPEATED_CONTAINS"),
exp.ArraySize: rename_func("REPEATED_COUNT"),
- exp.Create: create_with_partitions_sql,
+ exp.Create: preprocess([move_schema_columns_to_partitioned_by]),
exp.DateAdd: _date_add_sql("ADD"),
exp.DateStrToDate: datestrtodate_sql,
exp.DateSub: _date_add_sql("SUB"),
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 6337ffd..b1540bb 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -9,7 +9,6 @@ from sqlglot.dialects.dialect import (
NormalizationStrategy,
approx_count_distinct_sql,
arg_max_or_min_no_count,
- create_with_partitions_sql,
datestrtodate_sql,
format_time_lambda,
if_sql,
@@ -32,6 +31,12 @@ from sqlglot.dialects.dialect import (
timestrtotime_sql,
var_map_sql,
)
+from sqlglot.transforms import (
+ remove_unique_constraints,
+ ctas_with_tmp_tables_to_create_tmp_view,
+ preprocess,
+ move_schema_columns_to_partitioned_by,
+)
from sqlglot.helper import seq_get
from sqlglot.parser import parse_var_map
from sqlglot.tokens import TokenType
@@ -55,30 +60,6 @@ TIME_DIFF_FACTOR = {
DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH")
-def _create_sql(self, expression: exp.Create) -> str:
- # remove UNIQUE column constraints
- for constraint in expression.find_all(exp.UniqueColumnConstraint):
- if constraint.parent:
- constraint.parent.pop()
-
- properties = expression.args.get("properties")
- temporary = any(
- isinstance(prop, exp.TemporaryProperty)
- for prop in (properties.expressions if properties else [])
- )
-
- # CTAS with temp tables map to CREATE TEMPORARY VIEW
- kind = expression.args["kind"]
- if kind.upper() == "TABLE" and temporary:
- if expression.expression:
- return f"CREATE TEMPORARY VIEW {self.sql(expression, 'this')} AS {self.sql(expression, 'expression')}"
- else:
- # CREATE TEMPORARY TABLE may require storage provider
- expression = self.temporary_storage_provider(expression)
-
- return create_with_partitions_sql(self, expression)
-
-
def _add_date_sql(self: Hive.Generator, expression: DATE_ADD_OR_SUB) -> str:
if isinstance(expression, exp.TsOrDsAdd) and not expression.unit:
return self.func("DATE_ADD", expression.this, expression.expression)
@@ -285,6 +266,7 @@ class Hive(Dialect):
class Parser(parser.Parser):
LOG_DEFAULTS_TO_LN = True
STRICT_CAST = False
+ VALUES_FOLLOWED_BY_PAREN = False
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
@@ -518,7 +500,13 @@ class Hive(Dialect):
"" if e.args.get("allow_null") else "NOT NULL"
),
exp.VarMap: var_map_sql,
- exp.Create: _create_sql,
+ exp.Create: preprocess(
+ [
+ remove_unique_constraints,
+ ctas_with_tmp_tables_to_create_tmp_view,
+ move_schema_columns_to_partitioned_by,
+ ]
+ ),
exp.Quantile: rename_func("PERCENTILE"),
exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"),
exp.RegexpExtract: regexp_extract_sql,
@@ -581,10 +569,6 @@ class Hive(Dialect):
return super()._jsonpathkey_sql(expression)
- def temporary_storage_provider(self, expression: exp.Create) -> exp.Create:
- # Hive has no temporary storage provider (there are hive settings though)
- return expression
-
def parameter_sql(self, expression: exp.Parameter) -> str:
this = self.sql(expression, "this")
expression_sql = self.sql(expression, "expression")
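The net effect of swapping _create_sql for the preprocess pipeline can be checked end to end (editor's example; the output shown is the expected form):

import sqlglot

# CTAS on a temporary table still maps to CREATE TEMPORARY VIEW in Hive
print(sqlglot.transpile("CREATE TEMPORARY TABLE t AS SELECT 1", write="hive")[0])
# expected: CREATE TEMPORARY VIEW t AS SELECT 1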
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 661ef7d..97c891d 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -445,6 +445,7 @@ class MySQL(Dialect):
LOG_DEFAULTS_TO_LN = True
STRING_ALIASES = True
+ VALUES_FOLLOWED_BY_PAREN = False
def _parse_primary_key_part(self) -> t.Optional[exp.Expression]:
this = self._parse_id_var()
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index 0c0d750..de693b9 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -88,6 +88,7 @@ class Oracle(Dialect):
class Parser(parser.Parser):
ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False
WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER, TokenType.KEEP}
+ VALUES_FOLLOWED_BY_PAREN = False
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 68e2c6d..126261e 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -244,6 +244,8 @@ class Postgres(Dialect):
"@@": TokenType.DAT,
"@>": TokenType.AT_GT,
"<@": TokenType.LT_AT,
+ "|/": TokenType.PIPE_SLASH,
+ "||/": TokenType.DPIPE_SLASH,
"BEGIN": TokenType.COMMAND,
"BEGIN TRANSACTION": TokenType.BEGIN,
"BIGSERIAL": TokenType.BIGSERIAL,
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 609103e..1e0e7e9 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -225,6 +225,8 @@ class Presto(Dialect):
}
class Parser(parser.Parser):
+ VALUES_FOLLOWED_BY_PAREN = False
+
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"ARBITRARY": exp.AnyValue.from_arg_list,
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index a64c1d4..135ffc6 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -136,11 +136,11 @@ class Redshift(Postgres):
refs.add(
(
this.args["from"] if i == 0 else this.args["joins"][i - 1]
- ).alias_or_name.lower()
+ ).this.alias.lower()
)
- table = join.this
- if isinstance(table, exp.Table):
+ table = join.this
+ if isinstance(table, exp.Table) and not join.args.get("on"):
if table.parts[0].name.lower() in refs:
table.replace(table.to_column())
return this
@@ -158,6 +158,7 @@ class Redshift(Postgres):
"UNLOAD": TokenType.COMMAND,
"VARBYTE": TokenType.VARBINARY,
}
+ KEYWORDS.pop("VALUES")
# Redshift allows # to appear as a table identifier prefix
SINGLE_TOKENS = Postgres.Tokenizer.SINGLE_TOKENS.copy()
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 37f9761..b4275ea 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -477,6 +477,8 @@ class Snowflake(Dialect):
"PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
"TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
"COLUMNS": _show_parser("COLUMNS"),
+ "USERS": _show_parser("USERS"),
+ "TERSE USERS": _show_parser("USERS"),
}
STAGED_FILE_SINGLE_TOKENS = {
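Editor's sketch of the new SHOW USERS support (the result type is inferred from the parser entries above):

from sqlglot import parse_one

show = parse_one("SHOW USERS", read="snowflake")
print(type(show).__name__)  # expected: Show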
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 44bd12d..c662ab5 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -5,8 +5,14 @@ import typing as t
from sqlglot import exp
from sqlglot.dialects.dialect import rename_func
from sqlglot.dialects.hive import _parse_ignore_nulls
-from sqlglot.dialects.spark2 import Spark2
+from sqlglot.dialects.spark2 import Spark2, temporary_storage_provider
from sqlglot.helper import seq_get
+from sqlglot.transforms import (
+ ctas_with_tmp_tables_to_create_tmp_view,
+ remove_unique_constraints,
+ preprocess,
+ move_partitioned_by_to_schema_columns,
+)
def _parse_datediff(args: t.List) -> exp.Expression:
@@ -35,6 +41,15 @@ def _parse_datediff(args: t.List) -> exp.Expression:
)
+def _normalize_partition(e: exp.Expression) -> exp.Expression:
+ """Normalize the expressions in PARTITION BY (<expression>, <expression>, ...)"""
+ if isinstance(e, str):
+ return exp.to_identifier(e)
+ if isinstance(e, exp.Literal):
+ return exp.to_identifier(e.name)
+ return e
+
+
class Spark(Spark2):
class Tokenizer(Spark2.Tokenizer):
RAW_STRINGS = [
@@ -72,6 +87,17 @@ class Spark(Spark2):
TRANSFORMS = {
**Spark2.Generator.TRANSFORMS,
+ exp.Create: preprocess(
+ [
+ remove_unique_constraints,
+ lambda e: ctas_with_tmp_tables_to_create_tmp_view(
+ e, temporary_storage_provider
+ ),
+ move_partitioned_by_to_schema_columns,
+ ]
+ ),
+ exp.PartitionedByProperty: lambda self,
+ e: f"PARTITIONED BY {self.wrap(self.expressions(sqls=[_normalize_partition(e) for e in e.this.expressions], skip_first=True))}",
exp.StartsWith: rename_func("STARTSWITH"),
exp.TimestampAdd: lambda self, e: self.func(
"DATEADD", e.args.get("unit") or "DAY", e.expression, e.this
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
index 9378d99..fa55b51 100644
--- a/sqlglot/dialects/spark2.py
+++ b/sqlglot/dialects/spark2.py
@@ -13,6 +13,12 @@ from sqlglot.dialects.dialect import (
)
from sqlglot.dialects.hive import Hive
from sqlglot.helper import seq_get
+from sqlglot.transforms import (
+ preprocess,
+ remove_unique_constraints,
+ ctas_with_tmp_tables_to_create_tmp_view,
+ move_schema_columns_to_partitioned_by,
+)
def _map_sql(self: Spark2.Generator, expression: exp.Map) -> str:
@@ -95,6 +101,13 @@ def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
return expression
+def temporary_storage_provider(expression: exp.Expression) -> exp.Expression:
+ # spark2, spark, Databricks require a storage provider for temporary tables
+ provider = exp.FileFormatProperty(this=exp.Literal.string("parquet"))
+ expression.args["properties"].append("expressions", provider)
+ return expression
+
+
class Spark2(Hive):
class Parser(Hive.Parser):
TRIM_PATTERN_FIRST = True
@@ -121,7 +134,6 @@ class Spark2(Hive):
),
zone=seq_get(args, 1),
),
- "IIF": exp.If.from_arg_list,
"INT": _parse_as_cast("int"),
"MAP_FROM_ARRAYS": exp.Map.from_arg_list,
"RLIKE": exp.RegexpLike.from_arg_list,
@@ -193,6 +205,15 @@ class Spark2(Hive):
e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
+ exp.Create: preprocess(
+ [
+ remove_unique_constraints,
+ lambda e: ctas_with_tmp_tables_to_create_tmp_view(
+ e, temporary_storage_provider
+ ),
+ move_schema_columns_to_partitioned_by,
+ ]
+ ),
exp.DateFromParts: rename_func("MAKE_DATE"),
exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")),
exp.DayOfMonth: rename_func("DAYOFMONTH"),
@@ -251,12 +272,6 @@ class Spark2(Hive):
return self.func("STRUCT", *args)
- def temporary_storage_provider(self, expression: exp.Create) -> exp.Create:
- # spark2, spark, Databricks require a storage provider for temporary tables
- provider = exp.FileFormatProperty(this=exp.Literal.string("parquet"))
- expression.args["properties"].append("expressions", provider)
- return expression
-
def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
if is_parse_json(expression.this):
schema = f"'{self.sql(expression, 'to')}'"
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index b292c81..6596c5b 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -132,6 +132,7 @@ class SQLite(Dialect):
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.DateAdd: _date_add_sql,
exp.DateStrToDate: lambda self, e: self.sql(e, "this"),
+ exp.If: rename_func("IIF"),
exp.ILike: no_ilike_sql,
exp.JSONExtract: _json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_sql,
diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py
index 3795045..e8ff249 100644
--- a/sqlglot/dialects/tableau.py
+++ b/sqlglot/dialects/tableau.py
@@ -1,10 +1,14 @@
from __future__ import annotations
-from sqlglot import exp, generator, parser, transforms
+from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import Dialect, rename_func
class Tableau(Dialect):
+ class Tokenizer(tokens.Tokenizer):
+ IDENTIFIERS = [("[", "]")]
+ QUOTES = ["'", '"']
+
class Generator(generator.Generator):
JOIN_HINTS = False
TABLE_HINTS = False
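With the new tokenizer, bracket-quoted identifiers should round-trip (editor's sketch):

import sqlglot

print(sqlglot.transpile("SELECT [Sales Amount] FROM [Orders]", read="tableau", write="tableau")[0])
# expected: SELECT [Sales Amount] FROM [Orders]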
diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py
index 7f9a11a..5b30cd4 100644
--- a/sqlglot/dialects/teradata.py
+++ b/sqlglot/dialects/teradata.py
@@ -74,6 +74,7 @@ class Teradata(Dialect):
class Parser(parser.Parser):
TABLESAMPLE_CSV = True
+ VALUES_FOLLOWED_BY_PAREN = False
CHARSET_TRANSLATORS = {
"GRAPHIC_TO_KANJISJIS",
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 70ea97e..85b2e12 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -457,7 +457,6 @@ class TSQL(Dialect):
"FORMAT": _parse_format,
"GETDATE": exp.CurrentTimestamp.from_arg_list,
"HASHBYTES": _parse_hashbytes,
- "IIF": exp.If.from_arg_list,
"ISNULL": exp.Coalesce.from_arg_list,
"JSON_QUERY": parser.parse_extract_json_with_path(exp.JSONExtract),
"JSON_VALUE": parser.parse_extract_json_with_path(exp.JSONExtractScalar),
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 11ebbaf..8ef750e 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -1090,6 +1090,11 @@ class Create(DDL):
"clone": False,
}
+ @property
+ def kind(self) -> t.Optional[str]:
+ kind = self.args.get("kind")
+ return kind and kind.upper()
+
# https://docs.snowflake.com/en/sql-reference/sql/create-clone
# https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_table_clone_statement
@@ -4626,6 +4631,11 @@ class CountIf(AggFunc):
_sql_names = ["COUNT_IF", "COUNTIF"]
+# cube root
+class Cbrt(Func):
+ pass
+
+
class CurrentDate(Func):
arg_types = {"this": False}
@@ -4728,7 +4738,7 @@ class Extract(Func):
class Timestamp(Func):
- arg_types = {"this": False, "expression": False}
+ arg_types = {"this": False, "expression": False, "with_tz": False}
class TimestampAdd(Func, TimeUnit):
@@ -4833,7 +4843,7 @@ class Posexplode(Explode):
pass
-class PosexplodeOuter(Posexplode):
+class PosexplodeOuter(Posexplode, ExplodeOuter):
pass
@@ -4868,6 +4878,7 @@ class Xor(Connector, Func):
class If(Func):
arg_types = {"this": True, "true": True, "false": False}
+ _sql_names = ["IF", "IIF"]
class Nullif(Func):
@@ -6883,6 +6894,7 @@ def replace_tables(
table = to_table(
new_name,
**{k: v for k, v in node.args.items() if k not in TABLE_PARTS},
+ dialect=dialect,
)
table.add_comments([original])
return table
@@ -7072,6 +7084,60 @@ def cast_unless(
return cast(expr, to, **opts)
+def array(
+ *expressions: ExpOrStr, copy: bool = True, dialect: DialectType = None, **kwargs
+) -> Array:
+ """
+ Returns an array.
+
+ Examples:
+ >>> array(1, 'x').sql()
+ 'ARRAY(1, x)'
+
+ Args:
+ expressions: the expressions to add to the array.
+ copy: whether or not to copy the argument expressions.
+ dialect: the source dialect.
+ kwargs: the kwargs used to instantiate the function of interest.
+
+ Returns:
+ An array expression.
+ """
+ return Array(
+ expressions=[
+ maybe_parse(expression, copy=copy, dialect=dialect, **kwargs)
+ for expression in expressions
+ ]
+ )
+
+
+def tuple_(
+ *expressions: ExpOrStr, copy: bool = True, dialect: DialectType = None, **kwargs
+) -> Tuple:
+ """
+ Returns a tuple.
+
+ Examples:
+ >>> tuple_(1, 'x').sql()
+ '(1, x)'
+
+ Args:
+ expressions: the expressions to add to the tuple.
+ copy: whether or not to copy the argument expressions.
+ dialect: the source dialect.
+ kwargs: the kwargs used to instantiate the function of interest.
+
+ Returns:
+ A tuple expression.
+ """
+ return Tuple(
+ expressions=[
+ maybe_parse(expression, copy=copy, dialect=dialect, **kwargs)
+ for expression in expressions
+ ]
+ )
+
+
def true() -> Boolean:
"""
Returns a true Boolean expression.
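One detail worth noting about the new array()/tuple_() builders: copy=False reuses the given nodes instead of copying them, which several hunks above rely on for performance. A small sketch:

from sqlglot import exp

col = exp.column("x")
arr = exp.array(col, copy=False)  # no defensive copy of the argument node
assert arr.expressions[0] is col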
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 318d782..4ff5a0e 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -124,6 +124,7 @@ class Generator(metaclass=_Generator):
exp.StabilityProperty: lambda self, e: e.name,
exp.TemporaryProperty: lambda self, e: "TEMPORARY",
exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}",
+ exp.Timestamp: lambda self, e: self.func("TIMESTAMP", e.this, e.expression),
exp.ToTableProperty: lambda self, e: f"TO {self.sql(e.this)}",
exp.TransformModelProperty: lambda self, e: self.func("TRANSFORM", *e.expressions),
exp.TransientProperty: lambda self, e: "TRANSIENT",
@@ -3360,7 +3361,7 @@ class Generator(metaclass=_Generator):
return self.sql(arg)
cond_for_null = arg.is_(exp.null())
- return self.sql(exp.func("IF", cond_for_null, exp.null(), exp.Array(expressions=[arg])))
+ return self.sql(exp.func("IF", cond_for_null, exp.null(), exp.array(arg, copy=False)))
def tsordstotime_sql(self, expression: exp.TsOrDsToTime) -> str:
this = expression.this
diff --git a/sqlglot/helper.py b/sqlglot/helper.py
index 9799fe2..35a4586 100644
--- a/sqlglot/helper.py
+++ b/sqlglot/helper.py
@@ -6,7 +6,7 @@ import logging
import re
import sys
import typing as t
-from collections.abc import Collection
+from collections.abc import Collection, Set
from contextlib import contextmanager
from copy import copy
from enum import Enum
@@ -496,3 +496,31 @@ DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"}
def is_date_unit(expression: t.Optional[exp.Expression]) -> bool:
return expression is not None and expression.name.lower() in DATE_UNITS
+
+
+K = t.TypeVar("K")
+V = t.TypeVar("V")
+
+
+class SingleValuedMapping(t.Mapping[K, V]):
+ """
+ Mapping where all keys return the same value.
+
+ This rigamarole is meant to avoid copying keys, which was originally intended
+ as an optimization while qualifying columns for tables with lots of columns.
+ """
+
+ def __init__(self, keys: t.Collection[K], value: V):
+ self._keys = keys if isinstance(keys, Set) else set(keys)
+ self._value = value
+
+ def __getitem__(self, key: K) -> V:
+ if key in self._keys:
+ return self._value
+ raise KeyError(key)
+
+ def __len__(self) -> int:
+ return len(self._keys)
+
+ def __iter__(self) -> t.Iterator[K]:
+ return iter(self._keys)
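Usage sketch for SingleValuedMapping (behavior follows directly from the class above):

from sqlglot.helper import SingleValuedMapping

cols = SingleValuedMapping(["a", "b", "c"], "t")
assert cols["a"] == "t" and len(cols) == 3 and sorted(cols) == ["a", "b", "c"]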
diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py
index bdd1d14..f10fbb9 100644
--- a/sqlglot/lineage.py
+++ b/sqlglot/lineage.py
@@ -153,7 +153,7 @@ def lineage(
raise ValueError(f"Could not find {column} in {scope.expression}")
for s in scope.union_scopes:
- to_node(index, scope=s, upstream=upstream)
+ to_node(index, scope=s, upstream=upstream, alias=alias)
return upstream
@@ -209,7 +209,11 @@ def lineage(
if isinstance(source, Scope):
# The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
to_node(
- c.name, scope=source, scope_name=table, upstream=node, alias=aliases.get(table)
+ c.name,
+ scope=source,
+ scope_name=table,
+ upstream=node,
+ alias=aliases.get(table) or alias,
)
else:
# The source is not a scope - we've reached the end of the line. At this point, if a source is not found
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index cb9312c..ce274bb 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -204,7 +204,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.TimeAdd,
exp.TimeStrToTime,
exp.TimeSub,
- exp.Timestamp,
exp.TimestampAdd,
exp.TimestampSub,
exp.UnixToTime,
@@ -276,6 +275,10 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
+ exp.Timestamp: lambda self, e: self._annotate_with_type(
+ e,
+ exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP,
+ ),
exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
exp.Struct: lambda self, e: self._annotate_by_args(e, "expressions", struct=True),
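Sketch of the new Timestamp annotation rule (expected type inferred from the lambda above):

from sqlglot import parse_one
from sqlglot.optimizer.annotate_types import annotate_types

ts = parse_one("TIMESTAMP('2024-01-01 00:00:00+00')", read="bigquery")
print(annotate_types(ts).type)  # expected: TIMESTAMPTZ, since with_tz is set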
diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py
index faf18c6..0aa8134 100644
--- a/sqlglot/optimizer/canonicalize.py
+++ b/sqlglot/optimizer/canonicalize.py
@@ -38,7 +38,12 @@ def replace_date_funcs(node: exp.Expression) -> exp.Expression:
if isinstance(node, exp.Date) and not node.expressions and not node.args.get("zone"):
return exp.cast(node.this, to=exp.DataType.Type.DATE)
if isinstance(node, exp.Timestamp) and not node.expression:
- return exp.cast(node.this, to=exp.DataType.Type.TIMESTAMP)
+ if not node.type:
+ from sqlglot.optimizer.annotate_types import annotate_types
+
+ node = annotate_types(node)
+ return exp.cast(node.this, to=node.type or exp.DataType.Type.TIMESTAMP)
+
return node
@@ -76,9 +81,8 @@ def coerce_type(node: exp.Expression) -> exp.Expression:
def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
if (
isinstance(expression, exp.Cast)
- and expression.to.type
and expression.this.type
- and expression.to.type.this == expression.this.type.this
+ and expression.to.this == expression.this.type.this
):
return expression.this
return expression
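The relaxed remove_redundant_casts check can be exercised like this (editor's sketch; the output shown is the expected form):

from sqlglot import parse_one
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.canonicalize import canonicalize

# the inner cast already types x as DATE, so the outer cast is redundant
expr = annotate_types(parse_one("SELECT CAST(CAST(x AS DATE) AS DATE) FROM t"))
print(canonicalize(expr).sql())
# expected: SELECT CAST(x AS DATE) FROM t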
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 1656727..5c27bc3 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -6,7 +6,7 @@ import typing as t
from sqlglot import alias, exp
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
-from sqlglot.helper import seq_get
+from sqlglot.helper import seq_get, SingleValuedMapping
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema
@@ -586,8 +586,8 @@ class Resolver:
def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
self.scope = scope
self.schema = schema
- self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
- self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
+ self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
+ self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
self._all_columns: t.Optional[t.Set[str]] = None
self._infer_schema = infer_schema
@@ -640,7 +640,7 @@ class Resolver:
}
return self._all_columns
- def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
+ def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
"""Resolve the source columns for a given source `name`."""
if name not in self.scope.sources:
raise OptimizeError(f"Unknown table: {name}")
@@ -662,10 +662,15 @@ class Resolver:
else:
column_aliases = []
- # If the source's columns are aliased, their aliases shadow the corresponding column names
- return [alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)]
+ if column_aliases:
+ # If the source's columns are aliased, their aliases shadow the corresponding column names.
+ # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
+ return [
+ alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)
+ ]
+ return columns
- def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
+ def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
if self._source_columns is None:
self._source_columns = {
source_name: self.get_source_columns(source_name)
@@ -676,8 +681,8 @@ class Resolver:
return self._source_columns
def _get_unambiguous_columns(
- self, source_columns: t.Dict[str, t.List[str]]
- ) -> t.Dict[str, str]:
+ self, source_columns: t.Dict[str, t.Sequence[str]]
+ ) -> t.Mapping[str, str]:
"""
Find all the unambiguous columns in sources.
@@ -693,12 +698,17 @@ class Resolver:
source_columns_pairs = list(source_columns.items())
first_table, first_columns = source_columns_pairs[0]
- unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
+
+ if len(source_columns_pairs) == 1:
+ # Performance optimization - avoid copying first_columns if there is only one table.
+ return SingleValuedMapping(first_columns, first_table)
+
+ unambiguous_columns = {col: first_table for col in first_columns}
all_columns = set(unambiguous_columns)
for table, columns in source_columns_pairs[1:]:
- unique = self._find_unique_columns(columns)
- ambiguous = set(all_columns).intersection(unique)
+ unique = set(columns)
+ ambiguous = all_columns.intersection(unique)
all_columns.update(columns)
for column in ambiguous:
@@ -707,19 +717,3 @@ class Resolver:
unambiguous_columns[column] = table
return unambiguous_columns
-
- @staticmethod
- def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
- """
- Find the unique columns in a list of columns.
-
- Example:
- >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
- ['a', 'c']
-
- This is necessary because duplicate column names are ambiguous.
- """
- counts: t.Dict[str, int] = {}
- for column in columns:
- counts[column] = counts.get(column, 0) + 1
- return {column for column, count in counts.items() if count == 1}
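The single-source fast path in _get_unambiguous_columns is exercised by any one-table query (sketch; exact output formatting may vary):

from sqlglot import parse_one
from sqlglot.optimizer.qualify_columns import qualify_columns
from sqlglot.schema import ensure_schema

schema = ensure_schema({"t": {"a": "int", "b": "int"}})
print(qualify_columns(parse_one("SELECT a, b FROM t"), schema).sql())
# expected: columns resolved against t, e.g. SELECT t.a AS a, t.b AS b FROM t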
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index dfa3024..25c5789 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -29,8 +29,8 @@ def parse_var_map(args: t.List) -> exp.StarMap | exp.VarMap:
values.append(args[i + 1])
return exp.VarMap(
- keys=exp.Array(expressions=keys),
- values=exp.Array(expressions=values),
+ keys=exp.array(*keys, copy=False),
+ values=exp.array(*values, copy=False),
)
@@ -638,6 +638,8 @@ class Parser(metaclass=_Parser):
TokenType.NOT: lambda self: self.expression(exp.Not, this=self._parse_equality()),
TokenType.TILDA: lambda self: self.expression(exp.BitwiseNot, this=self._parse_unary()),
TokenType.DASH: lambda self: self.expression(exp.Neg, this=self._parse_unary()),
+ TokenType.PIPE_SLASH: lambda self: self.expression(exp.Sqrt, this=self._parse_unary()),
+ TokenType.DPIPE_SLASH: lambda self: self.expression(exp.Cbrt, this=self._parse_unary()),
}
PRIMARY_PARSERS = {
@@ -1000,9 +1002,13 @@ class Parser(metaclass=_Parser):
MODIFIERS_ATTACHED_TO_UNION = True
UNION_MODIFIERS = {"order", "limit", "offset"}
- # parses no parenthesis if statements as commands
+ # Parses IF statements that lack parentheses as commands
NO_PAREN_IF_COMMANDS = True
+ # Whether or not a VALUES keyword needs to be followed by '(' to form a VALUES clause.
+ # If this is True and '(' is not found, the keyword will be treated as an identifier
+ VALUES_FOLLOWED_BY_PAREN = True
+
__slots__ = (
"error_level",
"error_message_context",
@@ -2058,7 +2064,7 @@ class Parser(metaclass=_Parser):
partition=self._parse_partition(),
where=self._match_pair(TokenType.REPLACE, TokenType.WHERE)
and self._parse_conjunction(),
- expression=self._parse_ddl_select(),
+ expression=self._parse_derived_table_values() or self._parse_ddl_select(),
conflict=self._parse_on_conflict(),
returning=returning or self._parse_returning(),
overwrite=overwrite,
@@ -2267,8 +2273,7 @@ class Parser(metaclass=_Parser):
self._match_r_paren()
return self.expression(exp.Tuple, expressions=expressions)
- # In presto we can have VALUES 1, 2 which results in 1 column & 2 rows.
- # https://prestodb.io/docs/current/sql/values.html
+ # In some dialects we can have VALUES 1, 2 which results in 1 column & 2 rows.
return self.expression(exp.Tuple, expressions=[self._parse_expression()])
def _parse_projections(self) -> t.List[exp.Expression]:
@@ -2367,12 +2372,8 @@ class Parser(metaclass=_Parser):
# We return early here so that the UNION isn't attached to the subquery by the
# following call to _parse_set_operations, but instead becomes the parent node
return self._parse_subquery(this, parse_alias=parse_subquery_alias)
- elif self._match(TokenType.VALUES):
- this = self.expression(
- exp.Values,
- expressions=self._parse_csv(self._parse_value),
- alias=self._parse_table_alias(),
- )
+ elif self._match(TokenType.VALUES, advance=False):
+ this = self._parse_derived_table_values()
elif from_:
this = exp.select("*").from_(from_.this, copy=False)
else:
@@ -2969,7 +2970,7 @@ class Parser(metaclass=_Parser):
def _parse_derived_table_values(self) -> t.Optional[exp.Values]:
is_derived = self._match_pair(TokenType.L_PAREN, TokenType.VALUES)
- if not is_derived and not self._match(TokenType.VALUES):
+ if not is_derived and not self._match_text_seq("VALUES"):
return None
expressions = self._parse_csv(self._parse_value)
@@ -3655,8 +3656,15 @@ class Parser(metaclass=_Parser):
def _parse_type(self, parse_interval: bool = True) -> t.Optional[exp.Expression]:
interval = parse_interval and self._parse_interval()
if interval:
- # Convert INTERVAL 'val_1' unit_1 ... 'val_n' unit_n into a sum of intervals
- while self._match_set((TokenType.STRING, TokenType.NUMBER), advance=False):
+ # Convert INTERVAL 'val_1' unit_1 [+] ... [+] 'val_n' unit_n into a sum of intervals
+ while True:
+ index = self._index
+ self._match(TokenType.PLUS)
+
+ if not self._match_set((TokenType.STRING, TokenType.NUMBER), advance=False):
+ self._retreat(index)
+ break
+
interval = self.expression( # type: ignore
exp.Add, this=interval, expression=self._parse_interval(match_interval=False)
)
@@ -3872,9 +3880,15 @@ class Parser(metaclass=_Parser):
def _parse_column_reference(self) -> t.Optional[exp.Expression]:
this = self._parse_field()
- if isinstance(this, exp.Identifier):
- this = self.expression(exp.Column, this=this)
- return this
+ if (
+ not this
+ and self._match(TokenType.VALUES, advance=False)
+ and self.VALUES_FOLLOWED_BY_PAREN
+ and (not self._next or self._next.token_type != TokenType.L_PAREN)
+ ):
+ this = self._parse_id_var()
+
+ return self.expression(exp.Column, this=this) if isinstance(this, exp.Identifier) else this
def _parse_column_ops(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
this = self._parse_bracket(this)
@@ -5511,7 +5525,7 @@ class Parser(metaclass=_Parser):
then = self.expression(
exp.Insert,
this=self._parse_value(),
- expression=self._match(TokenType.VALUES) and self._parse_value(),
+ expression=self._match_text_seq("VALUES") and self._parse_value(),
)
elif self._match(TokenType.UPDATE):
expressions = self._parse_star()
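Two of the parser changes above in action (editor's sketch):

from sqlglot import parse_one

# VALUES not followed by '(' can now fall back to a plain identifier
print(parse_one("SELECT values FROM t").sql())  # expected: SELECT values FROM t

# INTERVAL terms can be chained with an optional '+' between them
print(parse_one("SELECT INTERVAL '1' YEAR + '2' MONTH").sql())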
diff --git a/sqlglot/schema.py b/sqlglot/schema.py
index 1fd4025..dbd0caa 100644
--- a/sqlglot/schema.py
+++ b/sqlglot/schema.py
@@ -49,7 +49,7 @@ class Schema(abc.ABC):
only_visible: bool = False,
dialect: DialectType = None,
normalize: t.Optional[bool] = None,
- ) -> t.List[str]:
+ ) -> t.Sequence[str]:
"""
Get the column names for a table.
@@ -60,7 +60,7 @@ class Schema(abc.ABC):
normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
- The list of column names.
+ The sequence of column names.
"""
@abc.abstractmethod
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index b064957..2cfcfa6 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -57,6 +57,8 @@ class TokenType(AutoName):
AMP = auto()
DPIPE = auto()
PIPE = auto()
+ PIPE_SLASH = auto()
+ DPIPE_SLASH = auto()
CARET = auto()
TILDA = auto()
ARROW = auto()
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py
index f13569f..4777609 100644
--- a/sqlglot/transforms.py
+++ b/sqlglot/transforms.py
@@ -213,6 +213,19 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp
is_posexplode = isinstance(explode, exp.Posexplode)
explode_arg = explode.this
+ if isinstance(explode, exp.ExplodeOuter):
+ bracket = explode_arg[0]
+ bracket.set("safe", True)
+ bracket.set("offset", True)
+ explode_arg = exp.func(
+ "IF",
+ exp.func(
+ "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
+ ).eq(0),
+ exp.array(bracket, copy=False),
+ explode_arg,
+ )
+
# This ensures that we won't use [POS]EXPLODE's argument as a new selection
if isinstance(explode_arg, exp.Column):
taken_select_names.add(explode_arg.output_name)
@@ -466,6 +479,87 @@ def unqualify_columns(expression: exp.Expression) -> exp.Expression:
return expression
+def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
+ assert isinstance(expression, exp.Create)
+ for constraint in expression.find_all(exp.UniqueColumnConstraint):
+ if constraint.parent:
+ constraint.parent.pop()
+
+ return expression
+
+
+def ctas_with_tmp_tables_to_create_tmp_view(
+ expression: exp.Expression,
+ tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
+) -> exp.Expression:
+ assert isinstance(expression, exp.Create)
+ properties = expression.args.get("properties")
+ temporary = any(
+ isinstance(prop, exp.TemporaryProperty)
+ for prop in (properties.expressions if properties else [])
+ )
+
+ # CTAS with temp tables map to CREATE TEMPORARY VIEW
+ if expression.kind == "TABLE" and temporary:
+ if expression.expression:
+ return exp.Create(
+ kind="TEMPORARY VIEW",
+ this=expression.this,
+ expression=expression.expression,
+ )
+ return tmp_storage_provider(expression)
+
+ return expression
+
+
+def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
+ """
+ In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
+ PARTITIONED BY value is an array of column names, they are transformed into a schema.
+ The corresponding columns are removed from the create statement.
+ """
+ assert isinstance(expression, exp.Create)
+ has_schema = isinstance(expression.this, exp.Schema)
+ is_partitionable = expression.kind in {"TABLE", "VIEW"}
+
+ if has_schema and is_partitionable:
+ prop = expression.find(exp.PartitionedByProperty)
+ if prop and prop.this and not isinstance(prop.this, exp.Schema):
+ schema = expression.this
+ columns = {v.name.upper() for v in prop.this.expressions}
+ partitions = [col for col in schema.expressions if col.name.upper() in columns]
+ schema.set("expressions", [e for e in schema.expressions if e not in partitions])
+ prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
+ expression.set("this", schema)
+
+ return expression
+
+
+def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
+ """
+ Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
+
+ Currently, SQLGlot uses the DATASOURCE format for Spark 3.
+ """
+ assert isinstance(expression, exp.Create)
+ prop = expression.find(exp.PartitionedByProperty)
+ if (
+ prop
+ and prop.this
+ and isinstance(prop.this, exp.Schema)
+ and all(isinstance(e, exp.ColumnDef) and e.args.get("kind") for e in prop.this.expressions)
+ ):
+ prop_this = exp.Tuple(
+ expressions=[exp.to_identifier(e.this) for e in prop.this.expressions]
+ )
+ schema = expression.this
+ for e in prop.this.expressions:
+ schema.append("expressions", e)
+ prop.set("this", prop_this)
+
+ return expression
+
+
def preprocess(
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
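Finally, the pattern the dialect hunks above all follow, composing Create handling from these reusable transforms, looks like this in a generator's TRANSFORMS map (sketch):

from sqlglot import exp
from sqlglot.transforms import (
    preprocess,
    remove_unique_constraints,
    move_schema_columns_to_partitioned_by,
)

TRANSFORMS = {
    exp.Create: preprocess([remove_unique_constraints, move_schema_columns_to_partitioned_by]),
}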