author     Daniel Baumann <daniel.baumann@progress-linux.org>  2024-03-03 14:11:07 +0000
committer  Daniel Baumann <daniel.baumann@progress-linux.org>  2024-03-03 14:11:07 +0000
commit     42a1548cecf48d18233f56e3385cf9c89abcb9c2 (patch)
tree       5e0fff4ecbd1fd7dd1022a7580139038df2a824c /sqlglot
parent     Releasing debian version 21.1.2-1. (diff)
download   sqlglot-42a1548cecf48d18233f56e3385cf9c89abcb9c2.tar.xz
           sqlglot-42a1548cecf48d18233f56e3385cf9c89abcb9c2.zip
Merging upstream version 22.2.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
-rw-r--r--  sqlglot/__init__.py                          |    6
-rw-r--r--  sqlglot/dataframe/sql/dataframe.py           |    6
-rw-r--r--  sqlglot/dataframe/sql/functions.py           |    4
-rw-r--r--  sqlglot/dialects/bigquery.py                 |  102
-rw-r--r--  sqlglot/dialects/clickhouse.py               |    9
-rw-r--r--  sqlglot/dialects/databricks.py               |    2
-rw-r--r--  sqlglot/dialects/dialect.py                  |    2
-rw-r--r--  sqlglot/dialects/doris.py                    |    3
-rw-r--r--  sqlglot/dialects/drill.py                    |    3
-rw-r--r--  sqlglot/dialects/duckdb.py                   |   64
-rw-r--r--  sqlglot/dialects/hive.py                     |   28
-rw-r--r--  sqlglot/dialects/mysql.py                    |   21
-rw-r--r--  sqlglot/dialects/oracle.py                   |   57
-rw-r--r--  sqlglot/dialects/postgres.py                 |   26
-rw-r--r--  sqlglot/dialects/presto.py                   |   29
-rw-r--r--  sqlglot/dialects/redshift.py                 |   35
-rw-r--r--  sqlglot/dialects/snowflake.py                |   40
-rw-r--r--  sqlglot/dialects/spark2.py                   |   11
-rw-r--r--  sqlglot/dialects/sqlite.py                   |   16
-rw-r--r--  sqlglot/dialects/tsql.py                     |  189
-rw-r--r--  sqlglot/diff.py                              |   41
-rw-r--r--  sqlglot/executor/context.py                  |    2
-rw-r--r--  sqlglot/executor/python.py                   |    7
-rw-r--r--  sqlglot/expressions.py                       |  502
-rw-r--r--  sqlglot/generator.py                         |  201
-rw-r--r--  sqlglot/helper.py                            |   12
-rw-r--r--  sqlglot/lineage.py                           |  237
-rw-r--r--  sqlglot/optimizer/annotate_types.py          |   41
-rw-r--r--  sqlglot/optimizer/normalize_identifiers.py   |    6
-rw-r--r--  sqlglot/optimizer/qualify_columns.py         |   21
-rw-r--r--  sqlglot/optimizer/qualify_tables.py          |    4
-rw-r--r--  sqlglot/optimizer/scope.py                   |   12
-rw-r--r--  sqlglot/optimizer/simplify.py                |   16
-rw-r--r--  sqlglot/optimizer/unnest_subqueries.py       |   14
-rw-r--r--  sqlglot/parser.py                            |  340
-rw-r--r--  sqlglot/tokens.py                            |   21
-rw-r--r--  sqlglot/transforms.py                        |   18
37 files changed, 1388 insertions(+), 760 deletions(-)
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py
index 2207a28..e30232c 100644
--- a/sqlglot/__init__.py
+++ b/sqlglot/__init__.py
@@ -88,13 +88,11 @@ def parse(
@t.overload
-def parse_one(sql: str, *, into: t.Type[E], **opts) -> E:
- ...
+def parse_one(sql: str, *, into: t.Type[E], **opts) -> E: ...
@t.overload
-def parse_one(sql: str, **opts) -> Expression:
- ...
+def parse_one(sql: str, **opts) -> Expression: ...
def parse_one(
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
index 7e3f07b..0bacbf9 100644
--- a/sqlglot/dataframe/sql/dataframe.py
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -140,12 +140,10 @@ class DataFrame:
return cte, name
@t.overload
- def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]:
- ...
+ def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]: ...
@t.overload
- def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]:
- ...
+ def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]: ...
def _ensure_list_of_columns(self, cols):
return Column.ensure_cols(ensure_list(cols))
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
index 308b639..db5201f 100644
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -210,7 +210,7 @@ def sec(col: ColumnOrName) -> Column:
def signum(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "SIGNUM")
+ return Column.invoke_expression_over_column(col, expression.Sign)
def sin(col: ColumnOrName) -> Column:
@@ -592,7 +592,7 @@ def date_diff(end: ColumnOrName, start: ColumnOrName) -> Column:
def add_months(start: ColumnOrName, months: t.Union[ColumnOrName, int]) -> Column:
- return Column.invoke_anonymous_function(start, "ADD_MONTHS", months)
+ return Column.invoke_expression_over_column(start, expression.AddMonths, expression=months)
def months_between(
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index f867617..5bfc3ea 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -42,7 +42,10 @@ def _derived_table_values_to_unnest(self: BigQuery.Generator, expression: exp.Va
alias = expression.args.get("alias")
for tup in expression.find_all(exp.Tuple):
field_aliases = alias.columns if alias else (f"_c{i}" for i in range(len(tup.expressions)))
- expressions = [exp.alias_(fld, name) for fld, name in zip(tup.expressions, field_aliases)]
+ expressions = [
+ exp.PropertyEQ(this=exp.to_identifier(name), expression=fld)
+ for name, fld in zip(field_aliases, tup.expressions)
+ ]
structs.append(exp.Struct(expressions=expressions))
return self.unnest_sql(exp.Unnest(expressions=[exp.array(*structs, copy=False)]))
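Note: this hunk swaps the Alias nodes inside the generated STRUCTs for PropertyEQ nodes. A minimal sketch of the rewrite it feeds (the printed output is an assumption, not taken from this patch):

    import sqlglot

    # BigQuery has no VALUES in FROM, so the derived table should become
    # UNNEST over an array of structs, e.g.
    # SELECT * FROM UNNEST([STRUCT(1 AS x, 'a' AS y)]) AS t
    print(sqlglot.transpile(
        "SELECT * FROM (VALUES (1, 'a')) AS t(x, y)",
        read="presto",
        write="bigquery",
    )[0])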
@@ -111,6 +114,8 @@ def _alias_ordered_group(expression: exp.Expression) -> exp.Expression:
}
for grouped in group.expressions:
+ if grouped.is_int:
+ continue
alias = aliases.get(grouped)
if alias:
grouped.replace(exp.column(alias))
@@ -226,8 +231,11 @@ class BigQuery(Dialect):
# bigquery udfs are case sensitive
NORMALIZE_FUNCTIONS = False
+ # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_elements_date_time
TIME_MAPPING = {
"%D": "%m/%d/%y",
+ "%E*S": "%S.%f",
+ "%E6S": "%S.%f",
}
ESCAPE_SEQUENCES = {
@@ -266,14 +274,20 @@ class BigQuery(Dialect):
while isinstance(parent, exp.Dot):
parent = parent.parent
- # In BigQuery, CTEs aren't case-sensitive, but table names are (by default, at least).
- # The following check is essentially a heuristic to detect tables based on whether or
- # not they're qualified. It also avoids normalizing UDFs, because they're case-sensitive.
- if (
- not isinstance(parent, exp.UserDefinedFunction)
- and not (isinstance(parent, exp.Table) and parent.db)
- and not expression.meta.get("is_table")
- ):
+ # In BigQuery, CTEs are case-insensitive, but UDF and table names are case-sensitive
+ # by default. The following check uses a heuristic to detect tables based on whether
+ # they are qualified. This should generally be correct, because tables in BigQuery
+ # must be qualified with at least a dataset, unless @@dataset_id is set.
+ case_sensitive = (
+ isinstance(parent, exp.UserDefinedFunction)
+ or (
+ isinstance(parent, exp.Table)
+ and parent.db
+ and (parent.meta.get("quoted_table") or not parent.meta.get("maybe_column"))
+ )
+ or expression.meta.get("is_table")
+ )
+ if not case_sensitive:
expression.set("this", expression.this.lower())
return expression
@@ -302,6 +316,7 @@ class BigQuery(Dialect):
"BYTES": TokenType.BINARY,
"CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
"DECLARE": TokenType.COMMAND,
+ "ELSEIF": TokenType.COMMAND,
"EXCEPTION": TokenType.COMMAND,
"FLOAT64": TokenType.DOUBLE,
"FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT,
@@ -315,8 +330,8 @@ class BigQuery(Dialect):
class Parser(parser.Parser):
PREFIXED_PIVOT_COLUMNS = True
-
LOG_DEFAULTS_TO_LN = True
+ SUPPORTS_IMPLICIT_UNNEST = True
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
@@ -410,6 +425,7 @@ class BigQuery(Dialect):
STATEMENT_PARSERS = {
**parser.Parser.STATEMENT_PARSERS,
+ TokenType.ELSE: lambda self: self._parse_as_command(self._prev),
TokenType.END: lambda self: self._parse_as_command(self._prev),
TokenType.FOR: lambda self: self._parse_for_in(),
}
@@ -433,8 +449,11 @@ class BigQuery(Dialect):
if isinstance(this, exp.Identifier):
table_name = this.name
while self._match(TokenType.DASH, advance=False) and self._next:
- self._advance(2)
- table_name += f"-{self._prev.text}"
+ text = ""
+ while self._curr and self._curr.token_type != TokenType.DOT:
+ self._advance()
+ text += self._prev.text
+ table_name += text
this = exp.Identifier(this=table_name, quoted=this.args.get("quoted"))
elif isinstance(this, exp.Literal):
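Note: the widened loop keeps consuming tokens until the next dot, so project IDs that mix dashes and digits stay a single identifier. A sketch, assuming that behavior:

    from sqlglot import parse_one

    # `my-project-123` would previously lose its trailing numeric segment;
    # with this hunk it should parse as one table part (assumed behavior).
    select = parse_one("SELECT * FROM my-project-123.dataset.tbl", read="bigquery")
    print(select.sql("bigquery"))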
@@ -448,12 +467,28 @@ class BigQuery(Dialect):
return this
def _parse_table_parts(
- self, schema: bool = False, is_db_reference: bool = False
+ self, schema: bool = False, is_db_reference: bool = False, wildcard: bool = False
) -> exp.Table:
- table = super()._parse_table_parts(schema=schema, is_db_reference=is_db_reference)
+ table = super()._parse_table_parts(
+ schema=schema, is_db_reference=is_db_reference, wildcard=True
+ )
+
+ # proj-1.db.tbl -- `1.` is tokenized as a float so we need to unravel it here
+ if not table.catalog:
+ if table.db:
+ parts = table.db.split(".")
+ if len(parts) == 2 and not table.args["db"].quoted:
+ table.set("catalog", exp.Identifier(this=parts[0]))
+ table.set("db", exp.Identifier(this=parts[1]))
+ else:
+ parts = table.name.split(".")
+ if len(parts) == 2 and not table.this.quoted:
+ table.set("db", exp.Identifier(this=parts[0]))
+ table.set("this", exp.Identifier(this=parts[1]))
+
if isinstance(table.this, exp.Identifier) and "." in table.name:
catalog, db, this, *rest = (
- t.cast(t.Optional[exp.Expression], exp.to_identifier(x))
+ t.cast(t.Optional[exp.Expression], exp.to_identifier(x, quoted=True))
for x in split_num_words(table.name, ".", 3)
)
@@ -461,16 +496,15 @@ class BigQuery(Dialect):
this = exp.Dot.build(t.cast(t.List[exp.Expression], [this, *rest]))
table = exp.Table(this=this, db=db, catalog=catalog)
+ table.meta["quoted_table"] = True
return table
@t.overload
- def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject:
- ...
+ def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: ...
@t.overload
- def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg:
- ...
+ def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg: ...
def _parse_json_object(self, agg=False):
json_object = super()._parse_json_object()
@@ -532,6 +566,7 @@ class BigQuery(Dialect):
IGNORE_NULLS_IN_FUNC = True
JSON_PATH_SINGLE_QUOTE_ESCAPE = True
CAN_IMPLEMENT_ARRAY_ANY = True
+ NAMED_PLACEHOLDER_TOKEN = "@"
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@@ -762,22 +797,25 @@ class BigQuery(Dialect):
"within",
}
+ def table_parts(self, expression: exp.Table) -> str:
+ # Depending on the context, `x.y` may not resolve to the same data source as `x`.`y`, so
+ # we need to make sure the correct quoting is used in each case.
+ #
+ # For example, if there is a CTE x that clashes with a schema name, then the former will
+ # return the table y in that schema, whereas the latter will return the CTE's y column:
+ #
+ # - WITH x AS (SELECT [1, 2] AS y) SELECT * FROM x, `x.y` -> cross join
+ # - WITH x AS (SELECT [1, 2] AS y) SELECT * FROM x, `x`.`y` -> implicit unnest
+ if expression.meta.get("quoted_table"):
+ table_parts = ".".join(p.name for p in expression.parts)
+ return self.sql(exp.Identifier(this=table_parts, quoted=True))
+
+ return super().table_parts(expression)
+
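Note: the comment's own two spellings can be checked directly; a sketch, with the round-trip below being an assumption:

    from sqlglot import parse_one

    sql = "WITH x AS (SELECT [1, 2] AS y) SELECT * FROM x, `x.y`"
    # The single quoted identifier should survive as `x.y` (cross join),
    # not be rewritten to `x`.`y` (implicit unnest of the CTE column).
    print(parse_one(sql, read="bigquery").sql("bigquery"))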
def timetostr_sql(self, expression: exp.TimeToStr) -> str:
this = expression.this if isinstance(expression.this, exp.TsOrDsToDate) else expression
return self.func("FORMAT_DATE", self.format_time(expression), this.this)
- def struct_sql(self, expression: exp.Struct) -> str:
- args = []
- for expr in expression.expressions:
- if isinstance(expr, self.KEY_VALUE_DEFINITIONS):
- arg = f"{self.sql(expr, 'expression')} AS {expr.this.name}"
- else:
- arg = self.sql(expr)
-
- args.append(arg)
-
- return self.func("STRUCT", *args)
-
def eq_sql(self, expression: exp.EQ) -> str:
# Operands of = cannot be NULL in BigQuery
if isinstance(expression.left, exp.Null) or isinstance(expression.right, exp.Null):
@@ -803,7 +841,7 @@ class BigQuery(Dialect):
def array_sql(self, expression: exp.Array) -> str:
first_arg = seq_get(expression.expressions, 0)
- if isinstance(first_arg, exp.Subqueryable):
+ if isinstance(first_arg, exp.Query):
return f"ARRAY{self.wrap(self.sql(first_arg))}"
return inline_array_sql(self, expression)
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index 05d6a03..90167f6 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -68,7 +68,6 @@ class ClickHouse(Dialect):
"DATE32": TokenType.DATE32,
"DATETIME64": TokenType.DATETIME64,
"DICTIONARY": TokenType.DICTIONARY,
- "ENUM": TokenType.ENUM,
"ENUM8": TokenType.ENUM8,
"ENUM16": TokenType.ENUM16,
"FINAL": TokenType.FINAL,
@@ -93,6 +92,7 @@ class ClickHouse(Dialect):
"AGGREGATEFUNCTION": TokenType.AGGREGATEFUNCTION,
"SIMPLEAGGREGATEFUNCTION": TokenType.SIMPLEAGGREGATEFUNCTION,
"SYSTEM": TokenType.COMMAND,
+ "PREWHERE": TokenType.PREWHERE,
}
SINGLE_TOKENS = {
@@ -129,6 +129,7 @@ class ClickHouse(Dialect):
"MAP": parser.build_var_map,
"MATCH": exp.RegexpLike.from_arg_list,
"RANDCANONICAL": exp.Rand.from_arg_list,
+ "TUPLE": exp.Struct.from_arg_list,
"UNIQ": exp.ApproxDistinct.from_arg_list,
"XOR": lambda args: exp.Xor(expressions=args),
}
@@ -390,7 +391,7 @@ class ClickHouse(Dialect):
return self.expression(
exp.CTE,
- this=self._parse_field(),
+ this=self._parse_conjunction(),
alias=self._parse_table_alias(),
scalar=True,
)
@@ -732,3 +733,7 @@ class ClickHouse(Dialect):
return f"{this_name}{self.sep()}{this_properties}{self.sep()}{this_schema}"
return super().createable_sql(expression, locations)
+
+ def prewhere_sql(self, expression: exp.PreWhere) -> str:
+ this = self.indent(self.sql(expression, "this"))
+ return f"{self.seg('PREWHERE')}{self.sep()}{this}"
diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py
index 96eff18..188b6a7 100644
--- a/sqlglot/dialects/databricks.py
+++ b/sqlglot/dialects/databricks.py
@@ -69,7 +69,7 @@ class Databricks(Spark):
def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str:
constraint = expression.find(exp.GeneratedAsIdentityColumnConstraint)
- kind = expression.args.get("kind")
+ kind = expression.kind
if (
constraint
and isinstance(kind, exp.DataType)
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index b0a78d2..599505c 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -443,7 +443,7 @@ class Dialect(metaclass=_Dialect):
identify: If set to `False`, the quotes will only be added if the identifier is deemed
"unsafe", with respect to its characters and this dialect's normalization strategy.
"""
- if isinstance(expression, exp.Identifier):
+ if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
name = expression.this
expression.set(
"quoted",
diff --git a/sqlglot/dialects/doris.py b/sqlglot/dialects/doris.py
index 067a045..9a84848 100644
--- a/sqlglot/dialects/doris.py
+++ b/sqlglot/dialects/doris.py
@@ -21,6 +21,7 @@ class Doris(MySQL):
**MySQL.Parser.FUNCTIONS,
"COLLECT_SET": exp.ArrayUniqueAgg.from_arg_list,
"DATE_TRUNC": build_timestamp_trunc,
+ "MONTHS_ADD": exp.AddMonths.from_arg_list,
"REGEXP": exp.RegexpLike.from_arg_list,
"TO_DATE": exp.TsOrDsToDate.from_arg_list,
}
@@ -41,6 +42,7 @@ class Doris(MySQL):
TRANSFORMS = {
**MySQL.Generator.TRANSFORMS,
+ exp.AddMonths: rename_func("MONTHS_ADD"),
exp.ApproxDistinct: approx_count_distinct_sql,
exp.ArgMax: rename_func("MAX_BY"),
exp.ArgMin: rename_func("MIN_BY"),
@@ -58,7 +60,6 @@ class Doris(MySQL):
exp.StrToUnix: lambda self, e: self.func("UNIX_TIMESTAMP", e.this, self.format_time(e)),
exp.Split: rename_func("SPLIT_BY_STRING"),
exp.TimeStrToDate: rename_func("TO_DATE"),
- exp.ToChar: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)),
exp.TsOrDsAdd: lambda self, e: self.func("DATE_ADD", e.this, e.expression),
exp.TsOrDsToDate: lambda self, e: self.func("TO_DATE", e.this),
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
index 4e699f5..c1f6afa 100644
--- a/sqlglot/dialects/drill.py
+++ b/sqlglot/dialects/drill.py
@@ -156,6 +156,3 @@ class Drill(Dialect):
exp.TsOrDiToDi: lambda self,
e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
}
-
- def normalize_func(self, name: str) -> str:
- return name if exp.SAFE_IDENTIFIER_RE.match(name) else f"`{name}`"
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index 925c5ae..f74dc97 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -79,6 +79,21 @@ def _build_date_diff(args: t.List) -> exp.Expression:
return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0))
+def _build_generate_series(end_exclusive: bool = False) -> t.Callable[[t.List], exp.GenerateSeries]:
+ def _builder(args: t.List) -> exp.GenerateSeries:
+ # Check https://duckdb.org/docs/sql/functions/nested.html#range-functions
+ if len(args) == 1:
+ # DuckDB uses 0 as a default for the series' start when it's omitted
+ args.insert(0, exp.Literal.number("0"))
+
+ gen_series = exp.GenerateSeries.from_arg_list(args)
+ gen_series.set("is_end_exclusive", end_exclusive)
+
+ return gen_series
+
+ return _builder
+
+
def _build_make_timestamp(args: t.List) -> exp.Expression:
if len(args) == 1:
return exp.UnixToTime(this=seq_get(args, 0), scale=exp.UnixToTime.MICROS)
@@ -95,13 +110,13 @@ def _build_make_timestamp(args: t.List) -> exp.Expression:
def _struct_sql(self: DuckDB.Generator, expression: exp.Struct) -> str:
args: t.List[str] = []
- for expr in expression.expressions:
- if isinstance(expr, exp.Alias):
- key = expr.alias
- value = expr.this
- else:
- key = expr.name or expr.this.name
+ for i, expr in enumerate(expression.expressions):
+ if isinstance(expr, exp.PropertyEQ):
+ key = expr.name
value = expr.expression
+ else:
+ key = f"_{i}"
+ value = expr
args.append(f"{self.sql(exp.Literal.string(key))}: {self.sql(value)}")
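Note: with STRUCT_PACK now going through the generic aliased-args machinery, named entries arrive as PropertyEQ nodes and positional ones fall back to a `_{i}` key. A sketch (the output shape is an assumption):

    import sqlglot

    # DuckDB's STRUCT_PACK(x := 1) should come back as the struct literal
    # {'x': 1} after round-tripping through exp.Struct (assumption).
    print(sqlglot.transpile("SELECT STRUCT_PACK(x := 1)", read="duckdb")[0])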
@@ -148,13 +163,6 @@ def _rename_unless_within_group(
)
-def _build_struct_pack(args: t.List) -> exp.Struct:
- args_with_columns_as_identifiers = [
- exp.PropertyEQ(this=arg.this.this, expression=arg.expression) for arg in args
- ]
- return exp.Struct.from_arg_list(args_with_columns_as_identifiers)
-
-
class DuckDB(Dialect):
NULL_ORDERING = "nulls_are_last"
SUPPORTS_USER_DEFINED_TYPES = False
@@ -189,6 +197,7 @@ class DuckDB(Dialect):
"CHARACTER VARYING": TokenType.TEXT,
"EXCLUDE": TokenType.EXCEPT,
"LOGICAL": TokenType.BOOLEAN,
+ "ONLY": TokenType.ONLY,
"PIVOT_WIDER": TokenType.PIVOT,
"SIGNED": TokenType.INT,
"STRING": TokenType.VARCHAR,
@@ -213,6 +222,8 @@ class DuckDB(Dialect):
TokenType.TILDA: exp.RegexpLike,
}
+ FUNCTIONS_WITH_ALIASED_ARGS = {*parser.Parser.FUNCTIONS_WITH_ALIASED_ARGS, "STRUCT_PACK"}
+
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"ARRAY_HAS": exp.ArrayContains.from_arg_list,
@@ -261,12 +272,14 @@ class DuckDB(Dialect):
"STRING_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
"STRING_TO_ARRAY": exp.Split.from_arg_list,
"STRPTIME": build_formatted_time(exp.StrToTime, "duckdb"),
- "STRUCT_PACK": _build_struct_pack,
+ "STRUCT_PACK": exp.Struct.from_arg_list,
"STR_SPLIT": exp.Split.from_arg_list,
"STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
"TO_TIMESTAMP": exp.UnixToTime.from_arg_list,
"UNNEST": exp.Explode.from_arg_list,
"XOR": binary_from_function(exp.BitwiseXor),
+ "GENERATE_SERIES": _build_generate_series(),
+ "RANGE": _build_generate_series(end_exclusive=True),
}
FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy()
@@ -313,6 +326,8 @@ class DuckDB(Dialect):
return pivot_column_names(aggregations, dialect="duckdb")
class Generator(generator.Generator):
+ PARAMETER_TOKEN = "$"
+ NAMED_PLACEHOLDER_TOKEN = "$"
JOIN_HINTS = False
TABLE_HINTS = False
QUERY_HINTS = False
@@ -535,5 +550,22 @@ class DuckDB(Dialect):
return self.sql(expression, "this")
return super().columndef_sql(expression, sep)
- def placeholder_sql(self, expression: exp.Placeholder) -> str:
- return f"${expression.name}" if expression.name else "?"
+ def join_sql(self, expression: exp.Join) -> str:
+ if (
+ expression.side == "LEFT"
+ and not expression.args.get("on")
+ and isinstance(expression.this, exp.Unnest)
+ ):
+ # Some dialects support `LEFT JOIN UNNEST(...)` without an explicit ON clause
+ # DuckDB doesn't, but we can just add a dummy ON clause that is always true
+ return super().join_sql(expression.on(exp.true()))
+
+ return super().join_sql(expression)
+
+ def generateseries_sql(self, expression: exp.GenerateSeries) -> str:
+ # GENERATE_SERIES(a, b) -> [a, b], RANGE(a, b) -> [a, b)
+ if expression.args.get("is_end_exclusive"):
+ expression.set("is_end_exclusive", None)
+ return rename_func("RANGE")(self, expression)
+
+ return super().generateseries_sql(expression)
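Note: GENERATE_SERIES and RANGE now share one AST node and differ only in the `is_end_exclusive` flag, so each regenerates under its own name. A sketch, assuming this identity round-trip:

    import sqlglot

    # RANGE(5) gains DuckDB's implicit 0 start on parse and is renamed
    # back to RANGE on generation, e.g. SELECT * FROM RANGE(0, 5)
    print(sqlglot.transpile("SELECT * FROM RANGE(5)", read="duckdb")[0])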
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 43211dc..55a9254 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -140,6 +140,15 @@ def _str_to_unix_sql(self: Hive.Generator, expression: exp.StrToUnix) -> str:
return self.func("UNIX_TIMESTAMP", expression.this, time_format("hive")(self, expression))
+def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str:
+ timestamp = self.sql(expression, "this")
+ scale = expression.args.get("scale")
+ if scale in (None, exp.UnixToTime.SECONDS):
+ return rename_func("FROM_UNIXTIME")(self, expression)
+
+ return f"FROM_UNIXTIME({timestamp} / POW(10, {scale}))"
+
+
def _str_to_date_sql(self: Hive.Generator, expression: exp.StrToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
@@ -536,7 +545,7 @@ class Hive(Dialect):
exp.UnixToStr: lambda self, e: self.func(
"FROM_UNIXTIME", e.this, time_format("hive")(self, e)
),
- exp.UnixToTime: rename_func("FROM_UNIXTIME"),
+ exp.UnixToTime: _unix_to_time_sql,
exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"),
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'this')}",
exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"),
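Note: `_unix_to_time_sql` above divides by POW(10, scale) for sub-second inputs and keeps the bare rename for seconds. A sketch built directly on the AST (the expected strings are assumptions):

    from sqlglot import exp

    # Milliseconds (scale=3) should divide before FROM_UNIXTIME:
    node = exp.UnixToTime(this=exp.column("ts"), scale=exp.Literal.number(3))
    print(node.sql("hive"))  # FROM_UNIXTIME(ts / POW(10, 3))

    # No scale falls through to the plain rename:
    print(exp.UnixToTime(this=exp.column("ts")).sql("hive"))  # FROM_UNIXTIME(ts)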
@@ -609,9 +618,8 @@ class Hive(Dialect):
return self.properties(properties, prefix=self.seg("TBLPROPERTIES"))
def datatype_sql(self, expression: exp.DataType) -> str:
- if (
- expression.this in (exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR)
- and not expression.expressions
+ if expression.this in self.PARAMETERIZABLE_TEXT_TYPES and (
+ not expression.expressions or expression.expressions[0].name == "MAX"
):
expression = exp.DataType.build("text")
elif expression.is_type(exp.DataType.Type.TEXT) and expression.expressions:
@@ -631,3 +639,15 @@ class Hive(Dialect):
def version_sql(self, expression: exp.Version) -> str:
sql = super().version_sql(expression)
return sql.replace("FOR ", "", 1)
+
+ def struct_sql(self, expression: exp.Struct) -> str:
+ values = []
+
+ for i, e in enumerate(expression.expressions):
+ if isinstance(e, exp.PropertyEQ):
+ self.unsupported("Hive does not support named structs.")
+ values.append(e.expression)
+ else:
+ values.append(e)
+
+ return self.func("STRUCT", *values)
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index e549f62..6ebae1e 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -185,7 +185,6 @@ class MySQL(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"CHARSET": TokenType.CHARACTER_SET,
- "ENUM": TokenType.ENUM,
"FORCE": TokenType.FORCE,
"IGNORE": TokenType.IGNORE,
"LOCK TABLES": TokenType.COMMAND,
@@ -391,6 +390,11 @@ class MySQL(Dialect):
"WARNINGS": _show_parser("WARNINGS"),
}
+ PROPERTY_PARSERS = {
+ **parser.Parser.PROPERTY_PARSERS,
+ "LOCK": lambda self: self._parse_property_assignment(exp.LockProperty),
+ }
+
SET_PARSERS = {
**parser.Parser.SET_PARSERS,
"PERSIST": lambda self: self._parse_set_item_assignment("PERSIST"),
@@ -416,16 +420,11 @@ class MySQL(Dialect):
"SPATIAL",
}
- PROFILE_TYPES = {
- "ALL",
- "BLOCK IO",
- "CONTEXT SWITCHES",
- "CPU",
- "IPC",
- "MEMORY",
- "PAGE FAULTS",
- "SOURCE",
- "SWAPS",
+ PROFILE_TYPES: parser.OPTIONS_TYPE = {
+ **dict.fromkeys(("ALL", "CPU", "IPC", "MEMORY", "SOURCE", "SWAPS"), tuple()),
+ "BLOCK": ("IO",),
+ "CONTEXT": ("SWITCHES",),
+ "PAGE": ("FAULTS",),
}
TYPE_TOKENS = {
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index fcb3aab..bccdad0 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -66,6 +66,26 @@ class Oracle(Dialect):
"FF6": "%f", # only 6 digits are supported in python formats
}
+ class Tokenizer(tokens.Tokenizer):
+ VAR_SINGLE_TOKENS = {"@", "$", "#"}
+
+ KEYWORDS = {
+ **tokens.Tokenizer.KEYWORDS,
+ "(+)": TokenType.JOIN_MARKER,
+ "BINARY_DOUBLE": TokenType.DOUBLE,
+ "BINARY_FLOAT": TokenType.FLOAT,
+ "COLUMNS": TokenType.COLUMN,
+ "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
+ "MINUS": TokenType.EXCEPT,
+ "NVARCHAR2": TokenType.NVARCHAR,
+ "ORDER SIBLINGS BY": TokenType.ORDER_SIBLINGS_BY,
+ "SAMPLE": TokenType.TABLE_SAMPLE,
+ "START": TokenType.BEGIN,
+ "SYSDATE": TokenType.CURRENT_TIMESTAMP,
+ "TOP": TokenType.TOP,
+ "VARCHAR2": TokenType.VARCHAR,
+ }
+
class Parser(parser.Parser):
ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False
WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER, TokenType.KEEP}
@@ -93,6 +113,21 @@ class Oracle(Dialect):
"XMLTABLE": lambda self: self._parse_xml_table(),
}
+ NO_PAREN_FUNCTION_PARSERS = {
+ **parser.Parser.NO_PAREN_FUNCTION_PARSERS,
+ "CONNECT_BY_ROOT": lambda self: self.expression(
+ exp.ConnectByRoot, this=self._parse_column()
+ ),
+ }
+
+ PROPERTY_PARSERS = {
+ **parser.Parser.PROPERTY_PARSERS,
+ "GLOBAL": lambda self: self._match_text_seq("TEMPORARY")
+ and self.expression(exp.TemporaryProperty, this="GLOBAL"),
+ "PRIVATE": lambda self: self._match_text_seq("TEMPORARY")
+ and self.expression(exp.TemporaryProperty, this="PRIVATE"),
+ }
+
QUERY_MODIFIER_PARSERS = {
**parser.Parser.QUERY_MODIFIER_PARSERS,
TokenType.ORDER_SIBLINGS_BY: lambda self: ("order", self._parse_order()),
@@ -190,6 +225,7 @@ class Oracle(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
+ exp.ConnectByRoot: lambda self, e: f"CONNECT_BY_ROOT {self.sql(e, 'this')}",
exp.DateStrToDate: lambda self, e: self.func(
"TO_DATE", e.this, exp.Literal.string("YYYY-MM-DD")
),
@@ -207,6 +243,7 @@ class Oracle(Dialect):
exp.Substring: rename_func("SUBSTR"),
exp.Table: lambda self, e: self.table_sql(e, sep=" "),
exp.TableSample: lambda self, e: self.tablesample_sql(e, sep=" "),
+ exp.TemporaryProperty: lambda _, e: f"{e.name or 'GLOBAL'} TEMPORARY",
exp.TimeToStr: lambda self, e: self.func("TO_CHAR", e.this, self.format_time(e)),
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.Trim: trim_sql,
@@ -242,23 +279,3 @@ class Oracle(Dialect):
if len(expression.args.get("actions", [])) > 1:
return f"ADD ({actions})"
return f"ADD {actions}"
-
- class Tokenizer(tokens.Tokenizer):
- VAR_SINGLE_TOKENS = {"@", "$", "#"}
-
- KEYWORDS = {
- **tokens.Tokenizer.KEYWORDS,
- "(+)": TokenType.JOIN_MARKER,
- "BINARY_DOUBLE": TokenType.DOUBLE,
- "BINARY_FLOAT": TokenType.FLOAT,
- "COLUMNS": TokenType.COLUMN,
- "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
- "MINUS": TokenType.EXCEPT,
- "NVARCHAR2": TokenType.NVARCHAR,
- "ORDER SIBLINGS BY": TokenType.ORDER_SIBLINGS_BY,
- "SAMPLE": TokenType.TABLE_SAMPLE,
- "START": TokenType.BEGIN,
- "SYSDATE": TokenType.CURRENT_TIMESTAMP,
- "TOP": TokenType.TOP,
- "VARCHAR2": TokenType.VARCHAR,
- }
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index c78f8a3..b53ae07 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -138,7 +138,9 @@ def _auto_increment_to_serial(expression: exp.Expression) -> exp.Expression:
def _serial_to_generated(expression: exp.Expression) -> exp.Expression:
- kind = expression.args.get("kind")
+ if not isinstance(expression, exp.ColumnDef):
+ return expression
+ kind = expression.kind
if not kind:
return expression
@@ -279,6 +281,7 @@ class Postgres(Dialect):
"TEMP": TokenType.TEMPORARY,
"CSTRING": TokenType.PSEUDO_TYPE,
"OID": TokenType.OBJECT_IDENTIFIER,
+ "ONLY": TokenType.ONLY,
"OPERATOR": TokenType.OPERATOR,
"REGCLASS": TokenType.OBJECT_IDENTIFIER,
"REGCOLLATION": TokenType.OBJECT_IDENTIFIER,
@@ -451,6 +454,7 @@ class Postgres(Dialect):
exp.JSONBExtract: lambda self, e: self.binary(e, "#>"),
exp.JSONBExtractScalar: lambda self, e: self.binary(e, "#>>"),
exp.JSONBContains: lambda self, e: self.binary(e, "?"),
+ exp.ParseJSON: lambda self, e: self.sql(exp.cast(e.this, exp.DataType.Type.JSON)),
exp.JSONPathKey: json_path_key_only_name,
exp.JSONPathRoot: lambda *_: "",
exp.JSONPathSubscript: lambda self, e: self.json_path_part(e.this),
@@ -506,6 +510,26 @@ class Postgres(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
+ def unnest_sql(self, expression: exp.Unnest) -> str:
+ if len(expression.expressions) == 1:
+ from sqlglot.optimizer.annotate_types import annotate_types
+
+ this = annotate_types(expression.expressions[0])
+ if this.is_type("array<json>"):
+ while isinstance(this, exp.Cast):
+ this = this.this
+
+ arg = self.sql(exp.cast(this, exp.DataType.Type.JSON))
+ alias = self.sql(expression, "alias")
+ alias = f" AS {alias}" if alias else ""
+
+ if expression.args.get("offset"):
+ self.unsupported("Unsupported JSON_ARRAY_ELEMENTS with offset")
+
+ return f"JSON_ARRAY_ELEMENTS({arg}){alias}"
+
+ return super().unnest_sql(expression)
+
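Note: a single-argument UNNEST whose operand is typed array<json> is peeled of its casts and re-cast to plain JSON for JSON_ARRAY_ELEMENTS. A sketch, where both the parse path and the rendering are assumptions:

    from sqlglot import parse_one

    # Expected shape (assumed):
    # SELECT * FROM JSON_ARRAY_ELEMENTS(CAST(j AS JSON)) AS e
    sql = "SELECT * FROM UNNEST(CAST(j AS JSON[])) AS e"
    print(parse_one(sql, read="postgres").sql("postgres"))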
def bracket_sql(self, expression: exp.Bracket) -> str:
"""Forms like ARRAY[1, 2, 3][3] aren't allowed; we need to wrap the ARRAY."""
if isinstance(expression.this, exp.Array):
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 8429547..3649bd2 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -453,11 +453,32 @@ class Presto(Dialect):
return super().bracket_sql(expression)
def struct_sql(self, expression: exp.Struct) -> str:
- if any(isinstance(arg, self.KEY_VALUE_DEFINITIONS) for arg in expression.expressions):
- self.unsupported("Struct with key-value definitions is unsupported.")
- return self.function_fallback_sql(expression)
+ from sqlglot.optimizer.annotate_types import annotate_types
+
+ expression = annotate_types(expression)
+ values: t.List[str] = []
+ schema: t.List[str] = []
+ unknown_type = False
+
+ for e in expression.expressions:
+ if isinstance(e, exp.PropertyEQ):
+ if e.type and e.type.is_type(exp.DataType.Type.UNKNOWN):
+ unknown_type = True
+ else:
+ schema.append(f"{self.sql(e, 'this')} {self.sql(e.type)}")
+ values.append(self.sql(e, "expression"))
+ else:
+ values.append(self.sql(e))
+
+ size = len(expression.expressions)
- return rename_func("ROW")(self, expression)
+ if not size or len(schema) != size:
+ if unknown_type:
+ self.unsupported(
+ "Cannot convert untyped key-value definitions (try annotate_types)."
+ )
+ return self.func("ROW", *values)
+ return f"CAST(ROW({', '.join(values)}) AS ROW({', '.join(schema)}))"
def interval_sql(self, expression: exp.Interval) -> str:
unit = self.sql(expression, "unit")
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index 2201c78..0db87ec 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -70,6 +70,8 @@ class Redshift(Postgres):
"SYSDATE": lambda self: self.expression(exp.CurrentTimestamp, transaction=True),
}
+ SUPPORTS_IMPLICIT_UNNEST = True
+
def _parse_table(
self,
schema: bool = False,
@@ -124,27 +126,6 @@ class Redshift(Postgres):
self._retreat(index)
return None
- def _parse_query_modifiers(
- self, this: t.Optional[exp.Expression]
- ) -> t.Optional[exp.Expression]:
- this = super()._parse_query_modifiers(this)
-
- if this:
- refs = set()
-
- for i, join in enumerate(this.args.get("joins", [])):
- refs.add(
- (
- this.args["from"] if i == 0 else this.args["joins"][i - 1]
- ).this.alias.lower()
- )
-
- table = join.this
- if isinstance(table, exp.Table) and not join.args.get("on"):
- if table.parts[0].name.lower() in refs:
- table.replace(table.to_column())
- return this
-
class Tokenizer(Postgres.Tokenizer):
BIT_STRINGS = []
HEX_STRINGS = []
@@ -225,6 +206,18 @@ class Redshift(Postgres):
RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot", "type"}
+ def unnest_sql(self, expression: exp.Unnest) -> str:
+ args = expression.expressions
+ num_args = len(args)
+
+ if num_args > 1:
+ self.unsupported(f"Unsupported number of arguments in UNNEST: {num_args}")
+ return ""
+
+ arg = self.sql(seq_get(args, 0))
+ alias = self.expressions(expression.args.get("alias"), key="columns")
+ return f"{arg} AS {alias}" if alias else arg
+
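Note: Redshift renders UNNEST as a bare correlated reference plus alias, matching its implicit SUPER unnesting. A sketch, assuming the alias lands in the columns slot:

    import sqlglot

    # Expected output (assumed): SELECT f FROM t, t.friends AS f
    print(sqlglot.transpile(
        "SELECT f FROM t, UNNEST(t.friends) AS f",
        read="bigquery",
        write="redshift",
    )[0])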
def with_properties(self, properties: exp.Properties) -> str:
"""Redshift doesn't have `WITH` as part of their with_properties so we remove it"""
return self.properties(properties, prefix=" ", suffix="")
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index c773e50..20fdfb7 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -21,7 +21,7 @@ from sqlglot.dialects.dialect import (
var_map_sql,
)
from sqlglot.expressions import Literal
-from sqlglot.helper import is_int, seq_get
+from sqlglot.helper import flatten, is_int, seq_get
from sqlglot.tokens import TokenType
if t.TYPE_CHECKING:
@@ -66,7 +66,7 @@ def _build_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]:
return exp.Struct(
expressions=[
- t.cast(exp.Condition, k).eq(v) for k, v in zip(expression.keys, expression.values)
+ exp.PropertyEQ(this=k, expression=v) for k, v in zip(expression.keys, expression.values)
]
)
@@ -409,8 +409,16 @@ class Snowflake(Dialect):
"TERSE OBJECTS": _show_parser("OBJECTS"),
"TABLES": _show_parser("TABLES"),
"TERSE TABLES": _show_parser("TABLES"),
+ "VIEWS": _show_parser("VIEWS"),
+ "TERSE VIEWS": _show_parser("VIEWS"),
"PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
"TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
+ "IMPORTED KEYS": _show_parser("IMPORTED KEYS"),
+ "TERSE IMPORTED KEYS": _show_parser("IMPORTED KEYS"),
+ "UNIQUE KEYS": _show_parser("UNIQUE KEYS"),
+ "TERSE UNIQUE KEYS": _show_parser("UNIQUE KEYS"),
+ "SEQUENCES": _show_parser("SEQUENCES"),
+ "TERSE SEQUENCES": _show_parser("SEQUENCES"),
"COLUMNS": _show_parser("COLUMNS"),
"USERS": _show_parser("USERS"),
"TERSE USERS": _show_parser("USERS"),
@@ -424,11 +432,13 @@ class Snowflake(Dialect):
FLATTEN_COLUMNS = ["SEQ", "KEY", "PATH", "INDEX", "VALUE", "THIS"]
+ SCHEMA_KINDS = {"OBJECTS", "TABLES", "VIEWS", "SEQUENCES", "UNIQUE KEYS", "IMPORTED KEYS"}
+
def _parse_colon_get_path(
self: parser.Parser, this: t.Optional[exp.Expression]
) -> t.Optional[exp.Expression]:
while True:
- path = self._parse_bitwise()
+ path = self._parse_bitwise() or self._parse_var(any_token=True)
# The cast :: operator has a lower precedence than the extraction operator :, so
# we rearrange the AST appropriately to avoid casting the 2nd argument of GET_PATH
@@ -535,7 +545,7 @@ class Snowflake(Dialect):
return table
def _parse_table_parts(
- self, schema: bool = False, is_db_reference: bool = False
+ self, schema: bool = False, is_db_reference: bool = False, wildcard: bool = False
) -> exp.Table:
# https://docs.snowflake.com/en/user-guide/querying-stage
if self._match(TokenType.STRING, advance=False):
@@ -603,7 +613,7 @@ class Snowflake(Dialect):
if self._curr:
scope = self._parse_table_parts()
elif self._curr:
- scope_kind = "SCHEMA" if this in ("OBJECTS", "TABLES") else "TABLE"
+ scope_kind = "SCHEMA" if this in self.SCHEMA_KINDS else "TABLE"
scope = self._parse_table_parts()
return self.expression(
@@ -758,10 +768,6 @@ class Snowflake(Dialect):
"POSITION", e.args.get("substr"), e.this, e.args.get("position")
),
exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
- exp.Struct: lambda self, e: self.func(
- "OBJECT_CONSTRUCT",
- *(arg for expression in e.expressions for arg in expression.flatten()),
- ),
exp.Stuff: rename_func("INSERT"),
exp.TimestampDiff: lambda self, e: self.func(
"TIMESTAMPDIFF", e.unit, e.expression, e.this
@@ -937,3 +943,19 @@ class Snowflake(Dialect):
def cluster_sql(self, expression: exp.Cluster) -> str:
return f"CLUSTER BY ({self.expressions(expression, flat=True)})"
+
+ def struct_sql(self, expression: exp.Struct) -> str:
+ keys = []
+ values = []
+
+ for i, e in enumerate(expression.expressions):
+ if isinstance(e, exp.PropertyEQ):
+ keys.append(
+ exp.Literal.string(e.name) if isinstance(e.this, exp.Identifier) else e.this
+ )
+ values.append(e.expression)
+ else:
+ keys.append(exp.Literal.string(f"_{i}"))
+ values.append(e)
+
+ return self.func("OBJECT_CONSTRUCT", *flatten(zip(keys, values)))
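Note: named entries become ('key', value) pairs and positional ones get synthetic '_<i>' keys before everything is flattened into OBJECT_CONSTRUCT. A sketch (output assumed):

    import sqlglot

    # e.g. SELECT OBJECT_CONSTRUCT('x', 1, 'y', 2)
    print(sqlglot.transpile(
        "SELECT STRUCT(1 AS x, 2 AS y)",
        read="bigquery",
        write="snowflake",
    )[0])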
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
index 60cf8e1..63eae6e 100644
--- a/sqlglot/dialects/spark2.py
+++ b/sqlglot/dialects/spark2.py
@@ -263,14 +263,9 @@ class Spark2(Hive):
CREATE_FUNCTION_RETURN_AS = False
def struct_sql(self, expression: exp.Struct) -> str:
- args = []
- for arg in expression.expressions:
- if isinstance(arg, self.KEY_VALUE_DEFINITIONS):
- args.append(exp.alias_(arg.expression, arg.this.name))
- else:
- args.append(arg)
-
- return self.func("STRUCT", *args)
+ from sqlglot.generator import Generator
+
+ return Generator.struct_sql(self, expression)
def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
if is_parse_json(expression.this):
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index 6596c5b..2b17ff9 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -92,6 +92,7 @@ class SQLite(Dialect):
NVL2_SUPPORTED = False
JSON_PATH_BRACKETED_KEY_SUPPORTED = False
SUPPORTS_CREATE_TABLE_LIKE = False
+ SUPPORTS_TABLE_ALIAS_COLUMNS = False
SUPPORTED_JSON_PATH_PARTS = {
exp.JSONPathKey,
@@ -173,6 +174,21 @@ class SQLite(Dialect):
return super().cast_sql(expression)
+ def generateseries_sql(self, expression: exp.GenerateSeries) -> str:
+ parent = expression.parent
+ alias = parent and parent.args.get("alias")
+
+ if isinstance(alias, exp.TableAlias) and alias.columns:
+ column_alias = alias.columns[0]
+ alias.set("columns", None)
+ sql = self.sql(
+ exp.select(exp.alias_("value", column_alias)).from_(expression).subquery()
+ )
+ else:
+ sql = super().generateseries_sql(expression)
+
+ return sql
+
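Note: SQLite's generate_series exposes its output as a fixed `value` column, so a requested column alias is folded into a wrapping subquery. A sketch (output shape assumed):

    import sqlglot

    # e.g. SELECT * FROM (SELECT value AS n FROM GENERATE_SERIES(1, 5)) AS t
    print(sqlglot.transpile(
        "SELECT * FROM GENERATE_SERIES(1, 5) AS t(n)",
        read="postgres",
        write="sqlite",
    )[0])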
def datediff_sql(self, expression: exp.DateDiff) -> str:
unit = expression.args.get("unit")
unit = unit.name.upper() if unit else "DAY"
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 5955352..b6f491f 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -18,7 +18,6 @@ from sqlglot.dialects.dialect import (
timestrtotime_sql,
trim_sql,
)
-from sqlglot.expressions import DataType
from sqlglot.helper import seq_get
from sqlglot.time import format_time
from sqlglot.tokens import TokenType
@@ -63,6 +62,44 @@ DEFAULT_START_DATE = datetime.date(1900, 1, 1)
BIT_TYPES = {exp.EQ, exp.NEQ, exp.Is, exp.In, exp.Select, exp.Alias}
+# Unsupported options:
+# - OPTIMIZE FOR ( @variable_name { UNKNOWN | = <literal_constant> } [ , ...n ] )
+# - TABLE HINT
+OPTIONS: parser.OPTIONS_TYPE = {
+ **dict.fromkeys(
+ (
+ "DISABLE_OPTIMIZED_PLAN_FORCING",
+ "FAST",
+ "IGNORE_NONCLUSTERED_COLUMNSTORE_INDEX",
+ "LABEL",
+ "MAXDOP",
+ "MAXRECURSION",
+ "MAX_GRANT_PERCENT",
+ "MIN_GRANT_PERCENT",
+ "NO_PERFORMANCE_SPOOL",
+ "QUERYTRACEON",
+ "RECOMPILE",
+ ),
+ tuple(),
+ ),
+ "CONCAT": ("UNION",),
+ "DISABLE": ("EXTERNALPUSHDOWN", "SCALEOUTEXECUTION"),
+ "EXPAND": ("VIEWS",),
+ "FORCE": ("EXTERNALPUSHDOWN", "ORDER", "SCALEOUTEXECUTION"),
+ "HASH": ("GROUP", "JOIN", "UNION"),
+ "KEEP": ("PLAN",),
+ "KEEPFIXED": ("PLAN",),
+ "LOOP": ("JOIN",),
+ "MERGE": ("JOIN", "UNION"),
+ "OPTIMIZE": (("FOR", "UNKNOWN"),),
+ "ORDER": ("GROUP",),
+ "PARAMETERIZATION": ("FORCED", "SIMPLE"),
+ "ROBUST": ("PLAN",),
+ "USE": ("PLAN",),
+}
+
+OPTIONS_THAT_REQUIRE_EQUAL = ("MAX_GRANT_PERCENT", "MIN_GRANT_PERCENT", "LABEL")
+
def _build_formatted_time(
exp_class: t.Type[E], full_format_mapping: t.Optional[bool] = None
@@ -221,19 +258,17 @@ def qualify_derived_table_outputs(expression: exp.Expression) -> exp.Expression:
# We keep track of the unaliased column projection indexes instead of the expressions
# themselves, because the latter are going to be replaced by new nodes when the aliases
# are added and hence we won't be able to reach these newly added Alias parents
- subqueryable = expression.this
+ query = expression.this
unaliased_column_indexes = (
- i
- for i, c in enumerate(subqueryable.selects)
- if isinstance(c, exp.Column) and not c.alias
+ i for i, c in enumerate(query.selects) if isinstance(c, exp.Column) and not c.alias
)
- qualify_outputs(subqueryable)
+ qualify_outputs(query)
# Preserve the quoting information of columns for newly added Alias nodes
- subqueryable_selects = subqueryable.selects
+ query_selects = query.selects
for select_index in unaliased_column_indexes:
- alias = subqueryable_selects[select_index]
+ alias = query_selects[select_index]
column = alias.this
if isinstance(column.this, exp.Identifier):
alias.args["alias"].set("quoted", column.this.quoted)
@@ -420,7 +455,6 @@ class TSQL(Dialect):
"IMAGE": TokenType.IMAGE,
"MONEY": TokenType.MONEY,
"NTEXT": TokenType.TEXT,
- "NVARCHAR(MAX)": TokenType.TEXT,
"PRINT": TokenType.COMMAND,
"PROC": TokenType.PROCEDURE,
"REAL": TokenType.FLOAT,
@@ -431,15 +465,24 @@ class TSQL(Dialect):
"TOP": TokenType.TOP,
"UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER,
"UPDATE STATISTICS": TokenType.COMMAND,
- "VARCHAR(MAX)": TokenType.TEXT,
"XML": TokenType.XML,
"OUTPUT": TokenType.RETURNING,
"SYSTEM_USER": TokenType.CURRENT_USER,
"FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT,
+ "OPTION": TokenType.OPTION,
}
class Parser(parser.Parser):
SET_REQUIRES_ASSIGNMENT_DELIMITER = False
+ LOG_DEFAULTS_TO_LN = True
+ ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False
+ STRING_ALIASES = True
+ NO_PAREN_IF_COMMANDS = False
+
+ QUERY_MODIFIER_PARSERS = {
+ **parser.Parser.QUERY_MODIFIER_PARSERS,
+ TokenType.OPTION: lambda self: ("options", self._parse_options()),
+ }
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
@@ -472,19 +515,7 @@ class TSQL(Dialect):
"TIMEFROMPARTS": _build_timefromparts,
}
- JOIN_HINTS = {
- "LOOP",
- "HASH",
- "MERGE",
- "REMOTE",
- }
-
- VAR_LENGTH_DATATYPES = {
- DataType.Type.NVARCHAR,
- DataType.Type.VARCHAR,
- DataType.Type.CHAR,
- DataType.Type.NCHAR,
- }
+ JOIN_HINTS = {"LOOP", "HASH", "MERGE", "REMOTE"}
RETURNS_TABLE_TOKENS = parser.Parser.ID_VAR_TOKENS - {
TokenType.TABLE,
@@ -496,11 +527,21 @@ class TSQL(Dialect):
TokenType.END: lambda self: self._parse_command(),
}
- LOG_DEFAULTS_TO_LN = True
+ def _parse_options(self) -> t.Optional[t.List[exp.Expression]]:
+ if not self._match(TokenType.OPTION):
+ return None
- ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False
- STRING_ALIASES = True
- NO_PAREN_IF_COMMANDS = False
+ def _parse_option() -> t.Optional[exp.Expression]:
+ option = self._parse_var_from_options(OPTIONS)
+ if not option:
+ return None
+
+ self._match(TokenType.EQ)
+ return self.expression(
+ exp.QueryOption, this=option, expression=self._parse_primary_or_var()
+ )
+
+ return self._parse_wrapped_csv(_parse_option)
def _parse_projections(self) -> t.List[exp.Expression]:
"""
@@ -576,48 +617,13 @@ class TSQL(Dialect):
def _parse_convert(
self, strict: bool, safe: t.Optional[bool] = None
) -> t.Optional[exp.Expression]:
- to = self._parse_types()
+ this = self._parse_types()
self._match(TokenType.COMMA)
- this = self._parse_conjunction()
-
- if not to or not this:
- return None
-
- # Retrieve length of datatype and override to default if not specified
- if seq_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES:
- to = exp.DataType.build(to.this, expressions=[exp.Literal.number(30)], nested=False)
-
- # Check whether a conversion with format is applicable
- if self._match(TokenType.COMMA):
- format_val = self._parse_number()
- format_val_name = format_val.name if format_val else ""
-
- if format_val_name not in TSQL.CONVERT_FORMAT_MAPPING:
- raise ValueError(
- f"CONVERT function at T-SQL does not support format style {format_val_name}"
- )
-
- format_norm = exp.Literal.string(TSQL.CONVERT_FORMAT_MAPPING[format_val_name])
-
- # Check whether the convert entails a string to date format
- if to.this == DataType.Type.DATE:
- return self.expression(exp.StrToDate, this=this, format=format_norm)
- # Check whether the convert entails a string to datetime format
- elif to.this == DataType.Type.DATETIME:
- return self.expression(exp.StrToTime, this=this, format=format_norm)
- # Check whether the convert entails a date to string format
- elif to.this in self.VAR_LENGTH_DATATYPES:
- return self.expression(
- exp.Cast if strict else exp.TryCast,
- to=to,
- this=self.expression(exp.TimeToStr, this=this, format=format_norm),
- safe=safe,
- )
- elif to.this == DataType.Type.TEXT:
- return self.expression(exp.TimeToStr, this=this, format=format_norm)
-
- # Entails a simple cast without any format requirement
- return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to, safe=safe)
+ args = [this, *self._parse_csv(self._parse_conjunction)]
+ convert = exp.Convert.from_arg_list(args)
+ convert.set("safe", safe)
+ convert.set("strict", strict)
+ return convert
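Note: CONVERT is now kept as an exp.Convert node instead of being lowered into casts and time-format functions, so style numbers survive untouched. A sketch of the assumed round-trip:

    import sqlglot

    # Style 112 (yyyymmdd) should be preserved verbatim (assumption):
    print(sqlglot.transpile(
        "SELECT CONVERT(VARCHAR(8), d, 112)", read="tsql", write="tsql"
    )[0])  # e.g. SELECT CONVERT(VARCHAR(8), d, 112)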
def _parse_user_defined_function(
self, kind: t.Optional[TokenType] = None
@@ -683,6 +689,26 @@ class TSQL(Dialect):
return self.expression(exp.UniqueColumnConstraint, this=this)
+ def _parse_partition(self) -> t.Optional[exp.Partition]:
+ if not self._match_text_seq("WITH", "(", "PARTITIONS"):
+ return None
+
+ def parse_range():
+ low = self._parse_bitwise()
+ high = self._parse_bitwise() if self._match_text_seq("TO") else None
+
+ return (
+ self.expression(exp.PartitionRange, this=low, expression=high) if high else low
+ )
+
+ partition = self.expression(
+ exp.Partition, expressions=self._parse_wrapped_csv(parse_range)
+ )
+
+ self._match_r_paren()
+
+ return partition
+
class Generator(generator.Generator):
LIMIT_IS_TOP = True
QUERY_HINTS = False
@@ -728,6 +754,9 @@ class TSQL(Dialect):
exp.DataType.Type.VARIANT: "SQL_VARIANT",
}
+ TYPE_MAPPING.pop(exp.DataType.Type.NCHAR)
+ TYPE_MAPPING.pop(exp.DataType.Type.NVARCHAR)
+
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.AnyValue: any_value_to_max_sql,
@@ -779,6 +808,20 @@ class TSQL(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
+ def convert_sql(self, expression: exp.Convert) -> str:
+ name = "TRY_CONVERT" if expression.args.get("safe") else "CONVERT"
+ return self.func(
+ name, expression.this, expression.expression, expression.args.get("style")
+ )
+
+ def queryoption_sql(self, expression: exp.QueryOption) -> str:
+ option = self.sql(expression, "this")
+ value = self.sql(expression, "expression")
+ if value:
+ optional_equal_sign = "= " if option in OPTIONS_THAT_REQUIRE_EQUAL else ""
+ return f"{option} {optional_equal_sign}{value}"
+ return option
+
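Note: OPTION hints parse into exp.QueryOption nodes driven by the OPTIONS table, and only MAX_GRANT_PERCENT, MIN_GRANT_PERCENT and LABEL regenerate the equals sign. A sketch of the assumed round-trip:

    import sqlglot

    # e.g. SELECT x FROM t OPTION(MAXDOP 1, LABEL = 'q')
    print(sqlglot.transpile(
        "SELECT x FROM t OPTION (MAXDOP 1, LABEL = 'q')",
        read="tsql",
        write="tsql",
    )[0])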
def lateral_op(self, expression: exp.Lateral) -> str:
cross_apply = expression.args.get("cross_apply")
if cross_apply is True:
@@ -876,11 +919,10 @@ class TSQL(Dialect):
if ctas_with:
ctas_with = ctas_with.pop()
- subquery = ctas_expression
- if isinstance(subquery, exp.Subqueryable):
- subquery = subquery.subquery()
+ if isinstance(ctas_expression, exp.UNWRAPPED_QUERIES):
+ ctas_expression = ctas_expression.subquery()
- select_into = exp.select("*").from_(exp.alias_(subquery, "temp", table=True))
+ select_into = exp.select("*").from_(exp.alias_(ctas_expression, "temp", table=True))
select_into.set("into", exp.Into(this=table))
select_into.set("with", ctas_with)
@@ -993,3 +1035,6 @@ class TSQL(Dialect):
this_sql = self.sql(this)
expression_sql = self.sql(expression, "expression")
return self.func(name, this_sql, expression_sql if expression_sql else None)
+
+ def partition_sql(self, expression: exp.Partition) -> str:
+ return f"WITH (PARTITIONS({self.expressions(expression, flat=True)}))"
diff --git a/sqlglot/diff.py b/sqlglot/diff.py
index c10d640..bda9136 100644
--- a/sqlglot/diff.py
+++ b/sqlglot/diff.py
@@ -119,13 +119,18 @@ def diff(
return ChangeDistiller(**kwargs).diff(source_copy, target_copy, matchings=matchings_copy)
-LEAF_EXPRESSION_TYPES = (
+# The expression types for which Update edits are allowed.
+UPDATABLE_EXPRESSION_TYPES = (
exp.Boolean,
exp.DataType,
- exp.Identifier,
exp.Literal,
+ exp.Table,
+ exp.Column,
+ exp.Lambda,
)
+IGNORED_LEAF_EXPRESSION_TYPES = (exp.Identifier,)
+
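Note: with identifiers excluded from both node indexes and leaf matching, a simple column rename should now surface as one Update edit on the Column node rather than edits on its Identifier leaf. A sketch:

    from sqlglot import diff, parse_one

    # Assumed edit shape: a single Update(Column(a) -> Column(b)) plus Keeps.
    for edit in diff(parse_one("SELECT a FROM t"), parse_one("SELECT b FROM t")):
        print(edit)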
class ChangeDistiller:
"""
@@ -152,8 +157,16 @@ class ChangeDistiller:
self._source = source
self._target = target
- self._source_index = {id(n): n for n, *_ in self._source.bfs()}
- self._target_index = {id(n): n for n, *_ in self._target.bfs()}
+ self._source_index = {
+ id(n): n
+ for n, *_ in self._source.bfs()
+ if not isinstance(n, IGNORED_LEAF_EXPRESSION_TYPES)
+ }
+ self._target_index = {
+ id(n): n
+ for n, *_ in self._target.bfs()
+ if not isinstance(n, IGNORED_LEAF_EXPRESSION_TYPES)
+ }
self._unmatched_source_nodes = set(self._source_index) - set(pre_matched_nodes)
self._unmatched_target_nodes = set(self._target_index) - set(pre_matched_nodes.values())
self._bigram_histo_cache: t.Dict[int, t.DefaultDict[str, int]] = {}
@@ -170,7 +183,10 @@ class ChangeDistiller:
for kept_source_node_id, kept_target_node_id in matching_set:
source_node = self._source_index[kept_source_node_id]
target_node = self._target_index[kept_target_node_id]
- if not isinstance(source_node, LEAF_EXPRESSION_TYPES) or source_node == target_node:
+ if (
+ not isinstance(source_node, UPDATABLE_EXPRESSION_TYPES)
+ or source_node == target_node
+ ):
edit_script.extend(
self._generate_move_edits(source_node, target_node, matching_set)
)
@@ -307,17 +323,16 @@ def _get_leaves(expression: exp.Expression) -> t.Iterator[exp.Expression]:
has_child_exprs = False
for _, node in expression.iter_expressions():
- has_child_exprs = True
- yield from _get_leaves(node)
+ if not isinstance(node, IGNORED_LEAF_EXPRESSION_TYPES):
+ has_child_exprs = True
+ yield from _get_leaves(node)
if not has_child_exprs:
yield expression
def _is_same_type(source: exp.Expression, target: exp.Expression) -> bool:
- if type(source) is type(target) and (
- not isinstance(source, exp.Identifier) or type(source.parent) is type(target.parent)
- ):
+ if type(source) is type(target):
if isinstance(source, exp.Join):
return source.args.get("side") == target.args.get("side")
@@ -343,7 +358,11 @@ def _expression_only_args(expression: exp.Expression) -> t.List[exp.Expression]:
if expression:
for a in expression.args.values():
args.extend(ensure_list(a))
- return [a for a in args if isinstance(a, exp.Expression)]
+ return [
+ a
+ for a in args
+ if isinstance(a, exp.Expression) and not isinstance(a, IGNORED_LEAF_EXPRESSION_TYPES)
+ ]
def _lcs(
diff --git a/sqlglot/executor/context.py b/sqlglot/executor/context.py
index e4c4040..a411c18 100644
--- a/sqlglot/executor/context.py
+++ b/sqlglot/executor/context.py
@@ -78,7 +78,7 @@ class Context:
def sort(self, key) -> None:
def sort_key(row: t.Tuple) -> t.Tuple:
self.set_row(row)
- return self.eval_tuple(key)
+ return tuple((t is None, t) for t in self.eval_tuple(key))
self.table.rows.sort(key=sort_key)
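Note: the wrapped key makes rows containing NULLs sortable: each value becomes an (is_null, value) pair, so None is never compared against a concrete value and NULLs sort last. The trick in plain Python:

    # Tuples compare element-wise; (False, v) always precedes (True, None).
    rows = [(3,), (None,), (1,)]
    print(sorted(rows, key=lambda row: tuple((v is None, v) for v in row)))
    # [(1,), (3,), (None,)]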
diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py
index c0becbe..a2b23d4 100644
--- a/sqlglot/executor/python.py
+++ b/sqlglot/executor/python.py
@@ -142,7 +142,6 @@ class PythonExecutor:
context = self.context({alias: table})
yield context
types = []
-
for row in reader:
if not types:
for v in row:
@@ -150,7 +149,11 @@ class PythonExecutor:
types.append(type(ast.literal_eval(v)))
except (ValueError, SyntaxError):
types.append(str)
- context.set_row(tuple(t(v) for t, v in zip(types, row)))
+
+ # We can't cast empty values ('') to non-string types, so we convert them to None instead
+ context.set_row(
+ tuple(None if (t is not str and v == "") else t(v) for t, v in zip(types, row))
+ )
yield context.table.reader
def join(self, step, context):
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 1408d3c..1a24875 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -548,12 +548,10 @@ class Expression(metaclass=_Expression):
return new_node
@t.overload
- def replace(self, expression: E) -> E:
- ...
+ def replace(self, expression: E) -> E: ...
@t.overload
- def replace(self, expression: None) -> None:
- ...
+ def replace(self, expression: None) -> None: ...
def replace(self, expression):
"""
@@ -913,14 +911,142 @@ class Predicate(Condition):
class DerivedTable(Expression):
@property
def selects(self) -> t.List[Expression]:
- return self.this.selects if isinstance(self.this, Subqueryable) else []
+ return self.this.selects if isinstance(self.this, Query) else []
@property
def named_selects(self) -> t.List[str]:
return [select.output_name for select in self.selects]
-class Unionable(Expression):
+class Query(Expression):
+ def subquery(self, alias: t.Optional[ExpOrStr] = None, copy: bool = True) -> Subquery:
+ """
+ Returns a `Subquery` that wraps around this query.
+
+ Example:
+ >>> subquery = Select().select("x").from_("tbl").subquery()
+ >>> Select().select("x").from_(subquery).sql()
+ 'SELECT x FROM (SELECT x FROM tbl)'
+
+ Args:
+ alias: an optional alias for the subquery.
+ copy: if `False`, modify this expression instance in-place.
+ """
+ instance = maybe_copy(self, copy)
+ if not isinstance(alias, Expression):
+ alias = TableAlias(this=to_identifier(alias)) if alias else None
+
+ return Subquery(this=instance, alias=alias)
+
+ def limit(
+ self, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts
+ ) -> Select:
+ """
+ Adds a LIMIT clause to this query.
+
+ Example:
+ >>> select("1").union(select("1")).limit(1).sql()
+ 'SELECT * FROM (SELECT 1 UNION SELECT 1) AS _l_0 LIMIT 1'
+
+ Args:
+ expression: the SQL code string to parse.
+ This can also be an integer.
+ If a `Limit` instance is passed, it will be used as-is.
+ If another `Expression` instance is passed, it will be wrapped in a `Limit`.
+ dialect: the dialect used to parse the input expression.
+ copy: if `False`, modify this expression instance in-place.
+ opts: other options to use to parse the input expressions.
+
+ Returns:
+ A limited Select expression.
+ """
+ return (
+ select("*")
+ .from_(self.subquery(alias="_l_0", copy=copy))
+ .limit(expression, dialect=dialect, copy=False, **opts)
+ )
+
+ @property
+ def ctes(self) -> t.List[CTE]:
+ """Returns a list of all the CTEs attached to this query."""
+ with_ = self.args.get("with")
+ return with_.expressions if with_ else []
+
+ @property
+ def selects(self) -> t.List[Expression]:
+ """Returns the query's projections."""
+ raise NotImplementedError("Query objects must implement `selects`")
+
+ @property
+ def named_selects(self) -> t.List[str]:
+ """Returns the output names of the query's projections."""
+ raise NotImplementedError("Query objects must implement `named_selects`")
+
+ def select(
+ self,
+ *expressions: t.Optional[ExpOrStr],
+ append: bool = True,
+ dialect: DialectType = None,
+ copy: bool = True,
+ **opts,
+ ) -> Query:
+ """
+ Append to or set the SELECT expressions.
+
+ Example:
+ >>> Select().select("x", "y").sql()
+ 'SELECT x, y'
+
+ Args:
+ *expressions: the SQL code strings to parse.
+ If an `Expression` instance is passed, it will be used as-is.
+ append: if `True`, add to any existing expressions.
+ Otherwise, this resets the expressions.
+ dialect: the dialect used to parse the input expressions.
+ copy: if `False`, modify this expression instance in-place.
+ opts: other options to use to parse the input expressions.
+
+ Returns:
+ The modified Query expression.
+ """
+ raise NotImplementedError("Query objects must implement `select`")
+
+ def with_(
+ self,
+ alias: ExpOrStr,
+ as_: ExpOrStr,
+ recursive: t.Optional[bool] = None,
+ append: bool = True,
+ dialect: DialectType = None,
+ copy: bool = True,
+ **opts,
+ ) -> Query:
+ """
+ Append to or set the common table expressions.
+
+ Example:
+ >>> Select().with_("tbl2", as_="SELECT * FROM tbl").select("x").from_("tbl2").sql()
+ 'WITH tbl2 AS (SELECT * FROM tbl) SELECT x FROM tbl2'
+
+ Args:
+ alias: the SQL code string to parse as the table name.
+ If an `Expression` instance is passed, this is used as-is.
+ as_: the SQL code string to parse as the table expression.
+ If an `Expression` instance is passed, it will be used as-is.
+ recursive: set the RECURSIVE part of the expression. Defaults to `False`.
+ append: if `True`, add to any existing expressions.
+ Otherwise, this resets the expressions.
+ dialect: the dialect used to parse the input expression.
+ copy: if `False`, modify this expression instance in-place.
+ opts: other options to use to parse the input expressions.
+
+ Returns:
+ The modified expression.
+ """
+ return _apply_cte_builder(
+ self, alias, as_, recursive=recursive, append=append, dialect=dialect, copy=copy, **opts
+ )
+
def union(
self, expression: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts
) -> Union:
@@ -946,7 +1072,7 @@ class Unionable(Expression):
def intersect(
self, expression: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts
- ) -> Unionable:
+ ) -> Intersect:
"""
Builds an INTERSECT expression.
@@ -969,7 +1095,7 @@ class Unionable(Expression):
def except_(
self, expression: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts
- ) -> Unionable:
+ ) -> Except:
"""
Builds an EXCEPT expression.
@@ -991,7 +1117,7 @@ class Unionable(Expression):
return except_(left=self, right=expression, distinct=distinct, dialect=dialect, **opts)
-class UDTF(DerivedTable, Unionable):
+class UDTF(DerivedTable):
@property
def selects(self) -> t.List[Expression]:
alias = self.args.get("alias")
@@ -1017,23 +1143,23 @@ class Refresh(Expression):
class DDL(Expression):
@property
- def ctes(self):
+ def ctes(self) -> t.List[CTE]:
+ """Returns a list of all the CTEs attached to this statement."""
with_ = self.args.get("with")
- if not with_:
- return []
- return with_.expressions
+ return with_.expressions if with_ else []
@property
- def named_selects(self) -> t.List[str]:
- if isinstance(self.expression, Subqueryable):
- return self.expression.named_selects
- return []
+ def selects(self) -> t.List[Expression]:
+ """If this statement contains a query (e.g. a CTAS), this returns the query's projections."""
+ return self.expression.selects if isinstance(self.expression, Query) else []
@property
- def selects(self) -> t.List[Expression]:
- if isinstance(self.expression, Subqueryable):
- return self.expression.selects
- return []
+ def named_selects(self) -> t.List[str]:
+ """
+ If this statement contains a query (e.g. a CTAS), this returns the output
+ names of the query's projections.
+ """
+ return self.expression.named_selects if isinstance(self.expression, Query) else []
class DML(Expression):
@@ -1096,6 +1222,19 @@ class Create(DDL):
return kind and kind.upper()
+class TruncateTable(Expression):
+ arg_types = {
+ "expressions": True,
+ "is_database": False,
+ "exists": False,
+ "only": False,
+ "cluster": False,
+ "identity": False,
+ "option": False,
+ "partition": False,
+ }
+
+
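A hedged round-trip sketch for the new `TruncateTable` node, assuming the `TRUNCATE` statement parser added in `sqlglot/parser.py` and the `truncatetable_sql` method added in `sqlglot/generator.py` (both later in this patch):

    >>> import sqlglot
    >>> sqlglot.parse_one("TRUNCATE TABLE t1, t2").sql()
    'TRUNCATE TABLE t1, t2'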
# https://docs.snowflake.com/en/sql-reference/sql/create-clone
# https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_table_clone_statement
# https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_table_copy
@@ -1271,6 +1410,10 @@ class ColumnDef(Expression):
def constraints(self) -> t.List[ColumnConstraint]:
return self.args.get("constraints") or []
+ @property
+ def kind(self) -> t.Optional[DataType]:
+ return self.args.get("kind")
+
class AlterColumn(Expression):
arg_types = {
@@ -1367,7 +1510,7 @@ class CharacterSetColumnConstraint(ColumnConstraintKind):
class CheckColumnConstraint(ColumnConstraintKind):
- pass
+ arg_types = {"this": True, "enforced": False}
class ClusteredColumnConstraint(ColumnConstraintKind):
@@ -1776,6 +1919,10 @@ class Partition(Expression):
arg_types = {"expressions": True}
+class PartitionRange(Expression):
+ arg_types = {"this": True, "expression": True}
+
+
class Fetch(Expression):
arg_types = {
"direction": False,
@@ -2173,6 +2320,10 @@ class LocationProperty(Property):
arg_types = {"this": True}
+class LockProperty(Property):
+ arg_types = {"this": True}
+
+
class LockingProperty(Property):
arg_types = {
"this": False,
@@ -2310,7 +2461,7 @@ class StabilityProperty(Property):
class TemporaryProperty(Property):
- arg_types = {}
+ arg_types = {"this": False}
class TransformModelProperty(Property):
@@ -2356,6 +2507,7 @@ class Properties(Expression):
"FORMAT": FileFormatProperty,
"LANGUAGE": LanguageProperty,
"LOCATION": LocationProperty,
+ "LOCK": LockProperty,
"PARTITIONED_BY": PartitionedByProperty,
"RETURNS": ReturnsProperty,
"ROW_FORMAT": RowFormatProperty,
@@ -2445,102 +2597,13 @@ class Tuple(Expression):
)
-class Subqueryable(Unionable):
- def subquery(self, alias: t.Optional[ExpOrStr] = None, copy: bool = True) -> Subquery:
- """
- Convert this expression to an aliased expression that can be used as a Subquery.
-
- Example:
- >>> subquery = Select().select("x").from_("tbl").subquery()
- >>> Select().select("x").from_(subquery).sql()
- 'SELECT x FROM (SELECT x FROM tbl)'
-
- Args:
- alias (str | Identifier): an optional alias for the subquery
- copy (bool): if `False`, modify this expression instance in-place.
-
- Returns:
- Alias: the subquery
- """
- instance = maybe_copy(self, copy)
- if not isinstance(alias, Expression):
- alias = TableAlias(this=to_identifier(alias)) if alias else None
-
- return Subquery(this=instance, alias=alias)
-
- def limit(
- self, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts
- ) -> Select:
- raise NotImplementedError
-
- @property
- def ctes(self):
- with_ = self.args.get("with")
- if not with_:
- return []
- return with_.expressions
-
- @property
- def selects(self) -> t.List[Expression]:
- raise NotImplementedError("Subqueryable objects must implement `selects`")
-
- @property
- def named_selects(self) -> t.List[str]:
- raise NotImplementedError("Subqueryable objects must implement `named_selects`")
-
- def select(
- self,
- *expressions: t.Optional[ExpOrStr],
- append: bool = True,
- dialect: DialectType = None,
- copy: bool = True,
- **opts,
- ) -> Subqueryable:
- raise NotImplementedError("Subqueryable objects must implement `select`")
-
- def with_(
- self,
- alias: ExpOrStr,
- as_: ExpOrStr,
- recursive: t.Optional[bool] = None,
- append: bool = True,
- dialect: DialectType = None,
- copy: bool = True,
- **opts,
- ) -> Subqueryable:
- """
- Append to or set the common table expressions.
-
- Example:
- >>> Select().with_("tbl2", as_="SELECT * FROM tbl").select("x").from_("tbl2").sql()
- 'WITH tbl2 AS (SELECT * FROM tbl) SELECT x FROM tbl2'
-
- Args:
- alias: the SQL code string to parse as the table name.
- If an `Expression` instance is passed, this is used as-is.
- as_: the SQL code string to parse as the table expression.
- If an `Expression` instance is passed, it will be used as-is.
- recursive: set the RECURSIVE part of the expression. Defaults to `False`.
- append: if `True`, add to any existing expressions.
- Otherwise, this resets the expressions.
- dialect: the dialect used to parse the input expression.
- copy: if `False`, modify this expression instance in-place.
- opts: other options to use to parse the input expressions.
-
- Returns:
- The modified expression.
- """
- return _apply_cte_builder(
- self, alias, as_, recursive=recursive, append=append, dialect=dialect, copy=copy, **opts
- )
-
-
QUERY_MODIFIERS = {
"match": False,
"laterals": False,
"joins": False,
"connect": False,
"pivots": False,
+ "prewhere": False,
"where": False,
"group": False,
"having": False,
@@ -2556,9 +2619,16 @@ QUERY_MODIFIERS = {
"sample": False,
"settings": False,
"format": False,
+ "options": False,
}
+# https://learn.microsoft.com/en-us/sql/t-sql/queries/option-clause-transact-sql?view=sql-server-ver16
+# https://learn.microsoft.com/en-us/sql/t-sql/queries/hints-transact-sql-query?view=sql-server-ver16
+class QueryOption(Expression):
+ arg_types = {"this": True, "expression": False}
+
+
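For illustration, a `QueryOption` can also be built directly; the argument names follow `arg_types` above, and this usage is hypothetical rather than taken from the patch:

    >>> from sqlglot import exp
    >>> opt = exp.QueryOption(this=exp.var("MAXDOP"), expression=exp.Literal.number(1))
    >>> opt.name
    'MAXDOP'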
# https://learn.microsoft.com/en-us/sql/t-sql/queries/hints-transact-sql-table?view=sql-server-ver16
class WithTableHint(Expression):
arg_types = {"expressions": True}
@@ -2590,6 +2660,7 @@ class Table(Expression):
"pattern": False,
"ordinality": False,
"when": False,
+ "only": False,
}
@property
@@ -2638,7 +2709,7 @@ class Table(Expression):
return col
-class Union(Subqueryable):
+class Union(Query):
arg_types = {
"with": False,
"this": True,
@@ -2648,34 +2719,6 @@ class Union(Subqueryable):
**QUERY_MODIFIERS,
}
- def limit(
- self, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts
- ) -> Select:
- """
- Set the LIMIT expression.
-
- Example:
- >>> select("1").union(select("1")).limit(1).sql()
- 'SELECT * FROM (SELECT 1 UNION SELECT 1) AS _l_0 LIMIT 1'
-
- Args:
- expression: the SQL code string to parse.
- This can also be an integer.
- If a `Limit` instance is passed, this is used as-is.
- If another `Expression` instance is passed, it will be wrapped in a `Limit`.
- dialect: the dialect used to parse the input expression.
- copy: if `False`, modify this expression instance in-place.
- opts: other options to use to parse the input expressions.
-
- Returns:
- The limited subqueryable.
- """
- return (
- select("*")
- .from_(self.subquery(alias="_l_0", copy=copy))
- .limit(expression, dialect=dialect, copy=False, **opts)
- )
-
def select(
self,
*expressions: t.Optional[ExpOrStr],
@@ -2684,26 +2727,7 @@ class Union(Subqueryable):
copy: bool = True,
**opts,
) -> Union:
- """Append to or set the SELECT of the union recursively.
-
- Example:
- >>> from sqlglot import parse_one
- >>> parse_one("select a from x union select a from y union select a from z").select("b").sql()
- 'SELECT a, b FROM x UNION SELECT a, b FROM y UNION SELECT a, b FROM z'
-
- Args:
- *expressions: the SQL code strings to parse.
- If an `Expression` instance is passed, it will be used as-is.
- append: if `True`, add to any existing expressions.
- Otherwise, this resets the expressions.
- dialect: the dialect used to parse the input expressions.
- copy: if `False`, modify this expression instance in-place.
- opts: other options to use to parse the input expressions.
-
- Returns:
- Union: the modified expression.
- """
- this = self.copy() if copy else self
+ this = maybe_copy(self, copy)
this.this.unnest().select(*expressions, append=append, dialect=dialect, copy=False, **opts)
this.expression.unnest().select(
*expressions, append=append, dialect=dialect, copy=False, **opts
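The docstring moved to the base `Query.select`, but the deleted example above still shows the recursive behavior this override keeps:

    >>> from sqlglot import parse_one
    >>> parse_one("select a from x union select a from y union select a from z").select("b").sql()
    'SELECT a, b FROM x UNION SELECT a, b FROM y UNION SELECT a, b FROM z'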
@@ -2800,7 +2824,7 @@ class Lock(Expression):
arg_types = {"update": True, "expressions": False, "wait": False}
-class Select(Subqueryable):
+class Select(Query):
arg_types = {
"with": False,
"kind": False,
@@ -3011,25 +3035,6 @@ class Select(Subqueryable):
def limit(
self, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts
) -> Select:
- """
- Set the LIMIT expression.
-
- Example:
- >>> Select().from_("tbl").select("x").limit(10).sql()
- 'SELECT x FROM tbl LIMIT 10'
-
- Args:
- expression: the SQL code string to parse.
- This can also be an integer.
- If a `Limit` instance is passed, this is used as-is.
- If another `Expression` instance is passed, it will be wrapped in a `Limit`.
- dialect: the dialect used to parse the input expression.
- copy: if `False`, modify this expression instance in-place.
- opts: other options to use to parse the input expressions.
-
- Returns:
- Select: the modified expression.
- """
return _apply_builder(
expression=expression,
instance=self,
@@ -3084,31 +3089,13 @@ class Select(Subqueryable):
copy: bool = True,
**opts,
) -> Select:
- """
- Append to or set the SELECT expressions.
-
- Example:
- >>> Select().select("x", "y").sql()
- 'SELECT x, y'
-
- Args:
- *expressions: the SQL code strings to parse.
- If an `Expression` instance is passed, it will be used as-is.
- append: if `True`, add to any existing expressions.
- Otherwise, this resets the expressions.
- dialect: the dialect used to parse the input expressions.
- copy: if `False`, modify this expression instance in-place.
- opts: other options to use to parse the input expressions.
-
- Returns:
- The modified Select expression.
- """
return _apply_list_builder(
*expressions,
instance=self,
arg="expressions",
append=append,
dialect=dialect,
+ into=Expression,
copy=copy,
**opts,
)
@@ -3416,12 +3403,8 @@ class Select(Subqueryable):
The new Create expression.
"""
instance = maybe_copy(self, copy)
- table_expression = maybe_parse(
- table,
- into=Table,
- dialect=dialect,
- **opts,
- )
+ table_expression = maybe_parse(table, into=Table, dialect=dialect, **opts)
+
properties_expression = None
if properties:
properties_expression = Properties.from_dict(properties)
@@ -3493,7 +3476,10 @@ class Select(Subqueryable):
return self.expressions
-class Subquery(DerivedTable, Unionable):
+UNWRAPPED_QUERIES = (Select, Union)
+
+
+class Subquery(DerivedTable, Query):
arg_types = {
"this": True,
"alias": False,
@@ -3502,9 +3488,7 @@ class Subquery(DerivedTable, Unionable):
}
def unnest(self):
- """
- Returns the first non subquery.
- """
+ """Returns the first non subquery."""
expression = self
while isinstance(expression, Subquery):
expression = expression.this
@@ -3516,6 +3500,18 @@ class Subquery(DerivedTable, Unionable):
expression = t.cast(Subquery, expression.parent)
return expression
+ def select(
+ self,
+ *expressions: t.Optional[ExpOrStr],
+ append: bool = True,
+ dialect: DialectType = None,
+ copy: bool = True,
+ **opts,
+ ) -> Subquery:
+ this = maybe_copy(self, copy)
+ this.unnest().select(*expressions, append=append, dialect=dialect, copy=False, **opts)
+ return this
+
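A hedged sketch of the new `Subquery.select`, assuming a parenthesized top-level query parses into a `Subquery`; the projection is added to the innermost unnested query:

    >>> from sqlglot import parse_one
    >>> parse_one("(SELECT x FROM tbl)").select("y").sql()
    '(SELECT x, y FROM tbl)'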
@property
def is_wrapper(self) -> bool:
"""
@@ -3603,6 +3599,10 @@ class WindowSpec(Expression):
}
+class PreWhere(Expression):
+ pass
+
+
class Where(Expression):
pass
@@ -3646,6 +3646,10 @@ class Boolean(Condition):
class DataTypeParam(Expression):
arg_types = {"this": True, "expression": False}
+ @property
+ def name(self) -> str:
+ return self.this.name
+
class DataType(Expression):
arg_types = {
@@ -3926,11 +3930,17 @@ class Rollback(Expression):
class AlterTable(Expression):
- arg_types = {"this": True, "actions": True, "exists": False, "only": False}
+ arg_types = {
+ "this": True,
+ "actions": True,
+ "exists": False,
+ "only": False,
+ "options": False,
+ }
class AddConstraint(Expression):
- arg_types = {"this": False, "expression": False, "enforced": False}
+ arg_types = {"expressions": True}
class DropPartition(Expression):
@@ -3996,6 +4006,10 @@ class Overlaps(Binary):
class Dot(Binary):
@property
+ def is_star(self) -> bool:
+ return self.expression.is_star
+
+ @property
def name(self) -> str:
return self.expression.name
@@ -4390,6 +4404,10 @@ class Anonymous(Func):
arg_types = {"this": True, "expressions": False}
is_var_len_args = True
+ @property
+ def name(self) -> str:
+ return self.this if isinstance(self.this, str) else self.this.name
+
class AnonymousAggFunc(AggFunc):
arg_types = {"this": True, "expressions": False}
@@ -4433,8 +4451,13 @@ class ToChar(Func):
arg_types = {"this": True, "format": False, "nlsparam": False}
+# https://learn.microsoft.com/en-us/sql/t-sql/functions/cast-and-convert-transact-sql?view=sql-server-ver16#syntax
+class Convert(Func):
+ arg_types = {"this": True, "expression": True, "style": False}
+
+
class GenerateSeries(Func):
- arg_types = {"start": True, "end": True, "step": False}
+ arg_types = {"start": True, "end": True, "step": False, "is_end_exclusive": False}
class ArrayAgg(AggFunc):
@@ -4624,6 +4647,11 @@ class ConcatWs(Concat):
_sql_names = ["CONCAT_WS"]
+# https://docs.oracle.com/cd/B13789_01/server.101/b10759/operators004.htm#i1035022
+class ConnectByRoot(Func):
+ pass
+
+
class Count(AggFunc):
arg_types = {"this": False, "expressions": False}
is_var_len_args = True
@@ -5197,6 +5225,10 @@ class Month(Func):
pass
+class AddMonths(Func):
+ arg_types = {"this": True, "expression": True}
+
+
class Nvl2(Func):
arg_types = {"this": True, "true": True, "false": False}
@@ -5313,6 +5345,10 @@ class SHA2(Func):
arg_types = {"this": True, "length": False}
+class Sign(Func):
+ _sql_names = ["SIGN", "SIGNUM"]
+
+
class SortArray(Func):
arg_types = {"this": True, "asc": False}
@@ -5554,7 +5590,13 @@ class Use(Expression):
class Merge(Expression):
- arg_types = {"this": True, "using": True, "on": True, "expressions": True, "with": False}
+ arg_types = {
+ "this": True,
+ "using": True,
+ "on": True,
+ "expressions": True,
+ "with": False,
+ }
class When(Func):
@@ -5587,8 +5629,7 @@ def maybe_parse(
prefix: t.Optional[str] = None,
copy: bool = False,
**opts,
-) -> E:
- ...
+) -> E: ...
@t.overload
@@ -5600,8 +5641,7 @@ def maybe_parse(
prefix: t.Optional[str] = None,
copy: bool = False,
**opts,
-) -> E:
- ...
+) -> E: ...
def maybe_parse(
@@ -5653,13 +5693,11 @@ def maybe_parse(
@t.overload
-def maybe_copy(instance: None, copy: bool = True) -> None:
- ...
+def maybe_copy(instance: None, copy: bool = True) -> None: ...
@t.overload
-def maybe_copy(instance: E, copy: bool = True) -> E:
- ...
+def maybe_copy(instance: E, copy: bool = True) -> E: ...
def maybe_copy(instance, copy=True):
@@ -6282,15 +6320,13 @@ SAFE_IDENTIFIER_RE: t.Pattern[str] = re.compile(r"^[_a-zA-Z][\w]*$")
@t.overload
-def to_identifier(name: None, quoted: t.Optional[bool] = None, copy: bool = True) -> None:
- ...
+def to_identifier(name: None, quoted: t.Optional[bool] = None, copy: bool = True) -> None: ...
@t.overload
def to_identifier(
name: str | Identifier, quoted: t.Optional[bool] = None, copy: bool = True
-) -> Identifier:
- ...
+) -> Identifier: ...
def to_identifier(name, quoted=None, copy=True):
@@ -6362,13 +6398,11 @@ def to_interval(interval: str | Literal) -> Interval:
@t.overload
-def to_table(sql_path: str | Table, **kwargs) -> Table:
- ...
+def to_table(sql_path: str | Table, **kwargs) -> Table: ...
@t.overload
-def to_table(sql_path: None, **kwargs) -> None:
- ...
+def to_table(sql_path: None, **kwargs) -> None: ...
def to_table(
@@ -6929,7 +6963,7 @@ def replace_placeholders(expression: Expression, *args, **kwargs) -> Expression:
if isinstance(node, Placeholder):
if node.name:
new_name = kwargs.get(node.name)
- if new_name:
+ if new_name is not None:
return convert(new_name)
else:
try:
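The `is not None` check matters for falsy replacements, which previously fell through unsubstituted; a minimal sketch:

    >>> from sqlglot import exp, parse_one
    >>> exp.replace_placeholders(parse_one("SELECT :lim"), lim=0).sql()
    'SELECT 0'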
@@ -6943,7 +6977,7 @@ def replace_placeholders(expression: Expression, *args, **kwargs) -> Expression:
def expand(
expression: Expression,
- sources: t.Dict[str, Subqueryable],
+ sources: t.Dict[str, Query],
dialect: DialectType = None,
copy: bool = True,
) -> Expression:
@@ -6959,7 +6993,7 @@ def expand(
Args:
expression: The expression to expand.
- sources: A dictionary of name to Subqueryables.
+        sources: A dictionary that maps source names to Query expressions.
dialect: The dialect of the sources dict.
copy: Whether to copy the expression during transformation. Defaults to True.
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 4bb5005..e6f5c4b 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -73,17 +73,16 @@ class Generator(metaclass=_Generator):
TRANSFORMS: t.Dict[t.Type[exp.Expression], t.Callable[..., str]] = {
**JSON_PATH_PART_TRANSFORMS,
exp.AutoRefreshProperty: lambda self, e: f"AUTO REFRESH {self.sql(e, 'this')}",
- exp.CaseSpecificColumnConstraint: lambda self,
+ exp.CaseSpecificColumnConstraint: lambda _,
e: f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC",
exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}",
exp.CharacterSetProperty: lambda self,
e: f"{'DEFAULT ' if e.args.get('default') else ''}CHARACTER SET={self.sql(e, 'this')}",
- exp.CheckColumnConstraint: lambda self, e: f"CHECK ({self.sql(e, 'this')})",
exp.ClusteredColumnConstraint: lambda self,
e: f"CLUSTERED ({self.expressions(e, 'this', indent=False)})",
exp.CollateColumnConstraint: lambda self, e: f"COLLATE {self.sql(e, 'this')}",
exp.CommentColumnConstraint: lambda self, e: f"COMMENT {self.sql(e, 'this')}",
- exp.CopyGrantsProperty: lambda self, e: "COPY GRANTS",
+ exp.CopyGrantsProperty: lambda *_: "COPY GRANTS",
exp.DateAdd: lambda self, e: self.func(
"DATE_ADD", e.this, e.expression, exp.Literal.string(e.text("unit"))
),
@@ -91,8 +90,8 @@ class Generator(metaclass=_Generator):
exp.DefaultColumnConstraint: lambda self, e: f"DEFAULT {self.sql(e, 'this')}",
exp.EncodeColumnConstraint: lambda self, e: f"ENCODE {self.sql(e, 'this')}",
exp.ExecuteAsProperty: lambda self, e: self.naked_property(e),
- exp.ExternalProperty: lambda self, e: "EXTERNAL",
- exp.HeapProperty: lambda self, e: "HEAP",
+ exp.ExternalProperty: lambda *_: "EXTERNAL",
+ exp.HeapProperty: lambda *_: "HEAP",
exp.InheritsProperty: lambda self, e: f"INHERITS ({self.expressions(e, flat=True)})",
exp.InlineLengthColumnConstraint: lambda self, e: f"INLINE LENGTH {self.sql(e, 'this')}",
exp.InputModelProperty: lambda self, e: f"INPUT{self.sql(e, 'this')}",
@@ -105,13 +104,13 @@ class Generator(metaclass=_Generator):
),
exp.LanguageProperty: lambda self, e: self.naked_property(e),
exp.LocationProperty: lambda self, e: self.naked_property(e),
- exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG",
- exp.MaterializedProperty: lambda self, e: "MATERIALIZED",
+ exp.LogProperty: lambda _, e: f"{'NO ' if e.args.get('no') else ''}LOG",
+ exp.MaterializedProperty: lambda *_: "MATERIALIZED",
exp.NonClusteredColumnConstraint: lambda self,
e: f"NONCLUSTERED ({self.expressions(e, 'this', indent=False)})",
- exp.NoPrimaryIndexProperty: lambda self, e: "NO PRIMARY INDEX",
- exp.NotForReplicationColumnConstraint: lambda self, e: "NOT FOR REPLICATION",
- exp.OnCommitProperty: lambda self,
+ exp.NoPrimaryIndexProperty: lambda *_: "NO PRIMARY INDEX",
+ exp.NotForReplicationColumnConstraint: lambda *_: "NOT FOR REPLICATION",
+ exp.OnCommitProperty: lambda _,
e: f"ON COMMIT {'DELETE' if e.args.get('delete') else 'PRESERVE'} ROWS",
exp.OnProperty: lambda self, e: f"ON {self.sql(e, 'this')}",
exp.OnUpdateColumnConstraint: lambda self, e: f"ON UPDATE {self.sql(e, 'this')}",
@@ -122,21 +121,21 @@ class Generator(metaclass=_Generator):
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
exp.SampleProperty: lambda self, e: f"SAMPLE BY {self.sql(e, 'this')}",
exp.SetConfigProperty: lambda self, e: self.sql(e, "this"),
- exp.SetProperty: lambda self, e: f"{'MULTI' if e.args.get('multi') else ''}SET",
+ exp.SetProperty: lambda _, e: f"{'MULTI' if e.args.get('multi') else ''}SET",
exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}",
- exp.SqlReadWriteProperty: lambda self, e: e.name,
- exp.SqlSecurityProperty: lambda self,
+ exp.SqlReadWriteProperty: lambda _, e: e.name,
+ exp.SqlSecurityProperty: lambda _,
e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
- exp.StabilityProperty: lambda self, e: e.name,
- exp.TemporaryProperty: lambda self, e: "TEMPORARY",
+ exp.StabilityProperty: lambda _, e: e.name,
+ exp.TemporaryProperty: lambda *_: "TEMPORARY",
exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}",
exp.Timestamp: lambda self, e: self.func("TIMESTAMP", e.this, e.expression),
exp.ToTableProperty: lambda self, e: f"TO {self.sql(e.this)}",
exp.TransformModelProperty: lambda self, e: self.func("TRANSFORM", *e.expressions),
- exp.TransientProperty: lambda self, e: "TRANSIENT",
- exp.UppercaseColumnConstraint: lambda self, e: "UPPERCASE",
+ exp.TransientProperty: lambda *_: "TRANSIENT",
+ exp.UppercaseColumnConstraint: lambda *_: "UPPERCASE",
exp.VarMap: lambda self, e: self.func("MAP", e.args["keys"], e.args["values"]),
- exp.VolatileProperty: lambda self, e: "VOLATILE",
+ exp.VolatileProperty: lambda *_: "VOLATILE",
exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}",
}
@@ -356,6 +355,7 @@ class Generator(metaclass=_Generator):
STRUCT_DELIMITER = ("<", ">")
PARAMETER_TOKEN = "@"
+ NAMED_PLACEHOLDER_TOKEN = ":"
PROPERTIES_LOCATION = {
exp.AlgorithmProperty: exp.Properties.Location.POST_CREATE,
@@ -388,6 +388,7 @@ class Generator(metaclass=_Generator):
exp.LanguageProperty: exp.Properties.Location.POST_SCHEMA,
exp.LikeProperty: exp.Properties.Location.POST_SCHEMA,
exp.LocationProperty: exp.Properties.Location.POST_SCHEMA,
+ exp.LockProperty: exp.Properties.Location.POST_SCHEMA,
exp.LockingProperty: exp.Properties.Location.POST_ALIAS,
exp.LogProperty: exp.Properties.Location.POST_NAME,
exp.MaterializedProperty: exp.Properties.Location.POST_CREATE,
@@ -459,11 +460,16 @@ class Generator(metaclass=_Generator):
exp.Paren,
)
+ PARAMETERIZABLE_TEXT_TYPES = {
+ exp.DataType.Type.NVARCHAR,
+ exp.DataType.Type.VARCHAR,
+ exp.DataType.Type.CHAR,
+ exp.DataType.Type.NCHAR,
+ }
+
# Expressions that need to have all CTEs under them bubbled up to them
EXPRESSIONS_WITHOUT_NESTED_CTES: t.Set[t.Type[exp.Expression]] = set()
- KEY_VALUE_DEFINITIONS = (exp.EQ, exp.PropertyEQ, exp.Slice)
-
SENTINEL_LINE_BREAK = "__SQLGLOT__LB__"
__slots__ = (
@@ -630,7 +636,7 @@ class Generator(metaclass=_Generator):
this_sql = self.indent(
(
self.sql(expression)
- if isinstance(expression, (exp.Select, exp.Union))
+ if isinstance(expression, exp.UNWRAPPED_QUERIES)
else self.sql(expression, "this")
),
level=1,
@@ -1535,8 +1541,8 @@ class Generator(metaclass=_Generator):
expr = self.sql(expression, "expression")
return f"{this} ({kind} => {expr})"
- def table_sql(self, expression: exp.Table, sep: str = " AS ") -> str:
- table = ".".join(
+ def table_parts(self, expression: exp.Table) -> str:
+ return ".".join(
self.sql(part)
for part in (
expression.args.get("catalog"),
@@ -1546,6 +1552,9 @@ class Generator(metaclass=_Generator):
if part is not None
)
+ def table_sql(self, expression: exp.Table, sep: str = " AS ") -> str:
+ table = self.table_parts(expression)
+ only = "ONLY " if expression.args.get("only") else ""
version = self.sql(expression, "version")
version = f" {version}" if version else ""
alias = self.sql(expression, "alias")
@@ -1572,7 +1581,7 @@ class Generator(metaclass=_Generator):
if when:
table = f"{table} {when}"
- return f"{table}{version}{file_format}{alias}{hints}{pivots}{joins}{laterals}{ordinality}"
+ return f"{only}{table}{version}{file_format}{alias}{hints}{pivots}{joins}{laterals}{ordinality}"
def tablesample_sql(
self,
@@ -1681,7 +1690,7 @@ class Generator(metaclass=_Generator):
alias_node = expression.args.get("alias")
column_names = alias_node and alias_node.columns
- selects: t.List[exp.Subqueryable] = []
+ selects: t.List[exp.Query] = []
for i, tup in enumerate(expression.expressions):
row = tup.expressions
@@ -1697,10 +1706,8 @@ class Generator(metaclass=_Generator):
# This may result in poor performance for large-cardinality `VALUES` tables, due to
# the deep nesting of the resulting exp.Unions. If this is a problem, either increase
# `sys.setrecursionlimit` to avoid RecursionErrors, or don't set `pretty`.
- subqueryable = reduce(lambda x, y: exp.union(x, y, distinct=False, copy=False), selects)
- return self.subquery_sql(
- subqueryable.subquery(alias_node and alias_node.this, copy=False)
- )
+ query = reduce(lambda x, y: exp.union(x, y, distinct=False, copy=False), selects)
+ return self.subquery_sql(query.subquery(alias_node and alias_node.this, copy=False))
alias = f" AS {self.sql(alias_node, 'this')}" if alias_node else ""
unions = " UNION ALL ".join(self.sql(select) for select in selects)
@@ -1854,7 +1861,7 @@ class Generator(metaclass=_Generator):
]
args_sql = ", ".join(self.sql(e) for e in args)
- args_sql = f"({args_sql})" if any(top and not e.is_number for e in args) else args_sql
+ args_sql = f"({args_sql})" if top and any(not e.is_number for e in args) else args_sql
expressions = self.expressions(expression, flat=True)
expressions = f" BY {expressions}" if expressions else ""
@@ -2070,12 +2077,17 @@ class Generator(metaclass=_Generator):
else []
)
+ options = self.expressions(expression, key="options")
+ if options:
+ options = f" OPTION{self.wrap(options)}"
+
return csv(
*sqls,
*[self.sql(join) for join in expression.args.get("joins") or []],
self.sql(expression, "connect"),
self.sql(expression, "match"),
*[self.sql(lateral) for lateral in expression.args.get("laterals") or []],
+ self.sql(expression, "prewhere"),
self.sql(expression, "where"),
self.sql(expression, "group"),
self.sql(expression, "having"),
@@ -2083,9 +2095,13 @@ class Generator(metaclass=_Generator):
self.sql(expression, "order"),
*offset_limit_modifiers,
*self.after_limit_modifiers(expression),
+ options,
sep="",
)
+ def queryoption_sql(self, expression: exp.QueryOption) -> str:
+ return ""
+
def offset_limit_modifiers(
self, expression: exp.Expression, fetch: bool, limit: t.Optional[exp.Fetch | exp.Limit]
) -> t.List[str]:
@@ -2140,9 +2156,9 @@ class Generator(metaclass=_Generator):
self.sql(
exp.Struct(
expressions=[
- exp.column(e.output_name).eq(
- e.this if isinstance(e, exp.Alias) else e
- )
+ exp.PropertyEQ(this=e.args.get("alias"), expression=e.this)
+ if isinstance(e, exp.Alias)
+ else e
for e in expression.expressions
]
)
@@ -2204,7 +2220,7 @@ class Generator(metaclass=_Generator):
return f"@@{kind}{this}"
def placeholder_sql(self, expression: exp.Placeholder) -> str:
- return f":{expression.name}" if expression.name else "?"
+ return f"{self.NAMED_PLACEHOLDER_TOKEN}{expression.name}" if expression.name else "?"
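The new `NAMED_PLACEHOLDER_TOKEN` hook lets dialects override the `:` sigil; the default rendering is unchanged:

    >>> from sqlglot import parse_one
    >>> parse_one("SELECT :x").sql()
    'SELECT :x'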
def subquery_sql(self, expression: exp.Subquery, sep: str = " AS ") -> str:
alias = self.sql(expression, "alias")
@@ -2261,6 +2277,9 @@ class Generator(metaclass=_Generator):
return f"UNNEST({args}){suffix}"
+ def prewhere_sql(self, expression: exp.PreWhere) -> str:
+ return ""
+
def where_sql(self, expression: exp.Where) -> str:
this = self.indent(self.sql(expression, "this"))
return f"{self.seg('WHERE')}{self.sep()}{this}"
@@ -2326,7 +2345,7 @@ class Generator(metaclass=_Generator):
def any_sql(self, expression: exp.Any) -> str:
this = self.sql(expression, "this")
- if isinstance(expression.this, exp.Subqueryable):
+ if isinstance(expression.this, exp.UNWRAPPED_QUERIES):
this = self.wrap(this)
return f"ANY {this}"
@@ -2568,7 +2587,7 @@ class Generator(metaclass=_Generator):
is_global = " GLOBAL" if expression.args.get("is_global") else ""
if query:
- in_sql = self.wrap(query)
+ in_sql = self.wrap(self.sql(query))
elif unnest:
in_sql = self.in_unnest_op(unnest)
elif field:
@@ -2610,7 +2629,7 @@ class Generator(metaclass=_Generator):
return f"REFERENCES {this}{expressions}{options}"
def anonymous_sql(self, expression: exp.Anonymous) -> str:
- return self.func(expression.name, *expression.expressions)
+ return self.func(self.sql(expression, "this"), *expression.expressions)
def paren_sql(self, expression: exp.Paren) -> str:
if isinstance(expression.unnest(), exp.Select):
@@ -2822,7 +2841,9 @@ class Generator(metaclass=_Generator):
exists = " IF EXISTS" if expression.args.get("exists") else ""
only = " ONLY" if expression.args.get("only") else ""
- return f"ALTER TABLE{exists}{only} {self.sql(expression, 'this')} {actions}"
+ options = self.expressions(expression, key="options")
+ options = f", {options}" if options else ""
+ return f"ALTER TABLE{exists}{only} {self.sql(expression, 'this')} {actions}{options}"
def add_column_sql(self, expression: exp.AlterTable) -> str:
if self.ALTER_TABLE_INCLUDE_COLUMN_KEYWORD:
@@ -2839,15 +2860,7 @@ class Generator(metaclass=_Generator):
return f"DROP{exists}{expressions}"
def addconstraint_sql(self, expression: exp.AddConstraint) -> str:
- this = self.sql(expression, "this")
- expression_ = self.sql(expression, "expression")
- add_constraint = f"ADD CONSTRAINT {this}" if this else "ADD"
-
- enforced = expression.args.get("enforced")
- if enforced is not None:
- return f"{add_constraint} CHECK ({expression_}){' ENFORCED' if enforced else ''}"
-
- return f"{add_constraint} {expression_}"
+ return f"ADD {self.expressions(expression)}"
def distinct_sql(self, expression: exp.Distinct) -> str:
this = self.expressions(expression, flat=True)
@@ -3296,6 +3309,10 @@ class Generator(metaclass=_Generator):
self.unsupported("Unsupported index constraint option.")
return ""
+ def checkcolumnconstraint_sql(self, expression: exp.CheckColumnConstraint) -> str:
+ enforced = " ENFORCED" if expression.args.get("enforced") else ""
+ return f"CHECK ({self.sql(expression, 'this')}){enforced}"
+
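With the dedicated method, the optional ENFORCED flag now round-trips alongside the condition; a minimal sketch of the unflagged case:

    >>> from sqlglot import parse_one
    >>> parse_one("CREATE TABLE t (x INT CHECK (x > 0))").sql()
    'CREATE TABLE t (x INT CHECK (x > 0))'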
def indexcolumnconstraint_sql(self, expression: exp.IndexColumnConstraint) -> str:
kind = self.sql(expression, "kind")
kind = f"{kind} INDEX" if kind else "INDEX"
@@ -3452,9 +3469,87 @@ class Generator(metaclass=_Generator):
return expression
- def _ensure_string_if_null(self, values: t.List[exp.Expression]) -> t.List[exp.Expression]:
- return [
- exp.func("COALESCE", exp.cast(value, "text"), exp.Literal.string(""))
- for value in values
- if value
- ]
+ def generateseries_sql(self, expression: exp.GenerateSeries) -> str:
+ expression.set("is_end_exclusive", None)
+ return self.function_fallback_sql(expression)
+
+ def struct_sql(self, expression: exp.Struct) -> str:
+ expression.set(
+ "expressions",
+ [
+ exp.alias_(e.expression, e.this) if isinstance(e, exp.PropertyEQ) else e
+ for e in expression.expressions
+ ],
+ )
+
+ return self.function_fallback_sql(expression)
+
+ def partitionrange_sql(self, expression: exp.PartitionRange) -> str:
+ low = self.sql(expression, "this")
+ high = self.sql(expression, "expression")
+
+ return f"{low} TO {high}"
+
+ def truncatetable_sql(self, expression: exp.TruncateTable) -> str:
+ target = "DATABASE" if expression.args.get("is_database") else "TABLE"
+ tables = f" {self.expressions(expression)}"
+
+ exists = " IF EXISTS" if expression.args.get("exists") else ""
+
+ on_cluster = self.sql(expression, "cluster")
+ on_cluster = f" {on_cluster}" if on_cluster else ""
+
+ identity = self.sql(expression, "identity")
+ identity = f" {identity} IDENTITY" if identity else ""
+
+ option = self.sql(expression, "option")
+ option = f" {option}" if option else ""
+
+ partition = self.sql(expression, "partition")
+ partition = f" {partition}" if partition else ""
+
+ return f"TRUNCATE {target}{exists}{tables}{on_cluster}{identity}{option}{partition}"
+
+ # This transpiles T-SQL's CONVERT function
+ # https://learn.microsoft.com/en-us/sql/t-sql/functions/cast-and-convert-transact-sql?view=sql-server-ver16
+ def convert_sql(self, expression: exp.Convert) -> str:
+ to = expression.this
+ value = expression.expression
+ style = expression.args.get("style")
+ safe = expression.args.get("safe")
+ strict = expression.args.get("strict")
+
+ if not to or not value:
+ return ""
+
+        # Retrieve the length of the data type and override it with the default if not specified
+ if not seq_get(to.expressions, 0) and to.this in self.PARAMETERIZABLE_TEXT_TYPES:
+ to = exp.DataType.build(to.this, expressions=[exp.Literal.number(30)], nested=False)
+
+ transformed: t.Optional[exp.Expression] = None
+ cast = exp.Cast if strict else exp.TryCast
+
+ # Check whether a conversion with format (T-SQL calls this 'style') is applicable
+ if isinstance(style, exp.Literal) and style.is_int:
+ from sqlglot.dialects.tsql import TSQL
+
+ style_value = style.name
+ converted_style = TSQL.CONVERT_FORMAT_MAPPING.get(style_value)
+ if not converted_style:
+ self.unsupported(f"Unsupported T-SQL 'style' value: {style_value}")
+
+ fmt = exp.Literal.string(converted_style)
+
+ if to.this == exp.DataType.Type.DATE:
+ transformed = exp.StrToDate(this=value, format=fmt)
+ elif to.this == exp.DataType.Type.DATETIME:
+ transformed = exp.StrToTime(this=value, format=fmt)
+ elif to.this in self.PARAMETERIZABLE_TEXT_TYPES:
+ transformed = cast(this=exp.TimeToStr(this=value, format=fmt), to=to, safe=safe)
+ elif to.this == exp.DataType.Type.TEXT:
+ transformed = exp.TimeToStr(this=value, format=fmt)
+
+ if not transformed:
+ transformed = cast(this=value, to=to, safe=safe)
+
+ return self.sql(transformed)
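A hedged sketch of the dispatch above, assuming the tsql dialect parses CONVERT into the new `exp.Convert` node and that `TSQL.CONVERT_FORMAT_MAPPING` maps style 120 to '%Y-%m-%d %H:%M:%S' (both assumptions; the exact rendering may differ):

    >>> import sqlglot
    >>> sqlglot.transpile("SELECT CONVERT(DATETIME, x, 120)", read="tsql")[0]
    "SELECT STR_TO_TIME(x, '%Y-%m-%d %H:%M:%S')"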
diff --git a/sqlglot/helper.py b/sqlglot/helper.py
index 35a4586..0d4547f 100644
--- a/sqlglot/helper.py
+++ b/sqlglot/helper.py
@@ -53,13 +53,11 @@ def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
@t.overload
-def ensure_list(value: t.Collection[T]) -> t.List[T]:
- ...
+def ensure_list(value: t.Collection[T]) -> t.List[T]: ...
@t.overload
-def ensure_list(value: T) -> t.List[T]:
- ...
+def ensure_list(value: T) -> t.List[T]: ...
def ensure_list(value):
@@ -81,13 +79,11 @@ def ensure_list(value):
@t.overload
-def ensure_collection(value: t.Collection[T]) -> t.Collection[T]:
- ...
+def ensure_collection(value: t.Collection[T]) -> t.Collection[T]: ...
@t.overload
-def ensure_collection(value: T) -> t.Collection[T]:
- ...
+def ensure_collection(value: T) -> t.Collection[T]: ...
def ensure_collection(value):
diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py
index f10fbb9..eb428dc 100644
--- a/sqlglot/lineage.py
+++ b/sqlglot/lineage.py
@@ -1,16 +1,19 @@
from __future__ import annotations
import json
+import logging
import typing as t
from dataclasses import dataclass, field
from sqlglot import Schema, exp, maybe_parse
from sqlglot.errors import SqlglotError
-from sqlglot.optimizer import Scope, build_scope, find_all_in_scope, qualify
+from sqlglot.optimizer import Scope, build_scope, find_all_in_scope, normalize_identifiers, qualify
if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType
+logger = logging.getLogger("sqlglot")
+
@dataclass(frozen=True)
class Node:
@@ -18,7 +21,8 @@ class Node:
expression: exp.Expression
source: exp.Expression
downstream: t.List[Node] = field(default_factory=list)
- alias: str = ""
+ source_name: str = ""
+ reference_node_name: str = ""
def walk(self) -> t.Iterator[Node]:
yield self
@@ -67,7 +71,7 @@ def lineage(
column: str | exp.Column,
sql: str | exp.Expression,
schema: t.Optional[t.Dict | Schema] = None,
- sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
+ sources: t.Optional[t.Dict[str, str | exp.Query]] = None,
dialect: DialectType = None,
**kwargs,
) -> Node:
@@ -86,14 +90,12 @@ def lineage(
"""
expression = maybe_parse(sql, dialect=dialect)
+ column = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name
if sources:
expression = exp.expand(
expression,
- {
- k: t.cast(exp.Subqueryable, maybe_parse(v, dialect=dialect))
- for k, v in sources.items()
- },
+ {k: t.cast(exp.Query, maybe_parse(v, dialect=dialect)) for k, v in sources.items()},
dialect=dialect,
)
@@ -109,122 +111,141 @@ def lineage(
if not scope:
raise SqlglotError("Cannot build lineage, sql must be SELECT")
- def to_node(
- column: str | int,
- scope: Scope,
- scope_name: t.Optional[str] = None,
- upstream: t.Optional[Node] = None,
- alias: t.Optional[str] = None,
- ) -> Node:
- aliases = {
- dt.alias: dt.comments[0].split()[1]
- for dt in scope.derived_tables
- if dt.comments and dt.comments[0].startswith("source: ")
- }
+ if not any(select.alias_or_name == column for select in scope.expression.selects):
+ raise SqlglotError(f"Cannot find column '{column}' in query.")
- # Find the specific select clause that is the source of the column we want.
- # This can either be a specific, named select or a generic `*` clause.
- select = (
- scope.expression.selects[column]
- if isinstance(column, int)
- else next(
- (select for select in scope.expression.selects if select.alias_or_name == column),
- exp.Star() if scope.expression.is_star else scope.expression,
- )
- )
+ return to_node(column, scope, dialect)
- if isinstance(scope.expression, exp.Union):
- upstream = upstream or Node(name="UNION", source=scope.expression, expression=select)
-
- index = (
- column
- if isinstance(column, int)
- else next(
- (
- i
- for i, select in enumerate(scope.expression.selects)
- if select.alias_or_name == column or select.is_star
- ),
- -1, # mypy will not allow a None here, but a negative index should never be returned
- )
- )
-
- if index == -1:
- raise ValueError(f"Could not find {column} in {scope.expression}")
- for s in scope.union_scopes:
- to_node(index, scope=s, upstream=upstream, alias=alias)
-
- return upstream
-
- if isinstance(scope.expression, exp.Select):
- # For better ergonomics in our node labels, replace the full select with
- # a version that has only the column we care about.
- # "x", SELECT x, y FROM foo
- # => "x", SELECT x FROM foo
- source = t.cast(exp.Expression, scope.expression.select(select, append=False))
- else:
- source = scope.expression
-
- # Create the node for this step in the lineage chain, and attach it to the previous one.
- node = Node(
- name=f"{scope_name}.{column}" if scope_name else str(column),
- source=source,
- expression=select,
- alias=alias or "",
+def to_node(
+ column: str | int,
+ scope: Scope,
+ dialect: DialectType,
+ scope_name: t.Optional[str] = None,
+ upstream: t.Optional[Node] = None,
+ source_name: t.Optional[str] = None,
+ reference_node_name: t.Optional[str] = None,
+) -> Node:
+ source_names = {
+ dt.alias: dt.comments[0].split()[1]
+ for dt in scope.derived_tables
+ if dt.comments and dt.comments[0].startswith("source: ")
+ }
+
+ # Find the specific select clause that is the source of the column we want.
+ # This can either be a specific, named select or a generic `*` clause.
+ select = (
+ scope.expression.selects[column]
+ if isinstance(column, int)
+ else next(
+ (select for select in scope.expression.selects if select.alias_or_name == column),
+ exp.Star() if scope.expression.is_star else scope.expression,
)
+ )
- if upstream:
- upstream.downstream.append(node)
+ if isinstance(scope.expression, exp.Union):
+ upstream = upstream or Node(name="UNION", source=scope.expression, expression=select)
- subquery_scopes = {
- id(subquery_scope.expression): subquery_scope
- for subquery_scope in scope.subquery_scopes
- }
+ index = (
+ column
+ if isinstance(column, int)
+ else next(
+ (
+ i
+ for i, select in enumerate(scope.expression.selects)
+ if select.alias_or_name == column or select.is_star
+ ),
+ -1, # mypy will not allow a None here, but a negative index should never be returned
+ )
+ )
- for subquery in find_all_in_scope(select, exp.Subqueryable):
- subquery_scope = subquery_scopes[id(subquery)]
+ if index == -1:
+ raise ValueError(f"Could not find {column} in {scope.expression}")
+
+ for s in scope.union_scopes:
+ to_node(
+ index,
+ scope=s,
+ dialect=dialect,
+ upstream=upstream,
+ source_name=source_name,
+ reference_node_name=reference_node_name,
+ )
- for name in subquery.named_selects:
- to_node(name, scope=subquery_scope, upstream=node)
+ return upstream
+
+ if isinstance(scope.expression, exp.Select):
+ # For better ergonomics in our node labels, replace the full select with
+ # a version that has only the column we care about.
+ # "x", SELECT x, y FROM foo
+ # => "x", SELECT x FROM foo
+ source = t.cast(exp.Expression, scope.expression.select(select, append=False))
+ else:
+ source = scope.expression
+
+ # Create the node for this step in the lineage chain, and attach it to the previous one.
+ node = Node(
+ name=f"{scope_name}.{column}" if scope_name else str(column),
+ source=source,
+ expression=select,
+ source_name=source_name or "",
+ reference_node_name=reference_node_name or "",
+ )
- # if the select is a star add all scope sources as downstreams
- if select.is_star:
- for source in scope.sources.values():
- if isinstance(source, Scope):
- source = source.expression
- node.downstream.append(Node(name=select.sql(), source=source, expression=source))
+ if upstream:
+ upstream.downstream.append(node)
- # Find all columns that went into creating this one to list their lineage nodes.
- source_columns = set(find_all_in_scope(select, exp.Column))
+ subquery_scopes = {
+ id(subquery_scope.expression): subquery_scope for subquery_scope in scope.subquery_scopes
+ }
- # If the source is a UDTF find columns used in the UTDF to generate the table
- if isinstance(source, exp.UDTF):
- source_columns |= set(source.find_all(exp.Column))
+ for subquery in find_all_in_scope(select, exp.UNWRAPPED_QUERIES):
+ subquery_scope = subquery_scopes.get(id(subquery))
+ if not subquery_scope:
+ logger.warning(f"Unknown subquery scope: {subquery.sql(dialect=dialect)}")
+ continue
- for c in source_columns:
- table = c.table
- source = scope.sources.get(table)
+ for name in subquery.named_selects:
+ to_node(name, scope=subquery_scope, dialect=dialect, upstream=node)
+    # If the select is a star, add all scope sources as downstreams
+ if select.is_star:
+ for source in scope.sources.values():
if isinstance(source, Scope):
- # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
- to_node(
- c.name,
- scope=source,
- scope_name=table,
- upstream=node,
- alias=aliases.get(table) or alias,
- )
- else:
- # The source is not a scope - we've reached the end of the line. At this point, if a source is not found
- # it means this column's lineage is unknown. This can happen if the definition of a source used in a query
- # is not passed into the `sources` map.
- source = source or exp.Placeholder()
- node.downstream.append(Node(name=c.sql(), source=source, expression=source))
-
- return node
+ source = source.expression
+ node.downstream.append(Node(name=select.sql(), source=source, expression=source))
+
+ # Find all columns that went into creating this one to list their lineage nodes.
+ source_columns = set(find_all_in_scope(select, exp.Column))
+
+    # If the source is a UDTF, find the columns used in the UDTF to generate the table
+ if isinstance(source, exp.UDTF):
+ source_columns |= set(source.find_all(exp.Column))
+
+ for c in source_columns:
+ table = c.table
+ source = scope.sources.get(table)
+
+ if isinstance(source, Scope):
+ selected_node, _ = scope.selected_sources.get(table, (None, None))
+ # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
+ to_node(
+ c.name,
+ scope=source,
+ dialect=dialect,
+ scope_name=table,
+ upstream=node,
+ source_name=source_names.get(table) or source_name,
+ reference_node_name=selected_node.name if selected_node else None,
+ )
+ else:
+ # The source is not a scope - we've reached the end of the line. At this point, if a source is not found
+ # it means this column's lineage is unknown. This can happen if the definition of a source used in a query
+ # is not passed into the `sources` map.
+ source = source or exp.Placeholder()
+ node.downstream.append(Node(name=c.sql(), source=source, expression=source))
- return to_node(column if isinstance(column, str) else column.name, scope)
+ return node
class GraphHTML:
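A hedged sketch of the reworked entry point: `to_node` is now module-level, and nodes expose `source_name` and `reference_node_name` in place of the old `alias` field:

    >>> from sqlglot.lineage import lineage
    >>> node = lineage("a", "SELECT a FROM (SELECT a FROM tbl) AS sub")
    >>> node.name, node.downstream[0].name
    ('a', 'sub.a')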
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index ce274bb..81b1ee6 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -191,6 +191,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.DateToDi,
exp.Floor,
exp.Levenshtein,
+ exp.Sign,
exp.StrPosition,
exp.TsOrDiToDi,
},
@@ -262,6 +263,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.DateTrunc: lambda self, e: self._annotate_timeunit(e),
exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
exp.Div: lambda self, e: self._annotate_div(e),
+ exp.Dot: lambda self, e: self._annotate_dot(e),
exp.Explode: lambda self, e: self._annotate_explode(e),
exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
@@ -273,15 +275,17 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
+ exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"),
exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
+ exp.Struct: lambda self, e: self._annotate_by_args(e, "expressions", struct=True),
exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
exp.Timestamp: lambda self, e: self._annotate_with_type(
e,
exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP,
),
exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
+ exp.Unnest: lambda self, e: self._annotate_unnest(e),
exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
- exp.Struct: lambda self, e: self._annotate_by_args(e, "expressions", struct=True),
}
NESTED_TYPES = {
@@ -380,8 +384,11 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
source = scope.sources.get(col.table)
if isinstance(source, exp.Table):
self._set_type(col, self.schema.get_column_type(source, col))
- elif source and col.table in selects and col.name in selects[col.table]:
- self._set_type(col, selects[col.table][col.name].type)
+ elif source:
+ if col.table in selects and col.name in selects[col.table]:
+ self._set_type(col, selects[col.table][col.name].type)
+ elif isinstance(source.expression, exp.Unnest):
+ self._set_type(col, source.expression.type)
# Then (possibly) annotate the remaining expressions in the scope
self._maybe_annotate(scope.expression)
@@ -514,7 +521,14 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
last_datatype = None
for expr in expressions:
- last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
+ expr_type = expr.type
+
+ # Stop at the first nested data type found - we don't want to _maybe_coerce nested types
+ if expr_type.args.get("nested"):
+ last_datatype = expr_type
+ break
+
+ last_datatype = self._maybe_coerce(last_datatype or expr_type, expr_type)
self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN)
@@ -594,7 +608,26 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
return expression
+ def _annotate_dot(self, expression: exp.Dot) -> exp.Dot:
+ self._annotate_args(expression)
+ self._set_type(expression, None)
+ this_type = expression.this.type
+
+ if this_type and this_type.is_type(exp.DataType.Type.STRUCT):
+ for e in this_type.expressions:
+ if e.name == expression.expression.name:
+ self._set_type(expression, e.kind)
+ break
+
+ return expression
+
def _annotate_explode(self, expression: exp.Explode) -> exp.Explode:
self._annotate_args(expression)
self._set_type(expression, seq_get(expression.this.type.expressions, 0))
return expression
+
+ def _annotate_unnest(self, expression: exp.Unnest) -> exp.Unnest:
+ self._annotate_args(expression)
+ child = seq_get(expression.expressions, 0)
+ self._set_type(expression, child and seq_get(child.type.expressions, 0))
+ return expression
diff --git a/sqlglot/optimizer/normalize_identifiers.py b/sqlglot/optimizer/normalize_identifiers.py
index d22a998..f2a0990 100644
--- a/sqlglot/optimizer/normalize_identifiers.py
+++ b/sqlglot/optimizer/normalize_identifiers.py
@@ -10,13 +10,11 @@ if t.TYPE_CHECKING:
@t.overload
-def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
- ...
+def normalize_identifiers(expression: E, dialect: DialectType = None) -> E: ...
@t.overload
-def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier:
- ...
+def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier: ...
def normalize_identifiers(expression, dialect=None):
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index ef589c9..233ffc9 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -120,6 +120,8 @@ def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) ->
For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
"""
for derived_table in derived_tables:
+ if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive:
+ continue
table_alias = derived_table.args.get("alias")
if table_alias:
table_alias.args.pop("columns", None)
@@ -214,7 +216,13 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
table = resolver.get_table(column.name) if resolve_table and not column.table else None
alias_expr, i = alias_to_expression.get(column.name, (None, 1))
double_agg = (
- (alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc))
+ (
+ alias_expr.find(exp.AggFunc)
+ and (
+ column.find_ancestor(exp.AggFunc)
+ and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
+ )
+ )
if alias_expr
else False
)
@@ -404,7 +412,7 @@ def _expand_stars(
tables = list(scope.selected_sources)
_add_except_columns(expression, tables, except_columns)
_add_replace_columns(expression, tables, replace_columns)
- elif expression.is_star:
+ elif expression.is_star and not isinstance(expression, exp.Dot):
tables = [expression.table]
_add_except_columns(expression.this, tables, except_columns)
_add_replace_columns(expression.this, tables, replace_columns)
@@ -437,7 +445,7 @@ def _expand_stars(
if pivot_columns:
new_selections.extend(
- exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
+ alias(exp.column(name, table=pivot.alias), name, copy=False)
for name in pivot_columns
if name not in columns_to_exclude
)
@@ -466,7 +474,7 @@ def _expand_stars(
)
# Ensures we don't overwrite the initial selections with an empty list
- if new_selections:
+ if new_selections and isinstance(scope.expression, exp.Select):
scope.expression.set("expressions", new_selections)
@@ -528,7 +536,8 @@ def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
new_selections.append(selection)
- scope.expression.set("expressions", new_selections)
+ if isinstance(scope.expression, exp.Select):
+ scope.expression.set("expressions", new_selections)
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
@@ -615,7 +624,7 @@ class Resolver:
node, _ = self.scope.selected_sources.get(table_name)
- if isinstance(node, exp.Subqueryable):
+ if isinstance(node, exp.Query):
while node and node.alias != table_name:
node = node.parent
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py
index d460e81..214ac0a 100644
--- a/sqlglot/optimizer/qualify_tables.py
+++ b/sqlglot/optimizer/qualify_tables.py
@@ -55,8 +55,8 @@ def qualify_tables(
if not table.args.get("catalog") and table.args.get("db"):
table.set("catalog", catalog)
- if not isinstance(expression, exp.Subqueryable):
- for node, *_ in expression.walk(prune=lambda n, *_: isinstance(n, exp.Unionable)):
+ if not isinstance(expression, exp.Query):
+ for node, *_ in expression.walk(prune=lambda n, *_: isinstance(n, exp.Query)):
if isinstance(node, exp.Table):
_qualify(node)
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index 0eae979..443fa6c 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -138,7 +138,7 @@ class Scope:
and _is_derived_table(node)
):
self._derived_tables.append(node)
- elif isinstance(node, exp.Subqueryable):
+ elif isinstance(node, exp.UNWRAPPED_QUERIES):
self._subqueries.append(node)
self._collected = True
@@ -225,7 +225,7 @@ class Scope:
SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
Returns:
- list[exp.Subqueryable]: subqueries
+ list[exp.Select | exp.Union]: subqueries
"""
self._ensure_collected()
return self._subqueries
@@ -486,8 +486,8 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
Returns:
list[Scope]: scope instances
"""
- if isinstance(expression, exp.Unionable) or (
- isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Unionable)
+ if isinstance(expression, exp.Query) or (
+ isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query)
):
return list(_traverse_scope(Scope(expression)))
@@ -615,7 +615,7 @@ def _is_derived_table(expression: exp.Subquery) -> bool:
as it doesn't introduce a new scope. If an alias is present, it shadows all names
under the Subquery, so that's one exception to this rule.
"""
- return bool(expression.alias or isinstance(expression.this, exp.Subqueryable))
+ return bool(expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES))
def _traverse_tables(scope):
@@ -786,7 +786,7 @@ def walk_in_scope(expression, bfs=True, prune=None):
and _is_derived_table(node)
)
or isinstance(node, exp.UDTF)
- or isinstance(node, exp.Subqueryable)
+ or isinstance(node, exp.UNWRAPPED_QUERIES)
):
crossed_scope_boundary = True
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index 9ffddb5..2e43d21 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -1185,7 +1185,7 @@ def gen(expression: t.Any) -> str:
GEN_MAP = {
exp.Add: lambda e: _binary(e, "+"),
exp.And: lambda e: _binary(e, "AND"),
- exp.Anonymous: lambda e: f"{e.this.upper()} {','.join(gen(e) for e in e.expressions)}",
+ exp.Anonymous: lambda e: _anonymous(e),
exp.Between: lambda e: f"{gen(e.this)} BETWEEN {gen(e.args.get('low'))} AND {gen(e.args.get('high'))}",
exp.Boolean: lambda e: "TRUE" if e.this else "FALSE",
exp.Bracket: lambda e: f"{gen(e.this)}[{gen(e.expressions)}]",
@@ -1219,6 +1219,20 @@ GEN_MAP = {
}
+def _anonymous(e: exp.Anonymous) -> str:
+ this = e.this
+ if isinstance(this, str):
+ name = this.upper()
+ elif isinstance(this, exp.Identifier):
+ name = f'"{this.name}"' if this.quoted else this.name.upper()
+ else:
+ raise ValueError(
+ f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
+ )
+
+ return f"{name} {','.join(gen(e) for e in e.expressions)}"
+
+
def _binary(e: exp.Binary, op: str) -> str:
return f"{gen(e.left)} {op} {gen(e.right)}"
diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py
index b4c7475..36d9da4 100644
--- a/sqlglot/optimizer/unnest_subqueries.py
+++ b/sqlglot/optimizer/unnest_subqueries.py
@@ -94,8 +94,20 @@ def unnest(select, parent_select, next_alias_name):
else:
_replace(predicate, join_key_not_null)
+ group = select.args.get("group")
+
+ if group:
+ if {value.this} != set(group.expressions):
+ select = (
+ exp.select(exp.column(value.alias, "_q"))
+ .from_(select.subquery("_q", copy=False), copy=False)
+ .group_by(exp.column(value.alias, "_q"), copy=False)
+ )
+ else:
+ select = select.group_by(value.this, copy=False)
+
parent_select.join(
- select.group_by(value.this, copy=False),
+ select,
on=column.eq(join_key),
join_type="LEFT",
join_alias=alias,
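
A rough sketch (the SQL and schema are illustrative) of the case this hunk handles: when the correlated subquery already has its own GROUP BY, it is wrapped in a derived table "_q" and re-grouped on the selected value instead of being grouped in place:

    import sqlglot
    from sqlglot.optimizer.qualify import qualify
    from sqlglot.optimizer.unnest_subqueries import unnest_subqueries

    sql = "SELECT * FROM x WHERE x.a = (SELECT SUM(y.b) FROM y WHERE y.c = x.d GROUP BY y.e)"
    schema = {"x": {"a": "int", "d": "int"}, "y": {"b": "int", "c": "int", "e": "int"}}
    print(unnest_subqueries(qualify(sqlglot.parse_one(sql), schema=schema)).sql(pretty=True))
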
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index 4e7f870..49dac2e 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -17,6 +17,8 @@ if t.TYPE_CHECKING:
logger = logging.getLogger("sqlglot")
+OPTIONS_TYPE = t.Dict[str, t.Sequence[t.Union[t.Sequence[str], str]]]
+
def build_var_map(args: t.List) -> exp.StarMap | exp.VarMap:
if len(args) == 1 and args[0].is_star:
@@ -367,6 +369,7 @@ class Parser(metaclass=_Parser):
TokenType.TEMPORARY,
TokenType.TOP,
TokenType.TRUE,
+ TokenType.TRUNCATE,
TokenType.UNIQUE,
TokenType.UNPIVOT,
TokenType.UPDATE,
@@ -435,6 +438,7 @@ class Parser(metaclass=_Parser):
TokenType.TABLE,
TokenType.TIMESTAMP,
TokenType.TIMESTAMPTZ,
+ TokenType.TRUNCATE,
TokenType.WINDOW,
TokenType.XOR,
*TYPE_TOKENS,
@@ -578,7 +582,7 @@ class Parser(metaclass=_Parser):
exp.Column: lambda self: self._parse_column(),
exp.Condition: lambda self: self._parse_conjunction(),
exp.DataType: lambda self: self._parse_types(allow_identifiers=False),
- exp.Expression: lambda self: self._parse_statement(),
+ exp.Expression: lambda self: self._parse_expression(),
exp.From: lambda self: self._parse_from(),
exp.Group: lambda self: self._parse_group(),
exp.Having: lambda self: self._parse_having(),
@@ -625,10 +629,10 @@ class Parser(metaclass=_Parser):
TokenType.SET: lambda self: self._parse_set(),
TokenType.UNCACHE: lambda self: self._parse_uncache(),
TokenType.UPDATE: lambda self: self._parse_update(),
+ TokenType.TRUNCATE: lambda self: self._parse_truncate_table(),
TokenType.USE: lambda self: self.expression(
exp.Use,
- kind=self._match_texts(("ROLE", "WAREHOUSE", "DATABASE", "SCHEMA"))
- and exp.var(self._prev.text),
+ kind=self._parse_var_from_options(self.USABLES, raise_unmatched=False),
this=self._parse_table(schema=False),
),
}
@@ -642,36 +646,44 @@ class Parser(metaclass=_Parser):
TokenType.DPIPE_SLASH: lambda self: self.expression(exp.Cbrt, this=self._parse_unary()),
}
- PRIMARY_PARSERS = {
- TokenType.STRING: lambda self, token: self.expression(
- exp.Literal, this=token.text, is_string=True
- ),
- TokenType.NUMBER: lambda self, token: self.expression(
- exp.Literal, this=token.text, is_string=False
- ),
- TokenType.STAR: lambda self, _: self.expression(
- exp.Star, **{"except": self._parse_except(), "replace": self._parse_replace()}
+ STRING_PARSERS = {
+ TokenType.HEREDOC_STRING: lambda self, token: self.expression(
+ exp.RawString, this=token.text
),
- TokenType.NULL: lambda self, _: self.expression(exp.Null),
- TokenType.TRUE: lambda self, _: self.expression(exp.Boolean, this=True),
- TokenType.FALSE: lambda self, _: self.expression(exp.Boolean, this=False),
- TokenType.BIT_STRING: lambda self, token: self.expression(exp.BitString, this=token.text),
- TokenType.HEX_STRING: lambda self, token: self.expression(exp.HexString, this=token.text),
- TokenType.BYTE_STRING: lambda self, token: self.expression(exp.ByteString, this=token.text),
- TokenType.INTRODUCER: lambda self, token: self._parse_introducer(token),
TokenType.NATIONAL_STRING: lambda self, token: self.expression(
exp.National, this=token.text
),
TokenType.RAW_STRING: lambda self, token: self.expression(exp.RawString, this=token.text),
- TokenType.HEREDOC_STRING: lambda self, token: self.expression(
- exp.RawString, this=token.text
+ TokenType.STRING: lambda self, token: self.expression(
+ exp.Literal, this=token.text, is_string=True
),
TokenType.UNICODE_STRING: lambda self, token: self.expression(
exp.UnicodeString,
this=token.text,
escape=self._match_text_seq("UESCAPE") and self._parse_string(),
),
+ }
+
+ NUMERIC_PARSERS = {
+ TokenType.BIT_STRING: lambda self, token: self.expression(exp.BitString, this=token.text),
+ TokenType.BYTE_STRING: lambda self, token: self.expression(exp.ByteString, this=token.text),
+ TokenType.HEX_STRING: lambda self, token: self.expression(exp.HexString, this=token.text),
+ TokenType.NUMBER: lambda self, token: self.expression(
+ exp.Literal, this=token.text, is_string=False
+ ),
+ }
+
+ PRIMARY_PARSERS = {
+ **STRING_PARSERS,
+ **NUMERIC_PARSERS,
+ TokenType.INTRODUCER: lambda self, token: self._parse_introducer(token),
+ TokenType.NULL: lambda self, _: self.expression(exp.Null),
+ TokenType.TRUE: lambda self, _: self.expression(exp.Boolean, this=True),
+ TokenType.FALSE: lambda self, _: self.expression(exp.Boolean, this=False),
TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(),
+ TokenType.STAR: lambda self, _: self.expression(
+ exp.Star, **{"except": self._parse_except(), "replace": self._parse_replace()}
+ ),
}
PLACEHOLDER_PARSERS = {
@@ -799,7 +811,9 @@ class Parser(metaclass=_Parser):
exp.CharacterSetColumnConstraint, this=self._parse_var_or_string()
),
"CHECK": lambda self: self.expression(
- exp.CheckColumnConstraint, this=self._parse_wrapped(self._parse_conjunction)
+ exp.CheckColumnConstraint,
+ this=self._parse_wrapped(self._parse_conjunction),
+ enforced=self._match_text_seq("ENFORCED"),
),
"COLLATE": lambda self: self.expression(
exp.CollateColumnConstraint, this=self._parse_var()
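
A short sketch of the new ENFORCED handling on column-level CHECK constraints (MySQL-style); the flag lands in the constraint's args:

    import sqlglot
    from sqlglot import exp

    ddl = sqlglot.parse_one("CREATE TABLE t (a INT CHECK (a > 0) ENFORCED)")
    check = ddl.find(exp.CheckColumnConstraint)
    print(check.args.get("enforced"))  # True when ENFORCED was present
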
@@ -873,6 +887,8 @@ class Parser(metaclass=_Parser):
FUNCTIONS_WITH_ALIASED_ARGS = {"STRUCT"}
+ KEY_VALUE_DEFINITIONS = (exp.Alias, exp.EQ, exp.PropertyEQ, exp.Slice)
+
FUNCTION_PARSERS = {
"CAST": lambda self: self._parse_cast(self.STRICT_CAST),
"CONVERT": lambda self: self._parse_convert(self.STRICT_CAST),
@@ -895,6 +911,7 @@ class Parser(metaclass=_Parser):
QUERY_MODIFIER_PARSERS = {
TokenType.MATCH_RECOGNIZE: lambda self: ("match", self._parse_match_recognize()),
+ TokenType.PREWHERE: lambda self: ("prewhere", self._parse_prewhere()),
TokenType.WHERE: lambda self: ("where", self._parse_where()),
TokenType.GROUP_BY: lambda self: ("group", self._parse_group()),
TokenType.HAVING: lambda self: ("having", self._parse_having()),
@@ -934,22 +951,23 @@ class Parser(metaclass=_Parser):
exp.DataType.Type.JSON: lambda self, this, _: self.expression(exp.ParseJSON, this=this),
}
- MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table)
-
DDL_SELECT_TOKENS = {TokenType.SELECT, TokenType.WITH, TokenType.L_PAREN}
PRE_VOLATILE_TOKENS = {TokenType.CREATE, TokenType.REPLACE, TokenType.UNIQUE}
TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"}
- TRANSACTION_CHARACTERISTICS = {
- "ISOLATION LEVEL REPEATABLE READ",
- "ISOLATION LEVEL READ COMMITTED",
- "ISOLATION LEVEL READ UNCOMMITTED",
- "ISOLATION LEVEL SERIALIZABLE",
- "READ WRITE",
- "READ ONLY",
+ TRANSACTION_CHARACTERISTICS: OPTIONS_TYPE = {
+ "ISOLATION": (
+ ("LEVEL", "REPEATABLE", "READ"),
+ ("LEVEL", "READ", "COMMITTED"),
+ ("LEVEL", "READ", "UNCOMITTED"),
+ ("LEVEL", "SERIALIZABLE"),
+ ),
+ "READ": ("WRITE", "ONLY"),
}
+ USABLES: OPTIONS_TYPE = dict.fromkeys(("ROLE", "WAREHOUSE", "DATABASE", "SCHEMA"), tuple())
+
INSERT_ALTERNATIVES = {"ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK"}
CLONE_KEYWORDS = {"CLONE", "COPY"}
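
A sketch of the effect of the restructured TRANSACTION_CHARACTERISTICS map: multi-word characteristics are matched keyword by keyword through _parse_var_from_options and still round-trip as a single option:

    import sqlglot

    ast = sqlglot.parse_one("SET TRANSACTION ISOLATION LEVEL READ COMMITTED")
    print(ast.sql())  # the characteristic is kept as one Var
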
@@ -1012,6 +1030,9 @@ class Parser(metaclass=_Parser):
# If this is True and '(' is not found, the keyword will be treated as an identifier
VALUES_FOLLOWED_BY_PAREN = True
+ # Whether implicit unnesting is supported, e.g. SELECT 1 FROM y.z AS z, z.a (Redshift)
+ SUPPORTS_IMPLICIT_UNNEST = False
+
__slots__ = (
"error_level",
"error_message_context",
@@ -2450,10 +2471,37 @@ class Parser(metaclass=_Parser):
alias=self._parse_table_alias() if parse_alias else None,
)
+ def _implicit_unnests_to_explicit(self, this: E) -> E:
+ from sqlglot.optimizer.normalize_identifiers import normalize_identifiers as _norm
+
+ refs = {_norm(this.args["from"].this.copy(), dialect=self.dialect).alias_or_name}
+ for i, join in enumerate(this.args.get("joins") or []):
+ table = join.this
+ normalized_table = table.copy()
+ normalized_table.meta["maybe_column"] = True
+ normalized_table = _norm(normalized_table, dialect=self.dialect)
+
+ if isinstance(table, exp.Table) and not join.args.get("on"):
+ if normalized_table.parts[0].name in refs:
+ table_as_column = table.to_column()
+ unnest = exp.Unnest(expressions=[table_as_column])
+
+ # Table.to_column creates a parent Alias node that we want to convert to
+ # a TableAlias and attach to the Unnest, so it matches the parser's output
+ if isinstance(table.args.get("alias"), exp.TableAlias):
+ table_as_column.replace(table_as_column.this)
+ exp.alias_(unnest, None, table=[table.args["alias"].this], copy=False)
+
+ table.replace(unnest)
+
+ refs.add(normalized_table.alias_or_name)
+
+ return this
+
def _parse_query_modifiers(
self, this: t.Optional[exp.Expression]
) -> t.Optional[exp.Expression]:
- if isinstance(this, self.MODIFIABLES):
+ if isinstance(this, (exp.Query, exp.Table)):
for join in iter(self._parse_join, None):
this.append("joins", join)
for lateral in iter(self._parse_lateral, None):
@@ -2478,6 +2526,10 @@ class Parser(metaclass=_Parser):
offset.set("expressions", limit_by_expressions)
continue
break
+
+ if self.SUPPORTS_IMPLICIT_UNNEST and this and "from" in this.args:
+ this = self._implicit_unnests_to_explicit(this)
+
return this
def _parse_hint(self) -> t.Optional[exp.Hint]:
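
A hedged illustration of SUPPORTS_IMPLICIT_UNNEST, which the class comment ties to Redshift: a comma-joined path rooted at a known alias is rewritten into an explicit UNNEST:

    import sqlglot

    ast = sqlglot.parse_one("SELECT 1 FROM y.z AS z, z.a", read="redshift")
    print(ast.sql())  # roughly: SELECT 1 FROM y.z AS z, UNNEST(z.a)
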
@@ -2803,7 +2855,9 @@ class Parser(metaclass=_Parser):
or self._parse_placeholder()
)
- def _parse_table_parts(self, schema: bool = False, is_db_reference: bool = False) -> exp.Table:
+ def _parse_table_parts(
+ self, schema: bool = False, is_db_reference: bool = False, wildcard: bool = False
+ ) -> exp.Table:
catalog = None
db = None
table: t.Optional[exp.Expression | str] = self._parse_table_part(schema=schema)
@@ -2817,8 +2871,20 @@ class Parser(metaclass=_Parser):
else:
catalog = db
db = table
+ # "" used for tsql FROM a..b case
table = self._parse_table_part(schema=schema) or ""
+ if (
+ wildcard
+ and self._is_connected()
+ and (isinstance(table, exp.Identifier) or not table)
+ and self._match(TokenType.STAR)
+ ):
+ if isinstance(table, exp.Identifier):
+ table.args["this"] += "*"
+ else:
+ table = exp.Identifier(this="*")
+
if is_db_reference:
catalog = db
db = table
@@ -2861,6 +2927,9 @@ class Parser(metaclass=_Parser):
bracket = parse_bracket and self._parse_bracket(None)
bracket = self.expression(exp.Table, this=bracket) if bracket else None
+
+ only = self._match(TokenType.ONLY)
+
this = t.cast(
exp.Expression,
bracket
@@ -2869,6 +2938,12 @@ class Parser(metaclass=_Parser):
),
)
+ if only:
+ this.set("only", only)
+
+ # Postgres supports a wildcard (table) suffix operator, which is a no-op in this context
+ self._match_text_seq("*")
+
if schema:
return self._parse_schema(this=this)
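
A sketch of the new ONLY handling (PostgreSQL table inheritance), assuming the Postgres tokenizer maps the ONLY keyword elsewhere in this release; the flag is stored on the Table node and should round-trip:

    import sqlglot

    print(sqlglot.parse_one("SELECT * FROM ONLY parent_table", read="postgres").sql("postgres"))
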
@@ -3161,6 +3236,14 @@ class Parser(metaclass=_Parser):
def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]:
return [agg.alias for agg in aggregations]
+ def _parse_prewhere(self, skip_where_token: bool = False) -> t.Optional[exp.PreWhere]:
+ if not skip_where_token and not self._match(TokenType.PREWHERE):
+ return None
+
+ return self.expression(
+ exp.PreWhere, comments=self._prev_comments, this=self._parse_conjunction()
+ )
+
def _parse_where(self, skip_where_token: bool = False) -> t.Optional[exp.Where]:
if not skip_where_token and not self._match(TokenType.WHERE):
return None
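
A short example of PREWHERE as a first-class query modifier (ClickHouse):

    import sqlglot

    ast = sqlglot.parse_one("SELECT a FROM t PREWHERE b = 1 WHERE c = 2", read="clickhouse")
    print(ast.args["prewhere"].sql("clickhouse"))  # PREWHERE b = 1
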
@@ -3291,8 +3374,12 @@ class Parser(metaclass=_Parser):
return None
return self.expression(exp_class, expressions=self._parse_csv(self._parse_ordered))
- def _parse_ordered(self, parse_method: t.Optional[t.Callable] = None) -> exp.Ordered:
+ def _parse_ordered(
+ self, parse_method: t.Optional[t.Callable] = None
+ ) -> t.Optional[exp.Ordered]:
this = parse_method() if parse_method else self._parse_conjunction()
+ if not this:
+ return None
asc = self._match(TokenType.ASC)
desc = self._match(TokenType.DESC) or (asc and False)
@@ -3510,7 +3597,7 @@ class Parser(metaclass=_Parser):
if self._match_text_seq("DISTINCT", "FROM"):
klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ
- return self.expression(klass, this=this, expression=self._parse_conjunction())
+ return self.expression(klass, this=this, expression=self._parse_bitwise())
expression = self._parse_null() or self._parse_boolean()
if not expression:
@@ -3528,7 +3615,7 @@ class Parser(metaclass=_Parser):
matched_l_paren = self._prev.token_type == TokenType.L_PAREN
expressions = self._parse_csv(lambda: self._parse_select_or_expression(alias=alias))
- if len(expressions) == 1 and isinstance(expressions[0], exp.Subqueryable):
+ if len(expressions) == 1 and isinstance(expressions[0], exp.Query):
this = self.expression(exp.In, this=this, query=expressions[0])
else:
this = self.expression(exp.In, this=this, expressions=expressions)
@@ -3959,7 +4046,7 @@ class Parser(metaclass=_Parser):
this = self._parse_query_modifiers(seq_get(expressions, 0))
- if isinstance(this, exp.Subqueryable):
+ if isinstance(this, exp.UNWRAPPED_QUERIES):
this = self._parse_set_operations(
self._parse_subquery(this=this, parse_alias=False)
)
@@ -4064,6 +4151,9 @@ class Parser(metaclass=_Parser):
alias = upper in self.FUNCTIONS_WITH_ALIASED_ARGS
args = self._parse_csv(lambda: self._parse_lambda(alias=alias))
+ if alias:
+ args = self._kv_to_prop_eq(args)
+
if function and not anonymous:
if "dialect" in function.__code__.co_varnames:
func = function(args, dialect=self.dialect)
@@ -4076,6 +4166,8 @@ class Parser(metaclass=_Parser):
this = func
else:
+ if token_type == TokenType.IDENTIFIER:
+ this = exp.Identifier(this=this, quoted=True)
this = self.expression(exp.Anonymous, this=this, expressions=args)
if isinstance(this, exp.Expression):
@@ -4084,6 +4176,26 @@ class Parser(metaclass=_Parser):
self._match_r_paren(this)
return self._parse_window(this)
+ def _kv_to_prop_eq(self, expressions: t.List[exp.Expression]) -> t.List[exp.Expression]:
+ transformed = []
+
+ for e in expressions:
+ if isinstance(e, self.KEY_VALUE_DEFINITIONS):
+ if isinstance(e, exp.Alias):
+ e = self.expression(exp.PropertyEQ, this=e.args.get("alias"), expression=e.this)
+
+ if not isinstance(e, exp.PropertyEQ):
+ e = self.expression(
+ exp.PropertyEQ, this=exp.to_identifier(e.name), expression=e.expression
+ )
+
+ if isinstance(e.this, exp.Column):
+ e.this.replace(e.this.this)
+
+ transformed.append(e)
+
+ return transformed
+
def _parse_function_parameter(self) -> t.Optional[exp.Expression]:
return self._parse_column_def(self._parse_id_var())
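
A sketch of what _kv_to_prop_eq normalizes: aliased STRUCT arguments become PropertyEQ nodes, so STRUCT(1 AS a) and DuckDB's {'a': 1} share one representation:

    import sqlglot
    from sqlglot import exp

    struct = sqlglot.parse_one("SELECT STRUCT(1 AS a)", read="bigquery").find(exp.Struct)
    print(repr(struct.expressions[0]))  # a PropertyEQ with Identifier "a" on the left
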
@@ -4496,7 +4608,7 @@ class Parser(metaclass=_Parser):
# https://duckdb.org/docs/sql/data_types/struct.html#creating-structs
if bracket_kind == TokenType.L_BRACE:
- this = self.expression(exp.Struct, expressions=expressions)
+ this = self.expression(exp.Struct, expressions=self._kv_to_prop_eq(expressions))
elif not this or this.name.upper() == "ARRAY":
this = self.expression(exp.Array, expressions=expressions)
else:
@@ -4747,12 +4859,10 @@ class Parser(metaclass=_Parser):
return None
@t.overload
- def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject:
- ...
+ def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: ...
@t.overload
- def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg:
- ...
+ def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg: ...
def _parse_json_object(self, agg=False):
star = self._parse_star()
@@ -5140,16 +5250,16 @@ class Parser(metaclass=_Parser):
return None
def _parse_string(self) -> t.Optional[exp.Expression]:
- if self._match_set((TokenType.STRING, TokenType.RAW_STRING)):
- return self.PRIMARY_PARSERS[self._prev.token_type](self, self._prev)
+ if self._match_set(self.STRING_PARSERS):
+ return self.STRING_PARSERS[self._prev.token_type](self, self._prev)
return self._parse_placeholder()
def _parse_string_as_identifier(self) -> t.Optional[exp.Identifier]:
return exp.to_identifier(self._match(TokenType.STRING) and self._prev.text, quoted=True)
def _parse_number(self) -> t.Optional[exp.Expression]:
- if self._match(TokenType.NUMBER):
- return self.PRIMARY_PARSERS[TokenType.NUMBER](self, self._prev)
+ if self._match_set(self.NUMERIC_PARSERS):
+ return self.NUMERIC_PARSERS[self._prev.token_type](self, self._prev)
return self._parse_placeholder()
def _parse_identifier(self) -> t.Optional[exp.Expression]:
@@ -5182,6 +5292,9 @@ class Parser(metaclass=_Parser):
def _parse_var_or_string(self) -> t.Optional[exp.Expression]:
return self._parse_var() or self._parse_string()
+ def _parse_primary_or_var(self) -> t.Optional[exp.Expression]:
+ return self._parse_primary() or self._parse_var(any_token=True)
+
def _parse_null(self) -> t.Optional[exp.Expression]:
if self._match_set(self.NULL_TOKENS):
return self.PRIMARY_PARSERS[TokenType.NULL](self, self._prev)
@@ -5200,16 +5313,12 @@ class Parser(metaclass=_Parser):
return self._parse_placeholder()
def _parse_parameter(self) -> exp.Parameter:
- def _parse_parameter_part() -> t.Optional[exp.Expression]:
- return (
- self._parse_identifier() or self._parse_primary() or self._parse_var(any_token=True)
- )
-
self._match(TokenType.L_BRACE)
- this = _parse_parameter_part()
- expression = self._match(TokenType.COLON) and _parse_parameter_part()
+ this = self._parse_identifier() or self._parse_primary_or_var()
+ expression = self._match(TokenType.COLON) and (
+ self._parse_identifier() or self._parse_primary_or_var()
+ )
self._match(TokenType.R_BRACE)
-
return self.expression(exp.Parameter, this=this, expression=expression)
def _parse_placeholder(self) -> t.Optional[exp.Expression]:
@@ -5376,35 +5485,15 @@ class Parser(metaclass=_Parser):
exp.DropPartition, expressions=self._parse_csv(self._parse_partition), exists=exists
)
- def _parse_add_constraint(self) -> exp.AddConstraint:
- this = None
- kind = self._prev.token_type
-
- if kind == TokenType.CONSTRAINT:
- this = self._parse_id_var()
-
- if self._match_text_seq("CHECK"):
- expression = self._parse_wrapped(self._parse_conjunction)
- enforced = self._match_text_seq("ENFORCED") or False
-
- return self.expression(
- exp.AddConstraint, this=this, expression=expression, enforced=enforced
- )
-
- if kind == TokenType.FOREIGN_KEY or self._match(TokenType.FOREIGN_KEY):
- expression = self._parse_foreign_key()
- elif kind == TokenType.PRIMARY_KEY or self._match(TokenType.PRIMARY_KEY):
- expression = self._parse_primary_key()
- else:
- expression = None
-
- return self.expression(exp.AddConstraint, this=this, expression=expression)
-
def _parse_alter_table_add(self) -> t.List[exp.Expression]:
index = self._index - 1
- if self._match_set(self.ADD_CONSTRAINT_TOKENS):
- return self._parse_csv(self._parse_add_constraint)
+ if self._match_set(self.ADD_CONSTRAINT_TOKENS, advance=False):
+ return self._parse_csv(
+ lambda: self.expression(
+ exp.AddConstraint, expressions=self._parse_csv(self._parse_constraint)
+ )
+ )
self._retreat(index)
if not self.ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN and self._match_text_seq("ADD"):
@@ -5472,6 +5561,7 @@ class Parser(metaclass=_Parser):
parser = self.ALTER_PARSERS.get(self._prev.text.upper()) if self._prev else None
if parser:
actions = ensure_list(parser(self))
+ options = self._parse_csv(self._parse_property)
if not self._curr and actions:
return self.expression(
@@ -5480,6 +5570,7 @@ class Parser(metaclass=_Parser):
exists=exists,
actions=actions,
only=only,
+ options=options,
)
return self._parse_as_command(start)
@@ -5610,11 +5701,34 @@ class Parser(metaclass=_Parser):
return set_
- def _parse_var_from_options(self, options: t.Collection[str]) -> t.Optional[exp.Var]:
- for option in options:
- if self._match_text_seq(*option.split(" ")):
- return exp.var(option)
- return None
+ def _parse_var_from_options(
+ self, options: OPTIONS_TYPE, raise_unmatched: bool = True
+ ) -> t.Optional[exp.Var]:
+ start = self._curr
+ if not start:
+ return None
+
+ option = start.text.upper()
+ continuations = options.get(option)
+
+ index = self._index
+ self._advance()
+ for keywords in continuations or []:
+ if isinstance(keywords, str):
+ keywords = (keywords,)
+
+ if self._match_text_seq(*keywords):
+ option = f"{option} {' '.join(keywords)}"
+ break
+ else:
+ if continuations or continuations is None:
+ if raise_unmatched:
+ self.raise_error(f"Unknown option {option}")
+
+ self._retreat(index)
+ return None
+
+ return exp.var(option)
def _parse_as_command(self, start: Token) -> exp.Command:
while self._curr:
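
A sketch of the USE parser now going through _parse_var_from_options(self.USABLES, raise_unmatched=False): a known keyword becomes the statement's kind, and anything else falls back to a plain table reference:

    import sqlglot

    print(repr(sqlglot.parse_one("USE WAREHOUSE wh", read="snowflake").args.get("kind")))
    print(sqlglot.parse_one("USE my_db").sql())  # no kind: USE my_db
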
@@ -5806,14 +5920,12 @@ class Parser(metaclass=_Parser):
return True
@t.overload
- def _replace_columns_with_dots(self, this: exp.Expression) -> exp.Expression:
- ...
+ def _replace_columns_with_dots(self, this: exp.Expression) -> exp.Expression: ...
@t.overload
def _replace_columns_with_dots(
self, this: t.Optional[exp.Expression]
- ) -> t.Optional[exp.Expression]:
- ...
+ ) -> t.Optional[exp.Expression]: ...
def _replace_columns_with_dots(self, this):
if isinstance(this, exp.Dot):
@@ -5849,3 +5961,53 @@ class Parser(metaclass=_Parser):
else:
column.replace(dot_or_id)
return node
+
+ def _parse_truncate_table(self) -> t.Optional[exp.TruncateTable] | exp.Expression:
+ start = self._prev
+
+ # Not to be confused with TRUNCATE(number, decimals) function call
+ if self._match(TokenType.L_PAREN):
+ self._retreat(self._index - 2)
+ return self._parse_function()
+
+ # ClickHouse supports TRUNCATE DATABASE as well
+ is_database = self._match(TokenType.DATABASE)
+
+ self._match(TokenType.TABLE)
+
+ exists = self._parse_exists(not_=False)
+
+ expressions = self._parse_csv(
+ lambda: self._parse_table(schema=True, is_db_reference=is_database)
+ )
+
+ cluster = self._parse_on_property() if self._match(TokenType.ON) else None
+
+ if self._match_text_seq("RESTART", "IDENTITY"):
+ identity = "RESTART"
+ elif self._match_text_seq("CONTINUE", "IDENTITY"):
+ identity = "CONTINUE"
+ else:
+ identity = None
+
+ if self._match_text_seq("CASCADE") or self._match_text_seq("RESTRICT"):
+ option = self._prev.text
+ else:
+ option = None
+
+ partition = self._parse_partition()
+
+ # Fallback case
+ if self._curr:
+ return self._parse_as_command(start)
+
+ return self.expression(
+ exp.TruncateTable,
+ expressions=expressions,
+ is_database=is_database,
+ exists=exists,
+ cluster=cluster,
+ identity=identity,
+ option=option,
+ partition=partition,
+ )
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index 939ca18..da9df7d 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -302,6 +302,7 @@ class TokenType(AutoName):
OBJECT_IDENTIFIER = auto()
OFFSET = auto()
ON = auto()
+ ONLY = auto()
OPERATOR = auto()
ORDER_BY = auto()
ORDER_SIBLINGS_BY = auto()
@@ -317,6 +318,7 @@ class TokenType(AutoName):
PIVOT = auto()
PLACEHOLDER = auto()
PRAGMA = auto()
+ PREWHERE = auto()
PRIMARY_KEY = auto()
PROCEDURE = auto()
PROPERTIES = auto()
@@ -353,6 +355,7 @@ class TokenType(AutoName):
TOP = auto()
THEN = auto()
TRUE = auto()
+ TRUNCATE = auto()
UNCACHE = auto()
UNION = auto()
UNNEST = auto()
@@ -370,6 +373,7 @@ class TokenType(AutoName):
UNIQUE = auto()
VERSION_SNAPSHOT = auto()
TIMESTAMP_SNAPSHOT = auto()
+ OPTION = auto()
_ALL_TOKEN_TYPES = list(TokenType)
@@ -657,6 +661,7 @@ class Tokenizer(metaclass=_Tokenizer):
"DROP": TokenType.DROP,
"ELSE": TokenType.ELSE,
"END": TokenType.END,
+ "ENUM": TokenType.ENUM,
"ESCAPE": TokenType.ESCAPE,
"EXCEPT": TokenType.EXCEPT,
"EXECUTE": TokenType.EXECUTE,
@@ -752,6 +757,7 @@ class Tokenizer(metaclass=_Tokenizer):
"TEMPORARY": TokenType.TEMPORARY,
"THEN": TokenType.THEN,
"TRUE": TokenType.TRUE,
+ "TRUNCATE": TokenType.TRUNCATE,
"UNION": TokenType.UNION,
"UNKNOWN": TokenType.UNKNOWN,
"UNNEST": TokenType.UNNEST,
@@ -860,7 +866,6 @@ class Tokenizer(metaclass=_Tokenizer):
"GRANT": TokenType.COMMAND,
"OPTIMIZE": TokenType.COMMAND,
"PREPARE": TokenType.COMMAND,
- "TRUNCATE": TokenType.COMMAND,
"VACUUM": TokenType.COMMAND,
"USER-DEFINED": TokenType.USERDEFINED,
"FOR VERSION": TokenType.VERSION_SNAPSHOT,
@@ -1036,12 +1041,6 @@ class Tokenizer(metaclass=_Tokenizer):
def _text(self) -> str:
return self.sql[self._start : self._current]
- def peek(self, i: int = 0) -> str:
- i = self._current + i
- if i < self.size:
- return self.sql[i]
- return ""
-
def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None:
self._prev_token_line = self._line
@@ -1182,12 +1181,8 @@ class Tokenizer(metaclass=_Tokenizer):
if self._peek.isdigit():
self._advance()
elif self._peek == "." and not decimal:
- after = self.peek(1)
- if after.isdigit() or not after.isalpha():
- decimal = True
- self._advance()
- else:
- return self._add(TokenType.VAR)
+ decimal = True
+ self._advance()
elif self._peek in ("-", "+") and scientific == 1:
scientific += 1
self._advance()
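
A hedged sketch of the simplified number scanner: with the VAR fallback removed, a dot after digits always extends the number, so e.g. "1.x" scans as NUMBER '1.' followed by VAR 'x':

    from sqlglot.tokens import Tokenizer

    for token in Tokenizer().tokenize("SELECT 1.x"):
        print(token.token_type.name, repr(token.text))
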
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py
index 4777609..04c1f7b 100644
--- a/sqlglot/transforms.py
+++ b/sqlglot/transforms.py
@@ -547,7 +547,7 @@ def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Exp
prop
and prop.this
and isinstance(prop.this, exp.Schema)
- and all(isinstance(e, exp.ColumnDef) and e.args.get("kind") for e in prop.this.expressions)
+ and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions)
):
prop_this = exp.Tuple(
expressions=[exp.to_identifier(e.this) for e in prop.this.expressions]
@@ -560,6 +560,22 @@ def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Exp
return expression
+def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
+ """
+ Convert struct arguments to aliases, e.g. STRUCT(1 AS y).
+ """
+ if isinstance(expression, exp.Struct):
+ expression.set(
+ "expressions",
+ [
+ exp.alias_(e.expression, e.this) if isinstance(e, exp.PropertyEQ) else e
+ for e in expression.expressions
+ ],
+ )
+
+ return expression
+
+
def preprocess(
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
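
A sketch of how a generator-side dialect can use struct_kv_to_alias to render PropertyEQ struct entries back as aliases (BigQuery-style STRUCT(1 AS y)):

    import sqlglot

    print(sqlglot.transpile("SELECT {'y': 1}", read="duckdb", write="bigquery")[0])
    # expected shape: SELECT STRUCT(1 AS y)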