author     Daniel Baumann <daniel.baumann@progress-linux.org>  2024-04-08 08:11:53 +0000
committer  Daniel Baumann <daniel.baumann@progress-linux.org>  2024-04-08 08:12:02 +0000
commit     8d36f5966675e23bee7026ba37ae0647fbf47300 (patch)
tree       df4227bbb3b07cb70df87237bcff03c8efd7822d /sqlglot
parent     Releasing debian version 22.2.0-1. (diff)
download   sqlglot-8d36f5966675e23bee7026ba37ae0647fbf47300.tar.xz
           sqlglot-8d36f5966675e23bee7026ba37ae0647fbf47300.zip
Merging upstream version 23.7.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
-rw-r--r--  sqlglot/__init__.py  17
-rw-r--r--  sqlglot/dataframe/sql/dataframe.py  13
-rw-r--r--  sqlglot/dataframe/sql/functions.py  14
-rw-r--r--  sqlglot/dataframe/sql/session.py  11
-rw-r--r--  sqlglot/dialects/__init__.py  2
-rw-r--r--  sqlglot/dialects/athena.py  37
-rw-r--r--  sqlglot/dialects/bigquery.py  54
-rw-r--r--  sqlglot/dialects/clickhouse.py  77
-rw-r--r--  sqlglot/dialects/dialect.py  115
-rw-r--r--  sqlglot/dialects/doris.py  14
-rw-r--r--  sqlglot/dialects/drill.py  16
-rw-r--r--  sqlglot/dialects/duckdb.py  42
-rw-r--r--  sqlglot/dialects/hive.py  7
-rw-r--r--  sqlglot/dialects/mysql.py  82
-rw-r--r--  sqlglot/dialects/oracle.py  2
-rw-r--r--  sqlglot/dialects/postgres.py  12
-rw-r--r--  sqlglot/dialects/presto.py  40
-rw-r--r--  sqlglot/dialects/prql.py  109
-rw-r--r--  sqlglot/dialects/redshift.py  22
-rw-r--r--  sqlglot/dialects/snowflake.py  201
-rw-r--r--  sqlglot/dialects/spark.py  11
-rw-r--r--  sqlglot/dialects/spark2.py  9
-rw-r--r--  sqlglot/dialects/sqlite.py  12
-rw-r--r--  sqlglot/dialects/starrocks.py  7
-rw-r--r--  sqlglot/dialects/tableau.py  2
-rw-r--r--  sqlglot/dialects/teradata.py  9
-rw-r--r--  sqlglot/dialects/trino.py  1
-rw-r--r--  sqlglot/dialects/tsql.py  24
-rw-r--r--  sqlglot/diff.py  16
-rw-r--r--  sqlglot/executor/__init__.py  15
-rw-r--r--  sqlglot/executor/env.py  14
-rw-r--r--  sqlglot/executor/python.py  4
-rw-r--r--  sqlglot/expressions.py  533
-rw-r--r--  sqlglot/generator.py  549
-rw-r--r--  sqlglot/helper.py  16
-rw-r--r--  sqlglot/lineage.py  52
-rw-r--r--  sqlglot/optimizer/annotate_types.py  127
-rw-r--r--  sqlglot/optimizer/canonicalize.py  35
-rw-r--r--  sqlglot/optimizer/eliminate_ctes.py  2
-rw-r--r--  sqlglot/optimizer/merge_subqueries.py  1
-rw-r--r--  sqlglot/optimizer/normalize.py  2
-rw-r--r--  sqlglot/optimizer/normalize_identifiers.py  8
-rw-r--r--  sqlglot/optimizer/optimizer.py  6
-rw-r--r--  sqlglot/optimizer/pushdown_predicates.py  8
-rw-r--r--  sqlglot/optimizer/pushdown_projections.py  14
-rw-r--r--  sqlglot/optimizer/qualify_columns.py  10
-rw-r--r--  sqlglot/optimizer/qualify_tables.py  6
-rw-r--r--  sqlglot/optimizer/scope.py  165
-rw-r--r--  sqlglot/optimizer/simplify.py  472
-rw-r--r--  sqlglot/optimizer/unnest_subqueries.py  2
-rw-r--r--  sqlglot/parser.py  493
-rw-r--r--  sqlglot/planner.py  12
-rw-r--r--  sqlglot/tokens.py  87
-rw-r--r--  sqlglot/transforms.py  8
54 files changed, 2549 insertions, 1070 deletions
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py
index e30232c..756532f 100644
--- a/sqlglot/__init__.py
+++ b/sqlglot/__init__.py
@@ -45,7 +45,7 @@ from sqlglot.expressions import (
from sqlglot.generator import Generator as Generator
from sqlglot.parser import Parser as Parser
from sqlglot.schema import MappingSchema as MappingSchema, Schema as Schema
-from sqlglot.tokens import Tokenizer as Tokenizer, TokenType as TokenType
+from sqlglot.tokens import Token as Token, Tokenizer as Tokenizer, TokenType as TokenType
if t.TYPE_CHECKING:
from sqlglot._typing import E
@@ -69,6 +69,21 @@ schema = MappingSchema()
"""The default schema used by SQLGlot (e.g. in the optimizer)."""
+def tokenize(sql: str, read: DialectType = None, dialect: DialectType = None) -> t.List[Token]:
+ """
+ Tokenizes the given SQL string.
+
+ Args:
+ sql: the SQL code string to tokenize.
+ read: the SQL dialect to apply during tokenizing (eg. "spark", "hive", "presto", "mysql").
+ dialect: the SQL dialect (alias for read).
+
+ Returns:
+ The resulting list of tokens.
+ """
+ return Dialect.get_or_raise(read or dialect).tokenize(sql)
+
+
def parse(
sql: str, read: DialectType = None, dialect: DialectType = None, **opts
) -> t.List[t.Optional[Expression]]:
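
A minimal usage sketch for the new top-level tokenize() helper added above (illustrative, not part of the commit):

    import sqlglot

    # Tokenize without parsing; `read` (or its alias `dialect`) picks the tokenizer.
    tokens = sqlglot.tokenize("SELECT a FROM t", read="mysql")
    print([token.token_type for token in tokens])
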
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
index 0bacbf9..8316c36 100644
--- a/sqlglot/dataframe/sql/dataframe.py
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -18,8 +18,6 @@ from sqlglot.dataframe.sql.transforms import replace_id_value
from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
from sqlglot.dataframe.sql.window import Window
from sqlglot.helper import ensure_list, object_to_dict, seq_get
-from sqlglot.optimizer import optimize as optimize_func
-from sqlglot.optimizer.qualify_columns import quote_identifiers
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql._typing import (
@@ -121,7 +119,9 @@ class DataFrame:
self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"]
)
replacement_mapping[old_name_id] = new_hashed_id
- expression = expression.transform(replace_id_value, replacement_mapping)
+ expression = expression.transform(replace_id_value, replacement_mapping).assert_is(
+ exp.Select
+ )
return expression
def _create_cte_from_expression(
@@ -306,11 +306,12 @@ class DataFrame:
replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}
for expression_type, select_expression in select_expressions:
- select_expression = select_expression.transform(replace_id_value, replacement_mapping)
+ select_expression = select_expression.transform(
+ replace_id_value, replacement_mapping
+ ).assert_is(exp.Select)
if optimize:
- quote_identifiers(select_expression, dialect=dialect)
select_expression = t.cast(
- exp.Select, optimize_func(select_expression, dialect=dialect)
+ exp.Select, self.spark._optimize(select_expression, dialect=dialect)
)
select_expression = df._replace_cte_names_with_hashes(select_expression)
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
index db5201f..b4dd2c6 100644
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -184,7 +184,7 @@ def floor(col: ColumnOrName) -> Column:
def log10(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.Log10)
+ return Column.invoke_expression_over_column(lit(10), expression.Log, expression=col)
def log1p(col: ColumnOrName) -> Column:
@@ -192,7 +192,7 @@ def log1p(col: ColumnOrName) -> Column:
def log2(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.Log2)
+ return Column.invoke_expression_over_column(lit(2), expression.Log, expression=col)
def log(arg1: t.Union[ColumnOrName, float], arg2: t.Optional[ColumnOrName] = None) -> Column:
@@ -356,15 +356,15 @@ def coalesce(*cols: ColumnOrName) -> Column:
def corr(col1: ColumnOrName, col2: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col1, "CORR", col2)
+ return Column.invoke_expression_over_column(col1, expression.Corr, expression=col2)
def covar_pop(col1: ColumnOrName, col2: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col1, "COVAR_POP", col2)
+ return Column.invoke_expression_over_column(col1, expression.CovarPop, expression=col2)
def covar_samp(col1: ColumnOrName, col2: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col1, "COVAR_SAMP", col2)
+ return Column.invoke_expression_over_column(col1, expression.CovarSamp, expression=col2)
def first(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column:
@@ -971,10 +971,10 @@ def array_join(
) -> Column:
if null_replacement is not None:
return Column.invoke_expression_over_column(
- col, expression.ArrayJoin, expression=lit(delimiter), null=lit(null_replacement)
+ col, expression.ArrayToString, expression=lit(delimiter), null=lit(null_replacement)
)
return Column.invoke_expression_over_column(
- col, expression.ArrayJoin, expression=lit(delimiter)
+ col, expression.ArrayToString, expression=lit(delimiter)
)
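
A hedged sketch of the log10/log2 change above: both now lower to the two-argument exp.Log node rather than the removed Log10/Log2 expressions (illustrative; Column.expression is the wrapped sqlglot expression):

    from sqlglot.dataframe.sql import functions as F

    # log10(x) builds exp.Log(this=10, expression=x) instead of exp.Log10(x).
    col = F.log10("x")
    print(col.expression.sql(dialect="spark"))  # expected to render as LOG(10, x)
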
diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py
index bfc022b..4e47aaa 100644
--- a/sqlglot/dataframe/sql/session.py
+++ b/sqlglot/dataframe/sql/session.py
@@ -12,6 +12,8 @@ from sqlglot.dataframe.sql.readwriter import DataFrameReader
from sqlglot.dataframe.sql.types import StructType
from sqlglot.dataframe.sql.util import get_column_mapping_from_schema_input
from sqlglot.helper import classproperty
+from sqlglot.optimizer import optimize
+from sqlglot.optimizer.qualify_columns import quote_identifiers
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql._typing import ColumnLiterals, SchemaInput
@@ -104,8 +106,15 @@ class SparkSession:
sel_expression = exp.Select(**select_kwargs)
return DataFrame(self, sel_expression)
+ def _optimize(
+ self, expression: exp.Expression, dialect: t.Optional[Dialect] = None
+ ) -> exp.Expression:
+ dialect = dialect or self.dialect
+ quote_identifiers(expression, dialect=dialect)
+ return optimize(expression, dialect=dialect)
+
def sql(self, sqlQuery: str) -> DataFrame:
- expression = sqlglot.parse_one(sqlQuery, read=self.dialect)
+ expression = self._optimize(sqlglot.parse_one(sqlQuery, read=self.dialect))
if isinstance(expression, exp.Select):
df = DataFrame(self, expression)
df = df._convert_leaf_to_cte()
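
A sketch of the behavioural change in SparkSession.sql(): parsing is now followed by _optimize(), which quotes identifiers and runs the optimizer before the DataFrame is built (illustrative):

    from sqlglot.dataframe.sql.session import SparkSession

    df = SparkSession().sql("SELECT 1 AS a")
    print(df.sql())  # the generated SQL should now come back qualified and quoted
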
diff --git a/sqlglot/dialects/__init__.py b/sqlglot/dialects/__init__.py
index 82552c9..29c6580 100644
--- a/sqlglot/dialects/__init__.py
+++ b/sqlglot/dialects/__init__.py
@@ -61,6 +61,7 @@ dialect implementations in order to understand how their various components can
----
"""
+from sqlglot.dialects.athena import Athena
from sqlglot.dialects.bigquery import BigQuery
from sqlglot.dialects.clickhouse import ClickHouse
from sqlglot.dialects.databricks import Databricks
@@ -73,6 +74,7 @@ from sqlglot.dialects.mysql import MySQL
from sqlglot.dialects.oracle import Oracle
from sqlglot.dialects.postgres import Postgres
from sqlglot.dialects.presto import Presto
+from sqlglot.dialects.prql import PRQL
from sqlglot.dialects.redshift import Redshift
from sqlglot.dialects.snowflake import Snowflake
from sqlglot.dialects.spark import Spark
diff --git a/sqlglot/dialects/athena.py b/sqlglot/dialects/athena.py
new file mode 100644
index 0000000..f2deec8
--- /dev/null
+++ b/sqlglot/dialects/athena.py
@@ -0,0 +1,37 @@
+from __future__ import annotations
+
+from sqlglot import exp
+from sqlglot.dialects.trino import Trino
+from sqlglot.tokens import TokenType
+
+
+class Athena(Trino):
+ class Parser(Trino.Parser):
+ STATEMENT_PARSERS = {
+ **Trino.Parser.STATEMENT_PARSERS,
+ TokenType.USING: lambda self: self._parse_as_command(self._prev),
+ }
+
+ class Generator(Trino.Generator):
+ PROPERTIES_LOCATION = {
+ **Trino.Generator.PROPERTIES_LOCATION,
+ exp.LocationProperty: exp.Properties.Location.POST_SCHEMA,
+ }
+
+ TYPE_MAPPING = {
+ **Trino.Generator.TYPE_MAPPING,
+ exp.DataType.Type.TEXT: "STRING",
+ }
+
+ TRANSFORMS = {
+ **Trino.Generator.TRANSFORMS,
+ exp.FileFormatProperty: lambda self, e: f"'FORMAT'={self.sql(e, 'this')}",
+ }
+
+ def property_sql(self, expression: exp.Property) -> str:
+ return (
+ f"{self.property_name(expression, string_key=True)}={self.sql(expression, 'value')}"
+ )
+
+ def with_properties(self, properties: exp.Properties) -> str:
+ return self.properties(properties, prefix=self.seg("TBLPROPERTIES"))
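
Since Athena simply subclasses Trino with DDL tweaks, a quick illustrative check of the new dialect (expected output per the TYPE_MAPPING override above):

    import sqlglot

    # TEXT maps to STRING for Athena.
    print(sqlglot.transpile("CREATE TABLE t (c TEXT)", write="athena")[0])
    # CREATE TABLE t (c STRING)
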
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 5bfc3ea..2167ba2 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -24,6 +24,7 @@ from sqlglot.dialects.dialect import (
rename_func,
timestrtotime_sql,
ts_or_ds_add_cast,
+ unit_to_var,
)
from sqlglot.helper import seq_get, split_num_words
from sqlglot.tokens import TokenType
@@ -41,14 +42,22 @@ def _derived_table_values_to_unnest(self: BigQuery.Generator, expression: exp.Va
structs = []
alias = expression.args.get("alias")
for tup in expression.find_all(exp.Tuple):
- field_aliases = alias.columns if alias else (f"_c{i}" for i in range(len(tup.expressions)))
+ field_aliases = (
+ alias.columns
+ if alias and alias.columns
+ else (f"_c{i}" for i in range(len(tup.expressions)))
+ )
expressions = [
exp.PropertyEQ(this=exp.to_identifier(name), expression=fld)
for name, fld in zip(field_aliases, tup.expressions)
]
structs.append(exp.Struct(expressions=expressions))
- return self.unnest_sql(exp.Unnest(expressions=[exp.array(*structs, copy=False)]))
+ # Due to `UNNEST_COLUMN_ONLY`, it is expected that the table alias be contained in the columns expression
+ alias_name_only = exp.TableAlias(columns=[alias.this]) if alias else None
+ return self.unnest_sql(
+ exp.Unnest(expressions=[exp.array(*structs, copy=False)], alias=alias_name_only)
+ )
def _returnsproperty_sql(self: BigQuery.Generator, expression: exp.ReturnsProperty) -> str:
@@ -190,7 +199,7 @@ def _ts_or_ds_add_sql(self: BigQuery.Generator, expression: exp.TsOrDsAdd) -> st
def _ts_or_ds_diff_sql(self: BigQuery.Generator, expression: exp.TsOrDsDiff) -> str:
expression.this.replace(exp.cast(expression.this, "TIMESTAMP", copy=True))
expression.expression.replace(exp.cast(expression.expression, "TIMESTAMP", copy=True))
- unit = expression.args.get("unit") or "DAY"
+ unit = unit_to_var(expression)
return self.func("DATE_DIFF", expression.this, expression.expression, unit)
@@ -238,16 +247,6 @@ class BigQuery(Dialect):
"%E6S": "%S.%f",
}
- ESCAPE_SEQUENCES = {
- "\\a": "\a",
- "\\b": "\b",
- "\\f": "\f",
- "\\n": "\n",
- "\\r": "\r",
- "\\t": "\t",
- "\\v": "\v",
- }
-
FORMAT_MAPPING = {
"DD": "%d",
"MM": "%m",
@@ -315,6 +314,7 @@ class BigQuery(Dialect):
"BEGIN TRANSACTION": TokenType.BEGIN,
"BYTES": TokenType.BINARY,
"CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
+ "DATETIME": TokenType.TIMESTAMP,
"DECLARE": TokenType.COMMAND,
"ELSEIF": TokenType.COMMAND,
"EXCEPTION": TokenType.COMMAND,
@@ -486,14 +486,14 @@ class BigQuery(Dialect):
table.set("db", exp.Identifier(this=parts[0]))
table.set("this", exp.Identifier(this=parts[1]))
- if isinstance(table.this, exp.Identifier) and "." in table.name:
+ if any("." in p.name for p in table.parts):
catalog, db, this, *rest = (
- t.cast(t.Optional[exp.Expression], exp.to_identifier(x, quoted=True))
- for x in split_num_words(table.name, ".", 3)
+ exp.to_identifier(p, quoted=True)
+ for p in split_num_words(".".join(p.name for p in table.parts), ".", 3)
)
if rest and this:
- this = exp.Dot.build(t.cast(t.List[exp.Expression], [this, *rest]))
+ this = exp.Dot.build([this, *rest]) # type: ignore
table = exp.Table(this=this, db=db, catalog=catalog)
table.meta["quoted_table"] = True
@@ -527,7 +527,9 @@ class BigQuery(Dialect):
return json_object
- def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
+ def _parse_bracket(
+ self, this: t.Optional[exp.Expression] = None
+ ) -> t.Optional[exp.Expression]:
bracket = super()._parse_bracket(this)
if this is bracket:
@@ -566,6 +568,7 @@ class BigQuery(Dialect):
IGNORE_NULLS_IN_FUNC = True
JSON_PATH_SINGLE_QUOTE_ESCAPE = True
CAN_IMPLEMENT_ARRAY_ANY = True
+ SUPPORTS_TO_NUMBER = False
NAMED_PLACEHOLDER_TOKEN = "@"
TRANSFORMS = {
@@ -588,7 +591,7 @@ class BigQuery(Dialect):
exp.CTE: transforms.preprocess([_pushdown_cte_column_names]),
exp.DateAdd: date_add_interval_sql("DATE", "ADD"),
exp.DateDiff: lambda self, e: self.func(
- "DATE_DIFF", e.this, e.expression, e.unit or "DAY"
+ "DATE_DIFF", e.this, e.expression, unit_to_var(e)
),
exp.DateFromParts: rename_func("DATE"),
exp.DateStrToDate: datestrtodate_sql,
@@ -607,6 +610,7 @@ class BigQuery(Dialect):
exp.IntDiv: rename_func("DIV"),
exp.JSONFormat: rename_func("TO_JSON_STRING"),
exp.Max: max_or_greatest,
+ exp.Mod: rename_func("MOD"),
exp.MD5: lambda self, e: self.func("TO_HEX", self.func("MD5", e.this)),
exp.MD5Digest: rename_func("MD5"),
exp.Min: min_or_least,
@@ -847,10 +851,10 @@ class BigQuery(Dialect):
return inline_array_sql(self, expression)
def bracket_sql(self, expression: exp.Bracket) -> str:
- this = self.sql(expression, "this")
+ this = expression.this
expressions = expression.expressions
- if len(expressions) == 1:
+ if len(expressions) == 1 and this and this.is_type(exp.DataType.Type.STRUCT):
arg = expressions[0]
if arg.type is None:
from sqlglot.optimizer.annotate_types import annotate_types
@@ -858,10 +862,10 @@ class BigQuery(Dialect):
arg = annotate_types(arg)
if arg.type and arg.type.this in exp.DataType.TEXT_TYPES:
- # BQ doesn't support bracket syntax with string values
- return f"{this}.{arg.name}"
+ # BQ doesn't support bracket syntax with string values for structs
+ return f"{self.sql(this)}.{arg.name}"
- expressions_sql = ", ".join(self.sql(e) for e in expressions)
+ expressions_sql = self.expressions(expression, flat=True)
offset = expression.args.get("offset")
if offset == 0:
@@ -874,7 +878,7 @@ class BigQuery(Dialect):
if expression.args.get("safe"):
expressions_sql = f"SAFE_{expressions_sql}"
- return f"{this}[{expressions_sql}]"
+ return f"{self.sql(this)}[{expressions_sql}]"
def in_unnest_op(self, expression: exp.Unnest) -> str:
return self.sql(expression)
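
Illustrative round-trip for the unit_to_var() change above: the DATE_DIFF unit stays a bare keyword instead of being coerced to a string (a sketch, not from the commit):

    import sqlglot

    sql = "SELECT DATE_DIFF(d1, d2, DAY) FROM t"
    print(sqlglot.transpile(sql, read="bigquery")[0])
    # expected round-trip: SELECT DATE_DIFF(d1, d2, DAY) FROM t
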
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index 90167f6..631dc30 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -15,7 +15,6 @@ from sqlglot.dialects.dialect import (
rename_func,
var_map_sql,
)
-from sqlglot.errors import ParseError
from sqlglot.helper import is_int, seq_get
from sqlglot.tokens import Token, TokenType
@@ -49,8 +48,9 @@ class ClickHouse(Dialect):
NULL_ORDERING = "nulls_are_last"
SUPPORTS_USER_DEFINED_TYPES = False
SAFE_DIVISION = True
+ LOG_BASE_FIRST: t.Optional[bool] = None
- ESCAPE_SEQUENCES = {
+ UNESCAPED_SEQUENCES = {
"\\0": "\0",
}
@@ -105,6 +105,7 @@ class ClickHouse(Dialect):
# * select x from t1 union all select x from t2 limit 1;
# * select x from t1 union all (select x from t2 limit 1);
MODIFIERS_ATTACHED_TO_UNION = False
+ INTERVAL_SPANS = False
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
@@ -260,6 +261,11 @@ class ClickHouse(Dialect):
"ArgMax",
]
+ FUNC_TOKENS = {
+ *parser.Parser.FUNC_TOKENS,
+ TokenType.SET,
+ }
+
AGG_FUNC_MAPPING = (
lambda functions, suffixes: {
f"{f}{sfx}": (f, sfx) for sfx in (suffixes + [""]) for f in functions
@@ -305,6 +311,10 @@ class ClickHouse(Dialect):
TokenType.SETTINGS,
}
+ ALIAS_TOKENS = parser.Parser.ALIAS_TOKENS - {
+ TokenType.FORMAT,
+ }
+
LOG_DEFAULTS_TO_LN = True
QUERY_MODIFIER_PARSERS = {
@@ -316,6 +326,17 @@ class ClickHouse(Dialect):
TokenType.FORMAT: lambda self: ("format", self._advance() or self._parse_id_var()),
}
+ CONSTRAINT_PARSERS = {
+ **parser.Parser.CONSTRAINT_PARSERS,
+ "INDEX": lambda self: self._parse_index_constraint(),
+ "CODEC": lambda self: self._parse_compress(),
+ }
+
+ SCHEMA_UNNAMED_CONSTRAINTS = {
+ *parser.Parser.SCHEMA_UNNAMED_CONSTRAINTS,
+ "INDEX",
+ }
+
def _parse_conjunction(self) -> t.Optional[exp.Expression]:
this = super()._parse_conjunction()
@@ -381,21 +402,20 @@ class ClickHouse(Dialect):
# https://clickhouse.com/docs/en/sql-reference/statements/select/with/
def _parse_cte(self) -> exp.CTE:
- index = self._index
- try:
- # WITH <identifier> AS <subquery expression>
- return super()._parse_cte()
- except ParseError:
- # WITH <expression> AS <identifier>
- self._retreat(index)
+ # WITH <identifier> AS <subquery expression>
+ cte: t.Optional[exp.CTE] = self._try_parse(super()._parse_cte)
- return self.expression(
+ if not cte:
+ # WITH <expression> AS <identifier>
+ cte = self.expression(
exp.CTE,
this=self._parse_conjunction(),
alias=self._parse_table_alias(),
scalar=True,
)
+ return cte
+
def _parse_join_parts(
self,
) -> t.Tuple[t.Optional[Token], t.Optional[Token], t.Optional[Token]]:
@@ -508,6 +528,27 @@ class ClickHouse(Dialect):
self._retreat(index)
return None
+ def _parse_index_constraint(
+ self, kind: t.Optional[str] = None
+ ) -> exp.IndexColumnConstraint:
+ # INDEX name1 expr TYPE type1(args) GRANULARITY value
+ this = self._parse_id_var()
+ expression = self._parse_conjunction()
+
+ index_type = self._match_text_seq("TYPE") and (
+ self._parse_function() or self._parse_var()
+ )
+
+ granularity = self._match_text_seq("GRANULARITY") and self._parse_term()
+
+ return self.expression(
+ exp.IndexColumnConstraint,
+ this=this,
+ expression=expression,
+ index_type=index_type,
+ granularity=granularity,
+ )
+
class Generator(generator.Generator):
QUERY_HINTS = False
STRUCT_DELIMITER = ("(", ")")
@@ -517,6 +558,7 @@ class ClickHouse(Dialect):
TABLESAMPLE_KEYWORDS = "SAMPLE"
LAST_DAY_SUPPORTS_DATE_PART = False
CAN_IMPLEMENT_ARRAY_ANY = True
+ SUPPORTS_TO_NUMBER = False
STRING_TYPE_MAPPING = {
exp.DataType.Type.CHAR: "String",
@@ -585,6 +627,9 @@ class ClickHouse(Dialect):
exp.Array: inline_array_sql,
exp.CastToStrType: rename_func("CAST"),
exp.CountIf: rename_func("countIf"),
+ exp.CompressColumnConstraint: lambda self,
+ e: f"CODEC({self.expressions(e, key='this', flat=True)})",
+ exp.ComputedColumnConstraint: lambda self, e: f"ALIAS {self.sql(e, 'this')}",
exp.CurrentDate: lambda self, e: self.func("CURRENT_DATE"),
exp.DateAdd: date_delta_sql("DATE_ADD"),
exp.DateDiff: date_delta_sql("DATE_DIFF"),
@@ -737,3 +782,15 @@ class ClickHouse(Dialect):
def prewhere_sql(self, expression: exp.PreWhere) -> str:
this = self.indent(self.sql(expression, "this"))
return f"{self.seg('PREWHERE')}{self.sep()}{this}"
+
+ def indexcolumnconstraint_sql(self, expression: exp.IndexColumnConstraint) -> str:
+ this = self.sql(expression, "this")
+ this = f" {this}" if this else ""
+ expr = self.sql(expression, "expression")
+ expr = f" {expr}" if expr else ""
+ index_type = self.sql(expression, "index_type")
+ index_type = f" TYPE {index_type}" if index_type else ""
+ granularity = self.sql(expression, "granularity")
+ granularity = f" GRANULARITY {granularity}" if granularity else ""
+
+ return f"INDEX{this}{expr}{index_type}{granularity}"
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 599505c..81057c2 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -31,6 +31,7 @@ class Dialects(str, Enum):
DIALECT = ""
+ ATHENA = "athena"
BIGQUERY = "bigquery"
CLICKHOUSE = "clickhouse"
DATABRICKS = "databricks"
@@ -42,6 +43,7 @@ class Dialects(str, Enum):
ORACLE = "oracle"
POSTGRES = "postgres"
PRESTO = "presto"
+ PRQL = "prql"
REDSHIFT = "redshift"
SNOWFLAKE = "snowflake"
SPARK = "spark"
@@ -108,11 +110,18 @@ class _Dialect(type):
klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
- klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}
+ base = seq_get(bases, 0)
+ base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
+ base_parser = (getattr(base, "parser_class", Parser),)
+ base_generator = (getattr(base, "generator_class", Generator),)
- klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
- klass.parser_class = getattr(klass, "Parser", Parser)
- klass.generator_class = getattr(klass, "Generator", Generator)
+ klass.tokenizer_class = klass.__dict__.get(
+ "Tokenizer", type("Tokenizer", base_tokenizer, {})
+ )
+ klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
+ klass.generator_class = klass.__dict__.get(
+ "Generator", type("Generator", base_generator, {})
+ )
klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
@@ -134,9 +143,31 @@ class _Dialect(type):
klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)
+ if "\\" in klass.tokenizer_class.STRING_ESCAPES:
+ klass.UNESCAPED_SEQUENCES = {
+ "\\a": "\a",
+ "\\b": "\b",
+ "\\f": "\f",
+ "\\n": "\n",
+ "\\r": "\r",
+ "\\t": "\t",
+ "\\v": "\v",
+ "\\\\": "\\",
+ **klass.UNESCAPED_SEQUENCES,
+ }
+
+ klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}
+
if enum not in ("", "bigquery"):
klass.generator_class.SELECT_KINDS = ()
+ if enum not in ("", "databricks", "hive", "spark", "spark2"):
+ modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
+ for modifier in ("cluster", "distribute", "sort"):
+ modifier_transforms.pop(modifier, None)
+
+ klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms
+
if not klass.SUPPORTS_SEMI_ANTI_JOIN:
klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
TokenType.ANTI,
@@ -189,8 +220,11 @@ class Dialect(metaclass=_Dialect):
False: Disables function name normalization.
"""
- LOG_BASE_FIRST = True
- """Whether the base comes first in the `LOG` function."""
+ LOG_BASE_FIRST: t.Optional[bool] = True
+ """
+ Whether the base comes first in the `LOG` function.
+ Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
+ """
NULL_ORDERING = "nulls_are_small"
"""
@@ -226,8 +260,8 @@ class Dialect(metaclass=_Dialect):
If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
"""
- ESCAPE_SEQUENCES: t.Dict[str, str] = {}
- """Mapping of an unescaped escape sequence to the corresponding character."""
+ UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
+ """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""
PSEUDOCOLUMNS: t.Set[str] = set()
"""
@@ -266,7 +300,7 @@ class Dialect(metaclass=_Dialect):
INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
INVERSE_TIME_TRIE: t.Dict = {}
- INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
+ ESCAPED_SEQUENCES: t.Dict[str, str] = {}
# Delimiters for string literals and identifiers
QUOTE_START = "'"
@@ -587,13 +621,21 @@ def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) ->
return ""
-def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
+def str_position_sql(
+ self: Generator, expression: exp.StrPosition, generate_instance: bool = False
+) -> str:
this = self.sql(expression, "this")
substr = self.sql(expression, "substr")
position = self.sql(expression, "position")
+ instance = expression.args.get("instance") if generate_instance else None
+ position_offset = ""
+
if position:
- return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
- return f"STRPOS({this}, {substr})"
+ # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
+ this = self.func("SUBSTR", this, position)
+ position_offset = f" + {position} - 1"
+
+ return self.func("STRPOS", this, substr, instance) + position_offset
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
@@ -689,9 +731,7 @@ def build_date_delta_with_interval(
if expression and expression.is_string:
expression = exp.Literal.number(expression.this)
- return expression_class(
- this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
- )
+ return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))
return _builder
@@ -710,18 +750,14 @@ def date_add_interval_sql(
) -> t.Callable[[Generator, exp.Expression], str]:
def func(self: Generator, expression: exp.Expression) -> str:
this = self.sql(expression, "this")
- unit = expression.args.get("unit")
- unit = exp.var(unit.name.upper() if unit else "DAY")
- interval = exp.Interval(this=expression.expression, unit=unit)
+ interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
return f"{data_type}_{kind}({this}, {self.sql(interval)})"
return func
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
- return self.func(
- "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
- )
+ return self.func("DATE_TRUNC", unit_to_str(expression), expression.this)
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
@@ -956,7 +992,7 @@ def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE
return self.func(
name,
- exp.var(expression.text("unit").upper() or "DAY"),
+ unit_to_var(expression),
expression.expression,
expression.this,
)
@@ -964,6 +1000,24 @@ def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE
return _delta_sql
+def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
+ unit = expression.args.get("unit")
+
+ if isinstance(unit, exp.Placeholder):
+ return unit
+ if unit:
+ return exp.Literal.string(unit.name)
+ return exp.Literal.string(default) if default else None
+
+
+def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
+ unit = expression.args.get("unit")
+
+ if isinstance(unit, (exp.Var, exp.Placeholder)):
+ return unit
+ return exp.Var(this=default) if default else None
+
+
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
trunc_curr_date = exp.func("date_trunc", "month", expression.this)
plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
@@ -998,7 +1052,7 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
def build_json_extract_path(
- expr_type: t.Type[F], zero_based_indexing: bool = True
+ expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
def _builder(args: t.List) -> F:
segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
@@ -1018,7 +1072,11 @@ def build_json_extract_path(
# This is done to avoid failing in the expression validator due to the arg count
del args[2:]
- return expr_type(this=seq_get(args, 0), expression=exp.JSONPath(expressions=segments))
+ return expr_type(
+ this=seq_get(args, 0),
+ expression=exp.JSONPath(expressions=segments),
+ only_json_types=arrow_req_json_type,
+ )
return _builder
@@ -1070,3 +1128,12 @@ def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> s
unnest = exp.Unnest(expressions=[expression.this])
filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
return self.sql(exp.Array(expressions=[filtered]))
+
+
+def to_number_with_nls_param(self, expression: exp.ToNumber) -> str:
+ return self.func(
+ "TO_NUMBER",
+ expression.this,
+ expression.args.get("format"),
+ expression.args.get("nlsparam"),
+ )
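
The two new unit helpers differ only in the node they produce, which is what lets each dialect choose between string-literal and bare-keyword units; a minimal sketch:

    from sqlglot import exp
    from sqlglot.dialects.dialect import unit_to_str, unit_to_var

    delta = exp.DateAdd(
        this=exp.column("d"), expression=exp.Literal.number(7), unit=exp.var("WEEK")
    )
    print(unit_to_str(delta).sql())  # 'WEEK' -- a string literal
    print(unit_to_var(delta).sql())  # WEEK -- a bare variable
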
diff --git a/sqlglot/dialects/doris.py b/sqlglot/dialects/doris.py
index 9a84848..f4ec0e5 100644
--- a/sqlglot/dialects/doris.py
+++ b/sqlglot/dialects/doris.py
@@ -7,6 +7,7 @@ from sqlglot.dialects.dialect import (
build_timestamp_trunc,
rename_func,
time_format,
+ unit_to_str,
)
from sqlglot.dialects.mysql import MySQL
@@ -27,7 +28,7 @@ class Doris(MySQL):
}
class Generator(MySQL.Generator):
- CAST_MAPPING = {}
+ LAST_DAY_SUPPORTS_DATE_PART = False
TYPE_MAPPING = {
**MySQL.Generator.TYPE_MAPPING,
@@ -36,8 +37,7 @@ class Doris(MySQL):
exp.DataType.Type.TIMESTAMPTZ: "DATETIME",
}
- LAST_DAY_SUPPORTS_DATE_PART = False
-
+ CAST_MAPPING = {}
TIMESTAMP_FUNC_TYPES = set()
TRANSFORMS = {
@@ -49,9 +49,7 @@ class Doris(MySQL):
exp.ArrayAgg: rename_func("COLLECT_LIST"),
exp.ArrayUniqueAgg: rename_func("COLLECT_SET"),
exp.CurrentTimestamp: lambda self, _: self.func("NOW"),
- exp.DateTrunc: lambda self, e: self.func(
- "DATE_TRUNC", e.this, "'" + e.text("unit") + "'"
- ),
+ exp.DateTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, unit_to_str(e)),
exp.JSONExtractScalar: arrow_json_extract_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.Map: rename_func("ARRAY_MAP"),
@@ -63,9 +61,7 @@ class Doris(MySQL):
exp.TsOrDsAdd: lambda self, e: self.func("DATE_ADD", e.this, e.expression),
exp.TsOrDsToDate: lambda self, e: self.func("TO_DATE", e.this),
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
- exp.TimestampTrunc: lambda self, e: self.func(
- "DATE_TRUNC", e.this, "'" + e.text("unit") + "'"
- ),
+ exp.TimestampTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, unit_to_str(e)),
exp.UnixToStr: lambda self, e: self.func(
"FROM_UNIXTIME", e.this, time_format("doris")(self, e)
),
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
index c1f6afa..0a00d92 100644
--- a/sqlglot/dialects/drill.py
+++ b/sqlglot/dialects/drill.py
@@ -1,6 +1,5 @@
from __future__ import annotations
-import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
@@ -12,18 +11,10 @@ from sqlglot.dialects.dialect import (
str_position_sql,
timestrtotime_sql,
)
+from sqlglot.dialects.mysql import date_add_sql
from sqlglot.transforms import preprocess, move_schema_columns_to_partitioned_by
-def _date_add_sql(kind: str) -> t.Callable[[Drill.Generator, exp.DateAdd | exp.DateSub], str]:
- def func(self: Drill.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
- this = self.sql(expression, "this")
- unit = exp.var(expression.text("unit").upper() or "DAY")
- return self.func(f"DATE_{kind}", this, exp.Interval(this=expression.expression, unit=unit))
-
- return func
-
-
def _str_to_date(self: Drill.Generator, expression: exp.StrToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
@@ -84,7 +75,6 @@ class Drill(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
- "DATE_FORMAT": build_formatted_time(exp.TimeToStr, "drill"),
"TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list,
"TO_CHAR": build_formatted_time(exp.TimeToStr, "drill"),
}
@@ -124,9 +114,9 @@ class Drill(Dialect):
exp.ArrayContains: rename_func("REPEATED_CONTAINS"),
exp.ArraySize: rename_func("REPEATED_COUNT"),
exp.Create: preprocess([move_schema_columns_to_partitioned_by]),
- exp.DateAdd: _date_add_sql("ADD"),
+ exp.DateAdd: date_add_sql("ADD"),
exp.DateStrToDate: datestrtodate_sql,
- exp.DateSub: _date_add_sql("SUB"),
+ exp.DateSub: date_add_sql("SUB"),
exp.DateToDi: lambda self,
e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.DATEINT_FORMAT}) AS INT)",
exp.DiToDate: lambda self,
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index f74dc97..6a1d07a 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -26,6 +26,7 @@ from sqlglot.dialects.dialect import (
str_to_time_sql,
timestamptrunc_sql,
timestrtotime_sql,
+ unit_to_var,
)
from sqlglot.helper import flatten, seq_get
from sqlglot.tokens import TokenType
@@ -33,15 +34,16 @@ from sqlglot.tokens import TokenType
def _ts_or_ds_add_sql(self: DuckDB.Generator, expression: exp.TsOrDsAdd) -> str:
this = self.sql(expression, "this")
- unit = self.sql(expression, "unit").strip("'") or "DAY"
- interval = self.sql(exp.Interval(this=expression.expression, unit=unit))
+ interval = self.sql(exp.Interval(this=expression.expression, unit=unit_to_var(expression)))
return f"CAST({this} AS {self.sql(expression.return_type)}) + {interval}"
-def _date_delta_sql(self: DuckDB.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
+def _date_delta_sql(
+ self: DuckDB.Generator, expression: exp.DateAdd | exp.DateSub | exp.TimeAdd
+) -> str:
this = self.sql(expression, "this")
- unit = self.sql(expression, "unit").strip("'") or "DAY"
- op = "+" if isinstance(expression, exp.DateAdd) else "-"
+ unit = unit_to_var(expression)
+ op = "+" if isinstance(expression, (exp.DateAdd, exp.TimeAdd)) else "-"
return f"{this} {op} {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
@@ -186,6 +188,11 @@ class DuckDB(Dialect):
return super().to_json_path(path)
class Tokenizer(tokens.Tokenizer):
+ HEREDOC_STRINGS = ["$"]
+
+ HEREDOC_TAG_IS_IDENTIFIER = True
+ HEREDOC_STRING_ALTERNATIVE = TokenType.PARAMETER
+
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"//": TokenType.DIV,
@@ -199,6 +206,7 @@ class DuckDB(Dialect):
"LOGICAL": TokenType.BOOLEAN,
"ONLY": TokenType.ONLY,
"PIVOT_WIDER": TokenType.PIVOT,
+ "POSITIONAL": TokenType.POSITIONAL,
"SIGNED": TokenType.INT,
"STRING": TokenType.VARCHAR,
"UBIGINT": TokenType.UBIGINT,
@@ -227,8 +235,8 @@ class DuckDB(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"ARRAY_HAS": exp.ArrayContains.from_arg_list,
- "ARRAY_SORT": exp.SortArray.from_arg_list,
"ARRAY_REVERSE_SORT": _build_sort_array_desc,
+ "ARRAY_SORT": exp.SortArray.from_arg_list,
"DATEDIFF": _build_date_diff,
"DATE_DIFF": _build_date_diff,
"DATE_TRUNC": date_trunc_to_time,
@@ -285,6 +293,11 @@ class DuckDB(Dialect):
FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy()
FUNCTION_PARSERS.pop("DECODE")
+ NO_PAREN_FUNCTION_PARSERS = {
+ **parser.Parser.NO_PAREN_FUNCTION_PARSERS,
+ "MAP": lambda self: self._parse_map(),
+ }
+
TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - {
TokenType.SEMI,
TokenType.ANTI,
@@ -299,6 +312,13 @@ class DuckDB(Dialect):
),
}
+ def _parse_map(self) -> exp.ToMap | exp.Map:
+ if self._match(TokenType.L_BRACE, advance=False):
+ return self.expression(exp.ToMap, this=self._parse_bracket())
+
+ args = self._parse_wrapped_csv(self._parse_conjunction)
+ return self.expression(exp.Map, keys=seq_get(args, 0), values=seq_get(args, 1))
+
def _parse_types(
self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True
) -> t.Optional[exp.Expression]:
@@ -345,6 +365,7 @@ class DuckDB(Dialect):
SUPPORTS_CREATE_TABLE_LIKE = False
MULTI_ARG_DISTINCT = False
CAN_IMPLEMENT_ARRAY_ANY = True
+ SUPPORTS_TO_NUMBER = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@@ -425,6 +446,7 @@ class DuckDB(Dialect):
"EPOCH", self.func("STRPTIME", e.this, self.format_time(e))
),
exp.Struct: _struct_sql,
+ exp.TimeAdd: _date_delta_sql,
exp.Timestamp: no_timestamp_sql,
exp.TimestampDiff: lambda self, e: self.func(
"DATE_DIFF", exp.Literal.string(e.unit), e.expression, e.this
@@ -478,7 +500,7 @@ class DuckDB(Dialect):
STAR_MAPPING = {**generator.Generator.STAR_MAPPING, "except": "EXCLUDE"}
- UNWRAPPED_INTERVAL_VALUES = (exp.Column, exp.Literal, exp.Paren)
+ UNWRAPPED_INTERVAL_VALUES = (exp.Literal, exp.Paren)
# DuckDB doesn't generally support CREATE TABLE .. properties
# https://duckdb.org/docs/sql/statements/create_table.html
@@ -569,3 +591,9 @@ class DuckDB(Dialect):
return rename_func("RANGE")(self, expression)
return super().generateseries_sql(expression)
+
+ def bracket_sql(self, expression: exp.Bracket) -> str:
+ if isinstance(expression.this, exp.Array):
+ expression.this.replace(exp.paren(expression.this))
+
+ return super().bracket_sql(expression)
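
Illustrative use of the new DuckDB MAP parser above, which accepts both the brace literal (exp.ToMap) and the two-list call form (exp.Map):

    import sqlglot

    print(sqlglot.parse_one("SELECT MAP {'a': 1}", read="duckdb").sql(dialect="duckdb"))
    print(sqlglot.parse_one("SELECT MAP(['a'], [1])", read="duckdb").sql(dialect="duckdb"))
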
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 55a9254..cc7debb 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -319,7 +319,9 @@ class Hive(Dialect):
"TO_DATE": build_formatted_time(exp.TsOrDsToDate, "hive"),
"TO_JSON": exp.JSONFormat.from_arg_list,
"UNBASE64": exp.FromBase64.from_arg_list,
- "UNIX_TIMESTAMP": build_formatted_time(exp.StrToUnix, "hive", True),
+ "UNIX_TIMESTAMP": lambda args: build_formatted_time(exp.StrToUnix, "hive", True)(
+ args or [exp.CurrentTimestamp()]
+ ),
"YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)),
}
@@ -431,6 +433,7 @@ class Hive(Dialect):
NVL2_SUPPORTED = False
LAST_DAY_SUPPORTS_DATE_PART = False
JSON_PATH_SINGLE_QUOTE_ESCAPE = True
+ SUPPORTS_TO_NUMBER = False
EXPRESSIONS_WITHOUT_NESTED_CTES = {
exp.Insert,
@@ -472,7 +475,7 @@ class Hive(Dialect):
exp.ArgMax: arg_max_or_min_no_count("MAX_BY"),
exp.ArgMin: arg_max_or_min_no_count("MIN_BY"),
exp.ArrayConcat: rename_func("CONCAT"),
- exp.ArrayJoin: lambda self, e: self.func("CONCAT_WS", e.expression, e.this),
+ exp.ArrayToString: lambda self, e: self.func("CONCAT_WS", e.expression, e.this),
exp.ArraySize: rename_func("SIZE"),
exp.ArraySort: _array_sort_sql,
exp.With: no_recursive_cte_sql,
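
A sketch of the UNIX_TIMESTAMP change above: with no arguments the builder now defaults to the current timestamp instead of failing (illustrative):

    import sqlglot

    print(sqlglot.parse_one("SELECT UNIX_TIMESTAMP()", read="hive").sql(dialect="hive"))
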
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 6ebae1e..1d53346 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -23,6 +23,7 @@ from sqlglot.dialects.dialect import (
build_date_delta_with_interval,
rename_func,
strposition_to_locate_sql,
+ unit_to_var,
)
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@@ -109,14 +110,14 @@ def _trim_sql(self: MySQL.Generator, expression: exp.Trim) -> str:
return f"TRIM({trim_type}{remove_chars}{from_part}{target})"
-def _date_add_sql(
+def date_add_sql(
kind: str,
-) -> t.Callable[[MySQL.Generator, exp.Expression], str]:
- def func(self: MySQL.Generator, expression: exp.Expression) -> str:
- this = self.sql(expression, "this")
- unit = expression.text("unit").upper() or "DAY"
- return (
- f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})"
+) -> t.Callable[[generator.Generator, exp.Expression], str]:
+ def func(self: generator.Generator, expression: exp.Expression) -> str:
+ return self.func(
+ f"DATE_{kind}",
+ expression.this,
+ exp.Interval(this=expression.expression, unit=unit_to_var(expression)),
)
return func
@@ -291,6 +292,7 @@ class MySQL(Dialect):
"DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)),
+ "FROM_UNIXTIME": build_formatted_time(exp.UnixToTime, "mysql"),
"ISNULL": isnull_to_is_null,
"LOCATE": locate_to_strposition,
"MAKETIME": exp.TimeFromParts.from_arg_list,
@@ -319,11 +321,7 @@ class MySQL(Dialect):
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS,
"CHAR": lambda self: self._parse_chr(),
- "GROUP_CONCAT": lambda self: self.expression(
- exp.GroupConcat,
- this=self._parse_lambda(),
- separator=self._match(TokenType.SEPARATOR) and self._parse_field(),
- ),
+ "GROUP_CONCAT": lambda self: self._parse_group_concat(),
# https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_values
"VALUES": lambda self: self.expression(
exp.Anonymous, this="VALUES", expressions=[self._parse_id_var()]
@@ -412,6 +410,11 @@ class MySQL(Dialect):
"SPATIAL": lambda self: self._parse_index_constraint(kind="SPATIAL"),
}
+ ALTER_PARSERS = {
+ **parser.Parser.ALTER_PARSERS,
+ "MODIFY": lambda self: self._parse_alter_table_alter(),
+ }
+
SCHEMA_UNNAMED_CONSTRAINTS = {
*parser.Parser.SCHEMA_UNNAMED_CONSTRAINTS,
"FULLTEXT",
@@ -458,7 +461,7 @@ class MySQL(Dialect):
this = self._parse_id_var(any_token=False)
index_type = self._match(TokenType.USING) and self._advance_any() and self._prev.text
- schema = self._parse_schema()
+ expressions = self._parse_wrapped_csv(self._parse_ordered)
options = []
while True:
@@ -478,9 +481,6 @@ class MySQL(Dialect):
elif self._match_text_seq("ENGINE_ATTRIBUTE"):
self._match(TokenType.EQ)
opt = exp.IndexConstraintOption(engine_attr=self._parse_string())
- elif self._match_text_seq("ENGINE_ATTRIBUTE"):
- self._match(TokenType.EQ)
- opt = exp.IndexConstraintOption(engine_attr=self._parse_string())
elif self._match_text_seq("SECONDARY_ENGINE_ATTRIBUTE"):
self._match(TokenType.EQ)
opt = exp.IndexConstraintOption(secondary_engine_attr=self._parse_string())
@@ -495,7 +495,7 @@ class MySQL(Dialect):
return self.expression(
exp.IndexColumnConstraint,
this=this,
- schema=schema,
+ expressions=expressions,
kind=kind,
index_type=index_type,
options=options,
@@ -617,6 +617,39 @@ class MySQL(Dialect):
return self.expression(exp.Chr, **kwargs)
+ def _parse_group_concat(self) -> t.Optional[exp.Expression]:
+ def concat_exprs(
+ node: t.Optional[exp.Expression], exprs: t.List[exp.Expression]
+ ) -> exp.Expression:
+ if isinstance(node, exp.Distinct) and len(node.expressions) > 1:
+ concat_exprs = [
+ self.expression(exp.Concat, expressions=node.expressions, safe=True)
+ ]
+ node.set("expressions", concat_exprs)
+ return node
+ if len(exprs) == 1:
+ return exprs[0]
+ return self.expression(exp.Concat, expressions=args, safe=True)
+
+ args = self._parse_csv(self._parse_lambda)
+
+ if args:
+ order = args[-1] if isinstance(args[-1], exp.Order) else None
+
+ if order:
+ # Order By is the last (or only) expression in the list and has consumed the 'expr' before it,
+ # remove 'expr' from exp.Order and add it back to args
+ args[-1] = order.this
+ order.set("this", concat_exprs(order.this, args))
+
+ this = order or concat_exprs(args[0], args)
+ else:
+ this = None
+
+ separator = self._parse_field() if self._match(TokenType.SEPARATOR) else None
+
+ return self.expression(exp.GroupConcat, this=this, separator=separator)
+
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
NULL_ORDERING_SUPPORTED = None
@@ -630,6 +663,7 @@ class MySQL(Dialect):
JSON_TYPE_REQUIRED_FOR_EXTRACTION = True
JSON_PATH_BRACKETED_KEY_SUPPORTED = False
JSON_KEY_VALUE_PAIR_SEP = ","
+ SUPPORTS_TO_NUMBER = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@@ -637,9 +671,9 @@ class MySQL(Dialect):
exp.DateDiff: _remove_ts_or_ds_to_date(
lambda self, e: self.func("DATEDIFF", e.this, e.expression), ("this", "expression")
),
- exp.DateAdd: _remove_ts_or_ds_to_date(_date_add_sql("ADD")),
+ exp.DateAdd: _remove_ts_or_ds_to_date(date_add_sql("ADD")),
exp.DateStrToDate: datestrtodate_sql,
- exp.DateSub: _remove_ts_or_ds_to_date(_date_add_sql("SUB")),
+ exp.DateSub: _remove_ts_or_ds_to_date(date_add_sql("SUB")),
exp.DateTrunc: _date_trunc_sql,
exp.Day: _remove_ts_or_ds_to_date(),
exp.DayOfMonth: _remove_ts_or_ds_to_date(rename_func("DAYOFMONTH")),
@@ -672,7 +706,7 @@ class MySQL(Dialect):
exp.TimeFromParts: rename_func("MAKETIME"),
exp.TimestampAdd: date_add_interval_sql("DATE", "ADD"),
exp.TimestampDiff: lambda self, e: self.func(
- "TIMESTAMPDIFF", e.text("unit"), e.expression, e.this
+ "TIMESTAMPDIFF", unit_to_var(e), e.expression, e.this
),
exp.TimestampSub: date_add_interval_sql("DATE", "SUB"),
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
@@ -682,9 +716,10 @@ class MySQL(Dialect):
),
exp.Trim: _trim_sql,
exp.TryCast: no_trycast_sql,
- exp.TsOrDsAdd: _date_add_sql("ADD"),
+ exp.TsOrDsAdd: date_add_sql("ADD"),
exp.TsOrDsDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression),
exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
+ exp.UnixToTime: lambda self, e: self.func("FROM_UNIXTIME", e.this, self.format_time(e)),
exp.Week: _remove_ts_or_ds_to_date(),
exp.WeekOfYear: _remove_ts_or_ds_to_date(rename_func("WEEKOFYEAR")),
exp.Year: _remove_ts_or_ds_to_date(),
@@ -751,11 +786,6 @@ class MySQL(Dialect):
result = f"{result} UNSIGNED"
return result
- def xor_sql(self, expression: exp.Xor) -> str:
- if expression.expressions:
- return self.expressions(expression, sep=" XOR ")
- return super().xor_sql(expression)
-
def jsonarraycontains_sql(self, expression: exp.JSONArrayContains) -> str:
return f"{self.sql(expression, 'this')} MEMBER OF({self.sql(expression, 'expression')})"
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index bccdad0..e038400 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import (
build_formatted_time,
no_ilike_sql,
rename_func,
+ to_number_with_nls_param,
trim_sql,
)
from sqlglot.helper import seq_get
@@ -246,6 +247,7 @@ class Oracle(Dialect):
exp.TemporaryProperty: lambda _, e: f"{e.name or 'GLOBAL'} TEMPORARY",
exp.TimeToStr: lambda self, e: self.func("TO_CHAR", e.this, self.format_time(e)),
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
+ exp.ToNumber: to_number_with_nls_param,
exp.Trim: trim_sql,
exp.UnixToTime: lambda self,
e: f"TO_DATE('1970-01-01', 'YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index b53ae07..11398ed 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -278,6 +278,7 @@ class Postgres(Dialect):
"REVOKE": TokenType.COMMAND,
"SERIAL": TokenType.SERIAL,
"SMALLSERIAL": TokenType.SMALLSERIAL,
+ "NAME": TokenType.NAME,
"TEMP": TokenType.TEMPORARY,
"CSTRING": TokenType.PSEUDO_TYPE,
"OID": TokenType.OBJECT_IDENTIFIER,
@@ -356,6 +357,16 @@ class Postgres(Dialect):
JSON_ARROWS_REQUIRE_JSON_TYPE = True
+ COLUMN_OPERATORS = {
+ **parser.Parser.COLUMN_OPERATORS,
+ TokenType.ARROW: lambda self, this, path: build_json_extract_path(
+ exp.JSONExtract, arrow_req_json_type=self.JSON_ARROWS_REQUIRE_JSON_TYPE
+ )([this, path]),
+ TokenType.DARROW: lambda self, this, path: build_json_extract_path(
+ exp.JSONExtractScalar, arrow_req_json_type=self.JSON_ARROWS_REQUIRE_JSON_TYPE
+ )([this, path]),
+ }
+
def _parse_operator(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
while True:
if not self._match(TokenType.L_PAREN):
@@ -484,6 +495,7 @@ class Postgres(Dialect):
]
),
exp.StrPosition: str_position_sql,
+ exp.StrToDate: lambda self, e: self.func("TO_DATE", e.this, self.format_time(e)),
exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
exp.StructExtract: struct_extract_sql,
exp.Substring: _substring_sql,
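
A sketch of the new Postgres COLUMN_OPERATORS: -> and ->> now build exp.JSONExtract / exp.JSONExtractScalar with proper JSON paths, carrying the only_json_types flag (illustrative):

    import sqlglot

    e = sqlglot.parse_one("SELECT data -> 'a' ->> 'b' FROM t", read="postgres")
    print(e.sql(dialect="postgres"))
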
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 3649bd2..25bba96 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -22,10 +22,13 @@ from sqlglot.dialects.dialect import (
rename_func,
right_to_substring_sql,
struct_extract_sql,
+ str_position_sql,
timestamptrunc_sql,
timestrtotime_sql,
ts_or_ds_add_cast,
+ unit_to_str,
)
+from sqlglot.dialects.hive import Hive
from sqlglot.dialects.mysql import MySQL
from sqlglot.helper import apply_index_offset, seq_get
from sqlglot.tokens import TokenType
@@ -93,14 +96,14 @@ def _ts_or_ds_to_date_sql(self: Presto.Generator, expression: exp.TsOrDsToDate)
def _ts_or_ds_add_sql(self: Presto.Generator, expression: exp.TsOrDsAdd) -> str:
expression = ts_or_ds_add_cast(expression)
- unit = exp.Literal.string(expression.text("unit") or "DAY")
+ unit = unit_to_str(expression)
return self.func("DATE_ADD", unit, expression.expression, expression.this)
def _ts_or_ds_diff_sql(self: Presto.Generator, expression: exp.TsOrDsDiff) -> str:
this = exp.cast(expression.this, "TIMESTAMP")
expr = exp.cast(expression.expression, "TIMESTAMP")
- unit = exp.Literal.string(expression.text("unit") or "DAY")
+ unit = unit_to_str(expression)
return self.func("DATE_DIFF", unit, expr, this)
@@ -196,6 +199,7 @@ class Presto(Dialect):
SUPPORTS_SEMI_ANTI_JOIN = False
TYPED_DIVISION = True
TABLESAMPLE_SIZE_IS_PERCENT = True
+ LOG_BASE_FIRST: t.Optional[bool] = None
# https://github.com/trinodb/trino/issues/17
# https://github.com/trinodb/trino/issues/12289
@@ -289,6 +293,7 @@ class Presto(Dialect):
SUPPORTS_SINGLE_ARG_CONCAT = False
LIKE_PROPERTY_INSIDE_SCHEMA = True
MULTI_ARG_DISTINCT = False
+ SUPPORTS_TO_NUMBER = False
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION,
@@ -323,6 +328,7 @@ class Presto(Dialect):
exp.ArrayConcat: rename_func("CONCAT"),
exp.ArrayContains: rename_func("CONTAINS"),
exp.ArraySize: rename_func("CARDINALITY"),
+ exp.ArrayToString: rename_func("ARRAY_JOIN"),
exp.ArrayUniqueAgg: rename_func("SET_AGG"),
exp.AtTimeZone: rename_func("AT_TIMEZONE"),
exp.BitwiseAnd: lambda self, e: self.func("BITWISE_AND", e.this, e.expression),
@@ -339,19 +345,19 @@ class Presto(Dialect):
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.DateAdd: lambda self, e: self.func(
"DATE_ADD",
- exp.Literal.string(e.text("unit") or "DAY"),
+ unit_to_str(e),
_to_int(e.expression),
e.this,
),
exp.DateDiff: lambda self, e: self.func(
- "DATE_DIFF", exp.Literal.string(e.text("unit") or "DAY"), e.expression, e.this
+ "DATE_DIFF", unit_to_str(e), e.expression, e.this
),
exp.DateStrToDate: datestrtodate_sql,
exp.DateToDi: lambda self,
e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)",
exp.DateSub: lambda self, e: self.func(
"DATE_ADD",
- exp.Literal.string(e.text("unit") or "DAY"),
+ unit_to_str(e),
_to_int(e.expression * -1),
e.this,
),
@@ -397,13 +403,10 @@ class Presto(Dialect):
]
),
exp.SortArray: _no_sort_array,
- exp.StrPosition: rename_func("STRPOS"),
+ exp.StrPosition: lambda self, e: str_position_sql(self, e, generate_instance=True),
exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)",
exp.StrToMap: rename_func("SPLIT_TO_MAP"),
exp.StrToTime: _str_to_time_sql,
- exp.StrToUnix: lambda self, e: self.func(
- "TO_UNIXTIME", self.func("DATE_PARSE", e.this, self.format_time(e))
- ),
exp.StructExtract: struct_extract_sql,
exp.Table: transforms.preprocess([_unnest_sequence]),
exp.Timestamp: no_timestamp_sql,
@@ -436,6 +439,22 @@ class Presto(Dialect):
exp.Xor: bool_xor_sql,
}
+ def strtounix_sql(self, expression: exp.StrToUnix) -> str:
+ # Since `TO_UNIXTIME` requires a `TIMESTAMP`, we need to parse the argument into one.
+ # To do this, we first try to `DATE_PARSE` it, but since this can fail when there's a
+ # timezone involved, we wrap it in a `TRY` call and use `PARSE_DATETIME` as a fallback,
+ # which seems to be using the same time mapping as Hive, as per:
+ # https://joda-time.sourceforge.net/apidocs/org/joda/time/format/DateTimeFormat.html
+ value_as_text = exp.cast(expression.this, "text")
+ parse_without_tz = self.func("DATE_PARSE", value_as_text, self.format_time(expression))
+ parse_with_tz = self.func(
+ "PARSE_DATETIME",
+ value_as_text,
+ self.format_time(expression, Hive.INVERSE_TIME_MAPPING, Hive.INVERSE_TIME_TRIE),
+ )
+ coalesced = self.func("COALESCE", self.func("TRY", parse_without_tz), parse_with_tz)
+ return self.func("TO_UNIXTIME", coalesced)
+
def bracket_sql(self, expression: exp.Bracket) -> str:
if expression.args.get("safe"):
return self.func(
@@ -481,8 +500,7 @@ class Presto(Dialect):
return f"CAST(ROW({', '.join(values)}) AS ROW({', '.join(schema)}))"
def interval_sql(self, expression: exp.Interval) -> str:
- unit = self.sql(expression, "unit")
- if expression.this and unit.startswith("WEEK"):
+ if expression.this and expression.text("unit").upper().startswith("WEEK"):
return f"({expression.this.name} * INTERVAL '7' DAY)"
return super().interval_sql(expression)
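
An illustrative transpilation hitting the new strtounix_sql() fallback above (the exact format strings come from the Hive/Presto time mappings):

    import sqlglot

    sql = "SELECT UNIX_TIMESTAMP(ts, 'yyyy-MM-dd')"
    print(sqlglot.transpile(sql, read="hive", write="presto")[0])
    # expected shape: TO_UNIXTIME(COALESCE(TRY(DATE_PARSE(...)), PARSE_DATETIME(...)))
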
diff --git a/sqlglot/dialects/prql.py b/sqlglot/dialects/prql.py
new file mode 100644
index 0000000..3005753
--- /dev/null
+++ b/sqlglot/dialects/prql.py
@@ -0,0 +1,109 @@
+from __future__ import annotations
+
+import typing as t
+
+from sqlglot import exp, parser, tokens
+from sqlglot.dialects.dialect import Dialect
+from sqlglot.tokens import TokenType
+
+
+class PRQL(Dialect):
+ class Tokenizer(tokens.Tokenizer):
+ IDENTIFIERS = ["`"]
+ QUOTES = ["'", '"']
+
+ SINGLE_TOKENS = {
+ **tokens.Tokenizer.SINGLE_TOKENS,
+ "=": TokenType.ALIAS,
+ "'": TokenType.QUOTE,
+ '"': TokenType.QUOTE,
+ "`": TokenType.IDENTIFIER,
+ "#": TokenType.COMMENT,
+ }
+
+ KEYWORDS = {
+ **tokens.Tokenizer.KEYWORDS,
+ }
+
+ class Parser(parser.Parser):
+ TRANSFORM_PARSERS = {
+ "DERIVE": lambda self, query: self._parse_selection(query),
+ "SELECT": lambda self, query: self._parse_selection(query, append=False),
+ "TAKE": lambda self, query: self._parse_take(query),
+ }
+
+ def _parse_statement(self) -> t.Optional[exp.Expression]:
+ expression = self._parse_expression()
+ expression = expression if expression else self._parse_query()
+ return expression
+
+ def _parse_query(self) -> t.Optional[exp.Query]:
+ from_ = self._parse_from()
+
+ if not from_:
+ return None
+
+ query = exp.select("*").from_(from_, copy=False)
+
+ while self._match_texts(self.TRANSFORM_PARSERS):
+ query = self.TRANSFORM_PARSERS[self._prev.text.upper()](self, query)
+
+ return query
+
+ def _parse_selection(self, query: exp.Query, append: bool = True) -> exp.Query:
+ if self._match(TokenType.L_BRACE):
+ selects = self._parse_csv(self._parse_expression)
+
+ if not self._match(TokenType.R_BRACE, expression=query):
+ self.raise_error("Expecting }")
+ else:
+ expression = self._parse_expression()
+ selects = [expression] if expression else []
+
+ projections = {
+ select.alias_or_name: select.this if isinstance(select, exp.Alias) else select
+ for select in query.selects
+ }
+
+ selects = [
+ select.transform(
+ lambda s: (projections[s.name].copy() if s.name in projections else s)
+ if isinstance(s, exp.Column)
+ else s,
+ copy=False,
+ )
+ for select in selects
+ ]
+
+ return query.select(*selects, append=append, copy=False)
+
+ def _parse_take(self, query: exp.Query) -> t.Optional[exp.Query]:
+ num = self._parse_number() # TODO: TAKE for ranges a..b
+ return query.limit(num) if num else None
+
+ def _parse_expression(self) -> t.Optional[exp.Expression]:
+ if self._next and self._next.token_type == TokenType.ALIAS:
+ alias = self._parse_id_var(True)
+ self._match(TokenType.ALIAS)
+ return self.expression(exp.Alias, this=self._parse_conjunction(), alias=alias)
+ return self._parse_conjunction()
+
+ def _parse_table(
+ self,
+ schema: bool = False,
+ joins: bool = False,
+ alias_tokens: t.Optional[t.Collection[TokenType]] = None,
+ parse_bracket: bool = False,
+ is_db_reference: bool = False,
+ ) -> t.Optional[exp.Expression]:
+ return self._parse_table_parts()
+
+ def _parse_from(
+ self, joins: bool = False, skip_from_token: bool = False
+ ) -> t.Optional[exp.From]:
+ if not skip_from_token and not self._match(TokenType.FROM):
+ return None
+
+ return self.expression(
+ exp.From, comments=self._prev_comments, this=self._parse_table(joins=joins)
+ )
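
A minimal end-to-end sketch of the new PRQL dialect (illustrative; only from/derive/select/take are wired up above):

    import sqlglot

    prql = "from employees derive gross = salary + bonus select {name, gross} take 10"
    print(sqlglot.transpile(prql, read="prql", write="duckdb")[0])
    # expected shape: SELECT name, salary + bonus AS gross FROM employees LIMIT 10
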
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index 0db87ec..1f0c411 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -92,23 +92,6 @@ class Redshift(Postgres):
return self.expression(exp.Pivot, this=table, unpivot=True) if unpivot else table
- def _parse_types(
- self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True
- ) -> t.Optional[exp.Expression]:
- this = super()._parse_types(
- check_func=check_func, schema=schema, allow_identifiers=allow_identifiers
- )
-
- if (
- isinstance(this, exp.DataType)
- and this.is_type("varchar")
- and this.expressions
- and this.expressions[0].this == exp.column("MAX")
- ):
- this.set("expressions", [exp.var("MAX")])
-
- return this
-
def _parse_convert(
self, strict: bool, safe: t.Optional[bool] = None
) -> t.Optional[exp.Expression]:
@@ -153,6 +136,7 @@ class Redshift(Postgres):
NVL2_SUPPORTED = True
LAST_DAY_SUPPORTS_DATE_PART = False
CAN_IMPLEMENT_ARRAY_ANY = False
+ MULTI_ARG_DISTINCT = True
TYPE_MAPPING = {
**Postgres.Generator.TYPE_MAPPING,
@@ -187,9 +171,13 @@ class Redshift(Postgres):
),
exp.SortKeyProperty: lambda self,
e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
+ exp.StartsWith: lambda self,
+ e: f"{self.sql(e.this)} LIKE {self.sql(e.expression)} || '%'",
exp.TableSample: no_tablesample_sql,
exp.TsOrDsAdd: date_delta_sql("DATEADD"),
exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
+ exp.UnixToTime: lambda self,
+ e: f"(TIMESTAMP 'epoch' + {self.sql(e.this)} * INTERVAL '1 SECOND')",
}
 # Postgres maps exp.Pivot to no_pivot_sql, but Redshift supports pivots
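# A hedged sketch of the new exp.UnixToTime transform above: Redshift lacks
# FROM_UNIXTIME, so epoch arithmetic is emitted instead. MySQL is one dialect
# that parses FROM_UNIXTIME into exp.UnixToTime.
import sqlglot

sqlglot.transpile("SELECT FROM_UNIXTIME(ts)", read="mysql", write="redshift")[0]
# -> SELECT (TIMESTAMP 'epoch' + ts * INTERVAL '1 SECOND')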
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 20fdfb7..73a9166 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -20,8 +20,7 @@ from sqlglot.dialects.dialect import (
timestrtotime_sql,
var_map_sql,
)
-from sqlglot.expressions import Literal
-from sqlglot.helper import flatten, is_int, seq_get
+from sqlglot.helper import flatten, is_float, is_int, seq_get
from sqlglot.tokens import TokenType
if t.TYPE_CHECKING:
@@ -29,33 +28,35 @@ if t.TYPE_CHECKING:
# from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html
-def _build_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime, exp.TimeStrToTime]:
- if len(args) == 2:
- first_arg, second_arg = args
- if second_arg.is_string:
- # case: <string_expr> [ , <format> ]
- return build_formatted_time(exp.StrToTime, "snowflake")(args)
- return exp.UnixToTime(this=first_arg, scale=second_arg)
+def _build_datetime(
+ name: str, kind: exp.DataType.Type, safe: bool = False
+) -> t.Callable[[t.List], exp.Func]:
+ def _builder(args: t.List) -> exp.Func:
+ value = seq_get(args, 0)
+
+ if isinstance(value, exp.Literal):
+ int_value = is_int(value.this)
- from sqlglot.optimizer.simplify import simplify_literals
+ # Converts calls like `TO_TIME('01:02:03')` into casts
+ if len(args) == 1 and value.is_string and not int_value:
+ return exp.cast(value, kind)
- # The first argument might be an expression like 40 * 365 * 86400, so we try to
- # reduce it using `simplify_literals` first and then check if it's a Literal.
- first_arg = seq_get(args, 0)
- if not isinstance(simplify_literals(first_arg, root=True), Literal):
- # case: <variant_expr> or other expressions such as columns
- return exp.TimeStrToTime.from_arg_list(args)
+ # Handles `TO_TIMESTAMP(str, fmt)` and `TO_TIMESTAMP(num, scale)` as special
+ # cases so we can transpile them, since they're relatively common
+ if kind == exp.DataType.Type.TIMESTAMP:
+ if int_value:
+ return exp.UnixToTime(this=value, scale=seq_get(args, 1))
+ if not is_float(value.this):
+ return build_formatted_time(exp.StrToTime, "snowflake")(args)
- if first_arg.is_string:
- if is_int(first_arg.this):
- # case: <integer>
- return exp.UnixToTime.from_arg_list(args)
+ if len(args) == 2 and kind == exp.DataType.Type.DATE:
+ formatted_exp = build_formatted_time(exp.TsOrDsToDate, "snowflake")(args)
+ formatted_exp.set("safe", safe)
+ return formatted_exp
- # case: <date_expr>
- return build_formatted_time(exp.StrToTime, "snowflake", default=True)(args)
+ return exp.Anonymous(this=name, expressions=args)
- # case: <numeric_expr>
- return exp.UnixToTime.from_arg_list(args)
+ return _builder
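# A hedged sketch of how _build_datetime routes Snowflake's TO_* calls, per the
# builder above: plain strings become casts, integer literals become epoch
# conversions, and (string, format) pairs become formatted-time parses.
import sqlglot
from sqlglot import exp

cast = sqlglot.parse_one("SELECT TO_TIME('01:02:03')", read="snowflake").find(exp.Cast)
unix = sqlglot.parse_one("SELECT TO_TIMESTAMP(1612137600)", read="snowflake").find(exp.UnixToTime)
fmt = sqlglot.parse_one("SELECT TO_TIMESTAMP('2020-01-01', 'yyyy-mm-dd')", read="snowflake").find(exp.StrToTime)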
def _build_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]:
@@ -77,6 +78,17 @@ def _build_datediff(args: t.List) -> exp.DateDiff:
)
+def _build_date_time_add(expr_type: t.Type[E]) -> t.Callable[[t.List], E]:
+ def _builder(args: t.List) -> E:
+ return expr_type(
+ this=seq_get(args, 2),
+ expression=seq_get(args, 1),
+ unit=_map_date_part(seq_get(args, 0)),
+ )
+
+ return _builder
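# Hedged sketch: Snowflake's DATEADD(unit, increment, expr) argument order is
# normalized by the builder above so that `this` is the date expression.
import sqlglot
from sqlglot import exp

add = sqlglot.parse_one("SELECT DATEADD(day, 5, d)", read="snowflake").find(exp.DateAdd)
# add.this -> d, add.expression -> 5, add.unit -> DAY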
+
+
# https://docs.snowflake.com/en/sql-reference/functions/div0
def _build_if_from_div0(args: t.List) -> exp.If:
cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0))
@@ -97,14 +109,6 @@ def _build_if_from_nullifzero(args: t.List) -> exp.If:
return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0))
-def _datatype_sql(self: Snowflake.Generator, expression: exp.DataType) -> str:
- if expression.is_type("array"):
- return "ARRAY"
- elif expression.is_type("map"):
- return "OBJECT"
- return self.datatype_sql(expression)
-
-
def _regexpilike_sql(self: Snowflake.Generator, expression: exp.RegexpILike) -> str:
flag = expression.text("flag")
@@ -258,6 +262,25 @@ def _unqualify_unpivot_columns(expression: exp.Expression) -> exp.Expression:
return expression
+def _flatten_structured_types_unless_iceberg(expression: exp.Expression) -> exp.Expression:
+ assert isinstance(expression, exp.Create)
+
+ def _flatten_structured_type(expression: exp.DataType) -> exp.DataType:
+ if expression.this in exp.DataType.NESTED_TYPES:
+ expression.set("expressions", None)
+ return expression
+
+ props = expression.args.get("properties")
+ if isinstance(expression.this, exp.Schema) and not (props and props.find(exp.IcebergProperty)):
+ for schema_expression in expression.this.expressions:
+ if isinstance(schema_expression, exp.ColumnDef):
+ column_type = schema_expression.kind
+ if isinstance(column_type, exp.DataType):
+ column_type.transform(_flatten_structured_type, copy=False)
+
+ return expression
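# A hedged sketch of the flattening above: standard Snowflake tables only take
# the bare nested type names, while Iceberg tables keep the parametrized form.
import sqlglot

sqlglot.transpile("CREATE TABLE t (c ARRAY(INT))", read="snowflake", write="snowflake")[0]
# -> roughly CREATE TABLE t (c ARRAY); with an ICEBERG property, ARRAY(INT)
# would be preserved.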
+
+
class Snowflake(Dialect):
# https://docs.snowflake.com/en/sql-reference/identifiers-syntax
NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE
@@ -312,7 +335,13 @@ class Snowflake(Dialect):
class Parser(parser.Parser):
IDENTIFY_PIVOT_STRINGS = True
+ ID_VAR_TOKENS = {
+ *parser.Parser.ID_VAR_TOKENS,
+ TokenType.MATCH_CONDITION,
+ }
+
TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS | {TokenType.WINDOW}
+ TABLE_ALIAS_TOKENS.discard(TokenType.MATCH_CONDITION)
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
@@ -327,17 +356,13 @@ class Snowflake(Dialect):
end=exp.Sub(this=seq_get(args, 1), expression=exp.Literal.number(1)),
step=seq_get(args, 2),
),
- "ARRAY_TO_STRING": exp.ArrayJoin.from_arg_list,
"BITXOR": binary_from_function(exp.BitwiseXor),
"BIT_XOR": binary_from_function(exp.BitwiseXor),
"BOOLXOR": binary_from_function(exp.Xor),
"CONVERT_TIMEZONE": _build_convert_timezone,
+ "DATE": _build_datetime("DATE", exp.DataType.Type.DATE),
"DATE_TRUNC": _date_trunc_to_time,
- "DATEADD": lambda args: exp.DateAdd(
- this=seq_get(args, 2),
- expression=seq_get(args, 1),
- unit=_map_date_part(seq_get(args, 0)),
- ),
+ "DATEADD": _build_date_time_add(exp.DateAdd),
"DATEDIFF": _build_datediff,
"DIV0": _build_if_from_div0,
"FLATTEN": exp.Explode.from_arg_list,
@@ -349,17 +374,34 @@ class Snowflake(Dialect):
this=seq_get(args, 0), unit=_map_date_part(seq_get(args, 1))
),
"LISTAGG": exp.GroupConcat.from_arg_list,
+ "MEDIAN": lambda args: exp.PercentileCont(
+ this=seq_get(args, 0), expression=exp.Literal.number(0.5)
+ ),
"NULLIFZERO": _build_if_from_nullifzero,
"OBJECT_CONSTRUCT": _build_object_construct,
"REGEXP_REPLACE": _build_regexp_replace,
"REGEXP_SUBSTR": exp.RegexpExtract.from_arg_list,
"RLIKE": exp.RegexpLike.from_arg_list,
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
+ "TIMEADD": _build_date_time_add(exp.TimeAdd),
"TIMEDIFF": _build_datediff,
+ "TIMESTAMPADD": _build_date_time_add(exp.DateAdd),
"TIMESTAMPDIFF": _build_datediff,
"TIMESTAMPFROMPARTS": _build_timestamp_from_parts,
"TIMESTAMP_FROM_PARTS": _build_timestamp_from_parts,
- "TO_TIMESTAMP": _build_to_timestamp,
+ "TRY_TO_DATE": _build_datetime("TRY_TO_DATE", exp.DataType.Type.DATE, safe=True),
+ "TO_DATE": _build_datetime("TO_DATE", exp.DataType.Type.DATE),
+ "TO_NUMBER": lambda args: exp.ToNumber(
+ this=seq_get(args, 0),
+ format=seq_get(args, 1),
+ precision=seq_get(args, 2),
+ scale=seq_get(args, 3),
+ ),
+ "TO_TIME": _build_datetime("TO_TIME", exp.DataType.Type.TIME),
+ "TO_TIMESTAMP": _build_datetime("TO_TIMESTAMP", exp.DataType.Type.TIMESTAMP),
+ "TO_TIMESTAMP_LTZ": _build_datetime("TO_TIMESTAMP_LTZ", exp.DataType.Type.TIMESTAMPLTZ),
+ "TO_TIMESTAMP_NTZ": _build_datetime("TO_TIMESTAMP_NTZ", exp.DataType.Type.TIMESTAMP),
+ "TO_TIMESTAMP_TZ": _build_datetime("TO_TIMESTAMP_TZ", exp.DataType.Type.TIMESTAMPTZ),
"TO_VARCHAR": exp.ToChar.from_arg_list,
"ZEROIFNULL": _build_if_from_zeroifnull,
}
@@ -377,7 +419,6 @@ class Snowflake(Dialect):
**parser.Parser.RANGE_PARSERS,
TokenType.LIKE_ANY: parser.binary_range_parser(exp.LikeAny),
TokenType.ILIKE_ANY: parser.binary_range_parser(exp.ILikeAny),
- TokenType.COLON: lambda self, this: self._parse_colon_get_path(this),
}
ALTER_PARSERS = {
@@ -434,35 +475,35 @@ class Snowflake(Dialect):
SCHEMA_KINDS = {"OBJECTS", "TABLES", "VIEWS", "SEQUENCES", "UNIQUE KEYS", "IMPORTED KEYS"}
- def _parse_colon_get_path(
- self: parser.Parser, this: t.Optional[exp.Expression]
- ) -> t.Optional[exp.Expression]:
- while True:
- path = self._parse_bitwise() or self._parse_var(any_token=True)
+ def _parse_column_ops(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
+ this = super()._parse_column_ops(this)
+
+ casts = []
+ json_path = []
+
+ while self._match(TokenType.COLON):
+ path = super()._parse_column_ops(self._parse_field(any_token=True))
# The cast :: operator has a lower precedence than the extraction operator :, so
# we rearrange the AST appropriately to avoid casting the 2nd argument of GET_PATH
- if isinstance(path, exp.Cast):
- target_type = path.to
+ while isinstance(path, exp.Cast):
+ casts.append(path.to)
path = path.this
- else:
- target_type = None
- if isinstance(path, exp.Expression):
- path = exp.Literal.string(path.sql(dialect="snowflake"))
+ if path:
+ json_path.append(path.sql(dialect="snowflake", copy=False))
- # The extraction operator : is left-associative
+ if json_path:
this = self.expression(
- exp.JSONExtract, this=this, expression=self.dialect.to_json_path(path)
+ exp.JSONExtract,
+ this=this,
+ expression=self.dialect.to_json_path(exp.Literal.string(".".join(json_path))),
)
- if target_type:
- this = exp.cast(this, target_type)
+ while casts:
+ this = self.expression(exp.Cast, this=this, to=casts.pop())
- if not self._match(TokenType.COLON):
- break
-
- return self._parse_range(this)
+ return this
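# A hedged sketch of the rewritten `:` parsing: chained path segments collapse
# into a single JSON extraction, and trailing :: casts are hoisted so they
# apply to the extracted value rather than to the path.
import sqlglot

sqlglot.transpile("SELECT col:a.b::int FROM t", read="snowflake", write="snowflake")[0]
# -> roughly SELECT CAST(GET_PATH(col, 'a.b') AS INT) FROM t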
# https://docs.snowflake.com/en/sql-reference/functions/date_part.html
# https://docs.snowflake.com/en/sql-reference/functions-date-time.html#label-supported-date-time-parts
@@ -663,6 +704,7 @@ class Snowflake(Dialect):
"EXCLUDE": TokenType.EXCEPT,
"ILIKE ANY": TokenType.ILIKE_ANY,
"LIKE ANY": TokenType.LIKE_ANY,
+ "MATCH_CONDITION": TokenType.MATCH_CONDITION,
"MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
"MINUS": TokenType.EXCEPT,
"NCHAR VARYING": TokenType.VARCHAR,
@@ -703,6 +745,7 @@ class Snowflake(Dialect):
LIMIT_ONLY_LITERALS = True
JSON_KEY_VALUE_PAIR_SEP = ","
INSERT_OVERWRITE = " OVERWRITE INTO"
+ STRUCT_DELIMITER = ("(", ")")
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@@ -711,15 +754,14 @@ class Snowflake(Dialect):
exp.Array: inline_array_sql,
exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.ArrayContains: lambda self, e: self.func("ARRAY_CONTAINS", e.expression, e.this),
- exp.ArrayJoin: rename_func("ARRAY_TO_STRING"),
exp.AtTimeZone: lambda self, e: self.func(
"CONVERT_TIMEZONE", e.args.get("zone"), e.this
),
exp.BitwiseXor: rename_func("BITXOR"),
+ exp.Create: transforms.preprocess([_flatten_structured_types_unless_iceberg]),
exp.DateAdd: date_delta_sql("DATEADD"),
exp.DateDiff: date_delta_sql("DATEDIFF"),
exp.DateStrToDate: datestrtodate_sql,
- exp.DataType: _datatype_sql,
exp.DayOfMonth: rename_func("DAYOFMONTH"),
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.DayOfYear: rename_func("DAYOFYEAR"),
@@ -769,6 +811,7 @@ class Snowflake(Dialect):
),
exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
exp.Stuff: rename_func("INSERT"),
+ exp.TimeAdd: date_delta_sql("TIMEADD"),
exp.TimestampDiff: lambda self, e: self.func(
"TIMESTAMPDIFF", e.unit, e.expression, e.this
),
@@ -783,6 +826,9 @@ class Snowflake(Dialect):
exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True),
exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
+ exp.TsOrDsToDate: lambda self, e: self.func(
+ "TRY_TO_DATE" if e.args.get("safe") else "TO_DATE", e.this, self.format_time(e)
+ ),
exp.UnixToTime: rename_func("TO_TIMESTAMP"),
exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
exp.WeekOfYear: rename_func("WEEKOFYEAR"),
@@ -797,6 +843,8 @@ class Snowflake(Dialect):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
+ exp.DataType.Type.NESTED: "OBJECT",
+ exp.DataType.Type.STRUCT: "OBJECT",
exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",
}
@@ -811,6 +859,37 @@ class Snowflake(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
+ UNSUPPORTED_VALUES_EXPRESSIONS = {
+ exp.Struct,
+ }
+
+ def values_sql(self, expression: exp.Values, values_as_table: bool = True) -> str:
+ if expression.find(*self.UNSUPPORTED_VALUES_EXPRESSIONS):
+ values_as_table = False
+
+ return super().values_sql(expression, values_as_table=values_as_table)
+
+ def datatype_sql(self, expression: exp.DataType) -> str:
+ expressions = expression.expressions
+ if (
+ expressions
+ and expression.is_type(*exp.DataType.STRUCT_TYPES)
+ and any(isinstance(field_type, exp.DataType) for field_type in expressions)
+ ):
+ # The correct syntax is OBJECT [ (<key> <value_type> [NOT NULL] [, ...]) ]
+ return "OBJECT"
+
+ return super().datatype_sql(expression)
+
+ def tonumber_sql(self, expression: exp.ToNumber) -> str:
+ return self.func(
+ "TO_NUMBER",
+ expression.this,
+ expression.args.get("format"),
+ expression.args.get("precision"),
+ expression.args.get("scale"),
+ )
+
def timestampfromparts_sql(self, expression: exp.TimestampFromParts) -> str:
milli = expression.args.get("milli")
if milli is not None:
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 20c0fce..88b5ddc 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -3,7 +3,7 @@ from __future__ import annotations
import typing as t
from sqlglot import exp
-from sqlglot.dialects.dialect import rename_func
+from sqlglot.dialects.dialect import rename_func, unit_to_var
from sqlglot.dialects.hive import _build_with_ignore_nulls
from sqlglot.dialects.spark2 import Spark2, temporary_storage_provider
from sqlglot.helper import seq_get
@@ -78,6 +78,8 @@ class Spark(Spark2):
return this
class Generator(Spark2.Generator):
+ SUPPORTS_TO_NUMBER = True
+
TYPE_MAPPING = {
**Spark2.Generator.TYPE_MAPPING,
exp.DataType.Type.MONEY: "DECIMAL(15, 4)",
@@ -100,7 +102,7 @@ class Spark(Spark2):
e: f"PARTITIONED BY {self.wrap(self.expressions(sqls=[_normalize_partition(e) for e in e.this.expressions], skip_first=True))}",
exp.StartsWith: rename_func("STARTSWITH"),
exp.TimestampAdd: lambda self, e: self.func(
- "DATEADD", e.args.get("unit") or "DAY", e.expression, e.this
+ "DATEADD", unit_to_var(e), e.expression, e.this
),
exp.TryCast: lambda self, e: (
self.trycast_sql(e) if e.args.get("safe") else self.cast_sql(e)
@@ -117,11 +119,10 @@ class Spark(Spark2):
return self.function_fallback_sql(expression)
def datediff_sql(self, expression: exp.DateDiff) -> str:
- unit = self.sql(expression, "unit")
end = self.sql(expression, "this")
start = self.sql(expression, "expression")
- if unit:
- return self.func("DATEDIFF", unit, start, end)
+ if expression.unit:
+ return self.func("DATEDIFF", unit_to_var(expression), start, end)
return self.func("DATEDIFF", end, start)
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
index 63eae6e..069916f 100644
--- a/sqlglot/dialects/spark2.py
+++ b/sqlglot/dialects/spark2.py
@@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import (
pivot_column_names,
rename_func,
trim_sql,
+ unit_to_str,
)
from sqlglot.dialects.hive import Hive
from sqlglot.helper import seq_get
@@ -203,6 +204,7 @@ class Spark2(Hive):
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
exp.ArraySum: lambda self,
e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
+ exp.ArrayToString: rename_func("ARRAY_JOIN"),
exp.AtTimeZone: lambda self, e: self.func(
"FROM_UTC_TIMESTAMP", e.this, e.args.get("zone")
),
@@ -218,7 +220,7 @@ class Spark2(Hive):
]
),
exp.DateFromParts: rename_func("MAKE_DATE"),
- exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")),
+ exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, unit_to_str(e)),
exp.DayOfMonth: rename_func("DAYOFMONTH"),
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.DayOfYear: rename_func("DAYOFYEAR"),
@@ -241,9 +243,7 @@ class Spark2(Hive):
),
exp.StrToDate: _str_to_date,
exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
- exp.TimestampTrunc: lambda self, e: self.func(
- "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
- ),
+ exp.TimestampTrunc: lambda self, e: self.func("DATE_TRUNC", unit_to_str(e), e.this),
exp.Trim: trim_sql,
exp.UnixToTime: _unix_to_time_sql,
exp.VariancePop: rename_func("VAR_POP"),
@@ -252,7 +252,6 @@ class Spark2(Hive):
[transforms.remove_within_group_for_percentiles]
),
}
- TRANSFORMS.pop(exp.ArrayJoin)
TRANSFORMS.pop(exp.ArraySort)
TRANSFORMS.pop(exp.ILike)
TRANSFORMS.pop(exp.Left)
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index 2b17ff9..ef7d9aa 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -33,6 +33,14 @@ def _json_extract_sql(self: SQLite.Generator, expression: exp.JSONExtract) -> st
return arrow_json_extract_sql(self, expression)
+def _build_strftime(args: t.List) -> exp.Anonymous | exp.TimeToStr:
+ if len(args) == 1:
+ args.append(exp.CurrentTimestamp())
+ if len(args) == 2:
+ return exp.TimeToStr(this=exp.TsOrDsToTimestamp(this=args[1]), format=args[0])
+ return exp.Anonymous(this="STRFTIME", expressions=args)
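# Hedged sketch of the STRFTIME builder above: a missing second argument
# defaults to the current timestamp, and two-argument calls become TimeToStr.
import sqlglot
from sqlglot import exp

node = sqlglot.parse_one("SELECT STRFTIME('%Y-%m-%d', d)", read="sqlite").find(exp.TimeToStr)
# node.this is wrapped in exp.TsOrDsToTimestamp; calls with more than two
# arguments fall back to an anonymous STRFTIME.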
+
+
def _transform_create(expression: exp.Expression) -> exp.Expression:
"""Move primary key to a column and enforce auto_increment on primary keys."""
schema = expression.this
@@ -82,6 +90,7 @@ class SQLite(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"EDITDIST3": exp.Levenshtein.from_arg_list,
+ "STRFTIME": _build_strftime,
}
STRING_ALIASES = True
@@ -93,6 +102,7 @@ class SQLite(Dialect):
JSON_PATH_BRACKETED_KEY_SUPPORTED = False
SUPPORTS_CREATE_TABLE_LIKE = False
SUPPORTS_TABLE_ALIAS_COLUMNS = False
+ SUPPORTS_TO_NUMBER = False
SUPPORTED_JSON_PATH_PARTS = {
exp.JSONPathKey,
@@ -151,7 +161,9 @@ class SQLite(Dialect):
),
exp.TableSample: no_tablesample_sql,
exp.TimeStrToTime: lambda self, e: self.sql(e, "this"),
+ exp.TimeToStr: lambda self, e: self.func("STRFTIME", e.args.get("format"), e.this),
exp.TryCast: no_trycast_sql,
+ exp.TsOrDsToTimestamp: lambda self, e: self.sql(e, "this"),
}
# SQLite doesn't generally support CREATE TABLE .. properties
diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py
index 12ac600..5691f58 100644
--- a/sqlglot/dialects/starrocks.py
+++ b/sqlglot/dialects/starrocks.py
@@ -6,6 +6,7 @@ from sqlglot.dialects.dialect import (
arrow_json_extract_sql,
build_timestamp_trunc,
rename_func,
+ unit_to_str,
)
from sqlglot.dialects.mysql import MySQL
from sqlglot.helper import seq_get
@@ -39,15 +40,13 @@ class StarRocks(MySQL):
**MySQL.Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.DateDiff: lambda self, e: self.func(
- "DATE_DIFF", exp.Literal.string(e.text("unit") or "DAY"), e.this, e.expression
+ "DATE_DIFF", unit_to_str(e), e.this, e.expression
),
exp.JSONExtractScalar: arrow_json_extract_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.RegexpLike: rename_func("REGEXP"),
exp.StrToUnix: lambda self, e: self.func("UNIX_TIMESTAMP", e.this, self.format_time(e)),
- exp.TimestampTrunc: lambda self, e: self.func(
- "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
- ),
+ exp.TimestampTrunc: lambda self, e: self.func("DATE_TRUNC", unit_to_str(e), e.this),
exp.TimeStrToDate: rename_func("TO_DATE"),
exp.UnixToStr: lambda self, e: self.func("FROM_UNIXTIME", e.this, self.format_time(e)),
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py
index b736918..40feb67 100644
--- a/sqlglot/dialects/tableau.py
+++ b/sqlglot/dialects/tableau.py
@@ -5,6 +5,8 @@ from sqlglot.dialects.dialect import Dialect, rename_func
class Tableau(Dialect):
+ LOG_BASE_FIRST = False
+
class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = [("[", "]")]
QUOTES = ["'", '"']
diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py
index 0663a1d..a65e10e 100644
--- a/sqlglot/dialects/teradata.py
+++ b/sqlglot/dialects/teradata.py
@@ -3,7 +3,13 @@ from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
-from sqlglot.dialects.dialect import Dialect, max_or_greatest, min_or_least, rename_func
+from sqlglot.dialects.dialect import (
+ Dialect,
+ max_or_greatest,
+ min_or_least,
+ rename_func,
+ to_number_with_nls_param,
+)
from sqlglot.tokens import TokenType
@@ -206,6 +212,7 @@ class Teradata(Dialect):
exp.StrToDate: lambda self,
e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})",
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
+ exp.ToNumber: to_number_with_nls_param,
exp.Use: lambda self, e: f"DATABASE {self.sql(e, 'this')}",
}
diff --git a/sqlglot/dialects/trino.py b/sqlglot/dialects/trino.py
index 1bbed67..457e2f0 100644
--- a/sqlglot/dialects/trino.py
+++ b/sqlglot/dialects/trino.py
@@ -7,6 +7,7 @@ from sqlglot.dialects.presto import Presto
class Trino(Presto):
SUPPORTS_USER_DEFINED_TYPES = False
+ LOG_BASE_FIRST = True
class Generator(Presto.Generator):
TRANSFORMS = {
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index b6f491f..8e06be6 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import (
NormalizationStrategy,
any_value_to_max_sql,
date_delta_sql,
+ datestrtodate_sql,
generatedasidentitycolumnconstraint_sql,
max_or_greatest,
min_or_least,
@@ -724,6 +725,7 @@ class TSQL(Dialect):
TABLESAMPLE_SEED_KEYWORD = "REPEATABLE"
SUPPORTS_SELECT_INTO = True
JSON_PATH_BRACKETED_KEY_SUPPORTED = False
+ SUPPORTS_TO_NUMBER = False
EXPRESSIONS_WITHOUT_NESTED_CTES = {
exp.Delete,
@@ -760,12 +762,14 @@ class TSQL(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.AnyValue: any_value_to_max_sql,
+ exp.ArrayToString: rename_func("STRING_AGG"),
exp.AutoIncrementColumnConstraint: lambda *_: "IDENTITY",
exp.DateAdd: date_delta_sql("DATEADD"),
exp.DateDiff: date_delta_sql("DATEDIFF"),
exp.CTE: transforms.preprocess([qualify_derived_table_outputs]),
exp.CurrentDate: rename_func("GETDATE"),
exp.CurrentTimestamp: rename_func("GETDATE"),
+ exp.DateStrToDate: datestrtodate_sql,
exp.Extract: rename_func("DATEPART"),
exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql,
exp.GroupConcat: _string_agg_sql,
@@ -808,6 +812,22 @@ class TSQL(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
+ def select_sql(self, expression: exp.Select) -> str:
+ if expression.args.get("offset"):
+ if not expression.args.get("order"):
+ # ORDER BY is required in order to use OFFSET in a query, so we add
+ # a no-op ORDER BY, since we don't really care about the order.
+ # See: https://www.microsoftpressstore.com/articles/article.aspx?p=2314819
+ expression.order_by(exp.select(exp.null()).subquery(), copy=False)
+
+ limit = expression.args.get("limit")
+ if isinstance(limit, exp.Limit):
+ # TOP and OFFSET can't be combined, so we need to use FETCH instead of TOP.
+ # We replace the limit here because otherwise TOP would be generated in select_sql.
+ limit.replace(exp.Fetch(direction="FIRST", count=limit.expression))
+
+ return super().select_sql(expression)
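# A hedged sketch of the OFFSET handling above when transpiling to T-SQL.
import sqlglot

sqlglot.transpile("SELECT x FROM t LIMIT 5 OFFSET 10", read="duckdb", write="tsql")[0]
# -> roughly SELECT x FROM t ORDER BY (SELECT NULL) OFFSET 10 ROWS FETCH FIRST 5 ROWS ONLY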
+
def convert_sql(self, expression: exp.Convert) -> str:
name = "TRY_CONVERT" if expression.args.get("safe") else "CONVERT"
return self.func(
@@ -862,12 +882,12 @@ class TSQL(Dialect):
return rename_func("DATETIMEFROMPARTS")(self, expression)
- def set_operation(self, expression: exp.Union, op: str) -> str:
+ def set_operations(self, expression: exp.Union) -> str:
limit = expression.args.get("limit")
if limit:
return self.sql(expression.limit(limit.pop(), copy=False))
- return super().set_operation(expression, op)
+ return super().set_operations(expression)
def setitem_sql(self, expression: exp.SetItem) -> str:
this = expression.this
diff --git a/sqlglot/diff.py b/sqlglot/diff.py
index bda9136..22c506a 100644
--- a/sqlglot/diff.py
+++ b/sqlglot/diff.py
@@ -103,7 +103,7 @@ def diff(
) -> t.Dict[int, exp.Expression]:
return {
id(old_node): new_node
- for (old_node, _, _), (new_node, _, _) in zip(original.walk(), copy.walk())
+ for old_node, new_node in zip(original.walk(), copy.walk())
if id(old_node) in matching_ids
}
@@ -158,14 +158,10 @@ class ChangeDistiller:
self._source = source
self._target = target
self._source_index = {
- id(n): n
- for n, *_ in self._source.bfs()
- if not isinstance(n, IGNORED_LEAF_EXPRESSION_TYPES)
+ id(n): n for n in self._source.bfs() if not isinstance(n, IGNORED_LEAF_EXPRESSION_TYPES)
}
self._target_index = {
- id(n): n
- for n, *_ in self._target.bfs()
- if not isinstance(n, IGNORED_LEAF_EXPRESSION_TYPES)
+ id(n): n for n in self._target.bfs() if not isinstance(n, IGNORED_LEAF_EXPRESSION_TYPES)
}
self._unmatched_source_nodes = set(self._source_index) - set(pre_matched_nodes)
self._unmatched_target_nodes = set(self._target_index) - set(pre_matched_nodes.values())
@@ -216,10 +212,10 @@ class ChangeDistiller:
matching_set = leaves_matching_set.copy()
ordered_unmatched_source_nodes = {
- id(n): None for n, *_ in self._source.bfs() if id(n) in self._unmatched_source_nodes
+ id(n): None for n in self._source.bfs() if id(n) in self._unmatched_source_nodes
}
ordered_unmatched_target_nodes = {
- id(n): None for n, *_ in self._target.bfs() if id(n) in self._unmatched_target_nodes
+ id(n): None for n in self._target.bfs() if id(n) in self._unmatched_target_nodes
}
for source_node_id in ordered_unmatched_source_nodes:
@@ -322,7 +318,7 @@ class ChangeDistiller:
def _get_leaves(expression: exp.Expression) -> t.Iterator[exp.Expression]:
has_child_exprs = False
- for _, node in expression.iter_expressions():
+ for node in expression.iter_expressions():
if not isinstance(node, IGNORED_LEAF_EXPRESSION_TYPES):
has_child_exprs = True
yield from _get_leaves(node)
diff --git a/sqlglot/executor/__init__.py b/sqlglot/executor/__init__.py
index c8f9148..29c0e68 100644
--- a/sqlglot/executor/__init__.py
+++ b/sqlglot/executor/__init__.py
@@ -10,11 +10,13 @@ import logging
import time
import typing as t
+from sqlglot import exp
from sqlglot.errors import ExecuteError
from sqlglot.executor.python import PythonExecutor
from sqlglot.executor.table import Table, ensure_tables
from sqlglot.helper import dict_depth
from sqlglot.optimizer import optimize
+from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.planner import Plan
from sqlglot.schema import ensure_schema, flatten_schema, nested_get, nested_set
@@ -26,15 +28,11 @@ if t.TYPE_CHECKING:
from sqlglot.schema import Schema
-PYTHON_TYPE_TO_SQLGLOT = {
- "dict": "MAP",
-}
-
-
def execute(
sql: str | Expression,
schema: t.Optional[t.Dict | Schema] = None,
read: DialectType = None,
+ dialect: DialectType = None,
tables: t.Optional[t.Dict] = None,
) -> Table:
"""
@@ -48,11 +46,13 @@ def execute(
2. {db: {table: {col: type}}}
3. {catalog: {db: {table: {col: type}}}}
 read: the SQL dialect to apply during parsing (e.g. "spark", "hive", "presto", "mysql").
+ dialect: the SQL dialect (alias for read).
tables: additional tables to register.
Returns:
Simple columnar data structure.
"""
+ read = read or dialect
tables_ = ensure_tables(tables, dialect=read)
if not schema:
@@ -64,8 +64,9 @@ def execute(
assert table is not None
for column in table.columns:
- py_type = type(table[0][column]).__name__
- nested_set(schema, [*keys, column], PYTHON_TYPE_TO_SQLGLOT.get(py_type) or py_type)
+ value = table[0][column]
+ column_type = annotate_types(exp.convert(value)).type or type(value).__name__
+ nested_set(schema, [*keys, column], column_type)
schema = ensure_schema(schema, dialect=read)
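# Hedged sketch: schema inference for sample rows now goes through the type
# annotator, so container values map to proper nested types instead of bare
# Python type names; the exact rendered type is illustrative.
from sqlglot import exp
from sqlglot.optimizer.annotate_types import annotate_types

annotate_types(exp.convert({"a": 1})).type  # e.g. MAP(VARCHAR, INT)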
diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py
index 218a8e0..c51049b 100644
--- a/sqlglot/executor/env.py
+++ b/sqlglot/executor/env.py
@@ -106,6 +106,13 @@ def cast(this, to):
return this
if isinstance(this, str):
return datetime.date.fromisoformat(this)
+ if to == exp.DataType.Type.TIME:
+ if isinstance(this, datetime.datetime):
+ return this.time()
+ if isinstance(this, datetime.time):
+ return this
+ if isinstance(this, str):
+ return datetime.time.fromisoformat(this)
if to in (exp.DataType.Type.DATETIME, exp.DataType.Type.TIMESTAMP):
if isinstance(this, datetime.datetime):
return this
@@ -139,7 +146,7 @@ def interval(this, unit):
@null_if_any("this", "expression")
-def arrayjoin(this, expression, null=None):
+def arraytostring(this, expression, null=None):
return expression.join(x for x in (x if x is not None else null for x in this) if x is not None)
@@ -173,7 +180,7 @@ ENV = {
"ABS": null_if_any(lambda this: abs(this)),
"ADD": null_if_any(lambda e, this: e + this),
"ARRAYANY": null_if_any(lambda arr, func: any(func(e) for e in arr)),
- "ARRAYJOIN": arrayjoin,
+ "ARRAYTOSTRING": arraytostring,
"BETWEEN": null_if_any(lambda this, low, high: low <= this and this <= high),
"BITWISEAND": null_if_any(lambda this, e: this & e),
"BITWISELEFTSHIFT": null_if_any(lambda this, e: this << e),
@@ -212,6 +219,7 @@ ENV = {
"ORDERED": ordered,
"POW": pow,
"RIGHT": null_if_any(lambda this, e: this[-e:]),
+ "ROUND": null_if_any(lambda this, decimals=None, truncate=None: round(this, ndigits=decimals)),
"STRPOSITION": str_position,
"SUB": null_if_any(lambda e, this: e - this),
"SUBSTRING": substring,
@@ -225,10 +233,12 @@ ENV = {
"CURRENTTIME": datetime.datetime.now,
"CURRENTDATE": datetime.date.today,
"STRFTIME": null_if_any(lambda fmt, arg: datetime.datetime.fromisoformat(arg).strftime(fmt)),
+ "STRTOTIME": null_if_any(lambda arg, format: datetime.datetime.strptime(arg, format)),
"TRIM": null_if_any(lambda this, e=None: this.strip(e)),
"STRUCT": lambda *args: {
args[x]: args[x + 1]
for x in range(0, len(args), 2)
if (args[x + 1] is not None and args[x] is not None)
},
+ "UNIXTOTIME": null_if_any(lambda arg: datetime.datetime.utcfromtimestamp(arg)),
}
diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py
index a2b23d4..674ef78 100644
--- a/sqlglot/executor/python.py
+++ b/sqlglot/executor/python.py
@@ -157,7 +157,7 @@ class PythonExecutor:
yield context.table.reader
def join(self, step, context):
- source = step.name
+ source = step.source_name
source_table = context.tables[source]
source_context = self.context({source: source_table})
@@ -398,7 +398,7 @@ def _lambda_sql(self, e: exp.Lambda) -> str:
lambda n: (
exp.var(n.name) if isinstance(n, exp.Identifier) and n.name.lower() in names else n
)
- )
+ ).assert_is(exp.Lambda)
return f"lambda {self.expressions(e, flat=True)}: {self.sql(e, 'this')}"
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 1a24875..e79c04b 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -39,6 +39,8 @@ if t.TYPE_CHECKING:
from sqlglot._typing import E, Lit
from sqlglot.dialects.dialect import DialectType
+ Q = t.TypeVar("Q", bound="Query")
+
class _Expression(type):
def __new__(cls, clsname, bases, attrs):
@@ -72,6 +74,7 @@ class Expression(metaclass=_Expression):
parent: a reference to the parent expression (or None, in case of root expressions).
arg_key: the arg key an expression is associated with, i.e. the name its parent expression
uses to refer to it.
+ index: the index of an expression if it is inside a list argument in its parent.
comments: a list of comments that are associated with a given expression. This is used in
order to preserve comments when transpiling SQL code.
type: the `sqlglot.expressions.DataType` type of an expression. This is inferred by the
@@ -91,12 +94,13 @@ class Expression(metaclass=_Expression):
key = "expression"
arg_types = {"this": True}
- __slots__ = ("args", "parent", "arg_key", "comments", "_type", "_meta", "_hash")
+ __slots__ = ("args", "parent", "arg_key", "index", "comments", "_type", "_meta", "_hash")
def __init__(self, **args: t.Any):
self.args: t.Dict[str, t.Any] = args
self.parent: t.Optional[Expression] = None
self.arg_key: t.Optional[str] = None
+ self.index: t.Optional[int] = None
self.comments: t.Optional[t.List[str]] = None
self._type: t.Optional[DataType] = None
self._meta: t.Optional[t.Dict[str, t.Any]] = None
@@ -248,25 +252,44 @@ class Expression(metaclass=_Expression):
return self._meta
def __deepcopy__(self, memo):
- copy = self.__class__(**deepcopy(self.args))
- if self.comments is not None:
- copy.comments = deepcopy(self.comments)
-
- if self._type is not None:
- copy._type = self._type.copy()
-
- if self._meta is not None:
- copy._meta = deepcopy(self._meta)
-
- return copy
+ root = self.__class__()
+ stack = [(self, root)]
+
+ while stack:
+ node, copy = stack.pop()
+
+ if node.comments is not None:
+ copy.comments = deepcopy(node.comments)
+ if node._type is not None:
+ copy._type = deepcopy(node._type)
+ if node._meta is not None:
+ copy._meta = deepcopy(node._meta)
+ if node._hash is not None:
+ copy._hash = node._hash
+
+ for k, vs in node.args.items():
+ if hasattr(vs, "parent"):
+ stack.append((vs, vs.__class__()))
+ copy.set(k, stack[-1][-1])
+ elif type(vs) is list:
+ copy.args[k] = []
+
+ for v in vs:
+ if hasattr(v, "parent"):
+ stack.append((v, v.__class__()))
+ copy.append(k, stack[-1][-1])
+ else:
+ copy.append(k, v)
+ else:
+ copy.args[k] = vs
+
+ return root
def copy(self):
"""
Returns a deep copy of the expression.
"""
- new = deepcopy(self)
- new.parent = self.parent
- return new
+ return deepcopy(self)
def add_comments(self, comments: t.Optional[t.List[str]]) -> None:
if self.comments is None:
@@ -289,35 +312,59 @@ class Expression(metaclass=_Expression):
arg_key (str): name of the list expression arg
value (Any): value to append to the list
"""
- if not isinstance(self.args.get(arg_key), list):
+ if type(self.args.get(arg_key)) is not list:
self.args[arg_key] = []
- self.args[arg_key].append(value)
self._set_parent(arg_key, value)
+ values = self.args[arg_key]
+ if hasattr(value, "parent"):
+ value.index = len(values)
+ values.append(value)
- def set(self, arg_key: str, value: t.Any) -> None:
+ def set(self, arg_key: str, value: t.Any, index: t.Optional[int] = None) -> None:
"""
Sets arg_key to value.
Args:
arg_key: name of the expression arg.
value: value to set the arg to.
- """
- if value is None:
+ index: if the arg is a list, this specifies what position to add the value in it.
+ """
+ if index is not None:
+ expressions = self.args.get(arg_key) or []
+
+ if seq_get(expressions, index) is None:
+ return
+ if value is None:
+ expressions.pop(index)
+ for v in expressions[index:]:
+ v.index = v.index - 1
+ return
+
+ if isinstance(value, list):
+ expressions.pop(index)
+ expressions[index:index] = value
+ else:
+ expressions[index] = value
+
+ value = expressions
+ elif value is None:
self.args.pop(arg_key, None)
return
self.args[arg_key] = value
- self._set_parent(arg_key, value)
+ self._set_parent(arg_key, value, index)
- def _set_parent(self, arg_key: str, value: t.Any) -> None:
+ def _set_parent(self, arg_key: str, value: t.Any, index: t.Optional[int] = None) -> None:
if hasattr(value, "parent"):
value.parent = self
value.arg_key = arg_key
+ value.index = index
elif type(value) is list:
- for v in value:
+ for index, v in enumerate(value):
if hasattr(v, "parent"):
v.parent = self
v.arg_key = arg_key
+ v.index = index
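# Hedged sketch of the new index-aware Expression.set: positional updates on
# list args replace, splice or remove in place and keep sibling indexes fresh.
from sqlglot import exp

query = exp.select("a", "b").from_("t")
query.set("expressions", exp.column("c"), index=1)  # replaces the "b" projection
query.set("expressions", None, index=0)             # drops "a" and reindexes "c"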
@property
def depth(self) -> int:
@@ -328,16 +375,17 @@ class Expression(metaclass=_Expression):
return self.parent.depth + 1
return 0
- def iter_expressions(self) -> t.Iterator[t.Tuple[str, Expression]]:
+ def iter_expressions(self, reverse: bool = False) -> t.Iterator[Expression]:
"""Yields the key and expression for all arguments, exploding list args."""
- for k, vs in self.args.items():
+ # remove the tuple() call once python 3.7 support is dropped (dict views are reversible in 3.8+)
+ for vs in reversed(tuple(self.args.values())) if reverse else self.args.values():
if type(vs) is list:
- for v in vs:
+ for v in reversed(vs) if reverse else vs:
if hasattr(v, "parent"):
- yield k, v
+ yield v
else:
if hasattr(vs, "parent"):
- yield k, vs
+ yield vs
def find(self, *expression_types: t.Type[E], bfs: bool = True) -> t.Optional[E]:
"""
@@ -365,7 +413,7 @@ class Expression(metaclass=_Expression):
Returns:
The generator object.
"""
- for expression, *_ in self.walk(bfs=bfs):
+ for expression in self.walk(bfs=bfs):
if isinstance(expression, expression_types):
yield expression
@@ -405,15 +453,17 @@ class Expression(metaclass=_Expression):
expression = expression.parent
return expression
- def walk(self, bfs=True, prune=None):
+ def walk(
+ self, bfs: bool = True, prune: t.Optional[t.Callable[[Expression], bool]] = None
+ ) -> t.Iterator[Expression]:
"""
Returns a generator object which visits all nodes in this tree.
Args:
- bfs (bool): if set to True the BFS traversal order will be applied,
+ bfs: if set to True the BFS traversal order will be applied,
otherwise the DFS traversal will be used instead.
- prune ((node, parent, arg_key) -> bool): callable that returns True if
- the generator should stop traversing this branch of the tree.
+ prune: callable that returns True if the generator should stop traversing
+ this branch of the tree.
Returns:
the generator object.
@@ -423,7 +473,9 @@ class Expression(metaclass=_Expression):
else:
yield from self.dfs(prune=prune)
- def dfs(self, parent=None, key=None, prune=None):
+ def dfs(
+ self, prune: t.Optional[t.Callable[[Expression], bool]] = None
+ ) -> t.Iterator[Expression]:
"""
Returns a generator object which visits all nodes in this tree in
the DFS (Depth-first) order.
@@ -431,15 +483,22 @@ class Expression(metaclass=_Expression):
Returns:
The generator object.
"""
- parent = parent or self.parent
- yield self, parent, key
- if prune and prune(self, parent, key):
- return
+ stack = [self]
+
+ while stack:
+ node = stack.pop()
- for k, v in self.iter_expressions():
- yield from v.dfs(self, k, prune)
+ yield node
- def bfs(self, prune=None):
+ if prune and prune(node):
+ continue
+
+ for v in node.iter_expressions(reverse=True):
+ stack.append(v)
+
+ def bfs(
+ self, prune: t.Optional[t.Callable[[Expression], bool]] = None
+ ) -> t.Iterator[Expression]:
"""
Returns a generator object which visits all nodes in this tree in
the BFS (Breadth-first) order.
@@ -447,17 +506,18 @@ class Expression(metaclass=_Expression):
Returns:
The generator object.
"""
- queue = deque([(self, self.parent, None)])
+ queue = deque([self])
while queue:
- item, parent, key = queue.popleft()
+ node = queue.popleft()
- yield item, parent, key
- if prune and prune(item, parent, key):
+ yield node
+
+ if prune and prune(node):
continue
- for k, v in item.iter_expressions():
- queue.append((v, item, k))
+ for v in node.iter_expressions():
+ queue.append(v)
def unnest(self):
"""
@@ -480,7 +540,7 @@ class Expression(metaclass=_Expression):
"""
Returns unnested operands as a tuple.
"""
- return tuple(arg.unnest() for _, arg in self.iter_expressions())
+ return tuple(arg.unnest() for arg in self.iter_expressions())
def flatten(self, unnest=True):
"""
@@ -488,7 +548,7 @@ class Expression(metaclass=_Expression):
A AND B AND C -> [A, B, C]
"""
- for node, _, _ in self.dfs(prune=lambda n, p, *_: p and type(n) is not self.__class__):
+ for node in self.dfs(prune=lambda n: n.parent and type(n) is not self.__class__):
if type(node) is not self.__class__:
yield node.unnest() if unnest and not isinstance(node, Subquery) else node
@@ -520,32 +580,35 @@ class Expression(metaclass=_Expression):
return Dialect.get_or_raise(dialect).generate(self, **opts)
- def transform(self, fun, *args, copy=True, **kwargs):
+ def transform(self, fun: t.Callable, *args: t.Any, copy: bool = True, **kwargs) -> Expression:
"""
- Recursively visits all tree nodes (excluding already transformed ones)
+ Visits all tree nodes (excluding already transformed ones)
and applies the given transformation function to each node.
Args:
- fun (function): a function which takes a node as an argument and returns a
+ fun: a function which takes a node as an argument and returns a
new transformed node or the same node without modifications. If the function
returns None, then the corresponding node will be removed from the syntax tree.
- copy (bool): if set to True a new tree instance is constructed, otherwise the tree is
+ copy: if set to True a new tree instance is constructed, otherwise the tree is
modified in place.
Returns:
The transformed tree.
"""
- node = self.copy() if copy else self
- new_node = fun(node, *args, **kwargs)
+ root = None
+ new_node = None
- if new_node is None or not isinstance(new_node, Expression):
- return new_node
- if new_node is not node:
- new_node.parent = node.parent
- return new_node
+ for node in (self.copy() if copy else self).dfs(prune=lambda n: n is not new_node):
+ parent, arg_key, index = node.parent, node.arg_key, node.index
+ new_node = fun(node, *args, **kwargs)
+
+ if not root:
+ root = new_node
+ elif new_node is not node:
+ parent.set(arg_key, new_node, index)
- replace_children(new_node, lambda child: child.transform(fun, *args, copy=False, **kwargs))
- return new_node
+ assert root
+ return root.assert_is(Expression)
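# Hedged sketch of the now-iterative transform: the callback still runs once
# per node, with returned nodes spliced back via parent.set(..., index).
import sqlglot
from sqlglot import exp

def uppercase_columns(node):
    if isinstance(node, exp.Column):
        return exp.column(node.name.upper())
    return node

sqlglot.parse_one("SELECT a, b FROM t").transform(uppercase_columns).sql()
# -> SELECT A, B FROM t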
@t.overload
def replace(self, expression: E) -> E: ...
@@ -572,13 +635,26 @@ class Expression(metaclass=_Expression):
Returns:
The new expression or expressions.
"""
- if not self.parent:
+ parent = self.parent
+
+ if not parent or parent is expression:
return expression
- parent = self.parent
- self.parent = None
+ key = self.arg_key
+ value = parent.args.get(key)
+
+ if type(expression) is list and isinstance(value, Expression):
+ # We are trying to replace an Expression with a list, so it's assumed that
+ # the intention was to really replace the parent of this expression.
+ value.parent.replace(expression)
+ else:
+ parent.set(key, expression, self.index)
+
+ if expression is not self:
+ self.parent = None
+ self.arg_key = None
+ self.index = None
- replace_children(parent, lambda child: expression if child is self else child)
return expression
def pop(self: E) -> E:
@@ -816,6 +892,9 @@ class Expression(metaclass=_Expression):
div.args["safe"] = safe
return div
+ def asc(self, nulls_first: bool = True) -> Ordered:
+ return Ordered(this=self.copy(), nulls_first=nulls_first)
+
def desc(self, nulls_first: bool = False) -> Ordered:
return Ordered(this=self.copy(), desc=True, nulls_first=nulls_first)
@@ -983,13 +1062,13 @@ class Query(Expression):
raise NotImplementedError("Query objects must implement `named_selects`")
def select(
- self,
+ self: Q,
*expressions: t.Optional[ExpOrStr],
append: bool = True,
dialect: DialectType = None,
copy: bool = True,
**opts,
- ) -> Query:
+ ) -> Q:
"""
Append to or set the SELECT expressions.
@@ -1012,7 +1091,7 @@ class Query(Expression):
raise NotImplementedError("Query objects must implement `select`")
def with_(
- self,
+ self: Q,
alias: ExpOrStr,
as_: ExpOrStr,
recursive: t.Optional[bool] = None,
@@ -1020,7 +1099,7 @@ class Query(Expression):
dialect: DialectType = None,
copy: bool = True,
**opts,
- ) -> Query:
+ ) -> Q:
"""
Append to or set the common table expressions.
@@ -1222,6 +1301,18 @@ class Create(DDL):
return kind and kind.upper()
+class SequenceProperties(Expression):
+ arg_types = {
+ "increment": False,
+ "minvalue": False,
+ "maxvalue": False,
+ "cache": False,
+ "start": False,
+ "owned": False,
+ "options": False,
+ }
+
+
class TruncateTable(Expression):
arg_types = {
"expressions": True,
@@ -1243,7 +1334,7 @@ class Clone(Expression):
class Describe(Expression):
- arg_types = {"this": True, "extended": False, "kind": False, "expressions": False}
+ arg_types = {"this": True, "style": False, "kind": False, "expressions": False}
class Kill(Expression):
@@ -1321,7 +1412,12 @@ class WithinGroup(Expression):
 # ClickHouse supports scalar CTEs
# https://clickhouse.com/docs/en/sql-reference/statements/select/with
class CTE(DerivedTable):
- arg_types = {"this": True, "alias": True, "scalar": False}
+ arg_types = {
+ "this": True,
+ "alias": True,
+ "scalar": False,
+ "materialized": False,
+ }
class TableAlias(Expression):
@@ -1541,6 +1637,15 @@ class EncodeColumnConstraint(ColumnConstraintKind):
pass
+# https://www.postgresql.org/docs/current/sql-createtable.html#SQL-CREATETABLE-EXCLUDE
+class ExcludeColumnConstraint(ColumnConstraintKind):
+ pass
+
+
+class WithOperator(Expression):
+ arg_types = {"this": True, "op": True}
+
+
class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind):
# this: True -> ALWAYS, this: False -> BY DEFAULT
arg_types = {
@@ -1560,13 +1665,16 @@ class GeneratedAsRowColumnConstraint(ColumnConstraintKind):
# https://dev.mysql.com/doc/refman/8.0/en/create-table.html
+# https://github.com/ClickHouse/ClickHouse/blob/master/src/Parsers/ParserCreateQuery.h#L646
class IndexColumnConstraint(ColumnConstraintKind):
arg_types = {
"this": False,
- "schema": True,
+ "expressions": False,
"kind": False,
"index_type": False,
"options": False,
+ "expression": False, # Clickhouse
+ "granularity": False,
}
@@ -1605,7 +1713,7 @@ class TitleColumnConstraint(ColumnConstraintKind):
class UniqueColumnConstraint(ColumnConstraintKind):
- arg_types = {"this": False, "index_type": False}
+ arg_types = {"this": False, "index_type": False, "on_conflict": False}
class UppercaseColumnConstraint(ColumnConstraintKind):
@@ -1714,6 +1822,7 @@ class Drop(Expression):
arg_types = {
"this": False,
"kind": False,
+ "expressions": False,
"exists": False,
"temporary": False,
"materialized": False,
@@ -1733,7 +1842,7 @@ class Check(Expression):
# https://docs.snowflake.com/en/sql-reference/constructs/connect-by
class Connect(Expression):
- arg_types = {"start": False, "connect": True}
+ arg_types = {"start": False, "connect": True, "nocycle": False}
class Prior(Expression):
@@ -1815,20 +1924,30 @@ class Index(Expression):
arg_types = {
"this": False,
"table": False,
- "using": False,
- "where": False,
- "columns": False,
"unique": False,
"primary": False,
"amp": False, # teradata
+ "params": False,
+ }
+
+
+class IndexParameters(Expression):
+ arg_types = {
+ "using": False,
"include": False,
- "partition_by": False, # teradata
+ "columns": False,
+ "with_storage": False,
+ "partition_by": False,
+ "tablespace": False,
+ "where": False,
}
class Insert(DDL, DML):
arg_types = {
+ "hint": False,
"with": False,
+ "is_function": False,
"this": True,
"expression": False,
"conflict": False,
@@ -1883,8 +2002,8 @@ class OnConflict(Expression):
arg_types = {
"duplicate": False,
"expressions": False,
- "nothing": False,
- "key": False,
+ "action": False,
+ "conflict_keys": False,
"constraint": False,
}
@@ -1981,6 +2100,7 @@ class Join(Expression):
"method": False,
"global": False,
"hint": False,
+ "match_condition": False, # Snowflake
}
@property
@@ -2173,6 +2293,10 @@ class AutoRefreshProperty(Property):
arg_types = {"this": True}
+class BackupProperty(Property):
+ arg_types = {"this": True}
+
+
class BlockCompressionProperty(Property):
arg_types = {
"autotemp": False,
@@ -2253,6 +2377,14 @@ class FreespaceProperty(Property):
arg_types = {"this": True, "percent": False}
+class GlobalProperty(Property):
+ arg_types = {}
+
+
+class IcebergProperty(Property):
+ arg_types = {}
+
+
class InheritsProperty(Property):
arg_types = {"expressions": True}
@@ -2266,13 +2398,7 @@ class OutputModelProperty(Property):
class IsolatedLoadingProperty(Property):
- arg_types = {
- "no": False,
- "concurrent": False,
- "for_all": False,
- "for_insert": False,
- "for_none": False,
- }
+ arg_types = {"no": False, "concurrent": False, "target": False}
class JournalProperty(Property):
@@ -2436,6 +2562,10 @@ class SetProperty(Property):
arg_types = {"multi": True}
+class SharingProperty(Property):
+ arg_types = {"this": False}
+
+
class SetConfigProperty(Property):
arg_types = {"this": True}
@@ -2472,6 +2602,15 @@ class TransientProperty(Property):
arg_types = {"this": False}
+class UnloggedProperty(Property):
+ arg_types = {}
+
+
+# https://learn.microsoft.com/en-us/sql/t-sql/statements/create-view-transact-sql?view=sql-server-ver16
+class ViewAttributeProperty(Property):
+ arg_types = {"this": True}
+
+
class VolatileProperty(Property):
arg_types = {"this": False}
@@ -3630,6 +3769,10 @@ class SessionParameter(Condition):
class Placeholder(Condition):
arg_types = {"this": False, "kind": False}
+ @property
+ def name(self) -> str:
+ return self.this or "?"
+
class Null(Condition):
arg_types: t.Dict[str, t.Any] = {}
@@ -3714,6 +3857,7 @@ class DataType(Expression):
MEDIUMINT = auto()
MEDIUMTEXT = auto()
MONEY = auto()
+ NAME = auto()
NCHAR = auto()
NESTED = auto()
NULL = auto()
@@ -3764,47 +3908,85 @@ class DataType(Expression):
XML = auto()
YEAR = auto()
+ STRUCT_TYPES = {
+ Type.NESTED,
+ Type.OBJECT,
+ Type.STRUCT,
+ }
+
+ NESTED_TYPES = {
+ *STRUCT_TYPES,
+ Type.ARRAY,
+ Type.MAP,
+ }
+
TEXT_TYPES = {
Type.CHAR,
Type.NCHAR,
- Type.VARCHAR,
Type.NVARCHAR,
Type.TEXT,
+ Type.VARCHAR,
+ Type.NAME,
}
- INTEGER_TYPES = {
- Type.INT,
- Type.TINYINT,
- Type.SMALLINT,
+ SIGNED_INTEGER_TYPES = {
Type.BIGINT,
+ Type.INT,
Type.INT128,
Type.INT256,
+ Type.MEDIUMINT,
+ Type.SMALLINT,
+ Type.TINYINT,
+ }
+
+ UNSIGNED_INTEGER_TYPES = {
+ Type.UBIGINT,
+ Type.UINT,
+ Type.UINT128,
+ Type.UINT256,
+ Type.UMEDIUMINT,
+ Type.USMALLINT,
+ Type.UTINYINT,
+ }
+
+ INTEGER_TYPES = {
+ *SIGNED_INTEGER_TYPES,
+ *UNSIGNED_INTEGER_TYPES,
Type.BIT,
}
FLOAT_TYPES = {
- Type.FLOAT,
Type.DOUBLE,
+ Type.FLOAT,
+ }
+
+ REAL_TYPES = {
+ *FLOAT_TYPES,
+ Type.BIGDECIMAL,
+ Type.DECIMAL,
+ Type.MONEY,
+ Type.SMALLMONEY,
+ Type.UDECIMAL,
}
NUMERIC_TYPES = {
*INTEGER_TYPES,
- *FLOAT_TYPES,
+ *REAL_TYPES,
}
TEMPORAL_TYPES = {
+ Type.DATE,
+ Type.DATE32,
+ Type.DATETIME,
+ Type.DATETIME64,
Type.TIME,
- Type.TIMETZ,
Type.TIMESTAMP,
- Type.TIMESTAMPTZ,
Type.TIMESTAMPLTZ,
- Type.TIMESTAMP_S,
+ Type.TIMESTAMPTZ,
Type.TIMESTAMP_MS,
Type.TIMESTAMP_NS,
- Type.DATE,
- Type.DATE32,
- Type.DATETIME,
- Type.DATETIME64,
+ Type.TIMESTAMP_S,
+ Type.TIMETZ,
}
@classmethod
@@ -4163,8 +4345,6 @@ class Not(Unary):
class Paren(Unary):
- arg_types = {"this": True, "with": False}
-
@property
def output_name(self) -> str:
return self.this.name
@@ -4277,7 +4457,7 @@ class TimeUnit(Expression):
super().__init__(**args)
@property
- def unit(self) -> t.Optional[Var]:
+ def unit(self) -> t.Optional[Var | IntervalSpan]:
return self.args.get("unit")
@@ -4451,6 +4631,18 @@ class ToChar(Func):
arg_types = {"this": True, "format": False, "nlsparam": False}
+# https://docs.snowflake.com/en/sql-reference/functions/to_decimal
+# https://docs.oracle.com/en/database/oracle/oracle-database/23/sqlrf/TO_NUMBER.html
+class ToNumber(Func):
+ arg_types = {
+ "this": True,
+ "format": False,
+ "nlsparam": False,
+ "precision": False,
+ "scale": False,
+ }
+
+
# https://learn.microsoft.com/en-us/sql/t-sql/functions/cast-and-convert-transact-sql?view=sql-server-ver16#syntax
class Convert(Func):
arg_types = {"this": True, "expression": True, "style": False}
@@ -4496,8 +4688,9 @@ class ArrayFilter(Func):
_sql_names = ["FILTER", "ARRAY_FILTER"]
-class ArrayJoin(Func):
+class ArrayToString(Func):
arg_types = {"this": True, "expression": True, "null": False}
+ _sql_names = ["ARRAY_TO_STRING", "ARRAY_JOIN"]
class ArrayOverlaps(Binary, Func):
@@ -4580,7 +4773,13 @@ class Case(Func):
class Cast(Func):
- arg_types = {"this": True, "to": True, "format": False, "safe": False}
+ arg_types = {
+ "this": True,
+ "to": True,
+ "format": False,
+ "safe": False,
+ "action": False,
+ }
@property
def name(self) -> str:
@@ -4889,6 +5088,10 @@ class ToBase64(Func):
pass
+class GenerateDateArray(Func):
+ arg_types = {"start": True, "end": True, "interval": False}
+
+
class Greatest(Func):
arg_types = {"this": True, "expressions": False}
is_var_len_args = True
@@ -5142,14 +5345,6 @@ class Log(Func):
arg_types = {"this": True, "expression": False}
-class Log2(Func):
- pass
-
-
-class Log10(Func):
- pass
-
-
class LogicalOr(AggFunc):
_sql_names = ["LOGICAL_OR", "BOOL_OR", "BOOLOR_AGG"]
@@ -5176,6 +5371,11 @@ class Map(Func):
return values.expressions if values else []
+# Represents the MAP {...} syntax in DuckDB, which converts a struct to a MAP
+class ToMap(Func):
+ pass
+
+
class MapFromEntries(Func):
pass
@@ -5501,13 +5701,17 @@ class TsOrDsToDateStr(Func):
class TsOrDsToDate(Func):
- arg_types = {"this": True, "format": False}
+ arg_types = {"this": True, "format": False, "safe": False}
class TsOrDsToTime(Func):
pass
+class TsOrDsToTimestamp(Func):
+ pass
+
+
class TsOrDiToDi(Func):
pass
@@ -5528,7 +5732,14 @@ class UnixToStr(Func):
# https://prestodb.io/docs/current/functions/datetime.html
 # Presto's from_unixtime also accepts optional zone or hours/minutes offset arguments
class UnixToTime(Func):
- arg_types = {"this": True, "scale": False, "zone": False, "hours": False, "minutes": False}
+ arg_types = {
+ "this": True,
+ "scale": False,
+ "zone": False,
+ "hours": False,
+ "minutes": False,
+ "format": False,
+ }
SECONDS = Literal.number(0)
DECIS = Literal.number(1)
@@ -5565,6 +5776,10 @@ class Upper(Func):
_sql_names = ["UPPER", "UCASE"]
+class Corr(Binary, AggFunc):
+ pass
+
+
class Variance(AggFunc):
_sql_names = ["VARIANCE", "VARIANCE_SAMP", "VAR_SAMP"]
@@ -5573,6 +5788,14 @@ class VariancePop(AggFunc):
_sql_names = ["VARIANCE_POP", "VAR_POP"]
+class CovarSamp(Binary, AggFunc):
+ pass
+
+
+class CovarPop(Binary, AggFunc):
+ pass
+
+
class Week(Func):
arg_types = {"this": True, "mode": False}
@@ -6516,7 +6739,7 @@ def subquery(
**opts,
) -> Select:
"""
- Build a subquery expression.
+ Build a subquery expression that's selected from.
Example:
>>> subquery('select x from tbl', 'bar').select('x').sql()
@@ -6766,7 +6989,7 @@ def convert(value: t.Any, copy: bool = False) -> Expression:
copy: Whether to copy `value` (only applies to Expressions and collections).
Returns:
- Expression: the equivalent expression object.
+ The equivalent expression object.
"""
if isinstance(value, Expression):
return maybe_copy(value, copy)
@@ -6778,15 +7001,28 @@ def convert(value: t.Any, copy: bool = False) -> Expression:
return null()
if isinstance(value, numbers.Number):
return Literal.number(value)
+ if isinstance(value, bytes):
+ return HexString(this=value.hex())
if isinstance(value, datetime.datetime):
datetime_literal = Literal.string(
- (value if value.tzinfo else value.replace(tzinfo=datetime.timezone.utc)).isoformat()
+ (value if value.tzinfo else value.replace(tzinfo=datetime.timezone.utc)).isoformat(
+ sep=" "
+ )
)
return TimeStrToTime(this=datetime_literal)
if isinstance(value, datetime.date):
date_literal = Literal.string(value.strftime("%Y-%m-%d"))
return DateStrToDate(this=date_literal)
if isinstance(value, tuple):
+ if hasattr(value, "_fields"):
+ return Struct(
+ expressions=[
+ PropertyEQ(
+ this=to_identifier(k), expression=convert(getattr(value, k), copy=copy)
+ )
+ for k in value._fields
+ ]
+ )
return Tuple(expressions=[convert(v, copy=copy) for v in value])
if isinstance(value, list):
return Array(expressions=[convert(v, copy=copy) for v in value])
@@ -6795,6 +7031,13 @@ def convert(value: t.Any, copy: bool = False) -> Expression:
keys=Array(expressions=[convert(k, copy=copy) for k in value]),
values=Array(expressions=[convert(v, copy=copy) for v in value.values()]),
)
+ if hasattr(value, "__dict__"):
+ return Struct(
+ expressions=[
+ PropertyEQ(this=to_identifier(k), expression=convert(v, copy=copy))
+ for k, v in value.__dict__.items()
+ ]
+ )
raise ValueError(f"Cannot convert {value}")
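# Hedged sketch of the widened convert(): bytes, namedtuples and objects with
# a __dict__ now have expression equivalents.
import collections
from sqlglot import exp

Point = collections.namedtuple("Point", ["x", "y"])
exp.convert(b"\x01\x02")  # -> HexString with the hex-encoded payload
exp.convert(Point(1, 2))  # -> Struct of PropertyEQ nodes keyed by the fields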
@@ -6802,7 +7045,7 @@ def replace_children(expression: Expression, fun: t.Callable, *args, **kwargs) -
"""
Replace children of an expression with the result of a lambda fun(child) -> exp.
"""
- for k, v in expression.args.items():
+ for k, v in tuple(expression.args.items()):
is_list_arg = type(v) is list
child_nodes = v if is_list_arg else [v]
@@ -6812,12 +7055,36 @@ def replace_children(expression: Expression, fun: t.Callable, *args, **kwargs) -
if isinstance(cn, Expression):
for child_node in ensure_collection(fun(cn, *args, **kwargs)):
new_child_nodes.append(child_node)
- child_node.parent = expression
- child_node.arg_key = k
else:
new_child_nodes.append(cn)
- expression.args[k] = new_child_nodes if is_list_arg else seq_get(new_child_nodes, 0)
+ expression.set(k, new_child_nodes if is_list_arg else seq_get(new_child_nodes, 0))
+
+
+def replace_tree(
+ expression: Expression,
+ fun: t.Callable,
+ prune: t.Optional[t.Callable[[Expression], bool]] = None,
+) -> Expression:
+ """
+ Replace an entire tree with the result of function calls on each node.
+
+ The tree is traversed in reverse DFS order, so leaves are visited first.
+ If new nodes are created as a result of function calls, they will also be traversed.
+ """
+ stack = list(expression.dfs(prune=prune))
+
+ while stack:
+ node = stack.pop()
+ new_node = fun(node)
+
+ if new_node is not node:
+ node.replace(new_node)
+
+ if isinstance(new_node, Expression):
+ stack.append(new_node)
+
+ return new_node
def column_table_names(expression: Expression, exclude: str = "") -> t.Set[str]:
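
A minimal usage sketch for the new replace_tree helper; the folding rule here is
illustrative only, chosen to show that replacements happen in place, leaves first.

import sqlglot
from sqlglot import exp

def fold_add_zero(node):
    # illustrative rewrite: a + 0 -> a
    if isinstance(node, exp.Add) and node.expression.name == "0":
        return node.this
    return node

tree = sqlglot.parse_one("SELECT a + 0 FROM t")
exp.replace_tree(tree, fold_add_zero)
print(tree.sql())  # SELECT a FROM t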
@@ -6936,7 +7203,7 @@ def replace_tables(
return table
return node
- return expression.transform(_replace_tables, copy=copy)
+ return expression.transform(_replace_tables, copy=copy) # type: ignore
def replace_placeholders(expression: Expression, *args, **kwargs) -> Expression:
@@ -6961,8 +7228,8 @@ def replace_placeholders(expression: Expression, *args, **kwargs) -> Expression:
def _replace_placeholders(node: Expression, args, **kwargs) -> Expression:
if isinstance(node, Placeholder):
- if node.name:
- new_name = kwargs.get(node.name)
+ if node.this:
+ new_name = kwargs.get(node.this)
if new_name is not None:
return convert(new_name)
else:
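
Caller-facing behavior is unchanged by the switch to `node.this`, since a Placeholder
stores its name there:

import sqlglot
from sqlglot import exp

tree = sqlglot.parse_one("SELECT * FROM :tbl WHERE col = :val")
print(exp.replace_placeholders(tree, tbl=exp.to_identifier("foo"), val="bar").sql())
# SELECT * FROM foo WHERE col = 'bar'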
@@ -7193,3 +7460,15 @@ def null() -> Null:
Returns a Null expression.
"""
return Null()
+
+
+NONNULL_CONSTANTS = (
+ Literal,
+ Boolean,
+)
+
+CONSTANTS = (
+ Literal,
+ Boolean,
+ Null,
+)
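
These tuples exist for isinstance checks, e.g.:

from sqlglot import exp

assert isinstance(exp.true(), exp.NONNULL_CONSTANTS)
assert isinstance(exp.null(), exp.CONSTANTS) and not isinstance(exp.null(), exp.NONNULL_CONSTANTS)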
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index e6f5c4b..76d9b5d 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -46,9 +46,11 @@ class Generator(metaclass=_Generator):
'safe': Only quote identifiers that are case insensitive.
normalize: Whether to normalize identifiers to lowercase.
Default: False.
- pad: The pad size in a formatted string.
+ pad: The pad size in a formatted string. For example, this affects the indentation of
+ a projection in a query, relative to its nesting level.
Default: 2.
- indent: The indentation size in a formatted string.
+ indent: The indentation size in a formatted string. For example, this affects the
+ indentation of subqueries and filters under a `WHERE` clause.
Default: 2.
normalize_functions: How to normalize function names. Possible values are:
"upper" or True (default): Convert names to uppercase.
@@ -73,6 +75,7 @@ class Generator(metaclass=_Generator):
TRANSFORMS: t.Dict[t.Type[exp.Expression], t.Callable[..., str]] = {
**JSON_PATH_PART_TRANSFORMS,
exp.AutoRefreshProperty: lambda self, e: f"AUTO REFRESH {self.sql(e, 'this')}",
+ exp.BackupProperty: lambda self, e: f"BACKUP {self.sql(e, 'this')}",
exp.CaseSpecificColumnConstraint: lambda _,
e: f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC",
exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}",
@@ -83,15 +86,15 @@ class Generator(metaclass=_Generator):
exp.CollateColumnConstraint: lambda self, e: f"COLLATE {self.sql(e, 'this')}",
exp.CommentColumnConstraint: lambda self, e: f"COMMENT {self.sql(e, 'this')}",
exp.CopyGrantsProperty: lambda *_: "COPY GRANTS",
- exp.DateAdd: lambda self, e: self.func(
- "DATE_ADD", e.this, e.expression, exp.Literal.string(e.text("unit"))
- ),
exp.DateFormatColumnConstraint: lambda self, e: f"FORMAT {self.sql(e, 'this')}",
exp.DefaultColumnConstraint: lambda self, e: f"DEFAULT {self.sql(e, 'this')}",
exp.EncodeColumnConstraint: lambda self, e: f"ENCODE {self.sql(e, 'this')}",
+ exp.ExcludeColumnConstraint: lambda self, e: f"EXCLUDE {self.sql(e, 'this').lstrip()}",
exp.ExecuteAsProperty: lambda self, e: self.naked_property(e),
exp.ExternalProperty: lambda *_: "EXTERNAL",
+ exp.GlobalProperty: lambda *_: "GLOBAL",
exp.HeapProperty: lambda *_: "HEAP",
+ exp.IcebergProperty: lambda *_: "ICEBERG",
exp.InheritsProperty: lambda self, e: f"INHERITS ({self.expressions(e, flat=True)})",
exp.InlineLengthColumnConstraint: lambda self, e: f"INLINE LENGTH {self.sql(e, 'this')}",
exp.InputModelProperty: lambda self, e: f"INPUT{self.sql(e, 'this')}",
@@ -123,6 +126,7 @@ class Generator(metaclass=_Generator):
exp.SetConfigProperty: lambda self, e: self.sql(e, "this"),
exp.SetProperty: lambda _, e: f"{'MULTI' if e.args.get('multi') else ''}SET",
exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}",
+ exp.SharingProperty: lambda self, e: f"SHARING={self.sql(e, 'this')}",
exp.SqlReadWriteProperty: lambda _, e: e.name,
exp.SqlSecurityProperty: lambda _,
e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
@@ -130,13 +134,17 @@ class Generator(metaclass=_Generator):
exp.TemporaryProperty: lambda *_: "TEMPORARY",
exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}",
exp.Timestamp: lambda self, e: self.func("TIMESTAMP", e.this, e.expression),
+ exp.ToMap: lambda self, e: f"MAP {self.sql(e, 'this')}",
exp.ToTableProperty: lambda self, e: f"TO {self.sql(e.this)}",
exp.TransformModelProperty: lambda self, e: self.func("TRANSFORM", *e.expressions),
exp.TransientProperty: lambda *_: "TRANSIENT",
exp.UppercaseColumnConstraint: lambda *_: "UPPERCASE",
+ exp.UnloggedProperty: lambda *_: "UNLOGGED",
exp.VarMap: lambda self, e: self.func("MAP", e.args["keys"], e.args["values"]),
+ exp.ViewAttributeProperty: lambda self, e: f"WITH {self.sql(e, 'this')}",
exp.VolatileProperty: lambda *_: "VOLATILE",
exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}",
+ exp.WithOperator: lambda self, e: f"{self.sql(e, 'this')} WITH {self.sql(e, 'op')}",
}
# Whether null ordering is supported in order by
@@ -321,6 +329,9 @@ class Generator(metaclass=_Generator):
# Whether any(f(x) for x in array) can be implemented by this dialect
CAN_IMPLEMENT_ARRAY_ANY = False
+ # Whether the function TO_NUMBER is supported
+ SUPPORTS_TO_NUMBER = True
+
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
@@ -350,6 +361,18 @@ class Generator(metaclass=_Generator):
"YEARS": "YEAR",
}
+ AFTER_HAVING_MODIFIER_TRANSFORMS = {
+ "cluster": lambda self, e: self.sql(e, "cluster"),
+ "distribute": lambda self, e: self.sql(e, "distribute"),
+ "qualify": lambda self, e: self.sql(e, "qualify"),
+ "sort": lambda self, e: self.sql(e, "sort"),
+ "windows": lambda self, e: (
+ self.seg("WINDOW ") + self.expressions(e, key="windows", flat=True)
+ if e.args.get("windows")
+ else ""
+ ),
+ }
+
TOKEN_MAPPING: t.Dict[TokenType, str] = {}
STRUCT_DELIMITER = ("<", ">")
@@ -361,6 +384,7 @@ class Generator(metaclass=_Generator):
exp.AlgorithmProperty: exp.Properties.Location.POST_CREATE,
exp.AutoIncrementProperty: exp.Properties.Location.POST_SCHEMA,
exp.AutoRefreshProperty: exp.Properties.Location.POST_SCHEMA,
+ exp.BackupProperty: exp.Properties.Location.POST_SCHEMA,
exp.BlockCompressionProperty: exp.Properties.Location.POST_NAME,
exp.CharacterSetProperty: exp.Properties.Location.POST_SCHEMA,
exp.ChecksumProperty: exp.Properties.Location.POST_NAME,
@@ -380,8 +404,10 @@ class Generator(metaclass=_Generator):
exp.FallbackProperty: exp.Properties.Location.POST_NAME,
exp.FileFormatProperty: exp.Properties.Location.POST_WITH,
exp.FreespaceProperty: exp.Properties.Location.POST_NAME,
+ exp.GlobalProperty: exp.Properties.Location.POST_CREATE,
exp.HeapProperty: exp.Properties.Location.POST_WITH,
exp.InheritsProperty: exp.Properties.Location.POST_SCHEMA,
+ exp.IcebergProperty: exp.Properties.Location.POST_CREATE,
exp.InputModelProperty: exp.Properties.Location.POST_SCHEMA,
exp.IsolatedLoadingProperty: exp.Properties.Location.POST_NAME,
exp.JournalProperty: exp.Properties.Location.POST_NAME,
@@ -414,6 +440,8 @@ class Generator(metaclass=_Generator):
exp.SettingsProperty: exp.Properties.Location.POST_SCHEMA,
exp.SetProperty: exp.Properties.Location.POST_CREATE,
exp.SetConfigProperty: exp.Properties.Location.POST_SCHEMA,
+ exp.SharingProperty: exp.Properties.Location.POST_EXPRESSION,
+ exp.SequenceProperties: exp.Properties.Location.POST_EXPRESSION,
exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA,
exp.SqlReadWriteProperty: exp.Properties.Location.POST_SCHEMA,
exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE,
@@ -423,6 +451,8 @@ class Generator(metaclass=_Generator):
exp.TransientProperty: exp.Properties.Location.POST_CREATE,
exp.TransformModelProperty: exp.Properties.Location.POST_SCHEMA,
exp.MergeTreeTTL: exp.Properties.Location.POST_SCHEMA,
+ exp.UnloggedProperty: exp.Properties.Location.POST_CREATE,
+ exp.ViewAttributeProperty: exp.Properties.Location.POST_SCHEMA,
exp.VolatileProperty: exp.Properties.Location.POST_CREATE,
exp.WithDataProperty: exp.Properties.Location.POST_EXPRESSION,
exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME,
@@ -441,6 +471,7 @@ class Generator(metaclass=_Generator):
exp.Insert,
exp.Join,
exp.Select,
+ exp.Union,
exp.Update,
exp.Where,
exp.With,
@@ -626,7 +657,7 @@ class Generator(metaclass=_Generator):
if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
return (
f"{self.sep()}{comments_sql}{sql}"
- if sql[0].isspace()
+ if not sql or sql[0].isspace()
else f"{comments_sql}{self.sep()}{sql}"
)
@@ -869,7 +900,9 @@ class Generator(metaclass=_Generator):
this = f" {this}" if this else ""
index_type = expression.args.get("index_type")
index_type = f" USING {index_type}" if index_type else ""
- return f"UNIQUE{this}{index_type}"
+ on_conflict = self.sql(expression, "on_conflict")
+ on_conflict = f" {on_conflict}" if on_conflict else ""
+ return f"UNIQUE{this}{index_type}{on_conflict}"
def createable_sql(self, expression: exp.Create, locations: t.DefaultDict) -> str:
return self.sql(expression, "this")
@@ -961,6 +994,31 @@ class Generator(metaclass=_Generator):
expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties_sql}{expression_sql}{postexpression_props_sql}{index_sql}{no_schema_binding}{clone}"
return self.prepend_ctes(expression, expression_sql)
+ def sequenceproperties_sql(self, expression: exp.SequenceProperties) -> str:
+ start = self.sql(expression, "start")
+ start = f"START WITH {start}" if start else ""
+ increment = self.sql(expression, "increment")
+ increment = f" INCREMENT BY {increment}" if increment else ""
+ minvalue = self.sql(expression, "minvalue")
+ minvalue = f" MINVALUE {minvalue}" if minvalue else ""
+ maxvalue = self.sql(expression, "maxvalue")
+ maxvalue = f" MAXVALUE {maxvalue}" if maxvalue else ""
+ owned = self.sql(expression, "owned")
+ owned = f" OWNED BY {owned}" if owned else ""
+
+ cache = expression.args.get("cache")
+ if cache is None:
+ cache_str = ""
+ elif cache is True:
+ cache_str = " CACHE"
+ else:
+ cache_str = f" CACHE {cache}"
+
+ options = self.expressions(expression, key="options", flat=True, sep=" ")
+ options = f" {options}" if options else ""
+
+ return f"{start}{increment}{minvalue}{maxvalue}{cache_str}{options}{owned}".lstrip()
+
def clone_sql(self, expression: exp.Clone) -> str:
this = self.sql(expression, "this")
shallow = "SHALLOW " if expression.args.get("shallow") else ""
@@ -968,8 +1026,9 @@ class Generator(metaclass=_Generator):
return f"{shallow}{keyword} {this}"
def describe_sql(self, expression: exp.Describe) -> str:
- extended = " EXTENDED" if expression.args.get("extended") else ""
- return f"DESCRIBE{extended} {self.sql(expression, 'this')}"
+ style = expression.args.get("style")
+ style = f" {style}" if style else ""
+ return f"DESCRIBE{style} {self.sql(expression, 'this')}"
def heredoc_sql(self, expression: exp.Heredoc) -> str:
tag = self.sql(expression, "tag")
@@ -993,7 +1052,14 @@ class Generator(metaclass=_Generator):
def cte_sql(self, expression: exp.CTE) -> str:
alias = self.sql(expression, "alias")
- return f"{alias} AS {self.wrap(expression)}"
+
+ materialized = expression.args.get("materialized")
+ if materialized is False:
+ materialized = "NOT MATERIALIZED "
+ elif materialized:
+ materialized = "MATERIALIZED "
+
+ return f"{alias} AS {materialized or ''}{self.wrap(expression)}"
def tablealias_sql(self, expression: exp.TableAlias) -> str:
alias = self.sql(expression, "this")
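
A sketch of the MATERIALIZED handling in cte_sql, assuming a dialect that accepts
Postgres-style CTE hints:

import sqlglot

sql = "WITH x AS MATERIALIZED (SELECT 1) SELECT * FROM x"
print(sqlglot.transpile(sql, read="postgres")[0])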
@@ -1044,7 +1110,7 @@ class Generator(metaclass=_Generator):
return f"{self.dialect.QUOTE_START}{this}{self.dialect.QUOTE_END}"
def rawstring_sql(self, expression: exp.RawString) -> str:
- string = self.escape_str(expression.this.replace("\\", "\\\\"))
+ string = self.escape_str(expression.this.replace("\\", "\\\\"), escape_backslash=False)
return f"{self.dialect.QUOTE_START}{string}{self.dialect.QUOTE_END}"
def datatypeparam_sql(self, expression: exp.DataTypeParam) -> str:
@@ -1114,6 +1180,8 @@ class Generator(metaclass=_Generator):
def drop_sql(self, expression: exp.Drop) -> str:
this = self.sql(expression, "this")
+ expressions = self.expressions(expression, flat=True)
+ expressions = f" ({expressions})" if expressions else ""
kind = expression.args["kind"]
exists_sql = " IF EXISTS " if expression.args.get("exists") else " "
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
@@ -1121,15 +1189,10 @@ class Generator(metaclass=_Generator):
cascade = " CASCADE" if expression.args.get("cascade") else ""
constraints = " CONSTRAINTS" if expression.args.get("constraints") else ""
purge = " PURGE" if expression.args.get("purge") else ""
- return (
- f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}{constraints}{purge}"
- )
+ return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{expressions}{cascade}{constraints}{purge}"
def except_sql(self, expression: exp.Except) -> str:
- return self.prepend_ctes(
- expression,
- self.set_operation(expression, self.except_op(expression)),
- )
+ return self.set_operations(expression)
def except_op(self, expression: exp.Except) -> str:
return f"EXCEPT{'' if expression.args.get('distinct') else ' ALL'}"
@@ -1163,17 +1226,9 @@ class Generator(metaclass=_Generator):
return f" /*+ {self.expressions(expression, sep=self.QUERY_HINT_SEP).strip()} */"
- def index_sql(self, expression: exp.Index) -> str:
- unique = "UNIQUE " if expression.args.get("unique") else ""
- primary = "PRIMARY " if expression.args.get("primary") else ""
- amp = "AMP " if expression.args.get("amp") else ""
- name = self.sql(expression, "this")
- name = f"{name} " if name else ""
- table = self.sql(expression, "table")
- table = f"{self.INDEX_ON} {table}" if table else ""
+ def indexparameters_sql(self, expression: exp.IndexParameters) -> str:
using = self.sql(expression, "using")
using = f" USING {using}" if using else ""
- index = "INDEX " if not table else ""
columns = self.expressions(expression, key="columns", flat=True)
columns = f"({columns})" if columns else ""
partition_by = self.expressions(expression, key="partition_by", flat=True)
@@ -1182,7 +1237,26 @@ class Generator(metaclass=_Generator):
include = self.expressions(expression, key="include", flat=True)
if include:
include = f" INCLUDE ({include})"
- return f"{unique}{primary}{amp}{index}{name}{table}{using}{columns}{include}{partition_by}{where}"
+ with_storage = self.expressions(expression, key="with_storage", flat=True)
+ with_storage = f" WITH ({with_storage})" if with_storage else ""
+ tablespace = self.sql(expression, "tablespace")
+ tablespace = f" USING INDEX TABLESPACE {tablespace}" if tablespace else ""
+
+ return f"{using}{columns}{include}{with_storage}{tablespace}{partition_by}{where}"
+
+ def index_sql(self, expression: exp.Index) -> str:
+ unique = "UNIQUE " if expression.args.get("unique") else ""
+ primary = "PRIMARY " if expression.args.get("primary") else ""
+ amp = "AMP " if expression.args.get("amp") else ""
+ name = self.sql(expression, "this")
+ name = f"{name} " if name else ""
+ table = self.sql(expression, "table")
+ table = f"{self.INDEX_ON} {table}" if table else ""
+
+ index = "INDEX " if not table else ""
+
+ params = self.sql(expression, "params")
+ return f"{unique}{primary}{amp}{index}{name}{table}{params}"
def identifier_sql(self, expression: exp.Identifier) -> str:
text = expression.name
@@ -1371,15 +1445,9 @@ class Generator(metaclass=_Generator):
no = " NO" if no else ""
concurrent = expression.args.get("concurrent")
concurrent = " CONCURRENT" if concurrent else ""
-
- for_ = ""
- if expression.args.get("for_all"):
- for_ = " FOR ALL"
- elif expression.args.get("for_insert"):
- for_ = " FOR INSERT"
- elif expression.args.get("for_none"):
- for_ = " FOR NONE"
- return f"WITH{no}{concurrent} ISOLATED LOADING{for_}"
+ target = self.sql(expression, "target")
+ target = f" {target}" if target else ""
+ return f"WITH{no}{concurrent} ISOLATED LOADING{target}"
def partitionboundspec_sql(self, expression: exp.PartitionBoundSpec) -> str:
if isinstance(expression.this, list):
@@ -1437,6 +1505,7 @@ class Generator(metaclass=_Generator):
return f"{sql})"
def insert_sql(self, expression: exp.Insert) -> str:
+ hint = self.sql(expression, "hint")
overwrite = expression.args.get("overwrite")
if isinstance(expression.this, exp.Directory):
@@ -1447,7 +1516,9 @@ class Generator(metaclass=_Generator):
alternative = expression.args.get("alternative")
alternative = f" OR {alternative}" if alternative else ""
ignore = " IGNORE" if expression.args.get("ignore") else ""
-
+ is_function = expression.args.get("is_function")
+ if is_function:
+ this = f"{this} FUNCTION"
this = f"{this} {self.sql(expression, 'this')}"
exists = " IF EXISTS" if expression.args.get("exists") else ""
@@ -1457,23 +1528,21 @@ class Generator(metaclass=_Generator):
where = self.sql(expression, "where")
where = f"{self.sep()}REPLACE WHERE {where}" if where else ""
expression_sql = f"{self.sep()}{self.sql(expression, 'expression')}"
- conflict = self.sql(expression, "conflict")
+ on_conflict = self.sql(expression, "conflict")
+ on_conflict = f" {on_conflict}" if on_conflict else ""
by_name = " BY NAME" if expression.args.get("by_name") else ""
returning = self.sql(expression, "returning")
if self.RETURNING_END:
- expression_sql = f"{expression_sql}{conflict}{returning}"
+ expression_sql = f"{expression_sql}{on_conflict}{returning}"
else:
- expression_sql = f"{returning}{expression_sql}{conflict}"
+ expression_sql = f"{returning}{expression_sql}{on_conflict}"
- sql = f"INSERT{alternative}{ignore}{this}{by_name}{exists}{partition_sql}{where}{expression_sql}"
+ sql = f"INSERT{hint}{alternative}{ignore}{this}{by_name}{exists}{partition_sql}{where}{expression_sql}"
return self.prepend_ctes(expression, sql)
def intersect_sql(self, expression: exp.Intersect) -> str:
- return self.prepend_ctes(
- expression,
- self.set_operation(expression, self.intersect_op(expression)),
- )
+ return self.set_operations(expression)
def intersect_op(self, expression: exp.Intersect) -> str:
return f"INTERSECT{'' if expression.args.get('distinct') else ' ALL'}"
@@ -1496,33 +1565,36 @@ class Generator(metaclass=_Generator):
def onconflict_sql(self, expression: exp.OnConflict) -> str:
conflict = "ON DUPLICATE KEY" if expression.args.get("duplicate") else "ON CONFLICT"
+
constraint = self.sql(expression, "constraint")
- if constraint:
- constraint = f"ON CONSTRAINT {constraint}"
- key = self.expressions(expression, key="key", flat=True)
- do = "" if expression.args.get("duplicate") else " DO "
- nothing = "NOTHING" if expression.args.get("nothing") else ""
+ constraint = f" ON CONSTRAINT {constraint}" if constraint else ""
+
+ conflict_keys = self.expressions(expression, key="conflict_keys", flat=True)
+ conflict_keys = f"({conflict_keys}) " if conflict_keys else " "
+ action = self.sql(expression, "action")
+
expressions = self.expressions(expression, flat=True)
- set_keyword = "SET " if self.DUPLICATE_KEY_UPDATE_WITH_SET else ""
if expressions:
- expressions = f"UPDATE {set_keyword}{expressions}"
- return f"{self.seg(conflict)} {constraint}{key}{do}{nothing}{expressions}"
+ set_keyword = "SET " if self.DUPLICATE_KEY_UPDATE_WITH_SET else ""
+ expressions = f" {set_keyword}{expressions}"
+
+ return f"{conflict}{constraint}{conflict_keys}{action}{expressions}"
def returning_sql(self, expression: exp.Returning) -> str:
return f"{self.seg('RETURNING')} {self.expressions(expression, flat=True)}"
def rowformatdelimitedproperty_sql(self, expression: exp.RowFormatDelimitedProperty) -> str:
- fields = expression.args.get("fields")
+ fields = self.sql(expression, "fields")
fields = f" FIELDS TERMINATED BY {fields}" if fields else ""
- escaped = expression.args.get("escaped")
+ escaped = self.sql(expression, "escaped")
escaped = f" ESCAPED BY {escaped}" if escaped else ""
- items = expression.args.get("collection_items")
+ items = self.sql(expression, "collection_items")
items = f" COLLECTION ITEMS TERMINATED BY {items}" if items else ""
- keys = expression.args.get("map_keys")
+ keys = self.sql(expression, "map_keys")
keys = f" MAP KEYS TERMINATED BY {keys}" if keys else ""
- lines = expression.args.get("lines")
+ lines = self.sql(expression, "lines")
lines = f" LINES TERMINATED BY {lines}" if lines else ""
- null = expression.args.get("null")
+ null = self.sql(expression, "null")
null = f" NULL DEFINED AS {null}" if null else ""
return f"ROW FORMAT DELIMITED{fields}{escaped}{items}{keys}{lines}{null}"
@@ -1563,7 +1635,9 @@ class Generator(metaclass=_Generator):
hints = f" {hints}" if hints and self.TABLE_HINTS else ""
pivots = self.expressions(expression, key="pivots", sep=" ", flat=True)
pivots = f" {pivots}" if pivots else ""
- joins = self.expressions(expression, key="joins", sep="", skip_first=True)
+ joins = self.indent(
+ self.expressions(expression, key="joins", sep="", flat=True), skip_first=True
+ )
laterals = self.expressions(expression, key="laterals", sep="")
file_format = self.sql(expression, "format")
@@ -1673,9 +1747,11 @@ class Generator(metaclass=_Generator):
sql = f"UPDATE {this} SET {set_sql}{expression_sql}{order}{limit}"
return self.prepend_ctes(expression, sql)
- def values_sql(self, expression: exp.Values) -> str:
+ def values_sql(self, expression: exp.Values, values_as_table: bool = True) -> str:
+ values_as_table = values_as_table and self.VALUES_AS_TABLE
+
# The VALUES clause is still valid in an `INSERT INTO ..` statement, for example
- if self.VALUES_AS_TABLE or not expression.find_ancestor(exp.From, exp.Join):
+ if values_as_table or not expression.find_ancestor(exp.From, exp.Join):
args = self.expressions(expression)
alias = self.sql(expression, "alias")
values = f"VALUES{self.seg('')}{args}"
@@ -1769,8 +1845,9 @@ class Generator(metaclass=_Generator):
def connect_sql(self, expression: exp.Connect) -> str:
start = self.sql(expression, "start")
start = self.seg(f"START WITH {start}") if start else ""
+ nocycle = " NOCYCLE" if expression.args.get("nocycle") else ""
connect = self.sql(expression, "connect")
- connect = self.seg(f"CONNECT BY {connect}")
+ connect = self.seg(f"CONNECT BY{nocycle} {connect}")
return start + connect
def prior_sql(self, expression: exp.Prior) -> str:
@@ -1793,6 +1870,8 @@ class Generator(metaclass=_Generator):
)
if op
)
+ match_cond = self.sql(expression, "match_condition")
+ match_cond = f" MATCH_CONDITION ({match_cond})" if match_cond else ""
on_sql = self.sql(expression, "on")
using = expression.args.get("using")
@@ -1816,7 +1895,7 @@ class Generator(metaclass=_Generator):
return f", {this_sql}"
op_sql = f"{op_sql} JOIN" if op_sql else "JOIN"
- return f"{self.seg(op_sql)} {this_sql}{on_sql}"
+ return f"{self.seg(op_sql)} {this_sql}{match_cond}{on_sql}"
def lambda_sql(self, expression: exp.Lambda, arrow_sep: str = "->") -> str:
args = self.expressions(expression, flat=True)
@@ -1919,13 +1998,17 @@ class Generator(metaclass=_Generator):
text = f"{self.dialect.QUOTE_START}{self.escape_str(text)}{self.dialect.QUOTE_END}"
return text
- def escape_str(self, text: str) -> str:
- text = text.replace(self.dialect.QUOTE_END, self._escaped_quote_end)
- if self.dialect.INVERSE_ESCAPE_SEQUENCES:
- text = "".join(self.dialect.INVERSE_ESCAPE_SEQUENCES.get(ch, ch) for ch in text)
- elif self.pretty:
+ def escape_str(self, text: str, escape_backslash: bool = True) -> str:
+ if self.dialect.ESCAPED_SEQUENCES:
+ to_escaped = self.dialect.ESCAPED_SEQUENCES
+ text = "".join(
+ to_escaped.get(ch, ch) if escape_backslash or ch != "\\" else ch for ch in text
+ )
+
+ if self.pretty:
text = text.replace("\n", self.SENTINEL_LINE_BREAK)
- return text
+
+ return text.replace(self.dialect.QUOTE_END, self._escaped_quote_end)
def loaddata_sql(self, expression: exp.LoadData) -> str:
local = " LOCAL" if expression.args.get("local") else ""
@@ -2016,7 +2099,7 @@ class Generator(metaclass=_Generator):
self.unsupported(
f"'{nulls_sort_change.strip()}' translation not supported with positional ordering"
)
- else:
+ elif not isinstance(expression.this, exp.Rand):
null_sort_order = " DESC" if nulls_sort_change == " NULLS FIRST" else ""
this = f"CASE WHEN {this} IS NULL THEN 1 ELSE 0 END{null_sort_order}, {this}"
nulls_sort_change = ""
@@ -2059,24 +2142,13 @@ class Generator(metaclass=_Generator):
return f"{self.seg('MATCH_RECOGNIZE')} {self.wrap(body)}{alias}"
def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
- limit: t.Optional[exp.Fetch | exp.Limit] = expression.args.get("limit")
-
- # If the limit is generated as TOP, we need to ensure it's not generated twice
- with_offset_limit_modifiers = not isinstance(limit, exp.Limit) or not self.LIMIT_IS_TOP
+ limit = expression.args.get("limit")
if self.LIMIT_FETCH == "LIMIT" and isinstance(limit, exp.Fetch):
limit = exp.Limit(expression=exp.maybe_copy(limit.args.get("count")))
elif self.LIMIT_FETCH == "FETCH" and isinstance(limit, exp.Limit):
limit = exp.Fetch(direction="FIRST", count=exp.maybe_copy(limit.expression))
- fetch = isinstance(limit, exp.Fetch)
-
- offset_limit_modifiers = (
- self.offset_limit_modifiers(expression, fetch, limit)
- if with_offset_limit_modifiers
- else []
- )
-
options = self.expressions(expression, key="options")
if options:
options = f" OPTION{self.wrap(options)}"
@@ -2091,9 +2163,9 @@ class Generator(metaclass=_Generator):
self.sql(expression, "where"),
self.sql(expression, "group"),
self.sql(expression, "having"),
- *self.after_having_modifiers(expression),
+ *[gen(self, expression) for gen in self.AFTER_HAVING_MODIFIER_TRANSFORMS.values()],
self.sql(expression, "order"),
- *offset_limit_modifiers,
+ *self.offset_limit_modifiers(expression, isinstance(limit, exp.Fetch), limit),
*self.after_limit_modifiers(expression),
options,
sep="",
@@ -2110,19 +2182,6 @@ class Generator(metaclass=_Generator):
self.sql(limit) if fetch else self.sql(expression, "offset"),
]
- def after_having_modifiers(self, expression: exp.Expression) -> t.List[str]:
- return [
- self.sql(expression, "qualify"),
- (
- self.seg("WINDOW ") + self.expressions(expression, key="windows", flat=True)
- if expression.args.get("windows")
- else ""
- ),
- self.sql(expression, "distribute"),
- self.sql(expression, "sort"),
- self.sql(expression, "cluster"),
- ]
-
def after_limit_modifiers(self, expression: exp.Expression) -> t.List[str]:
locks = self.expressions(expression, key="locks", sep=" ")
locks = f" {locks}" if locks else ""
@@ -2137,12 +2196,13 @@ class Generator(metaclass=_Generator):
distinct = self.sql(expression, "distinct")
distinct = f" {distinct}" if distinct else ""
kind = self.sql(expression, "kind")
+
limit = expression.args.get("limit")
- top = (
- self.limit_sql(limit, top=True)
- if isinstance(limit, exp.Limit) and self.LIMIT_IS_TOP
- else ""
- )
+ if isinstance(limit, exp.Limit) and self.LIMIT_IS_TOP:
+ top = self.limit_sql(limit, top=True)
+ limit.pop()
+ else:
+ top = ""
expressions = self.expressions(expression)
@@ -2220,7 +2280,7 @@ class Generator(metaclass=_Generator):
return f"@@{kind}{this}"
def placeholder_sql(self, expression: exp.Placeholder) -> str:
- return f"{self.NAMED_PLACEHOLDER_TOKEN}{expression.name}" if expression.name else "?"
+ return f"{self.NAMED_PLACEHOLDER_TOKEN}{expression.name}" if expression.this else "?"
def subquery_sql(self, expression: exp.Subquery, sep: str = " AS ") -> str:
alias = self.sql(expression, "alias")
@@ -2236,11 +2296,32 @@ class Generator(metaclass=_Generator):
this = self.indent(self.sql(expression, "this"))
return f"{self.seg('QUALIFY')}{self.sep()}{this}"
+ def set_operations(self, expression: exp.Union) -> str:
+ sqls: t.List[str] = []
+ stack: t.List[t.Union[str, exp.Expression]] = [expression]
+
+ while stack:
+ node = stack.pop()
+
+ if isinstance(node, exp.Union):
+ stack.append(node.expression)
+ stack.append(
+ self.maybe_comment(
+ getattr(self, f"{node.key}_op")(node),
+ expression=node.this,
+ comments=node.comments,
+ )
+ )
+ stack.append(node.this)
+ else:
+ sqls.append(self.sql(node))
+
+ this = self.sep().join(sqls)
+ this = self.query_modifiers(expression, this)
+ return self.prepend_ctes(expression, this)
+
def union_sql(self, expression: exp.Union) -> str:
- return self.prepend_ctes(
- expression,
- self.set_operation(expression, self.union_op(expression)),
- )
+ return self.set_operations(expression)
def union_op(self, expression: exp.Union) -> str:
kind = " DISTINCT" if self.EXPLICIT_UNION else ""
@@ -2345,8 +2426,10 @@ class Generator(metaclass=_Generator):
def any_sql(self, expression: exp.Any) -> str:
this = self.sql(expression, "this")
- if isinstance(expression.this, exp.UNWRAPPED_QUERIES):
- this = self.wrap(this)
+ if isinstance(expression.this, (*exp.UNWRAPPED_QUERIES, exp.Paren)):
+ if isinstance(expression.this, exp.UNWRAPPED_QUERIES):
+ this = self.wrap(this)
+ return f"ANY{this}"
return f"ANY {this}"
def exists_sql(self, expression: exp.Exists) -> str:
@@ -2632,13 +2715,8 @@ class Generator(metaclass=_Generator):
return self.func(self.sql(expression, "this"), *expression.expressions)
def paren_sql(self, expression: exp.Paren) -> str:
- if isinstance(expression.unnest(), exp.Select):
- sql = self.wrap(expression)
- else:
- sql = self.seg(self.indent(self.sql(expression, "this")), sep="")
- sql = f"({sql}{self.seg(')', sep='')}"
-
- return self.prepend_ctes(expression, sql)
+ sql = self.seg(self.indent(self.sql(expression, "this")), sep="")
+ return f"({sql}{self.seg(')', sep='')}"
def neg_sql(self, expression: exp.Neg) -> str:
# This makes sure we don't convert "- - 5" to "--5", which is a comment
@@ -2686,23 +2764,55 @@ class Generator(metaclass=_Generator):
def add_sql(self, expression: exp.Add) -> str:
return self.binary(expression, "+")
- def and_sql(self, expression: exp.And) -> str:
- return self.connector_sql(expression, "AND")
+ def and_sql(
+ self, expression: exp.And, stack: t.Optional[t.List[str | exp.Expression]] = None
+ ) -> str:
+ return self.connector_sql(expression, "AND", stack)
- def xor_sql(self, expression: exp.Xor) -> str:
- return self.connector_sql(expression, "XOR")
+ def or_sql(
+ self, expression: exp.Or, stack: t.Optional[t.List[str | exp.Expression]] = None
+ ) -> str:
+ return self.connector_sql(expression, "OR", stack)
- def connector_sql(self, expression: exp.Connector, op: str) -> str:
- if not self.pretty:
- return self.binary(expression, op)
+ def xor_sql(
+ self, expression: exp.Xor, stack: t.Optional[t.List[str | exp.Expression]] = None
+ ) -> str:
+ return self.connector_sql(expression, "XOR", stack)
- sqls = tuple(
- self.maybe_comment(self.sql(e), e, e.parent.comments or []) if i != 1 else self.sql(e)
- for i, e in enumerate(expression.flatten(unnest=False))
- )
+ def connector_sql(
+ self,
+ expression: exp.Connector,
+ op: str,
+ stack: t.Optional[t.List[str | exp.Expression]] = None,
+ ) -> str:
+ if stack is not None:
+ if expression.expressions:
+ stack.append(self.expressions(expression, sep=f" {op} "))
+ else:
+ stack.append(expression.right)
+ if expression.comments:
+ for comment in expression.comments:
+ op += f" /*{self.pad_comment(comment)}*/"
+ stack.extend((op, expression.left))
+ return op
+
+ stack = [expression]
+ sqls: t.List[str] = []
+ ops = set()
+
+ while stack:
+ node = stack.pop()
+ if isinstance(node, exp.Connector):
+ ops.add(getattr(self, f"{node.key}_sql")(node, stack))
+ else:
+ sql = self.sql(node)
+ if sqls and sqls[-1] in ops:
+ sqls[-1] += f" {sql}"
+ else:
+ sqls.append(sql)
- sep = "\n" if self.text_width(sqls) > self.max_text_width else " "
- return f"{sep}{op} ".join(sqls)
+ sep = "\n" if self.pretty and self.text_width(sqls) > self.max_text_width else " "
+ return sep.join(sqls)
def bitwiseand_sql(self, expression: exp.BitwiseAnd) -> str:
return self.binary(expression, "&")
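
and_sql/or_sql/xor_sql now feed the same stack, so wide predicate chains also generate
without deep recursion; a sketch:

import sqlglot

sql = "SELECT 1 WHERE " + " AND ".join(f"c{i} = {i}" for i in range(1000))
assert sqlglot.parse_one(sql).sql().count(" AND ") == 999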
@@ -2727,7 +2837,9 @@ class Generator(metaclass=_Generator):
format_sql = f" FORMAT {format_sql}" if format_sql else ""
to_sql = self.sql(expression, "to")
to_sql = f" {to_sql}" if to_sql else ""
- return f"{safe_prefix or ''}CAST({self.sql(expression, 'this')} AS{to_sql}{format_sql})"
+ action = self.sql(expression, "action")
+ action = f" {action}" if action else ""
+ return f"{safe_prefix or ''}CAST({self.sql(expression, 'this')} AS{to_sql}{format_sql}{action})"
def currentdate_sql(self, expression: exp.CurrentDate) -> str:
zone = self.sql(expression, "this")
@@ -2817,7 +2929,7 @@ class Generator(metaclass=_Generator):
# Remove db from tables
expression = expression.transform(
lambda n: exp.table_(n.this) if isinstance(n, exp.Table) else n
- )
+ ).assert_is(exp.RenameTable)
this = self.sql(expression, "this")
return f"RENAME TO {this}"
@@ -2889,30 +3001,6 @@ class Generator(metaclass=_Generator):
kind = "MAX" if expression.args.get("max") else "MIN"
return f"{this_sql} HAVING {kind} {expression_sql}"
- def _embed_ignore_nulls(self, expression: exp.IgnoreNulls | exp.RespectNulls, text: str) -> str:
- if self.IGNORE_NULLS_IN_FUNC and not expression.meta.get("inline"):
- # The first modifier here will be the one closest to the AggFunc's arg
- mods = sorted(
- expression.find_all(exp.HavingMax, exp.Order, exp.Limit),
- key=lambda x: 0
- if isinstance(x, exp.HavingMax)
- else (1 if isinstance(x, exp.Order) else 2),
- )
-
- if mods:
- mod = mods[0]
- this = expression.__class__(this=mod.this.copy())
- this.meta["inline"] = True
- mod.this.replace(this)
- return self.sql(expression.this)
-
- agg_func = expression.find(exp.AggFunc)
-
- if agg_func:
- return self.sql(agg_func)[:-1] + f" {text})"
-
- return f"{self.sql(expression, 'this')} {text}"
-
def intdiv_sql(self, expression: exp.IntDiv) -> str:
return self.sql(
exp.Cast(
@@ -2933,9 +3021,7 @@ class Generator(metaclass=_Generator):
r.replace(exp.Nullif(this=r.copy(), expression=exp.Literal.number(0)))
if self.dialect.TYPED_DIVISION and not expression.args.get("typed"):
- if not l.is_type(*exp.DataType.FLOAT_TYPES) and not r.is_type(
- *exp.DataType.FLOAT_TYPES
- ):
+ if not l.is_type(*exp.DataType.REAL_TYPES) and not r.is_type(*exp.DataType.REAL_TYPES):
l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DOUBLE))
elif not self.dialect.TYPED_DIVISION and expression.args.get("typed"):
@@ -3019,9 +3105,6 @@ class Generator(metaclass=_Generator):
def nullsafeneq_sql(self, expression: exp.NullSafeNEQ) -> str:
return self.binary(expression, "IS DISTINCT FROM")
- def or_sql(self, expression: exp.Or) -> str:
- return self.connector_sql(expression, "OR")
-
def slice_sql(self, expression: exp.Slice) -> str:
return self.binary(expression, ":")
@@ -3035,8 +3118,13 @@ class Generator(metaclass=_Generator):
this = expression.this
expr = expression.expression
- if not self.dialect.LOG_BASE_FIRST:
+ if self.dialect.LOG_BASE_FIRST is False:
this, expr = expr, this
+ elif self.dialect.LOG_BASE_FIRST is None and expr:
+ if this.name in ("2", "10"):
+ return self.func(f"LOG{this.name}", expr)
+
+ self.unsupported(f"Unsupported logarithm with base {self.sql(this)}")
return self.func("LOG", this, expr)
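
A sketch of the LOG_BASE_FIRST=None branch; flipping the flag on a bare Generator is a
hypothetical override here, only to show the LOG2/LOG10 fallback for bases 2 and 10:

from sqlglot import exp
from sqlglot.generator import Generator

gen = Generator()
gen.dialect.LOG_BASE_FIRST = None  # hypothetical: dialect has no two-argument LOG
print(gen.log_sql(exp.Log(this=exp.Literal.number(2), expression=exp.column("x"))))
# LOG2(x)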
@@ -3088,11 +3176,16 @@ class Generator(metaclass=_Generator):
def text_width(self, args: t.Iterable) -> int:
return sum(len(arg) for arg in args)
- def format_time(self, expression: exp.Expression) -> t.Optional[str]:
+ def format_time(
+ self,
+ expression: exp.Expression,
+ inverse_time_mapping: t.Optional[t.Dict[str, str]] = None,
+ inverse_time_trie: t.Optional[t.Dict] = None,
+ ) -> t.Optional[str]:
return format_time(
self.sql(expression, "format"),
- self.dialect.INVERSE_TIME_MAPPING,
- self.dialect.INVERSE_TIME_TRIE,
+ inverse_time_mapping or self.dialect.INVERSE_TIME_MAPPING,
+ inverse_time_trie or self.dialect.INVERSE_TIME_TRIE,
)
def expressions(
@@ -3117,8 +3210,11 @@ class Generator(metaclass=_Generator):
num_sqls = len(expressions)
# These are calculated once, in case the leading_comma / pretty options are set, respectively
- pad = " " * self.pad
- stripped_sep = sep.strip()
+ if self.pretty:
+ if self.leading_comma:
+ pad = " " * len(sep)
+ else:
+ stripped_sep = sep.strip()
result_sqls = []
for i, e in enumerate(expressions):
@@ -3154,13 +3250,6 @@ class Generator(metaclass=_Generator):
self.unsupported(f"Unsupported property {expression.__class__.__name__}")
return f"{property_name} {self.sql(expression, 'this')}"
- def set_operation(self, expression: exp.Union, op: str) -> str:
- this = self.maybe_comment(self.sql(expression, "this"), comments=expression.comments)
- op = self.seg(op)
- return self.query_modifiers(
- expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}"
- )
-
def tag_sql(self, expression: exp.Tag) -> str:
return f"{expression.args.get('prefix')}{self.sql(expression.this)}{expression.args.get('postfix')}"
@@ -3227,6 +3316,18 @@ class Generator(metaclass=_Generator):
return self.sql(exp.cast(expression.this, "text"))
+ def tonumber_sql(self, expression: exp.ToNumber) -> str:
+ if not self.SUPPORTS_TO_NUMBER:
+ self.unsupported("Unsupported TO_NUMBER function")
+ return self.sql(exp.cast(expression.this, "double"))
+
+ fmt = expression.args.get("format")
+ if not fmt:
+ self.unsupported("Conversion format is required for TO_NUMBER")
+ return self.sql(exp.cast(expression.this, "double"))
+
+ return self.func("TO_NUMBER", expression.this, fmt)
+
def dictproperty_sql(self, expression: exp.DictProperty) -> str:
this = self.sql(expression, "this")
kind = self.sql(expression, "kind")
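
A construction sketch for the new tonumber_sql hook; the argument names follow the
exp.ToNumber additions in this release:

from sqlglot import exp
from sqlglot.generator import Generator

node = exp.ToNumber(this=exp.Literal.string("12.3"), format=exp.Literal.string("99.9"))
print(Generator().sql(node))  # TO_NUMBER('12.3', '99.9')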
@@ -3320,11 +3421,11 @@ class Generator(metaclass=_Generator):
this = f" {this}" if this else ""
index_type = self.sql(expression, "index_type")
index_type = f" USING {index_type}" if index_type else ""
- schema = self.sql(expression, "schema")
- schema = f" {schema}" if schema else ""
+ expressions = self.expressions(expression, flat=True)
+ expressions = f" ({expressions})" if expressions else ""
options = self.expressions(expression, key="options", sep=" ")
options = f" {options}" if options else ""
- return f"{kind}{this}{index_type}{schema}{options}"
+ return f"{kind}{this}{index_type}{expressions}{options}"
def nvl2_sql(self, expression: exp.Nvl2) -> str:
if self.NVL2_SUPPORTED:
@@ -3396,6 +3497,13 @@ class Generator(metaclass=_Generator):
return self.sql(exp.cast(this, "time"))
+ def tsordstotimestamp_sql(self, expression: exp.TsOrDsToTimestamp) -> str:
+ this = expression.this
+ if isinstance(this, exp.TsOrDsToTimestamp) or this.is_type(exp.DataType.Type.TIMESTAMP):
+ return self.sql(this)
+
+ return self.sql(exp.cast(this, "timestamp"))
+
def tsordstodate_sql(self, expression: exp.TsOrDsToDate) -> str:
this = expression.this
time_format = self.format_time(expression)
@@ -3430,6 +3538,13 @@ class Generator(metaclass=_Generator):
return self.func("LAST_DAY", expression.this)
+ def dateadd_sql(self, expression: exp.DateAdd) -> str:
+ from sqlglot.dialects.dialect import unit_to_str
+
+ return self.func(
+ "DATE_ADD", expression.this, expression.expression, unit_to_str(expression)
+ )
+
def arrayany_sql(self, expression: exp.ArrayAny) -> str:
if self.CAN_IMPLEMENT_ARRAY_ANY:
filtered = exp.ArrayFilter(this=expression.this, expression=expression.expression)
@@ -3445,30 +3560,6 @@ class Generator(metaclass=_Generator):
return self.function_fallback_sql(expression)
- def _jsonpathkey_sql(self, expression: exp.JSONPathKey) -> str:
- this = expression.this
- if isinstance(this, exp.JSONPathWildcard):
- this = self.json_path_part(this)
- return f".{this}" if this else ""
-
- if exp.SAFE_IDENTIFIER_RE.match(this):
- return f".{this}"
-
- this = self.json_path_part(this)
- return f"[{this}]" if self.JSON_PATH_BRACKETED_KEY_SUPPORTED else f".{this}"
-
- def _jsonpathsubscript_sql(self, expression: exp.JSONPathSubscript) -> str:
- this = self.json_path_part(expression.this)
- return f"[{this}]" if this else ""
-
- def _simplify_unless_literal(self, expression: E) -> E:
- if not isinstance(expression, exp.Literal):
- from sqlglot.optimizer.simplify import simplify
-
- expression = simplify(expression, dialect=self.dialect)
-
- return expression
-
def generateseries_sql(self, expression: exp.GenerateSeries) -> str:
expression.set("is_end_exclusive", None)
return self.function_fallback_sql(expression)
@@ -3477,7 +3568,9 @@ class Generator(metaclass=_Generator):
expression.set(
"expressions",
[
- exp.alias_(e.expression, e.this) if isinstance(e, exp.PropertyEQ) else e
+ exp.alias_(e.expression, e.name if e.this.is_string else e.this)
+ if isinstance(e, exp.PropertyEQ)
+ else e
for e in expression.expressions
],
)
@@ -3553,3 +3646,51 @@ class Generator(metaclass=_Generator):
transformed = cast(this=value, to=to, safe=safe)
return self.sql(transformed)
+
+ def _jsonpathkey_sql(self, expression: exp.JSONPathKey) -> str:
+ this = expression.this
+ if isinstance(this, exp.JSONPathWildcard):
+ this = self.json_path_part(this)
+ return f".{this}" if this else ""
+
+ if exp.SAFE_IDENTIFIER_RE.match(this):
+ return f".{this}"
+
+ this = self.json_path_part(this)
+ return f"[{this}]" if self.JSON_PATH_BRACKETED_KEY_SUPPORTED else f".{this}"
+
+ def _jsonpathsubscript_sql(self, expression: exp.JSONPathSubscript) -> str:
+ this = self.json_path_part(expression.this)
+ return f"[{this}]" if this else ""
+
+ def _simplify_unless_literal(self, expression: E) -> E:
+ if not isinstance(expression, exp.Literal):
+ from sqlglot.optimizer.simplify import simplify
+
+ expression = simplify(expression, dialect=self.dialect)
+
+ return expression
+
+ def _embed_ignore_nulls(self, expression: exp.IgnoreNulls | exp.RespectNulls, text: str) -> str:
+ if self.IGNORE_NULLS_IN_FUNC and not expression.meta.get("inline"):
+ # The first modifier here will be the one closest to the AggFunc's arg
+ mods = sorted(
+ expression.find_all(exp.HavingMax, exp.Order, exp.Limit),
+ key=lambda x: 0
+ if isinstance(x, exp.HavingMax)
+ else (1 if isinstance(x, exp.Order) else 2),
+ )
+
+ if mods:
+ mod = mods[0]
+ this = expression.__class__(this=mod.this.copy())
+ this.meta["inline"] = True
+ mod.this.replace(this)
+ return self.sql(expression.this)
+
+ agg_func = expression.find(exp.AggFunc)
+
+ if agg_func:
+ return self.sql(agg_func)[:-1] + f" {text})"
+
+ return f"{self.sql(expression, 'this')} {text}"
diff --git a/sqlglot/helper.py b/sqlglot/helper.py
index 0d4547f..0187c51 100644
--- a/sqlglot/helper.py
+++ b/sqlglot/helper.py
@@ -181,7 +181,7 @@ def apply_index_offset(
annotate_types(expression)
if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
logger.warning("Applying array index offset (%s)", offset)
- expression = simplify(exp.Add(this=expression, expression=exp.Literal.number(offset)))
+ expression = simplify(expression + offset)
return [expression]
return expressions
@@ -204,13 +204,13 @@ def while_changing(expression: Expression, func: t.Callable[[Expression], E]) ->
The transformed expression.
"""
while True:
- for n, *_ in reversed(tuple(expression.walk())):
+ for n in reversed(tuple(expression.walk())):
n._hash = hash(n)
start = hash(expression)
expression = func(expression)
- for n, *_ in expression.walk():
+ for n in expression.walk():
n._hash = None
if start == hash(expression):
break
@@ -317,8 +317,16 @@ def find_new_name(taken: t.Collection[str], base: str) -> str:
def is_int(text: str) -> bool:
+ return is_type(text, int)
+
+
+def is_float(text: str) -> bool:
+ return is_type(text, float)
+
+
+def is_type(text: str, target_type: t.Type) -> bool:
try:
- int(text)
+ target_type(text)
return True
except ValueError:
return False
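
Both checks now share the coercion-based helper:

from sqlglot.helper import is_float, is_int

assert is_int("42") and not is_int("4.2")
assert is_float("4.2") and not is_float("abc")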
diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py
index eb428dc..c91bb36 100644
--- a/sqlglot/lineage.py
+++ b/sqlglot/lineage.py
@@ -28,10 +28,7 @@ class Node:
yield self
for d in self.downstream:
- if isinstance(d, Node):
- yield from d.walk()
- else:
- yield d
+ yield from d.walk()
def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML:
nodes = {}
@@ -71,8 +68,10 @@ def lineage(
column: str | exp.Column,
sql: str | exp.Expression,
schema: t.Optional[t.Dict | Schema] = None,
- sources: t.Optional[t.Dict[str, str | exp.Query]] = None,
+ sources: t.Optional[t.Mapping[str, str | exp.Query]] = None,
dialect: DialectType = None,
+ scope: t.Optional[Scope] = None,
+ trim_selects: bool = True,
**kwargs,
) -> Node:
"""Build the lineage graph for a column of a SQL query.
@@ -83,6 +82,8 @@ def lineage(
schema: The schema of tables.
sources: A mapping of queries which will be used to continue building lineage.
dialect: The dialect of input SQL.
+ scope: A pre-created scope to use instead of qualifying the expression and building a new one.
+ trim_selects: Whether to trim each node's select down to only the relevant columns.
**kwargs: Qualification optimizer kwargs.
Returns:
@@ -99,14 +100,15 @@ def lineage(
dialect=dialect,
)
- qualified = qualify.qualify(
- expression,
- dialect=dialect,
- schema=schema,
- **{"validate_qualify_columns": False, "identify": False, **kwargs}, # type: ignore
- )
+ if not scope:
+ expression = qualify.qualify(
+ expression,
+ dialect=dialect,
+ schema=schema,
+ **{"validate_qualify_columns": False, "identify": False, **kwargs}, # type: ignore
+ )
- scope = build_scope(qualified)
+ scope = build_scope(expression)
if not scope:
raise SqlglotError("Cannot build lineage, sql must be SELECT")
@@ -114,7 +116,7 @@ def lineage(
if not any(select.alias_or_name == column for select in scope.expression.selects):
raise SqlglotError(f"Cannot find column '{column}' in query.")
- return to_node(column, scope, dialect)
+ return to_node(column, scope, dialect, trim_selects=trim_selects)
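
A usage sketch for the new parameters: trim_selects=False keeps the full SELECT on each
node, and a pre-built scope skips re-qualification.

from sqlglot.lineage import lineage

node = lineage("a", "SELECT a FROM (SELECT a, b FROM x)", trim_selects=False)
print(node.source.sql())  # the node label keeps the untrimmed SELECT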
def to_node(
@@ -125,6 +127,7 @@ def to_node(
upstream: t.Optional[Node] = None,
source_name: t.Optional[str] = None,
reference_node_name: t.Optional[str] = None,
+ trim_selects: bool = True,
) -> Node:
source_names = {
dt.alias: dt.comments[0].split()[1]
@@ -143,6 +146,17 @@ def to_node(
)
)
+ if isinstance(scope.expression, exp.Subquery):
+ for source in scope.subquery_scopes:
+ return to_node(
+ column,
+ scope=source,
+ dialect=dialect,
+ upstream=upstream,
+ source_name=source_name,
+ reference_node_name=reference_node_name,
+ trim_selects=trim_selects,
+ )
if isinstance(scope.expression, exp.Union):
upstream = upstream or Node(name="UNION", source=scope.expression, expression=select)
@@ -170,11 +184,12 @@ def to_node(
upstream=upstream,
source_name=source_name,
reference_node_name=reference_node_name,
+ trim_selects=trim_selects,
)
return upstream
- if isinstance(scope.expression, exp.Select):
+ if trim_selects and isinstance(scope.expression, exp.Select):
# For better ergonomics in our node labels, replace the full select with
# a version that has only the column we care about.
# "x", SELECT x, y FROM foo
@@ -206,7 +221,13 @@ def to_node(
continue
for name in subquery.named_selects:
- to_node(name, scope=subquery_scope, dialect=dialect, upstream=node)
+ to_node(
+ name,
+ scope=subquery_scope,
+ dialect=dialect,
+ upstream=node,
+ trim_selects=trim_selects,
+ )
# if the select is a star add all scope sources as downstreams
if select.is_star:
@@ -237,6 +258,7 @@ def to_node(
upstream=node,
source_name=source_names.get(table) or source_name,
reference_node_name=selected_node.name if selected_node else None,
+ trim_selects=trim_selects,
)
else:
# The source is not a scope - we've reached the end of the line. At this point, if a source is not found
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index 81b1ee6..c85ef1c 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -168,8 +168,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.Exp,
exp.Ln,
exp.Log,
- exp.Log2,
- exp.Log10,
exp.Pow,
exp.Quantile,
exp.Round,
@@ -266,26 +264,30 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.Dot: lambda self, e: self._annotate_dot(e),
exp.Explode: lambda self, e: self._annotate_explode(e),
exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
+ exp.GenerateDateArray: lambda self, e: self._annotate_with_type(
+ e, exp.DataType.build("ARRAY<DATE>")
+ ),
exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
exp.Literal: lambda self, e: self._annotate_literal(e),
- exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
+ exp.Map: lambda self, e: self._annotate_map(e),
exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"),
exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
- exp.Struct: lambda self, e: self._annotate_by_args(e, "expressions", struct=True),
+ exp.Struct: lambda self, e: self._annotate_struct(e),
exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
exp.Timestamp: lambda self, e: self._annotate_with_type(
e,
exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP,
),
+ exp.ToMap: lambda self, e: self._annotate_to_map(e),
exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
exp.Unnest: lambda self, e: self._annotate_unnest(e),
- exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
+ exp.VarMap: lambda self, e: self._annotate_map(e),
}
NESTED_TYPES = {
@@ -358,6 +360,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
if isinstance(source.expression, exp.Lateral):
if isinstance(source.expression.this, exp.Explode):
values = [source.expression.this.this]
+ elif isinstance(source.expression, exp.Unnest):
+ values = [source.expression]
else:
values = source.expression.expressions[0].expressions
@@ -408,7 +412,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
)
def _annotate_args(self, expression: E) -> E:
- for _, value in expression.iter_expressions():
+ for value in expression.iter_expressions():
self._maybe_annotate(value)
return expression
@@ -425,23 +429,13 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
return exp.DataType.Type.UNKNOWN
- if type1_value in self.NESTED_TYPES:
- return type1
- if type2_value in self.NESTED_TYPES:
- return type2
-
- return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value # type: ignore
-
- # Note: the following "no_type_check" decorators were added because mypy was yelling due
- # to assigning Type values to expression.type (since its getter returns Optional[DataType]).
- # This is a known mypy issue: https://github.com/python/mypy/issues/3004
+ return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value
- @t.no_type_check
def _annotate_binary(self, expression: B) -> B:
self._annotate_args(expression)
left, right = expression.left, expression.right
- left_type, right_type = left.type.this, right.type.this
+ left_type, right_type = left.type.this, right.type.this # type: ignore
if isinstance(expression, exp.Connector):
if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
@@ -462,7 +456,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
return expression
- @t.no_type_check
def _annotate_unary(self, expression: E) -> E:
self._annotate_args(expression)
@@ -473,7 +466,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
return expression
- @t.no_type_check
def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
if expression.is_string:
self._set_type(expression, exp.DataType.Type.VARCHAR)
@@ -484,33 +476,17 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
return expression
- @t.no_type_check
def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
self._set_type(expression, target_type)
return self._annotate_args(expression)
@t.no_type_check
- def _annotate_struct_value(
- self, expression: exp.Expression
- ) -> t.Optional[exp.DataType] | exp.ColumnDef:
- alias = expression.args.get("alias")
- if alias:
- return exp.ColumnDef(this=alias.copy(), kind=expression.type)
-
- # Case: key = value or key := value
- if expression.expression:
- return exp.ColumnDef(this=expression.this.copy(), kind=expression.expression.type)
-
- return expression.type
-
- @t.no_type_check
def _annotate_by_args(
self,
expression: E,
*args: str,
promote: bool = False,
array: bool = False,
- struct: bool = False,
) -> E:
self._annotate_args(expression)
@@ -546,16 +522,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
),
)
- if struct:
- self._set_type(
- expression,
- exp.DataType(
- this=exp.DataType.Type.STRUCT,
- expressions=[self._annotate_struct_value(expr) for expr in expressions],
- nested=True,
- ),
- )
-
return expression
def _annotate_timeunit(
@@ -605,6 +571,10 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
self._set_type(expression, exp.DataType.Type.BIGINT)
else:
self._set_type(expression, self._maybe_coerce(left_type, right_type))
+ if expression.type and expression.type.this not in exp.DataType.REAL_TYPES:
+ self._set_type(
+ expression, self._maybe_coerce(expression.type, exp.DataType.Type.DOUBLE)
+ )
return expression
@@ -631,3 +601,68 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
child = seq_get(expression.expressions, 0)
self._set_type(expression, child and seq_get(child.type.expressions, 0))
return expression
+
+ def _annotate_struct_value(
+ self, expression: exp.Expression
+ ) -> t.Optional[exp.DataType] | exp.ColumnDef:
+ alias = expression.args.get("alias")
+ if alias:
+ return exp.ColumnDef(this=alias.copy(), kind=expression.type)
+
+ # Case: key = value or key := value
+ if expression.expression:
+ return exp.ColumnDef(this=expression.this.copy(), kind=expression.expression.type)
+
+ return expression.type
+
+ def _annotate_struct(self, expression: exp.Struct) -> exp.Struct:
+ self._annotate_args(expression)
+ self._set_type(
+ expression,
+ exp.DataType(
+ this=exp.DataType.Type.STRUCT,
+ expressions=[self._annotate_struct_value(expr) for expr in expression.expressions],
+ nested=True,
+ ),
+ )
+ return expression
+
+ @t.overload
+ def _annotate_map(self, expression: exp.Map) -> exp.Map: ...
+
+ @t.overload
+ def _annotate_map(self, expression: exp.VarMap) -> exp.VarMap: ...
+
+ def _annotate_map(self, expression):
+ self._annotate_args(expression)
+
+ keys = expression.args.get("keys")
+ values = expression.args.get("values")
+
+ map_type = exp.DataType(this=exp.DataType.Type.MAP)
+ if isinstance(keys, exp.Array) and isinstance(values, exp.Array):
+ key_type = seq_get(keys.type.expressions, 0) or exp.DataType.Type.UNKNOWN
+ value_type = seq_get(values.type.expressions, 0) or exp.DataType.Type.UNKNOWN
+
+ if key_type != exp.DataType.Type.UNKNOWN and value_type != exp.DataType.Type.UNKNOWN:
+ map_type.set("expressions", [key_type, value_type])
+ map_type.set("nested", True)
+
+ self._set_type(expression, map_type)
+ return expression
+
+ def _annotate_to_map(self, expression: exp.ToMap) -> exp.ToMap:
+ self._annotate_args(expression)
+
+ map_type = exp.DataType(this=exp.DataType.Type.MAP)
+ arg = expression.this
+ if arg.is_type(exp.DataType.Type.STRUCT):
+ for coldef in arg.type.expressions:
+ kind = coldef.kind
+ if kind != exp.DataType.Type.UNKNOWN:
+ map_type.set("expressions", [exp.DataType.build("varchar"), kind])
+ map_type.set("nested", True)
+ break
+
+ self._set_type(expression, map_type)
+ return expression
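
A typing sketch for the new map annotators; key/value type parameters are only attached
when both can be inferred:

from sqlglot import exp
from sqlglot.optimizer.annotate_types import annotate_types

m = exp.Map(
    keys=exp.array(exp.Literal.string("a")),
    values=exp.array(exp.Literal.number(1)),
)
print(annotate_types(m).type.sql())  # MAP<VARCHAR, INT>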
diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py
index 0aa8134..17a5089 100644
--- a/sqlglot/optimizer/canonicalize.py
+++ b/sqlglot/optimizer/canonicalize.py
@@ -16,16 +16,17 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:
Args:
expression: The expression to canonicalize.
"""
- exp.replace_children(expression, canonicalize)
- expression = add_text_to_concat(expression)
- expression = replace_date_funcs(expression)
- expression = coerce_type(expression)
- expression = remove_redundant_casts(expression)
- expression = ensure_bools(expression, _replace_int_predicate)
- expression = remove_ascending_order(expression)
+ def _canonicalize(expression: exp.Expression) -> exp.Expression:
+ expression = add_text_to_concat(expression)
+ expression = replace_date_funcs(expression)
+ expression = coerce_type(expression)
+ expression = remove_redundant_casts(expression)
+ expression = ensure_bools(expression, _replace_int_predicate)
+ expression = remove_ascending_order(expression)
+ return expression
- return expression
+ return exp.replace_tree(expression, _canonicalize)
def add_text_to_concat(node: exp.Expression) -> exp.Expression:
@@ -35,7 +36,11 @@ def add_text_to_concat(node: exp.Expression) -> exp.Expression:
def replace_date_funcs(node: exp.Expression) -> exp.Expression:
- if isinstance(node, exp.Date) and not node.expressions and not node.args.get("zone"):
+ if (
+ isinstance(node, (exp.Date, exp.TsOrDsToDate))
+ and not node.expressions
+ and not node.args.get("zone")
+ ):
return exp.cast(node.this, to=exp.DataType.Type.DATE)
if isinstance(node, exp.Timestamp) and not node.expression:
if not node.type:
@@ -121,15 +126,11 @@ def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
a = _coerce_timeunit_arg(a, b.unit)
if (
a.type
- and a.type.this == exp.DataType.Type.DATE
+ and a.type.this in exp.DataType.TEMPORAL_TYPES
and b.type
- and b.type.this
- not in (
- exp.DataType.Type.DATE,
- exp.DataType.Type.INTERVAL,
- )
+ and b.type.this in exp.DataType.TEXT_TYPES
):
- _replace_cast(b, exp.DataType.Type.DATE)
+ _replace_cast(b, exp.DataType.Type.DATETIME)
def _coerce_timeunit_arg(arg: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.Expression:
@@ -169,7 +170,7 @@ def _replace_cast(node: exp.Expression, to: exp.DataType.Type) -> None:
# with y as (select true as x) select x = 0 FROM y -- illegal presto query
def _replace_int_predicate(expression: exp.Expression) -> None:
if isinstance(expression, exp.Coalesce):
- for _, child in expression.iter_expressions():
+ for child in expression.iter_expressions():
_replace_int_predicate(child)
elif expression.type and expression.type.this in exp.DataType.INTEGER_TYPES:
expression.replace(expression.neq(0))
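
Editor's note (not part of the diff): with this change, zero-argument date functions
canonicalize to plain casts; a sketch assuming sqlglot 23.x:

    import sqlglot
    from sqlglot.optimizer.annotate_types import annotate_types
    from sqlglot.optimizer.canonicalize import canonicalize

    expr = annotate_types(sqlglot.parse_one("SELECT DATE('2023-01-01')"))
    print(canonicalize(expr).sql())  # SELECT CAST('2023-01-01' AS DATE)
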
diff --git a/sqlglot/optimizer/eliminate_ctes.py b/sqlglot/optimizer/eliminate_ctes.py
index 6f1865c..d2e876c 100644
--- a/sqlglot/optimizer/eliminate_ctes.py
+++ b/sqlglot/optimizer/eliminate_ctes.py
@@ -32,7 +32,7 @@ def eliminate_ctes(expression):
cte_node.pop()
# Pop the entire WITH clause if this is the last CTE
- if len(with_node.expressions) <= 0:
+ if with_node and len(with_node.expressions) <= 0:
with_node.pop()
# Decrement the ref count for all sources this CTE selects from
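
Editor's note (not part of the diff): the module's existing doctest still shows the
behavior the new guard protects:

    import sqlglot
    from sqlglot.optimizer.eliminate_ctes import eliminate_ctes

    expr = sqlglot.parse_one("WITH y AS (SELECT a FROM x) SELECT a FROM z")
    # The unreferenced CTE is removed, and with it the now-empty WITH clause
    print(eliminate_ctes(expr).sql())  # SELECT a FROM z
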
diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py
index ea148cc..603f5df 100644
--- a/sqlglot/optimizer/merge_subqueries.py
+++ b/sqlglot/optimizer/merge_subqueries.py
@@ -214,6 +214,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
and not _outer_select_joins_on_inner_select_join()
and not _is_a_window_expression_in_unmergable_operation()
and not _is_recursive()
+ and not (inner_select.args.get("order") and outer_scope.is_union)
)
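
Editor's note (not part of the diff): a sketch of the newly guarded case, assuming
sqlglot 23.x; merging here would hoist the subquery's ORDER BY into a UNION arm:

    import sqlglot
    from sqlglot.optimizer.qualify import qualify
    from sqlglot.optimizer.merge_subqueries import merge_subqueries

    sql = "SELECT a FROM (SELECT a FROM x ORDER BY a) AS t UNION ALL SELECT a FROM y"
    expr = qualify(sqlglot.parse_one(sql), schema={"x": {"a": "int"}, "y": {"a": "int"}})
    print(merge_subqueries(expr).sql())  # the ordered derived table is left unmerged
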
diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py
index 6bf877b..49b6c98 100644
--- a/sqlglot/optimizer/normalize.py
+++ b/sqlglot/optimizer/normalize.py
@@ -28,7 +28,7 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
Returns:
sqlglot.Expression: normalized expression
"""
- for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))):
+ for node in tuple(expression.walk(prune=lambda e: isinstance(e, exp.Connector))):
if isinstance(node, exp.Connector):
if normalized(node, dnf=dnf):
continue
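
Editor's note (not part of the diff): only the walk API changed here; behavior matches
the module's doctest:

    import sqlglot
    from sqlglot.optimizer.normalize import normalize

    print(normalize(sqlglot.parse_one("(x AND y) OR z")).sql())  # (x OR z) AND (y OR z)
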
diff --git a/sqlglot/optimizer/normalize_identifiers.py b/sqlglot/optimizer/normalize_identifiers.py
index f2a0990..eb84c00 100644
--- a/sqlglot/optimizer/normalize_identifiers.py
+++ b/sqlglot/optimizer/normalize_identifiers.py
@@ -53,10 +53,8 @@ def normalize_identifiers(expression, dialect=None):
if isinstance(expression, str):
expression = exp.parse_identifier(expression, dialect=dialect)
- def _normalize(node: E) -> E:
+ for node in expression.walk(prune=lambda n: n.meta.get("case_sensitive")):
if not node.meta.get("case_sensitive"):
- exp.replace_children(node, _normalize)
- node = dialect.normalize_identifier(node)
- return node
+ dialect.normalize_identifier(node)
- return _normalize(expression)
+ return expression
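
Editor's note (not part of the diff): normalization now mutates the tree in place
during a single walk, but the call shape is unchanged; a sketch assuming sqlglot 23.x:

    import sqlglot
    from sqlglot.optimizer.normalize_identifiers import normalize_identifiers

    expr = sqlglot.parse_one('SELECT Foo FROM "Bar"', read="snowflake")
    # Unquoted identifiers are uppercased for Snowflake; quoted ones keep their case
    print(normalize_identifiers(expr, dialect="snowflake").sql("snowflake"))
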
diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py
index 1c96e95..c82b8aa 100644
--- a/sqlglot/optimizer/optimizer.py
+++ b/sqlglot/optimizer/optimizer.py
@@ -82,13 +82,13 @@ def optimize(
**kwargs,
}
- expression = exp.maybe_parse(expression, dialect=dialect, copy=True)
+ optimized = exp.maybe_parse(expression, dialect=dialect, copy=True)
for rule in rules:
# Find any additional rule parameters, beyond `expression`
rule_params = rule.__code__.co_varnames
rule_kwargs = {
param: possible_kwargs[param] for param in rule_params if param in possible_kwargs
}
- expression = rule(expression, **rule_kwargs)
+ optimized = rule(optimized, **rule_kwargs)
- return t.cast(exp.Expression, expression)
+ return optimized
diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py
index 12c3b89..18c9e83 100644
--- a/sqlglot/optimizer/pushdown_predicates.py
+++ b/sqlglot/optimizer/pushdown_predicates.py
@@ -77,13 +77,13 @@ def pushdown(condition, sources, scope_ref_count, dialect, join_index=None):
pushdown_dnf(predicates, sources, scope_ref_count)
-def pushdown_cnf(predicates, scope, scope_ref_count, join_index=None):
+def pushdown_cnf(predicates, sources, scope_ref_count, join_index=None):
"""
If the predicates are in CNF like form, we can simply replace each block in the parent.
"""
join_index = join_index or {}
for predicate in predicates:
- for node in nodes_for_predicate(predicate, scope, scope_ref_count).values():
+ for node in nodes_for_predicate(predicate, sources, scope_ref_count).values():
if isinstance(node, exp.Join):
name = node.alias_or_name
predicate_tables = exp.column_table_names(predicate, name)
@@ -103,7 +103,7 @@ def pushdown_cnf(predicates, scope, scope_ref_count, join_index=None):
node.where(inner_predicate, copy=False)
-def pushdown_dnf(predicates, scope, scope_ref_count):
+def pushdown_dnf(predicates, sources, scope_ref_count):
"""
If the predicates are in DNF form, we can only push down conditions that are in all blocks.
Additionally, we can't remove predicates from their original form.
@@ -127,7 +127,7 @@ def pushdown_dnf(predicates, scope, scope_ref_count):
# pushdown all predicates to their respective nodes
for table in sorted(pushdown_tables):
for predicate in predicates:
- nodes = nodes_for_predicate(predicate, scope, scope_ref_count)
+ nodes = nodes_for_predicate(predicate, sources, scope_ref_count)
if table not in nodes:
continue
diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py
index 53490bf..d97fd36 100644
--- a/sqlglot/optimizer/pushdown_projections.py
+++ b/sqlglot/optimizer/pushdown_projections.py
@@ -54,11 +54,15 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
if any(select.is_star for select in right.expression.selects):
referenced_columns[right] = parent_selections
elif not any(select.is_star for select in left.expression.selects):
- referenced_columns[right] = [
- right.expression.selects[i].alias_or_name
- for i, select in enumerate(left.expression.selects)
- if SELECT_ALL in parent_selections or select.alias_or_name in parent_selections
- ]
+ if scope.expression.args.get("by_name"):
+ referenced_columns[right] = referenced_columns[left]
+ else:
+ referenced_columns[right] = [
+ right.expression.selects[i].alias_or_name
+ for i, select in enumerate(left.expression.selects)
+ if SELECT_ALL in parent_selections
+ or select.alias_or_name in parent_selections
+ ]
if isinstance(scope.expression, exp.Select):
if remove_unused_selections:
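
Editor's note (not part of the diff): a sketch of the BY NAME branch, assuming DuckDB
syntax; since columns match by name rather than position, the right arm simply
inherits the left arm's referenced columns:

    import sqlglot
    from sqlglot.optimizer import optimize

    sql = "SELECT a FROM (SELECT a, b FROM x UNION ALL BY NAME SELECT b, a FROM y) AS u"
    schema = {"x": {"a": "int", "b": "int"}, "y": {"a": "int", "b": "int"}}
    # Roughly: both arms keep only the column a that the outer query needs
    print(optimize(sql, schema=schema, dialect="duckdb").sql("duckdb"))
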
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 233ffc9..027c32c 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -209,7 +209,7 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
if not node:
return
- for column, *_ in walk_in_scope(node, prune=lambda node, *_: node.is_star):
+ for column in walk_in_scope(node, prune=lambda node: node.is_star):
if not isinstance(column, exp.Column):
continue
@@ -306,7 +306,7 @@ def _expand_positional_references(
else:
select = select.this
- if isinstance(select, exp.Literal):
+ if isinstance(select, exp.CONSTANTS) or select.find(exp.Explode, exp.Unnest):
new_nodes.append(node)
else:
new_nodes.append(select.copy())
@@ -425,7 +425,7 @@ def _expand_stars(
raise OptimizeError(f"Unknown table: {table}")
columns = resolver.get_source_columns(table, only_visible=True)
- columns = columns or scope.outer_column_list
+ columns = columns or scope.outer_columns
if pseudocolumns:
columns = [name for name in columns if name.upper() not in pseudocolumns]
@@ -517,7 +517,7 @@ def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
new_selections = []
for i, (selection, aliased_column) in enumerate(
- itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
+ itertools.zip_longest(scope.expression.selects, scope.outer_columns)
):
if selection is None:
break
@@ -544,7 +544,7 @@ def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool
"""Makes sure all identifiers that need to be quoted are quoted."""
return expression.transform(
Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
- )
+ ) # type: ignore
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
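
Editor's note (not part of the diff): a sketch of the widened positional-reference
rule, assuming sqlglot 23.x; ordinals that point at constants survive as ordinals
instead of being expanded:

    import sqlglot
    from sqlglot.optimizer.qualify import qualify

    sql = "SELECT a, 1 AS c FROM x GROUP BY 1, 2"
    print(qualify(sqlglot.parse_one(sql), schema={"x": {"a": "int"}}).sql())
    # roughly: SELECT x.a AS a, 1 AS c FROM x AS x GROUP BY x.a, 2
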
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py
index 214ac0a..a034bf5 100644
--- a/sqlglot/optimizer/qualify_tables.py
+++ b/sqlglot/optimizer/qualify_tables.py
@@ -56,7 +56,7 @@ def qualify_tables(
table.set("catalog", catalog)
if not isinstance(expression, exp.Query):
- for node, *_ in expression.walk(prune=lambda n, *_: isinstance(n, exp.Query)):
+ for node in expression.walk(prune=lambda n: isinstance(n, exp.Query)):
if isinstance(node, exp.Table):
_qualify(node)
@@ -118,11 +118,11 @@ def qualify_tables(
for i, e in enumerate(udtf.expressions[0].expressions):
table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
else:
- for node, parent, _ in scope.walk():
+ for node in scope.walk():
if (
isinstance(node, exp.Table)
and not node.alias
- and isinstance(parent, (exp.From, exp.Join))
+ and isinstance(node.parent, (exp.From, exp.Join))
):
# Mutates the table by attaching an alias to it
alias(node, node.name, copy=False, table=True)
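
Editor's note (not part of the diff): usage is unchanged; a quick sketch:

    import sqlglot
    from sqlglot.optimizer.qualify_tables import qualify_tables

    expr = sqlglot.parse_one("SELECT 1 FROM x JOIN y")
    print(qualify_tables(expr, db="db").sql())  # SELECT 1 FROM db.x AS x JOIN db.y AS y
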
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index 443fa6c..073ced2 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -8,7 +8,7 @@ from enum import Enum, auto
from sqlglot import exp
from sqlglot.errors import OptimizeError
-from sqlglot.helper import ensure_collection, find_new_name
+from sqlglot.helper import ensure_collection, find_new_name, seq_get
logger = logging.getLogger("sqlglot")
@@ -38,11 +38,11 @@ class Scope:
SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
The LATERAL VIEW EXPLODE gets x as a source.
cte_sources (dict[str, Scope]): Sources from CTES
- outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
- defines a column list of it's alias of this scope, this is that list of columns.
+ outer_columns (list[str]): If this is a derived table or CTE, and the outer query
+ defines a column list for the alias of this scope, this is that list of columns.
For example:
SELECT * FROM (SELECT ...) AS y(col1, col2)
- The inner query would have `["col1", "col2"]` for its `outer_column_list`
+ The inner query would have `["col1", "col2"]` for its `outer_columns`
parent (Scope): Parent scope
scope_type (ScopeType): Type of this scope, relative to its parent
subquery_scopes (list[Scope]): List of all child scopes for subqueries
@@ -58,7 +58,7 @@ class Scope:
self,
expression,
sources=None,
- outer_column_list=None,
+ outer_columns=None,
parent=None,
scope_type=ScopeType.ROOT,
lateral_sources=None,
@@ -70,7 +70,7 @@ class Scope:
self.cte_sources = cte_sources or {}
self.sources.update(self.lateral_sources)
self.sources.update(self.cte_sources)
- self.outer_column_list = outer_column_list or []
+ self.outer_columns = outer_columns or []
self.parent = parent
self.scope_type = scope_type
self.subquery_scopes = []
@@ -119,10 +119,11 @@ class Scope:
self._raw_columns = []
self._join_hints = []
- for node, parent, _ in self.walk(bfs=False):
+ for node in self.walk(bfs=False):
if node is self.expression:
continue
- elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
+
+ if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
self._raw_columns.append(node)
elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
self._tables.append(node)
@@ -132,10 +133,8 @@ class Scope:
self._udtfs.append(node)
elif isinstance(node, exp.CTE):
self._ctes.append(node)
- elif (
- isinstance(node, exp.Subquery)
- and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
- and _is_derived_table(node)
+ elif _is_derived_table(node) and isinstance(
+ node.parent, (exp.From, exp.Join, exp.Subquery)
):
self._derived_tables.append(node)
elif isinstance(node, exp.UNWRAPPED_QUERIES):
@@ -438,11 +437,21 @@ class Scope:
Yields:
Scope: scope instances in depth-first-search post-order
"""
- for child_scope in itertools.chain(
- self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
- ):
- yield from child_scope.traverse()
- yield self
+ stack = [self]
+ result = []
+ while stack:
+ scope = stack.pop()
+ result.append(scope)
+ stack.extend(
+ itertools.chain(
+ scope.cte_scopes,
+ scope.union_scopes,
+ scope.table_scopes,
+ scope.subquery_scopes,
+ )
+ )
+
+ yield from reversed(result)
def ref_count(self):
"""
@@ -481,14 +490,28 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Args:
- expression (exp.Expression): expression to traverse
+ expression: Expression to traverse
Returns:
- list[Scope]: scope instances
+ A list of the created scope instances
"""
- if isinstance(expression, exp.Query) or (
- isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query)
- ):
+ if isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query):
+ # We ignore the DDL expression and build a scope for its query instead
+ ddl_with = expression.args.get("with")
+ expression = expression.expression
+
+ # If the DDL has CTEs attached, we need to add them to the query, or
+ # prepend them if the query itself already has CTEs attached to it
+ if ddl_with:
+ ddl_with.pop()
+ query_ctes = expression.ctes
+ if not query_ctes:
+ expression.set("with", ddl_with)
+ else:
+ expression.args["with"].set("recursive", ddl_with.recursive)
+ expression.args["with"].set("expressions", [*ddl_with.expressions, *query_ctes])
+
+ if isinstance(expression, exp.Query):
return list(_traverse_scope(Scope(expression)))
return []
@@ -499,21 +522,21 @@ def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
Build a scope tree.
Args:
- expression (exp.Expression): expression to build the scope tree for
+ expression: Expression to build the scope tree for.
+
Returns:
- Scope: root scope
+ The root scope
"""
- scopes = traverse_scope(expression)
- if scopes:
- return scopes[-1]
- return None
+ return seq_get(traverse_scope(expression), -1)
def _traverse_scope(scope):
if isinstance(scope.expression, exp.Select):
yield from _traverse_select(scope)
elif isinstance(scope.expression, exp.Union):
+ yield from _traverse_ctes(scope)
yield from _traverse_union(scope)
+ return
elif isinstance(scope.expression, exp.Subquery):
if scope.is_root:
yield from _traverse_select(scope)
@@ -523,8 +546,6 @@ def _traverse_scope(scope):
yield from _traverse_tables(scope)
elif isinstance(scope.expression, exp.UDTF):
yield from _traverse_udtfs(scope)
- elif isinstance(scope.expression, exp.DDL):
- yield from _traverse_ddl(scope)
else:
logger.warning(
"Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression)
@@ -541,30 +562,38 @@ def _traverse_select(scope):
def _traverse_union(scope):
- yield from _traverse_ctes(scope)
+ prev_scope = None
+ union_scope_stack = [scope]
+ expression_stack = [scope.expression.right, scope.expression.left]
- # The last scope to be yield should be the top most scope
- left = None
- for left in _traverse_scope(
- scope.branch(
- scope.expression.left,
- outer_column_list=scope.outer_column_list,
- scope_type=ScopeType.UNION,
- )
- ):
- yield left
+ while expression_stack:
+ expression = expression_stack.pop()
+ union_scope = union_scope_stack[-1]
- right = None
- for right in _traverse_scope(
- scope.branch(
- scope.expression.right,
- outer_column_list=scope.outer_column_list,
+ new_scope = union_scope.branch(
+ expression,
+ outer_columns=union_scope.outer_columns,
scope_type=ScopeType.UNION,
)
- ):
- yield right
- scope.union_scopes = [left, right]
+ if isinstance(expression, exp.Union):
+ yield from _traverse_ctes(new_scope)
+
+ union_scope_stack.append(new_scope)
+ expression_stack.extend([expression.right, expression.left])
+ continue
+
+ for scope in _traverse_scope(new_scope):
+ yield scope
+
+ if prev_scope:
+ union_scope_stack.pop()
+ union_scope.union_scopes = [prev_scope, scope]
+ prev_scope = union_scope
+
+ yield union_scope
+ else:
+ prev_scope = scope
def _traverse_ctes(scope):
@@ -588,7 +617,7 @@ def _traverse_ctes(scope):
scope.branch(
cte.this,
cte_sources=sources,
- outer_column_list=cte.alias_column_names,
+ outer_columns=cte.alias_column_names,
scope_type=ScopeType.CTE,
)
):
@@ -615,7 +644,9 @@ def _is_derived_table(expression: exp.Subquery) -> bool:
as it doesn't introduce a new scope. If an alias is present, it shadows all names
under the Subquery, so that's one exception to this rule.
"""
- return bool(expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES))
+ return isinstance(expression, exp.Subquery) and bool(
+ expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES)
+ )
def _traverse_tables(scope):
@@ -681,7 +712,7 @@ def _traverse_tables(scope):
scope.branch(
expression,
lateral_sources=lateral_sources,
- outer_column_list=expression.alias_column_names,
+ outer_columns=expression.alias_column_names,
scope_type=scope_type,
)
):
@@ -719,13 +750,13 @@ def _traverse_udtfs(scope):
sources = {}
for expression in expressions:
- if isinstance(expression, exp.Subquery) and _is_derived_table(expression):
+ if _is_derived_table(expression):
top = None
for child_scope in _traverse_scope(
scope.branch(
expression,
scope_type=ScopeType.DERIVED_TABLE,
- outer_column_list=expression.alias_column_names,
+ outer_columns=expression.alias_column_names,
)
):
yield child_scope
@@ -738,18 +769,6 @@ def _traverse_udtfs(scope):
scope.sources.update(sources)
-def _traverse_ddl(scope):
- yield from _traverse_ctes(scope)
-
- query_scope = scope.branch(
- scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, sources=scope.sources
- )
- query_scope._collect()
- query_scope._ctes = scope.ctes + query_scope._ctes
-
- yield from _traverse_scope(query_scope)
-
-
def walk_in_scope(expression, bfs=True, prune=None):
"""
Returns a generator object which visits all nodes in the syntax tree, stopping at
@@ -769,23 +788,21 @@ def walk_in_scope(expression, bfs=True, prune=None):
# Whenever we set it to True, we exclude a subtree from traversal.
crossed_scope_boundary = False
- for node, parent, key in expression.walk(
- bfs=bfs, prune=lambda *args: crossed_scope_boundary or (prune and prune(*args))
+ for node in expression.walk(
+ bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n))
):
crossed_scope_boundary = False
- yield node, parent, key
+ yield node
if node is expression:
continue
if (
isinstance(node, exp.CTE)
or (
- isinstance(node, exp.Subquery)
- and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
- and _is_derived_table(node)
+ isinstance(node.parent, (exp.From, exp.Join, exp.Subquery))
+ and (_is_derived_table(node) or isinstance(node, exp.UDTF))
)
- or isinstance(node, exp.UDTF)
or isinstance(node, exp.UNWRAPPED_QUERIES)
):
crossed_scope_boundary = True
@@ -812,7 +829,7 @@ def find_all_in_scope(expression, expression_types, bfs=True):
Yields:
exp.Expression: nodes
"""
- for expression, *_ in walk_in_scope(expression, bfs=bfs):
+ for expression in walk_in_scope(expression, bfs=bfs):
if isinstance(expression, tuple(ensure_collection(expression_types))):
yield expression
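
Editor's note (not part of the diff): a sketch of the reworked DDL handling, assuming
sqlglot 23.x; the scope tree is built for the DDL's query, with any DDL-level CTEs
folded into it:

    import sqlglot
    from sqlglot.optimizer.scope import build_scope

    expr = sqlglot.parse_one("CREATE TABLE t AS WITH c AS (SELECT 1 AS a) SELECT a FROM c")
    root = build_scope(expr)
    print(root.expression.sql())  # the query itself, not the CREATE
    print(root.cte_scopes)        # one child scope for the CTE c
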
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index 2e43d21..d9a0d2b 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -9,19 +9,25 @@ from decimal import Decimal
import sqlglot
from sqlglot import Dialect, exp
-from sqlglot.helper import first, is_iterable, merge_ranges, while_changing
+from sqlglot.helper import first, merge_ranges, while_changing
from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType
DateTruncBinaryTransform = t.Callable[
- [exp.Expression, datetime.date, str, Dialect], t.Optional[exp.Expression]
+ [exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression]
]
# Final means that an expression should not be simplified
FINAL = "final"
+# Value ranges for byte-sized signed/unsigned integers
+TINYINT_MIN = -128
+TINYINT_MAX = 127
+UTINYINT_MIN = 0
+UTINYINT_MAX = 255
+
class UnsupportedUnit(Exception):
pass
@@ -63,14 +69,14 @@ def simplify(
group.meta[FINAL] = True
for e in expression.selects:
- for node, *_ in e.walk():
+ for node in e.walk():
if node in groups:
e.meta[FINAL] = True
break
having = expression.args.get("having")
if having:
- for node, *_ in having.walk():
+ for node in having.walk():
if node in groups:
having.meta[FINAL] = True
break
@@ -304,6 +310,8 @@ def _simplify_comparison(expression, left, right, or_=False):
r = extract_date(r)
if not r:
return None
+ # python won't compare date and datetime, but many engines will upcast
+ l, r = cast_as_datetime(l), cast_as_datetime(r)
for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
@@ -431,7 +439,7 @@ def propagate_constants(expression, root=True):
and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
):
constant_mapping = {}
- for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)):
+ for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
if isinstance(expr, exp.EQ):
l, r = expr.left, expr.right
@@ -544,7 +552,37 @@ def simplify_literals(expression, root=True):
return expression
+NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ)
+
+
+def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression:
+ if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast):
+ this = _simplify_integer_cast(expr.this)
+ else:
+ this = expr.this
+
+ if isinstance(expr, exp.Cast) and this.is_int:
+ num = int(this.name)
+
+ # Remove the (up)cast from small (byte-sized) integers in predicates, which is side-effect free.
+ # Downcasts on any integer type might cause overflow, so in that case the cast cannot be
+ # eliminated and the behavior is engine-dependent
+ if (
+ TINYINT_MIN <= num <= TINYINT_MAX and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES
+ ) or (
+ UTINYINT_MIN <= num <= UTINYINT_MAX
+ and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES
+ ):
+ return this
+
+ return expr
+
+
def _simplify_binary(expression, a, b):
+ if isinstance(expression, COMPARISONS):
+ a = _simplify_integer_cast(a)
+ b = _simplify_integer_cast(b)
+
if isinstance(expression, exp.Is):
if isinstance(b, exp.Not):
c = b.this
@@ -558,7 +596,7 @@ def _simplify_binary(expression, a, b):
return exp.true() if not_ else exp.false()
if is_null(a):
return exp.false() if not_ else exp.true()
- elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
+ elif isinstance(expression, NULL_OK):
return None
elif is_null(a) or is_null(b):
return exp.null()
@@ -591,17 +629,17 @@ def _simplify_binary(expression, a, b):
if boolean:
return boolean
elif _is_date_literal(a) and isinstance(b, exp.Interval):
- a, b = extract_date(a), extract_interval(b)
- if a and b:
+ date, b = extract_date(a), extract_interval(b)
+ if date and b:
if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
- return date_literal(a + b)
+ return date_literal(date + b, extract_type(a))
if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
- return date_literal(a - b)
+ return date_literal(date - b, extract_type(a))
elif isinstance(a, exp.Interval) and _is_date_literal(b):
- a, b = extract_interval(a), extract_date(b)
+ a, date = extract_interval(a), extract_date(b)
# you cannot subtract a date from an interval
if a and b and isinstance(expression, exp.Add):
- return date_literal(a + b)
+ return date_literal(a + date, extract_type(b))
elif _is_date_literal(a) and _is_date_literal(b):
if isinstance(expression, exp.Predicate):
a, b = extract_date(a), extract_date(b)
@@ -618,12 +656,16 @@ def simplify_parens(expression):
this = expression.this
parent = expression.parent
+ parent_is_predicate = isinstance(parent, exp.Predicate)
if not isinstance(this, exp.Select) and (
not isinstance(parent, (exp.Condition, exp.Binary))
or isinstance(parent, exp.Paren)
- or not isinstance(this, exp.Binary)
- or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
+ or (
+ not isinstance(this, exp.Binary)
+ and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
+ )
+ or (isinstance(this, exp.Predicate) and not parent_is_predicate)
or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
@@ -632,24 +674,12 @@ def simplify_parens(expression):
return expression
-NONNULL_CONSTANTS = (
- exp.Literal,
- exp.Boolean,
-)
-
-CONSTANTS = (
- exp.Literal,
- exp.Boolean,
- exp.Null,
-)
-
-
def _is_nonnull_constant(expression: exp.Expression) -> bool:
- return isinstance(expression, NONNULL_CONSTANTS) or _is_date_literal(expression)
+ return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression)
def _is_constant(expression: exp.Expression) -> bool:
- return isinstance(expression, CONSTANTS) or _is_date_literal(expression)
+ return isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression)
def simplify_coalesce(expression):
@@ -820,45 +850,55 @@ def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Opti
return floor, floor + interval(unit)
-def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
+def _datetrunc_eq_expression(
+ left: exp.Expression, drange: DateRange, target_type: t.Optional[exp.DataType]
+) -> exp.Expression:
"""Get the logical expression for a date range"""
return exp.and_(
- left >= date_literal(drange[0]),
- left < date_literal(drange[1]),
+ left >= date_literal(drange[0], target_type),
+ left < date_literal(drange[1], target_type),
copy=False,
)
def _datetrunc_eq(
- left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
+ left: exp.Expression,
+ date: datetime.date,
+ unit: str,
+ dialect: Dialect,
+ target_type: t.Optional[exp.DataType],
) -> t.Optional[exp.Expression]:
drange = _datetrunc_range(date, unit, dialect)
if not drange:
return None
- return _datetrunc_eq_expression(left, drange)
+ return _datetrunc_eq_expression(left, drange, target_type)
def _datetrunc_neq(
- left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
+ left: exp.Expression,
+ date: datetime.date,
+ unit: str,
+ dialect: Dialect,
+ target_type: t.Optional[exp.DataType],
) -> t.Optional[exp.Expression]:
drange = _datetrunc_range(date, unit, dialect)
if not drange:
return None
return exp.and_(
- left < date_literal(drange[0]),
- left >= date_literal(drange[1]),
+ left < date_literal(drange[0], target_type),
+ left >= date_literal(drange[1], target_type),
copy=False,
)
DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
- exp.LT: lambda l, dt, u, d: l
- < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u)),
- exp.GT: lambda l, dt, u, d: l >= date_literal(date_floor(dt, u, d) + interval(u)),
- exp.LTE: lambda l, dt, u, d: l < date_literal(date_floor(dt, u, d) + interval(u)),
- exp.GTE: lambda l, dt, u, d: l >= date_literal(date_ceil(dt, u, d)),
+ exp.LT: lambda l, dt, u, d, t: l
+ < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t),
+ exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t),
+ exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t),
+ exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t),
exp.EQ: _datetrunc_eq,
exp.NEQ: _datetrunc_neq,
}
@@ -876,9 +916,11 @@ def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expr
comparison = expression.__class__
if isinstance(expression, DATETRUNCS):
- date = extract_date(expression.this)
+ this = expression.this
+ trunc_type = extract_type(this)
+ date = extract_date(this)
if date and expression.unit:
- return date_literal(date_floor(date, expression.unit.name.lower(), dialect))
+ return date_literal(date_floor(date, expression.unit.name.lower(), dialect), trunc_type)
elif comparison not in DATETRUNC_COMPARISONS:
return expression
@@ -889,14 +931,21 @@ def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expr
return expression
l = t.cast(exp.DateTrunc, l)
+ trunc_arg = l.this
unit = l.unit.name.lower()
date = extract_date(r)
if not date:
return expression
- return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit, dialect) or expression
- elif isinstance(expression, exp.In):
+ return (
+ DATETRUNC_BINARY_COMPARISONS[comparison](
+ trunc_arg, date, unit, dialect, extract_type(trunc_arg, r)
+ )
+ or expression
+ )
+
+ if isinstance(expression, exp.In):
l = expression.this
rs = expression.expressions
@@ -917,8 +966,11 @@ def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expr
return expression
ranges = merge_ranges(ranges)
+ target_type = extract_type(l, *rs)
- return exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False)
+ return exp.or_(
+ *[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges], copy=False
+ )
return expression
@@ -954,7 +1006,7 @@ JOINS = {
def remove_where_true(expression):
for where in expression.find_all(exp.Where):
if always_true(where.this):
- where.parent.set("where", None)
+ where.pop()
for join in expression.find_all(exp.Join):
if (
always_true(join.args.get("on"))
@@ -962,7 +1014,7 @@ def remove_where_true(expression):
and not join.args.get("method")
and (join.side, join.kind) in JOINS
):
- join.set("on", None)
+ join.args["on"].pop()
join.set("side", None)
join.set("kind", "CROSS")
@@ -1067,15 +1119,25 @@ def extract_interval(expression):
return None
-def date_literal(date):
- return exp.cast(
- exp.Literal.string(date),
- (
+def extract_type(*expressions):
+ target_type = None
+ for expression in expressions:
+ target_type = expression.to if isinstance(expression, exp.Cast) else expression.type
+ if target_type:
+ break
+
+ return target_type
+
+
+def date_literal(date, target_type=None):
+ if not target_type or not target_type.is_type(*exp.DataType.TEMPORAL_TYPES):
+ target_type = (
exp.DataType.Type.DATETIME
if isinstance(date, datetime.datetime)
else exp.DataType.Type.DATE
- ),
- )
+ )
+
+ return exp.cast(exp.Literal.string(date), target_type)
def interval(unit: str, n: int = 1):
@@ -1169,73 +1231,251 @@ def gen(expression: t.Any) -> str:
Sorting and deduping sql is a necessary step for optimization. Calling the actual
generator is expensive, so we have a bare minimum sql generator here.
"""
- if expression is None:
- return "_"
- if is_iterable(expression):
- return ",".join(gen(e) for e in expression)
- if not isinstance(expression, exp.Expression):
- return str(expression)
-
- etype = type(expression)
- if etype in GEN_MAP:
- return GEN_MAP[etype](expression)
- return f"{expression.key} {gen(expression.args.values())}"
-
-
-GEN_MAP = {
- exp.Add: lambda e: _binary(e, "+"),
- exp.And: lambda e: _binary(e, "AND"),
- exp.Anonymous: lambda e: _anonymous(e),
- exp.Between: lambda e: f"{gen(e.this)} BETWEEN {gen(e.args.get('low'))} AND {gen(e.args.get('high'))}",
- exp.Boolean: lambda e: "TRUE" if e.this else "FALSE",
- exp.Bracket: lambda e: f"{gen(e.this)}[{gen(e.expressions)}]",
- exp.Column: lambda e: ".".join(gen(p) for p in e.parts),
- exp.DataType: lambda e: f"{e.this.name} {gen(tuple(e.args.values())[1:])}",
- exp.Div: lambda e: _binary(e, "/"),
- exp.Dot: lambda e: _binary(e, "."),
- exp.EQ: lambda e: _binary(e, "="),
- exp.GT: lambda e: _binary(e, ">"),
- exp.GTE: lambda e: _binary(e, ">="),
- exp.Identifier: lambda e: f'"{e.name}"' if e.quoted else e.name,
- exp.ILike: lambda e: _binary(e, "ILIKE"),
- exp.In: lambda e: f"{gen(e.this)} IN ({gen(tuple(e.args.values())[1:])})",
- exp.Is: lambda e: _binary(e, "IS"),
- exp.Like: lambda e: _binary(e, "LIKE"),
- exp.Literal: lambda e: f"'{e.name}'" if e.is_string else e.name,
- exp.LT: lambda e: _binary(e, "<"),
- exp.LTE: lambda e: _binary(e, "<="),
- exp.Mod: lambda e: _binary(e, "%"),
- exp.Mul: lambda e: _binary(e, "*"),
- exp.Neg: lambda e: _unary(e, "-"),
- exp.NEQ: lambda e: _binary(e, "<>"),
- exp.Not: lambda e: _unary(e, "NOT"),
- exp.Null: lambda e: "NULL",
- exp.Or: lambda e: _binary(e, "OR"),
- exp.Paren: lambda e: f"({gen(e.this)})",
- exp.Sub: lambda e: _binary(e, "-"),
- exp.Subquery: lambda e: f"({gen(e.args.values())})",
- exp.Table: lambda e: gen(e.args.values()),
- exp.Var: lambda e: e.name,
-}
+ return Gen().gen(expression)
+
+
+class Gen:
+ def __init__(self):
+ self.stack = []
+ self.sqls = []
+
+ def gen(self, expression: exp.Expression) -> str:
+ self.stack = [expression]
+ self.sqls.clear()
+
+ while self.stack:
+ node = self.stack.pop()
+
+ if isinstance(node, exp.Expression):
+ exp_handler_name = f"{node.key}_sql"
+
+ if hasattr(self, exp_handler_name):
+ getattr(self, exp_handler_name)(node)
+ elif isinstance(node, exp.Func):
+ self._function(node)
+ else:
+ key = node.key.upper()
+ self.stack.append(f"{key} " if self._args(node) else key)
+ elif type(node) is list:
+ for n in reversed(node):
+ if n is not None:
+ self.stack.extend((n, ","))
+ if node:
+ self.stack.pop()
+ else:
+ if node is not None:
+ self.sqls.append(str(node))
+ return "".join(self.sqls)
-def _anonymous(e: exp.Anonymous) -> str:
- this = e.this
- if isinstance(this, str):
- name = this.upper()
- elif isinstance(this, exp.Identifier):
- name = f'"{this.name}"' if this.quoted else this.name.upper()
- else:
- raise ValueError(
- f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
+ def add_sql(self, e: exp.Add) -> None:
+ self._binary(e, " + ")
+
+ def alias_sql(self, e: exp.Alias) -> None:
+ self.stack.extend(
+ (
+ e.args.get("alias"),
+ " AS ",
+ e.args.get("this"),
+ )
+ )
+
+ def and_sql(self, e: exp.And) -> None:
+ self._binary(e, " AND ")
+
+ def anonymous_sql(self, e: exp.Anonymous) -> None:
+ this = e.this
+ if isinstance(this, str):
+ name = this.upper()
+ elif isinstance(this, exp.Identifier):
+ name = this.this
+ name = f'"{name}"' if this.quoted else name.upper()
+ else:
+ raise ValueError(
+ f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
+ )
+
+ self.stack.extend(
+ (
+ ")",
+ e.expressions,
+ "(",
+ name,
+ )
+ )
+
+ def between_sql(self, e: exp.Between) -> None:
+ self.stack.extend(
+ (
+ e.args.get("high"),
+ " AND ",
+ e.args.get("low"),
+ " BETWEEN ",
+ e.this,
+ )
+ )
+
+ def boolean_sql(self, e: exp.Boolean) -> None:
+ self.stack.append("TRUE" if e.this else "FALSE")
+
+ def bracket_sql(self, e: exp.Bracket) -> None:
+ self.stack.extend(
+ (
+ "]",
+ e.expressions,
+ "[",
+ e.this,
+ )
+ )
+
+ def column_sql(self, e: exp.Column) -> None:
+ for p in reversed(e.parts):
+ self.stack.extend((p, "."))
+ self.stack.pop()
+
+ def datatype_sql(self, e: exp.DataType) -> None:
+ self._args(e, 1)
+ self.stack.append(f"{e.this.name} ")
+
+ def div_sql(self, e: exp.Div) -> None:
+ self._binary(e, " / ")
+
+ def dot_sql(self, e: exp.Dot) -> None:
+ self._binary(e, ".")
+
+ def eq_sql(self, e: exp.EQ) -> None:
+ self._binary(e, " = ")
+
+ def from_sql(self, e: exp.From) -> None:
+ self.stack.extend((e.this, "FROM "))
+
+ def gt_sql(self, e: exp.GT) -> None:
+ self._binary(e, " > ")
+
+ def gte_sql(self, e: exp.GTE) -> None:
+ self._binary(e, " >= ")
+
+ def identifier_sql(self, e: exp.Identifier) -> None:
+ self.stack.append(f'"{e.this}"' if e.quoted else e.this)
+
+ def ilike_sql(self, e: exp.ILike) -> None:
+ self._binary(e, " ILIKE ")
+
+ def in_sql(self, e: exp.In) -> None:
+ self.stack.append(")")
+ self._args(e, 1)
+ self.stack.extend(
+ (
+ "(",
+ " IN ",
+ e.this,
+ )
)
- return f"{name} {','.join(gen(e) for e in e.expressions)}"
+ def intdiv_sql(self, e: exp.IntDiv) -> None:
+ self._binary(e, " DIV ")
+
+ def is_sql(self, e: exp.Is) -> None:
+ self._binary(e, " IS ")
+
+ def like_sql(self, e: exp.Like) -> None:
+ self._binary(e, " LIKE ")
+
+ def literal_sql(self, e: exp.Literal) -> None:
+ self.stack.append(f"'{e.this}'" if e.is_string else e.this)
+
+ def lt_sql(self, e: exp.LT) -> None:
+ self._binary(e, " < ")
+
+ def lte_sql(self, e: exp.LTE) -> None:
+ self._binary(e, " <= ")
+
+ def mod_sql(self, e: exp.Mod) -> None:
+ self._binary(e, " % ")
+
+ def mul_sql(self, e: exp.Mul) -> None:
+ self._binary(e, " * ")
+
+ def neg_sql(self, e: exp.Neg) -> None:
+ self._unary(e, "-")
+
+ def neq_sql(self, e: exp.NEQ) -> None:
+ self._binary(e, " <> ")
+
+ def not_sql(self, e: exp.Not) -> None:
+ self._unary(e, "NOT ")
+
+ def null_sql(self, e: exp.Null) -> None:
+ self.stack.append("NULL")
+
+ def or_sql(self, e: exp.Or) -> None:
+ self._binary(e, " OR ")
+
+ def paren_sql(self, e: exp.Paren) -> None:
+ self.stack.extend(
+ (
+ ")",
+ e.this,
+ "(",
+ )
+ )
+
+ def sub_sql(self, e: exp.Sub) -> None:
+ self._binary(e, " - ")
+
+ def subquery_sql(self, e: exp.Subquery) -> None:
+ self._args(e, 2)
+ alias = e.args.get("alias")
+ if alias:
+ self.stack.append(alias)
+ self.stack.extend((")", e.this, "("))
+
+ def table_sql(self, e: exp.Table) -> None:
+ self._args(e, 4)
+ alias = e.args.get("alias")
+ if alias:
+ self.stack.append(alias)
+ for p in reversed(e.parts):
+ self.stack.extend((p, "."))
+ self.stack.pop()
+
+ def tablealias_sql(self, e: exp.TableAlias) -> None:
+ columns = e.columns
+
+ if columns:
+ self.stack.extend((")", columns, "("))
+
+ self.stack.extend((e.this, " AS "))
+
+ def var_sql(self, e: exp.Var) -> None:
+ self.stack.append(e.this)
+
+ def _binary(self, e: exp.Binary, op: str) -> None:
+ self.stack.extend((e.expression, op, e.this))
+
+ def _unary(self, e: exp.Unary, op: str) -> None:
+ self.stack.extend((e.this, op))
+
+ def _function(self, e: exp.Func) -> None:
+ self.stack.extend(
+ (
+ ")",
+ list(e.args.values()),
+ "(",
+ e.sql_name(),
+ )
+ )
-def _binary(e: exp.Binary, op: str) -> str:
- return f"{gen(e.left)} {op} {gen(e.right)}"
+ def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
+ kvs = []
+ arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types
+ for k in arg_types:
+ v = node.args.get(k)
-def _unary(e: exp.Unary, op: str) -> str:
- return f"{op} {gen(e.this)}"
+ if v is not None:
+ kvs.append([f":{k}", v])
+ if kvs:
+ self.stack.append(kvs)
+ return True
+ return False
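
Editor's note (not part of the diff): a sketch of the target-type plumbing, assuming
sqlglot 23.x; literals produced by the DATE_TRUNC rewrite now follow the type of the
compared expression:

    import sqlglot
    from sqlglot.optimizer.simplify import simplify

    expr = sqlglot.parse_one("DATE_TRUNC('YEAR', x) = CAST('2021-01-01' AS DATE)")
    print(simplify(expr).sql())
    # roughly: x < CAST('2022-01-01' AS DATE) AND x >= CAST('2021-01-01' AS DATE)
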
diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py
index 36d9da4..b83abe6 100644
--- a/sqlglot/optimizer/unnest_subqueries.py
+++ b/sqlglot/optimizer/unnest_subqueries.py
@@ -138,7 +138,7 @@ def decorrelate(select, parent_select, external_columns, next_alias_name):
if isinstance(predicate, exp.Binary):
key = (
predicate.right
- if any(node is column for node, *_ in predicate.left.walk())
+ if any(node is column for node in predicate.left.walk())
else predicate.left
)
else:
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index 49dac2e..91d8d13 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -15,6 +15,8 @@ if t.TYPE_CHECKING:
from sqlglot._typing import E, Lit
from sqlglot.dialects.dialect import Dialect, DialectType
+ T = t.TypeVar("T")
+
logger = logging.getLogger("sqlglot")
OPTIONS_TYPE = t.Dict[str, t.Sequence[t.Union[t.Sequence[str], str]]]
@@ -119,6 +121,9 @@ class Parser(metaclass=_Parser):
"JSON_EXTRACT_PATH_TEXT": build_extract_json_with_path(exp.JSONExtractScalar),
"LIKE": build_like,
"LOG": build_logarithm,
+ "LOG2": lambda args: exp.Log(this=exp.Literal.number(2), expression=seq_get(args, 0)),
+ "LOG10": lambda args: exp.Log(this=exp.Literal.number(10), expression=seq_get(args, 0)),
+ "MOD": lambda args: exp.Mod(this=seq_get(args, 0), expression=seq_get(args, 1)),
"TIME_TO_TIME_STR": lambda args: exp.Cast(
this=seq_get(args, 0),
to=exp.DataType(this=exp.DataType.Type.TEXT),
@@ -144,6 +149,7 @@ class Parser(metaclass=_Parser):
STRUCT_TYPE_TOKENS = {
TokenType.NESTED,
+ TokenType.OBJECT,
TokenType.STRUCT,
}
@@ -258,6 +264,7 @@ class Parser(metaclass=_Parser):
TokenType.IPV6,
TokenType.UNKNOWN,
TokenType.NULL,
+ TokenType.NAME,
*ENUM_TYPE_TOKENS,
*NESTED_TYPE_TOKENS,
*AGGREGATE_TYPE_TOKENS,
@@ -291,6 +298,7 @@ class Parser(metaclass=_Parser):
TokenType.VIEW,
TokenType.MODEL,
TokenType.DICTIONARY,
+ TokenType.SEQUENCE,
TokenType.STORAGE_INTEGRATION,
}
@@ -310,6 +318,7 @@ class Parser(metaclass=_Parser):
TokenType.ANTI,
TokenType.APPLY,
TokenType.ASC,
+ TokenType.ASOF,
TokenType.AUTO_INCREMENT,
TokenType.BEGIN,
TokenType.BPCHAR,
@@ -398,6 +407,8 @@ class Parser(metaclass=_Parser):
TokenType.WINDOW,
}
+ ALIAS_TOKENS = ID_VAR_TOKENS
+
COMMENT_TABLE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.IS}
UPDATE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.SET}
@@ -433,6 +444,7 @@ class Parser(metaclass=_Parser):
TokenType.VAR,
TokenType.LEFT,
TokenType.RIGHT,
+ TokenType.SEQUENCE,
TokenType.DATE,
TokenType.DATETIME,
TokenType.TABLE,
@@ -505,8 +517,9 @@ class Parser(metaclass=_Parser):
}
JOIN_METHODS = {
- TokenType.NATURAL,
TokenType.ASOF,
+ TokenType.NATURAL,
+ TokenType.POSITIONAL,
}
JOIN_SIDES = {
@@ -611,8 +624,8 @@ class Parser(metaclass=_Parser):
TokenType.ALTER: lambda self: self._parse_alter(),
TokenType.BEGIN: lambda self: self._parse_transaction(),
TokenType.CACHE: lambda self: self._parse_cache(),
- TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(),
TokenType.COMMENT: lambda self: self._parse_comment(),
+ TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(),
TokenType.CREATE: lambda self: self._parse_create(),
TokenType.DELETE: lambda self: self._parse_delete(),
TokenType.DESC: lambda self: self._parse_describe(),
@@ -627,9 +640,9 @@ class Parser(metaclass=_Parser):
TokenType.REFRESH: lambda self: self._parse_refresh(),
TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
TokenType.SET: lambda self: self._parse_set(),
+ TokenType.TRUNCATE: lambda self: self._parse_truncate_table(),
TokenType.UNCACHE: lambda self: self._parse_uncache(),
TokenType.UPDATE: lambda self: self._parse_update(),
- TokenType.TRUNCATE: lambda self: self._parse_truncate_table(),
TokenType.USE: lambda self: self.expression(
exp.Use,
kind=self._parse_var_from_options(self.USABLES, raise_unmatched=False),
@@ -714,6 +727,9 @@ class Parser(metaclass=_Parser):
"ALGORITHM": lambda self: self._parse_property_assignment(exp.AlgorithmProperty),
"AUTO": lambda self: self._parse_auto_property(),
"AUTO_INCREMENT": lambda self: self._parse_property_assignment(exp.AutoIncrementProperty),
+ "BACKUP": lambda self: self.expression(
+ exp.BackupProperty, this=self._parse_var(any_token=True)
+ ),
"BLOCKCOMPRESSION": lambda self: self._parse_blockcompression(),
"CHARSET": lambda self, **kwargs: self._parse_character_set(**kwargs),
"CHARACTER SET": lambda self, **kwargs: self._parse_character_set(**kwargs),
@@ -739,7 +755,9 @@ class Parser(metaclass=_Parser):
"FALLBACK": lambda self, **kwargs: self._parse_fallback(**kwargs),
"FORMAT": lambda self: self._parse_property_assignment(exp.FileFormatProperty),
"FREESPACE": lambda self: self._parse_freespace(),
+ "GLOBAL": lambda self: self.expression(exp.GlobalProperty),
"HEAP": lambda self: self.expression(exp.HeapProperty),
+ "ICEBERG": lambda self: self.expression(exp.IcebergProperty),
"IMMUTABLE": lambda self: self.expression(
exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE")
),
@@ -782,6 +800,7 @@ class Parser(metaclass=_Parser):
"SETTINGS": lambda self: self.expression(
exp.SettingsProperty, expressions=self._parse_csv(self._parse_set_item)
),
+ "SHARING": lambda self: self._parse_property_assignment(exp.SharingProperty),
"SORTKEY": lambda self: self._parse_sortkey(),
"SOURCE": lambda self: self._parse_dict_property(this="SOURCE"),
"STABLE": lambda self: self.expression(
@@ -789,7 +808,7 @@ class Parser(metaclass=_Parser):
),
"STORED": lambda self: self._parse_stored(),
"SYSTEM_VERSIONING": lambda self: self._parse_system_versioning_property(),
- "TBLPROPERTIES": lambda self: self._parse_wrapped_csv(self._parse_property),
+ "TBLPROPERTIES": lambda self: self._parse_wrapped_properties(),
"TEMP": lambda self: self.expression(exp.TemporaryProperty),
"TEMPORARY": lambda self: self.expression(exp.TemporaryProperty),
"TO": lambda self: self._parse_to_table(),
@@ -799,6 +818,7 @@ class Parser(metaclass=_Parser):
),
"TTL": lambda self: self._parse_ttl(),
"USING": lambda self: self._parse_property_assignment(exp.FileFormatProperty),
+ "UNLOGGED": lambda self: self.expression(exp.UnloggedProperty),
"VOLATILE": lambda self: self._parse_volatile_property(),
"WITH": lambda self: self._parse_with_property(),
}
@@ -832,6 +852,9 @@ class Parser(metaclass=_Parser):
exp.DefaultColumnConstraint, this=self._parse_bitwise()
),
"ENCODE": lambda self: self.expression(exp.EncodeColumnConstraint, this=self._parse_var()),
+ "EXCLUDE": lambda self: self.expression(
+ exp.ExcludeColumnConstraint, this=self._parse_index_params()
+ ),
"FOREIGN KEY": lambda self: self._parse_foreign_key(),
"FORMAT": lambda self: self.expression(
exp.DateFormatColumnConstraint, this=self._parse_var_or_string()
@@ -858,7 +881,7 @@ class Parser(metaclass=_Parser):
"UNIQUE": lambda self: self._parse_unique(),
"UPPERCASE": lambda self: self.expression(exp.UppercaseColumnConstraint),
"WITH": lambda self: self.expression(
- exp.Properties, expressions=self._parse_wrapped_csv(self._parse_property)
+ exp.Properties, expressions=self._parse_wrapped_properties()
),
}
@@ -871,7 +894,15 @@ class Parser(metaclass=_Parser):
"RENAME": lambda self: self._parse_alter_table_rename(),
}
- SCHEMA_UNNAMED_CONSTRAINTS = {"CHECK", "FOREIGN KEY", "LIKE", "PRIMARY KEY", "UNIQUE", "PERIOD"}
+ SCHEMA_UNNAMED_CONSTRAINTS = {
+ "CHECK",
+ "EXCLUDE",
+ "FOREIGN KEY",
+ "LIKE",
+ "PERIOD",
+ "PRIMARY KEY",
+ "UNIQUE",
+ }
NO_PAREN_FUNCTION_PARSERS = {
"ANY": lambda self: self.expression(exp.Any, this=self._parse_bitwise()),
@@ -966,18 +997,54 @@ class Parser(metaclass=_Parser):
"READ": ("WRITE", "ONLY"),
}
+ CONFLICT_ACTIONS: OPTIONS_TYPE = dict.fromkeys(
+ ("ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK", "UPDATE"), tuple()
+ )
+ CONFLICT_ACTIONS["DO"] = ("NOTHING", "UPDATE")
+
+ CREATE_SEQUENCE: OPTIONS_TYPE = {
+ "SCALE": ("EXTEND", "NOEXTEND"),
+ "SHARD": ("EXTEND", "NOEXTEND"),
+ "NO": ("CYCLE", "CACHE", "MAXVALUE", "MINVALUE"),
+ **dict.fromkeys(
+ (
+ "SESSION",
+ "GLOBAL",
+ "KEEP",
+ "NOKEEP",
+ "ORDER",
+ "NOORDER",
+ "NOCACHE",
+ "CYCLE",
+ "NOCYCLE",
+ "NOMINVALUE",
+ "NOMAXVALUE",
+ "NOSCALE",
+ "NOSHARD",
+ ),
+ tuple(),
+ ),
+ }
+
+ ISOLATED_LOADING_OPTIONS: OPTIONS_TYPE = {"FOR": ("ALL", "INSERT", "NONE")}
+
USABLES: OPTIONS_TYPE = dict.fromkeys(("ROLE", "WAREHOUSE", "DATABASE", "SCHEMA"), tuple())
+ CAST_ACTIONS: OPTIONS_TYPE = dict.fromkeys(("RENAME", "ADD"), ("FIELDS",))
+
INSERT_ALTERNATIVES = {"ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK"}
CLONE_KEYWORDS = {"CLONE", "COPY"}
HISTORICAL_DATA_KIND = {"TIMESTAMP", "OFFSET", "STATEMENT", "STREAM"}
- OPCLASS_FOLLOW_KEYWORDS = {"ASC", "DESC", "NULLS"}
+ OPCLASS_FOLLOW_KEYWORDS = {"ASC", "DESC", "NULLS", "WITH"}
+
OPTYPE_FOLLOW_TOKENS = {TokenType.COMMA, TokenType.R_PAREN}
TABLE_INDEX_HINT_TOKENS = {TokenType.FORCE, TokenType.IGNORE, TokenType.USE}
+ VIEW_ATTRIBUTES = {"ENCRYPTION", "SCHEMABINDING", "VIEW_METADATA"}
+
WINDOW_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.ROWS}
WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER}
WINDOW_SIDES = {"FOLLOWING", "PRECEDING"}
@@ -994,6 +1061,8 @@ class Parser(metaclass=_Parser):
UNNEST_OFFSET_ALIAS_TOKENS = ID_VAR_TOKENS - SET_OPERATIONS
+ SELECT_START_TOKENS = {TokenType.L_PAREN, TokenType.WITH, TokenType.SELECT}
+
STRICT_CAST = True
PREFIXED_PIVOT_COLUMNS = False
@@ -1033,6 +1102,9 @@ class Parser(metaclass=_Parser):
# Whether implicit unnesting is supported, e.g. SELECT 1 FROM y.z AS z, z.a (Redshift)
SUPPORTS_IMPLICIT_UNNEST = False
+ # Whether interval spans are supported, e.g. INTERVAL 1 YEAR TO MONTH
+ INTERVAL_SPANS = True
+
__slots__ = (
"error_level",
"error_message_context",
@@ -1285,6 +1357,27 @@ class Parser(metaclass=_Parser):
exp.Command, this=self._prev.text.upper(), expression=self._parse_string()
)
+ def _try_parse(self, parse_method: t.Callable[[], T], retreat: bool = False) -> t.Optional[T]:
+ """
+ Attempts to backtrack if a parse function that contains a try/catch internally raises an error. This behavior can
+ differ depending on the user-set ErrorLevel, so _try_parse aims to solve this by setting & resetting
+ the parser state accordingly
+ """
+ index = self._index
+ error_level = self.error_level
+
+ self.error_level = ErrorLevel.IMMEDIATE
+ try:
+ this = parse_method()
+ except ParseError:
+ this = None
+ finally:
+ if not this or retreat:
+ self._retreat(index)
+ self.error_level = error_level
+
+ return this
+
def _parse_comment(self, allow_exists: bool = True) -> exp.Expression:
start = self._prev
exists = self._parse_exists() if allow_exists else None
@@ -1377,13 +1470,22 @@ class Parser(metaclass=_Parser):
if not kind:
return self._parse_as_command(start)
+ if_exists = exists or self._parse_exists()
+ table = self._parse_table_parts(
+ schema=True, is_db_reference=self._prev.token_type == TokenType.SCHEMA
+ )
+
+ if self._match(TokenType.L_PAREN, advance=False):
+ expressions = self._parse_wrapped_csv(self._parse_types)
+ else:
+ expressions = None
+
return self.expression(
exp.Drop,
comments=start.comments,
- exists=exists or self._parse_exists(),
- this=self._parse_table(
- schema=True, is_db_reference=self._prev.token_type == TokenType.SCHEMA
- ),
+ exists=if_exists,
+ this=table,
+ expressions=expressions,
kind=kind,
temporary=temporary,
materialized=materialized,
@@ -1409,6 +1511,7 @@ class Parser(metaclass=_Parser):
or self._match_pair(TokenType.OR, TokenType.REPLACE)
or self._match_pair(TokenType.OR, TokenType.ALTER)
)
+
unique = self._match(TokenType.UNIQUE)
if self._match_pair(TokenType.TABLE, TokenType.FUNCTION, advance=False):
@@ -1489,7 +1592,11 @@ class Parser(metaclass=_Parser):
# exp.Properties.Location.POST_ALIAS
extend_props(self._parse_properties())
- expression = self._parse_ddl_select()
+ if create_token.token_type == TokenType.SEQUENCE:
+ expression = self._parse_types()
+ extend_props(self._parse_properties())
+ else:
+ expression = self._parse_ddl_select()
if create_token.token_type == TokenType.TABLE:
# exp.Properties.Location.POST_EXPRESSION
@@ -1539,6 +1646,40 @@ class Parser(metaclass=_Parser):
clone=clone,
)
+ def _parse_sequence_properties(self) -> t.Optional[exp.SequenceProperties]:
+ seq = exp.SequenceProperties()
+
+ options = []
+ index = self._index
+
+ while self._curr:
+ if self._match_text_seq("INCREMENT"):
+ self._match_text_seq("BY")
+ self._match_text_seq("=")
+ seq.set("increment", self._parse_term())
+ elif self._match_text_seq("MINVALUE"):
+ seq.set("minvalue", self._parse_term())
+ elif self._match_text_seq("MAXVALUE"):
+ seq.set("maxvalue", self._parse_term())
+ elif self._match(TokenType.START_WITH) or self._match_text_seq("START"):
+ self._match_text_seq("=")
+ seq.set("start", self._parse_term())
+ elif self._match_text_seq("CACHE"):
+ # T-SQL allows empty CACHE which is initialized dynamically
+ seq.set("cache", self._parse_number() or True)
+ elif self._match_text_seq("OWNED", "BY"):
+ # "OWNED BY NONE" is the default
+ seq.set("owned", None if self._match_text_seq("NONE") else self._parse_column())
+ else:
+ opt = self._parse_var_from_options(self.CREATE_SEQUENCE, raise_unmatched=False)
+ if opt:
+ options.append(opt)
+ else:
+ break
+
+ seq.set("options", options if options else None)
+ return None if self._index == index else seq
+
def _parse_property_before(self) -> t.Optional[exp.Expression]:
# only used for teradata currently
self._match(TokenType.COMMA)
@@ -1564,6 +1705,9 @@ class Parser(metaclass=_Parser):
return None
+ def _parse_wrapped_properties(self) -> t.List[exp.Expression]:
+ return self._parse_wrapped_csv(self._parse_property)
+
def _parse_property(self) -> t.Optional[exp.Expression]:
if self._match_texts(self.PROPERTY_PARSERS):
return self.PROPERTY_PARSERS[self._prev.text.upper()](self)
@@ -1582,12 +1726,12 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.EQ):
self._retreat(index)
- return None
+ return self._parse_sequence_properties()
return self.expression(
exp.Property,
this=key.to_dot() if isinstance(key, exp.Column) else key,
- value=self._parse_column() or self._parse_var(any_token=True),
+ value=self._parse_bitwise() or self._parse_var(any_token=True),
)
def _parse_stored(self) -> exp.FileFormatProperty:
@@ -1619,7 +1763,6 @@ class Parser(metaclass=_Parser):
prop = self._parse_property_before()
else:
prop = self._parse_property()
-
if not prop:
break
for p in ensure_list(prop):
@@ -1662,15 +1805,16 @@ class Parser(metaclass=_Parser):
return prop
- def _parse_with_property(
- self,
- ) -> t.Optional[exp.Expression] | t.List[exp.Expression]:
+ def _parse_with_property(self) -> t.Optional[exp.Expression] | t.List[exp.Expression]:
if self._match(TokenType.L_PAREN, advance=False):
- return self._parse_wrapped_csv(self._parse_property)
+ return self._parse_wrapped_properties()
if self._match_text_seq("JOURNAL"):
return self._parse_withjournaltable()
+ if self._match_texts(self.VIEW_ATTRIBUTES):
+ return self.expression(exp.ViewAttributeProperty, this=self._prev.text.upper())
+
if self._match_text_seq("DATA"):
return self._parse_withdata(no=False)
elif self._match_text_seq("NO", "DATA"):
@@ -1818,20 +1962,18 @@ class Parser(metaclass=_Parser):
autotemp=autotemp,
)
- def _parse_withisolatedloading(self) -> exp.IsolatedLoadingProperty:
+ def _parse_withisolatedloading(self) -> t.Optional[exp.IsolatedLoadingProperty]:
+ index = self._index
no = self._match_text_seq("NO")
concurrent = self._match_text_seq("CONCURRENT")
- self._match_text_seq("ISOLATED", "LOADING")
- for_all = self._match_text_seq("FOR", "ALL")
- for_insert = self._match_text_seq("FOR", "INSERT")
- for_none = self._match_text_seq("FOR", "NONE")
+
+ if not self._match_text_seq("ISOLATED", "LOADING"):
+ self._retreat(index)
+ return None
+
+ target = self._parse_var_from_options(self.ISOLATED_LOADING_OPTIONS, raise_unmatched=False)
return self.expression(
- exp.IsolatedLoadingProperty,
- no=no,
- concurrent=concurrent,
- for_all=for_all,
- for_insert=for_insert,
- for_none=for_none,
+ exp.IsolatedLoadingProperty, no=no, concurrent=concurrent, target=target
)
def _parse_locking(self) -> exp.LockingProperty:
@@ -2046,20 +2188,22 @@ class Parser(metaclass=_Parser):
def _parse_describe(self) -> exp.Describe:
kind = self._match_set(self.CREATABLES) and self._prev.text
- extended = self._match_text_seq("EXTENDED")
+ style = self._match_texts(("EXTENDED", "FORMATTED")) and self._prev.text.upper()
this = self._parse_table(schema=True)
properties = self._parse_properties()
expressions = properties.expressions if properties else None
return self.expression(
- exp.Describe, this=this, extended=extended, kind=kind, expressions=expressions
+ exp.Describe, this=this, style=style, kind=kind, expressions=expressions
)
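DESCRIBE now records EXTENDED or FORMATTED under a single "style" arg instead of a boolean "extended". A minimal sketch, assuming Hive-style input:

    import sqlglot

    expr = sqlglot.parse_one("DESCRIBE FORMATTED db.t", read="hive")
    print(expr.args.get("style"))  # expected-ish: 'FORMATTED'
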
def _parse_insert(self) -> exp.Insert:
comments = ensure_list(self._prev_comments)
+ hint = self._parse_hint()
overwrite = self._match(TokenType.OVERWRITE)
ignore = self._match(TokenType.IGNORE)
local = self._match_text_seq("LOCAL")
alternative = None
+ is_function = None
if self._match_text_seq("DIRECTORY"):
this: t.Optional[exp.Expression] = self.expression(
@@ -2075,13 +2219,17 @@ class Parser(metaclass=_Parser):
self._match(TokenType.INTO)
comments += ensure_list(self._prev_comments)
self._match(TokenType.TABLE)
- this = self._parse_table(schema=True)
+ is_function = self._match(TokenType.FUNCTION)
+
+ this = self._parse_table(schema=True) if not is_function else self._parse_function()
returning = self._parse_returning()
return self.expression(
exp.Insert,
comments=comments,
+ hint=hint,
+ is_function=is_function,
this=this,
by_name=self._match_text_seq("BY", "NAME"),
exists=self._parse_exists(),
@@ -2112,31 +2260,29 @@ class Parser(metaclass=_Parser):
if not conflict and not duplicate:
return None
- nothing = None
- expressions = None
- key = None
+ conflict_keys = None
constraint = None
if conflict:
if self._match_text_seq("ON", "CONSTRAINT"):
constraint = self._parse_id_var()
- else:
- key = self._parse_csv(self._parse_value)
+ elif self._match(TokenType.L_PAREN):
+ conflict_keys = self._parse_csv(self._parse_id_var)
+ self._match_r_paren()
- self._match_text_seq("DO")
- if self._match_text_seq("NOTHING"):
- nothing = True
- else:
- self._match(TokenType.UPDATE)
+ action = self._parse_var_from_options(self.CONFLICT_ACTIONS)
+ if self._prev.token_type == TokenType.UPDATE:
self._match(TokenType.SET)
expressions = self._parse_csv(self._parse_equality)
+ else:
+ expressions = None
return self.expression(
exp.OnConflict,
duplicate=duplicate,
expressions=expressions,
- nothing=nothing,
- key=key,
+ action=action,
+ conflict_keys=conflict_keys,
constraint=constraint,
)
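The ON CONFLICT rework replaces the nothing/key args with a generic "action" plus "conflict_keys". A minimal sketch, assuming Postgres input; the "conflict" arg name on exp.Insert is an assumption based on this diff:

    import sqlglot

    sql = "INSERT INTO t (x) VALUES (1) ON CONFLICT (x) DO UPDATE SET x = 2"
    conflict = sqlglot.parse_one(sql, read="postgres").args.get("conflict")
    print(conflict.args.get("action"), conflict.args.get("conflict_keys"))
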
@@ -2166,7 +2312,7 @@ class Parser(metaclass=_Parser):
serde_properties = None
if self._match(TokenType.SERDE_PROPERTIES):
serde_properties = self.expression(
- exp.SerdeProperties, expressions=self._parse_wrapped_csv(self._parse_property)
+ exp.SerdeProperties, expressions=self._parse_wrapped_properties()
)
return self.expression(
@@ -2433,8 +2579,19 @@ class Parser(metaclass=_Parser):
self.raise_error("Expected CTE to have alias")
self._match(TokenType.ALIAS)
+
+ if self._match_text_seq("NOT", "MATERIALIZED"):
+ materialized = False
+ elif self._match_text_seq("MATERIALIZED"):
+ materialized = True
+ else:
+ materialized = None
+
return self.expression(
- exp.CTE, this=self._parse_wrapped(self._parse_statement), alias=alias
+ exp.CTE,
+ this=self._parse_wrapped(self._parse_statement),
+ alias=alias,
+ materialized=materialized,
)
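CTE materialization hints are now captured on the exp.CTE node itself. A minimal sketch, assuming the Postgres reader and generator both honor the flag in this version:

    import sqlglot

    sql = "WITH t AS MATERIALIZED (SELECT 1 AS x) SELECT x FROM t"
    print(sqlglot.parse_one(sql, read="postgres").sql(dialect="postgres"))
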
def _parse_table_alias(
@@ -2472,7 +2629,9 @@ class Parser(metaclass=_Parser):
)
def _implicit_unnests_to_explicit(self, this: E) -> E:
- from sqlglot.optimizer.normalize_identifiers import normalize_identifiers as _norm
+ from sqlglot.optimizer.normalize_identifiers import (
+ normalize_identifiers as _norm,
+ )
refs = {_norm(this.args["from"].this.copy(), dialect=self.dialect).alias_or_name}
for i, join in enumerate(this.args.get("joins") or []):
@@ -2502,7 +2661,7 @@ class Parser(metaclass=_Parser):
self, this: t.Optional[exp.Expression]
) -> t.Optional[exp.Expression]:
if isinstance(this, (exp.Query, exp.Table)):
- for join in iter(self._parse_join, None):
+ for join in self._parse_joins():
this.append("joins", join)
for lateral in iter(self._parse_lateral, None):
this.append("laterals", lateral)
@@ -2535,7 +2694,12 @@ class Parser(metaclass=_Parser):
def _parse_hint(self) -> t.Optional[exp.Hint]:
if self._match(TokenType.HINT):
hints = []
- for hint in iter(lambda: self._parse_csv(self._parse_function), []):
+ for hint in iter(
+ lambda: self._parse_csv(
+ lambda: self._parse_function() or self._parse_var(upper=True)
+ ),
+ [],
+ ):
hints.extend(hint)
if not self._match_pair(TokenType.STAR, TokenType.SLASH):
@@ -2743,29 +2907,35 @@ class Parser(metaclass=_Parser):
if hint:
kwargs["hint"] = hint
+ if self._match(TokenType.MATCH_CONDITION):
+ kwargs["match_condition"] = self._parse_wrapped(self._parse_comparison)
+
if self._match(TokenType.ON):
kwargs["on"] = self._parse_conjunction()
elif self._match(TokenType.USING):
kwargs["using"] = self._parse_wrapped_id_vars()
- elif not (kind and kind.token_type == TokenType.CROSS):
+ elif not isinstance(kwargs["this"], exp.Unnest) and not (
+ kind and kind.token_type == TokenType.CROSS
+ ):
index = self._index
- join = self._parse_join()
+ joins: t.Optional[list] = list(self._parse_joins())
- if join and self._match(TokenType.ON):
+ if joins and self._match(TokenType.ON):
kwargs["on"] = self._parse_conjunction()
- elif join and self._match(TokenType.USING):
+ elif joins and self._match(TokenType.USING):
kwargs["using"] = self._parse_wrapped_id_vars()
else:
- join = None
+ joins = None
self._retreat(index)
- kwargs["this"].set("joins", [join] if join else None)
+ kwargs["this"].set("joins", joins if joins else None)
comments = [c for token in (method, side, kind) if token for c in token.comments]
return self.expression(exp.Join, comments=comments, **kwargs)
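MATCH_CONDITION joins (Snowflake's ASOF JOIN) are parsed into a wrapped comparison under the "match_condition" kwarg. A minimal sketch, assuming Snowflake syntax support lands with this diff:

    import sqlglot

    sql = "SELECT * FROM a ASOF JOIN b MATCH_CONDITION (a.ts >= b.ts) ON a.id = b.id"
    print(sqlglot.parse_one(sql, read="snowflake").sql(dialect="snowflake"))
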
def _parse_opclass(self) -> t.Optional[exp.Expression]:
this = self._parse_conjunction()
+
if self._match_texts(self.OPCLASS_FOLLOW_KEYWORDS, advance=False):
return this
@@ -2774,6 +2944,35 @@ class Parser(metaclass=_Parser):
return this
+ def _parse_index_params(self) -> exp.IndexParameters:
+ using = self._parse_var(any_token=True) if self._match(TokenType.USING) else None
+
+ if self._match(TokenType.L_PAREN, advance=False):
+ columns = self._parse_wrapped_csv(self._parse_with_operator)
+ else:
+ columns = None
+
+ include = self._parse_wrapped_id_vars() if self._match_text_seq("INCLUDE") else None
+ partition_by = self._parse_partition_by()
+ with_storage = self._match(TokenType.WITH) and self._parse_wrapped_properties()
+ tablespace = (
+ self._parse_var(any_token=True)
+ if self._match_text_seq("USING", "INDEX", "TABLESPACE")
+ else None
+ )
+ where = self._parse_where()
+
+ return self.expression(
+ exp.IndexParameters,
+ using=using,
+ columns=columns,
+ include=include,
+ partition_by=partition_by,
+ where=where,
+ with_storage=with_storage,
+ tablespace=tablespace,
+ )
+
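_parse_index_params consolidates USING, the column list, INCLUDE, PARTITION BY, WITH storage options, TABLESPACE, and WHERE into a single exp.IndexParameters node. A minimal sketch, assuming Postgres-style index DDL:

    import sqlglot

    sql = "CREATE INDEX idx ON t USING btree (a, b) INCLUDE (c) WHERE a > 0"
    print(repr(sqlglot.parse_one(sql, read="postgres")))
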
def _parse_index(
self,
index: t.Optional[exp.Expression] = None,
@@ -2797,27 +2996,16 @@ class Parser(metaclass=_Parser):
index = self._parse_id_var()
table = None
- using = self._parse_var(any_token=True) if self._match(TokenType.USING) else None
-
- if self._match(TokenType.L_PAREN, advance=False):
- columns = self._parse_wrapped_csv(lambda: self._parse_ordered(self._parse_opclass))
- else:
- columns = None
-
- include = self._parse_wrapped_id_vars() if self._match_text_seq("INCLUDE") else None
+ params = self._parse_index_params()
return self.expression(
exp.Index,
this=index,
table=table,
- using=using,
- columns=columns,
unique=unique,
primary=primary,
amp=amp,
- include=include,
- partition_by=self._parse_partition_by(),
- where=self._parse_where(),
+ params=params,
)
def _parse_table_hints(self) -> t.Optional[t.List[exp.Expression]]:
@@ -2977,7 +3165,7 @@ class Parser(metaclass=_Parser):
this = table_sample
if joins:
- for join in iter(self._parse_join, None):
+ for join in self._parse_joins():
this.append("joins", join)
if self._match_pair(TokenType.WITH, TokenType.ORDINALITY):
@@ -3126,8 +3314,8 @@ class Parser(metaclass=_Parser):
def _parse_pivots(self) -> t.Optional[t.List[exp.Pivot]]:
return list(iter(self._parse_pivot, None)) or None
- def _parse_joins(self) -> t.Optional[t.List[exp.Join]]:
- return list(iter(self._parse_join, None)) or None
+ def _parse_joins(self) -> t.Iterator[exp.Join]:
+ return iter(self._parse_join, None)
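_parse_joins now returns a lazy iterator built with the two-argument form of iter(), which calls a zero-argument callable until it returns the sentinel. A standalone illustration of that pattern:

    values = [1, 2, None, 3]
    pop = lambda: values.pop(0)
    print(list(iter(pop, None)))  # [1, 2] -- stops at the first None
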
# https://duckdb.org/docs/sql/statements/pivot
def _parse_simplified_pivot(self) -> exp.Pivot:
@@ -3328,6 +3516,7 @@ class Parser(metaclass=_Parser):
return None
self._match(TokenType.CONNECT_BY)
+ nocycle = self._match_text_seq("NOCYCLE")
self.NO_PAREN_FUNCTION_PARSERS["PRIOR"] = lambda self: self.expression(
exp.Prior, this=self._parse_bitwise()
)
@@ -3337,7 +3526,7 @@ class Parser(metaclass=_Parser):
if not start and self._match(TokenType.START_WITH):
start = self._parse_conjunction()
- return self.expression(exp.Connect, start=start, connect=connect)
+ return self.expression(exp.Connect, start=start, connect=connect, nocycle=nocycle)
def _parse_name_as_expression(self) -> exp.Alias:
return self.expression(
@@ -3417,9 +3606,12 @@ class Parser(metaclass=_Parser):
)
def _parse_limit(
- self, this: t.Optional[exp.Expression] = None, top: bool = False
+ self,
+ this: t.Optional[exp.Expression] = None,
+ top: bool = False,
+ skip_limit_token: bool = False,
) -> t.Optional[exp.Expression]:
- if self._match(TokenType.TOP if top else TokenType.LIMIT):
+ if skip_limit_token or self._match(TokenType.TOP if top else TokenType.LIMIT):
comments = self._prev_comments
if top:
limit_paren = self._match(TokenType.L_PAREN)
@@ -3681,6 +3873,11 @@ class Parser(metaclass=_Parser):
this = exp.Literal.string(parts[0])
unit = self.expression(exp.Var, this=parts[1].upper())
+ if self.INTERVAL_SPANS and self._match_text_seq("TO"):
+ unit = self.expression(
+ exp.IntervalSpan, this=unit, expression=self._parse_var(any_token=True, upper=True)
+ )
+
return self.expression(exp.Interval, this=this, unit=unit)
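Interval spans such as '1-2' YEAR TO MONTH now build an exp.IntervalSpan unit when the dialect enables INTERVAL_SPANS. A minimal sketch; which dialects set the flag is an assumption:

    import sqlglot

    print(repr(sqlglot.parse_one("SELECT INTERVAL '1-2' YEAR TO MONTH", read="postgres")))
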
def _parse_bitwise(self) -> t.Optional[exp.Expression]:
@@ -3783,6 +3980,9 @@ class Parser(metaclass=_Parser):
if not this:
return None
+ if isinstance(this, exp.Column) and not this.table:
+ this = exp.var(this.name.upper())
+
return self.expression(
exp.DataTypeParam, this=this, expression=self._parse_var(any_token=True)
)
@@ -3900,19 +4100,14 @@ class Parser(metaclass=_Parser):
elif self._match_text_seq("WITHOUT", "TIME", "ZONE"):
maybe_func = False
elif type_token == TokenType.INTERVAL:
- unit = self._parse_var()
-
- if self._match_text_seq("TO"):
- span = [exp.IntervalSpan(this=unit, expression=self._parse_var())]
- else:
- span = None
+ unit = self._parse_var(upper=True)
+ if unit:
+ if self._match_text_seq("TO"):
+ unit = exp.IntervalSpan(this=unit, expression=self._parse_var(upper=True))
- if span or not unit:
- this = self.expression(
- exp.DataType, this=exp.DataType.Type.INTERVAL, expressions=span
- )
- else:
this = self.expression(exp.DataType, this=self.expression(exp.Interval, unit=unit))
+ else:
+ this = self.expression(exp.DataType, this=exp.DataType.Type.INTERVAL)
if maybe_func and check_func:
index2 = self._index
@@ -3996,11 +4191,20 @@ class Parser(metaclass=_Parser):
else:
field = self._parse_field(anonymous_func=True, any_token=True)
- if isinstance(field, exp.Func):
+ if isinstance(field, exp.Func) and this:
# bigquery allows function calls like x.y.count(...)
# SAFE.SUBSTR(...)
# https://cloud.google.com/bigquery/docs/reference/standard-sql/functions-reference#function_call_rules
- this = self._replace_columns_with_dots(this)
+ this = exp.replace_tree(
+ this,
+ lambda n: (
+ self.expression(exp.Dot, this=n.args.get("table"), expression=n.this)
+ if n.table
+ else n.this
+ )
+ if isinstance(n, exp.Column)
+ else n,
+ )
if op:
this = op(self, this, field)
@@ -4050,10 +4254,14 @@ class Parser(metaclass=_Parser):
this = self._parse_set_operations(
self._parse_subquery(this=this, parse_alias=False)
)
+ elif isinstance(this, exp.Subquery):
+ this = self._parse_subquery(
+ this=self._parse_set_operations(this), parse_alias=False
+ )
elif len(expressions) > 1:
this = self.expression(exp.Tuple, expressions=expressions)
else:
- this = self.expression(exp.Paren, this=self._parse_set_operations(this))
+ this = self.expression(exp.Paren, this=this)
if this:
this.add_comments(comments)
@@ -4118,7 +4326,7 @@ class Parser(metaclass=_Parser):
parser = self.NO_PAREN_FUNCTION_PARSERS.get(upper)
if optional_parens and parser and token_type not in self.INVALID_FUNC_NAME_TOKENS:
self._advance()
- return parser(self)
+ return self._parse_window(parser(self))
if not self._next or self._next.token_type != TokenType.L_PAREN:
if optional_parens and token_type in self.NO_PAREN_FUNCTIONS:
@@ -4186,7 +4394,7 @@ class Parser(metaclass=_Parser):
if not isinstance(e, exp.PropertyEQ):
e = self.expression(
- exp.PropertyEQ, this=exp.to_identifier(e.name), expression=e.expression
+ exp.PropertyEQ, this=exp.to_identifier(e.this.name), expression=e.expression
)
if isinstance(e.this, exp.Column):
@@ -4267,19 +4475,15 @@ class Parser(metaclass=_Parser):
def _parse_schema(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]:
index = self._index
- if not self.errors:
- try:
- if self._parse_select(nested=True):
- return this
- except ParseError:
- pass
- finally:
- self.errors.clear()
- self._retreat(index)
-
if not self._match(TokenType.L_PAREN):
return this
+ # Disambiguate between schema and subquery/CTE, e.g. in INSERT INTO table (<expr>),
+ # expr can be of both types
+ if self._match_set(self.SELECT_START_TOKENS):
+ self._retreat(index)
+ return this
+
args = self._parse_csv(lambda: self._parse_constraint() or self._parse_field_def())
self._match_r_paren()
@@ -4300,7 +4504,7 @@ class Parser(metaclass=_Parser):
constraints: t.List[exp.Expression] = []
- if not kind and self._match(TokenType.ALIAS):
+ if (not kind and self._match(TokenType.ALIAS)) or self._match_text_seq("ALIAS"):
constraints.append(
self.expression(
exp.ComputedColumnConstraint,
@@ -4417,9 +4621,7 @@ class Parser(metaclass=_Parser):
self._match_text_seq("LENGTH")
return self.expression(exp.InlineLengthColumnConstraint, this=self._parse_bitwise())
- def _parse_not_constraint(
- self,
- ) -> t.Optional[exp.Expression]:
+ def _parse_not_constraint(self) -> t.Optional[exp.Expression]:
if self._match_text_seq("NULL"):
return self.expression(exp.NotNullColumnConstraint)
if self._match_text_seq("CASESPECIFIC"):
@@ -4447,16 +4649,21 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.CONSTRAINT):
return self._parse_unnamed_constraint(constraints=self.SCHEMA_UNNAMED_CONSTRAINTS)
- this = self._parse_id_var()
- expressions = []
+ return self.expression(
+ exp.Constraint,
+ this=self._parse_id_var(),
+ expressions=self._parse_unnamed_constraints(),
+ )
+ def _parse_unnamed_constraints(self) -> t.List[exp.Expression]:
+ constraints = []
while True:
constraint = self._parse_unnamed_constraint() or self._parse_function()
if not constraint:
break
- expressions.append(constraint)
+ constraints.append(constraint)
- return self.expression(exp.Constraint, this=this, expressions=expressions)
+ return constraints
def _parse_unnamed_constraint(
self, constraints: t.Optional[t.Collection[str]] = None
@@ -4478,6 +4685,7 @@ class Parser(metaclass=_Parser):
exp.UniqueColumnConstraint,
this=self._parse_schema(self._parse_id_var(any_token=False)),
index_type=self._match(TokenType.USING) and self._advance_any() and self._prev.text,
+ on_conflict=self._parse_on_conflict(),
)
def _parse_key_constraint_options(self) -> t.List[str]:
@@ -4592,7 +4800,7 @@ class Parser(metaclass=_Parser):
def _parse_bracket_key_value(self, is_map: bool = False) -> t.Optional[exp.Expression]:
return self._parse_slice(self._parse_alias(self._parse_conjunction(), explicit=True))
- def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
+ def _parse_bracket(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]:
if not self._match_set((TokenType.L_BRACKET, TokenType.L_BRACE)):
return this
@@ -4601,9 +4809,9 @@ class Parser(metaclass=_Parser):
lambda: self._parse_bracket_key_value(is_map=bracket_kind == TokenType.L_BRACE)
)
- if not self._match(TokenType.R_BRACKET) and bracket_kind == TokenType.L_BRACKET:
+ if bracket_kind == TokenType.L_BRACKET and not self._match(TokenType.R_BRACKET):
self.raise_error("Expected ]")
- elif not self._match(TokenType.R_BRACE) and bracket_kind == TokenType.L_BRACE:
+ elif bracket_kind == TokenType.L_BRACE and not self._match(TokenType.R_BRACE):
self.raise_error("Expected }")
# https://duckdb.org/docs/sql/data_types/struct.html#creating-structs
@@ -4645,8 +4853,8 @@ class Parser(metaclass=_Parser):
else:
self.raise_error("Expected END after CASE", self._prev)
- return self._parse_window(
- self.expression(exp.Case, comments=comments, this=expression, ifs=ifs, default=default)
+ return self.expression(
+ exp.Case, comments=comments, this=expression, ifs=ifs, default=default
)
def _parse_if(self) -> t.Optional[exp.Expression]:
@@ -4672,7 +4880,7 @@ class Parser(metaclass=_Parser):
self._match(TokenType.END)
this = self.expression(exp.If, this=condition, true=true, false=false)
- return self._parse_window(this)
+ return this
def _parse_next_value_for(self) -> t.Optional[exp.Expression]:
if not self._match_text_seq("VALUE", "FOR"):
@@ -4739,7 +4947,12 @@ class Parser(metaclass=_Parser):
to = self.expression(exp.CharacterSet, this=self._parse_var_or_string())
return self.expression(
- exp.Cast if strict else exp.TryCast, this=this, to=to, format=fmt, safe=safe
+ exp.Cast if strict else exp.TryCast,
+ this=this,
+ to=to,
+ format=fmt,
+ safe=safe,
+ action=self._parse_var_from_options(self.CAST_ACTIONS, raise_unmatched=False),
)
def _parse_string_agg(self) -> exp.Expression:
@@ -5087,6 +5300,9 @@ class Parser(metaclass=_Parser):
def _parse_window(
self, this: t.Optional[exp.Expression], alias: bool = False
) -> t.Optional[exp.Expression]:
+ func = this
+ comments = func.comments if isinstance(func, exp.Expression) else None
+
if self._match_pair(TokenType.FILTER, TokenType.L_PAREN):
self._match(TokenType.WHERE)
this = self.expression(
@@ -5132,9 +5348,16 @@ class Parser(metaclass=_Parser):
else:
over = self._prev.text.upper()
+ if comments:
+ func.comments = None # type: ignore
+
if not self._match(TokenType.L_PAREN):
return self.expression(
- exp.Window, this=this, alias=self._parse_id_var(False), over=over
+ exp.Window,
+ comments=comments,
+ this=this,
+ alias=self._parse_id_var(False),
+ over=over,
)
window_alias = self._parse_id_var(any_token=False, tokens=self.WINDOW_ALIAS_TOKENS)
@@ -5167,6 +5390,7 @@ class Parser(metaclass=_Parser):
window = self.expression(
exp.Window,
+ comments=comments,
this=this,
partition_by=partition,
order=order,
@@ -5218,7 +5442,7 @@ class Parser(metaclass=_Parser):
self._match_r_paren(aliases)
return aliases
- alias = self._parse_id_var(any_token) or (
+ alias = self._parse_id_var(any_token, tokens=self.ALIAS_TOKENS) or (
self.STRING_ALIASES and self._parse_string_as_identifier()
)
@@ -5512,10 +5736,11 @@ class Parser(metaclass=_Parser):
return self.expression(exp.AlterColumn, this=column, comment=self._parse_string())
self._match_text_seq("SET", "DATA")
+ self._match_text_seq("TYPE")
return self.expression(
exp.AlterColumn,
this=column,
- dtype=self._match_text_seq("TYPE") and self._parse_types(),
+ dtype=self._parse_types(),
collate=self._match(TokenType.COLLATE) and self._parse_term(),
using=self._match(TokenType.USING) and self._parse_conjunction(),
)
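TYPE is now matched unconditionally after SET DATA, so both SET DATA TYPE <t> and the bare Postgres TYPE <t> form populate dtype. A minimal sketch:

    import sqlglot

    for sql in (
        "ALTER TABLE t ALTER COLUMN c SET DATA TYPE TEXT",
        "ALTER TABLE t ALTER COLUMN c TYPE TEXT",
    ):
        print(sqlglot.parse_one(sql).sql())
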
@@ -5919,26 +6144,6 @@ class Parser(metaclass=_Parser):
return True
- @t.overload
- def _replace_columns_with_dots(self, this: exp.Expression) -> exp.Expression: ...
-
- @t.overload
- def _replace_columns_with_dots(
- self, this: t.Optional[exp.Expression]
- ) -> t.Optional[exp.Expression]: ...
-
- def _replace_columns_with_dots(self, this):
- if isinstance(this, exp.Dot):
- exp.replace_children(this, self._replace_columns_with_dots)
- elif isinstance(this, exp.Column):
- exp.replace_children(this, self._replace_columns_with_dots)
- table = this.args.get("table")
- this = (
- self.expression(exp.Dot, this=table, expression=this.this) if table else this.this
- )
-
- return this
-
def _replace_lambda(
self, node: t.Optional[exp.Expression], lambda_variables: t.Set[str]
) -> t.Optional[exp.Expression]:
@@ -6011,3 +6216,13 @@ class Parser(metaclass=_Parser):
option=option,
partition=partition,
)
+
+ def _parse_with_operator(self) -> t.Optional[exp.Expression]:
+ this = self._parse_ordered(self._parse_opclass)
+
+ if not self._match(TokenType.WITH):
+ return this
+
+ op = self._parse_var(any_token=True)
+
+ return self.expression(exp.WithOperator, this=this, op=op)
diff --git a/sqlglot/planner.py b/sqlglot/planner.py
index bbc52ab..5e4e23a 100644
--- a/sqlglot/planner.py
+++ b/sqlglot/planner.py
@@ -118,6 +118,7 @@ class Step:
if joins:
join = Join.from_joins(joins, ctes)
join.name = step.name
+ join.source_name = step.name
join.add_dependency(step)
step = join
@@ -187,13 +188,13 @@ class Step:
intermediate[v.name] = k
for projection in projections:
- for node, *_ in projection.walk():
+ for node in projection.walk():
name = intermediate.get(node)
if name:
node.replace(exp.column(name, step.name))
if aggregate.condition:
- for node, *_ in aggregate.condition.walk():
+ for node in aggregate.condition.walk():
name = intermediate.get(node) or intermediate.get(node.name)
if name:
node.replace(exp.column(name, step.name))
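These planner loops reflect the upstream walk() API change: it now yields nodes directly rather than (node, parent, key) tuples. A minimal sketch of the new iteration style:

    import sqlglot
    from sqlglot import exp

    tree = sqlglot.parse_one("SELECT a + 1 FROM t")
    print([node.name for node in tree.walk() if isinstance(node, exp.Column)])  # ['a']
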
@@ -331,7 +332,7 @@ class Join(Step):
@classmethod
def from_joins(
cls, joins: t.Iterable[exp.Join], ctes: t.Optional[t.Dict[str, Step]] = None
- ) -> Step:
+ ) -> Join:
step = Join()
for join in joins:
@@ -349,10 +350,11 @@ class Join(Step):
def __init__(self) -> None:
super().__init__()
+ self.source_name: t.Optional[str] = None
self.joins: t.Dict[str, t.Dict[str, t.List[str] | exp.Expression]] = {}
def _to_s(self, indent: str) -> t.List[str]:
- lines = []
+ lines = [f"{indent}Source: {self.source_name or self.name}"]
for name, join in self.joins.items():
lines.append(f"{indent}{name}: {join['side'] or 'INNER'}")
join_key = ", ".join(str(key) for key in t.cast(list, join.get("join_key") or []))
@@ -423,7 +425,7 @@ class SetOperation(Step):
@classmethod
def from_expression(
cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
- ) -> Step:
+ ) -> SetOperation:
assert isinstance(expression, exp.Union)
left = Step.from_expression(expression.left, ctes)
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index da9df7d..7f0cb5d 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -135,6 +135,7 @@ class TokenType(AutoName):
LONGBLOB = auto()
TINYBLOB = auto()
TINYTEXT = auto()
+ NAME = auto()
BINARY = auto()
VARBINARY = auto()
JSON = auto()
@@ -290,6 +291,7 @@ class TokenType(AutoName):
LOAD = auto()
LOCK = auto()
MAP = auto()
+ MATCH_CONDITION = auto()
MATCH_RECOGNIZE = auto()
MEMBER_OF = auto()
MERGE = auto()
@@ -317,6 +319,7 @@ class TokenType(AutoName):
PERCENT = auto()
PIVOT = auto()
PLACEHOLDER = auto()
+ POSITIONAL = auto()
PRAGMA = auto()
PREWHERE = auto()
PRIMARY_KEY = auto()
@@ -340,6 +343,7 @@ class TokenType(AutoName):
SELECT = auto()
SEMI = auto()
SEPARATOR = auto()
+ SEQUENCE = auto()
SERDE_PROPERTIES = auto()
SET = auto()
SETTINGS = auto()
@@ -518,6 +522,7 @@ class _Tokenizer(type):
break_=_TOKEN_TYPE_TO_INDEX[TokenType.BREAK],
dcolon=_TOKEN_TYPE_TO_INDEX[TokenType.DCOLON],
heredoc_string=_TOKEN_TYPE_TO_INDEX[TokenType.HEREDOC_STRING],
+ raw_string=_TOKEN_TYPE_TO_INDEX[TokenType.RAW_STRING],
hex_string=_TOKEN_TYPE_TO_INDEX[TokenType.HEX_STRING],
identifier=_TOKEN_TYPE_TO_INDEX[TokenType.IDENTIFIER],
number=_TOKEN_TYPE_TO_INDEX[TokenType.NUMBER],
@@ -562,8 +567,7 @@ class Tokenizer(metaclass=_Tokenizer):
"~": TokenType.TILDA,
"?": TokenType.PLACEHOLDER,
"@": TokenType.PARAMETER,
- # used for breaking a var like x'y' but nothing else
- # the token type doesn't matter
+ # Used for breaking a var like x'y', but nothing else; the token type doesn't matter
"'": TokenType.QUOTE,
"`": TokenType.IDENTIFIER,
'"': TokenType.IDENTIFIER,
@@ -796,6 +800,7 @@ class Tokenizer(metaclass=_Tokenizer):
"LONG": TokenType.BIGINT,
"BIGINT": TokenType.BIGINT,
"INT8": TokenType.TINYINT,
+ "UINT": TokenType.UINT,
"DEC": TokenType.DECIMAL,
"DECIMAL": TokenType.DECIMAL,
"BIGDECIMAL": TokenType.BIGDECIMAL,
@@ -856,6 +861,7 @@ class Tokenizer(metaclass=_Tokenizer):
"DATEMULTIRANGE": TokenType.DATEMULTIRANGE,
"UNIQUE": TokenType.UNIQUE,
"STRUCT": TokenType.STRUCT,
+ "SEQUENCE": TokenType.SEQUENCE,
"VARIANT": TokenType.VARIANT,
"ALTER": TokenType.ALTER,
"ANALYZE": TokenType.COMMAND,
@@ -888,7 +894,7 @@ class Tokenizer(metaclass=_Tokenizer):
COMMAND_PREFIX_TOKENS = {TokenType.SEMICOLON, TokenType.BEGIN}
- # handle numeric literals like in hive (3L = BIGINT)
+ # Handle numeric literals as in Hive (3L = BIGINT)
NUMERIC_LITERALS: t.Dict[str, str] = {}
COMMENTS = ["--", ("/*", "*/")]
@@ -917,7 +923,7 @@ class Tokenizer(metaclass=_Tokenizer):
if USE_RS_TOKENIZER:
self._rs_dialect_settings = RsTokenizerDialectSettings(
- escape_sequences=self.dialect.ESCAPE_SEQUENCES,
+ unescaped_sequences=self.dialect.UNESCAPED_SEQUENCES,
identifiers_can_start_with_digit=self.dialect.IDENTIFIERS_CAN_START_WITH_DIGIT,
)
@@ -961,8 +967,7 @@ class Tokenizer(metaclass=_Tokenizer):
while self.size and not self._end:
current = self._current
- # skip spaces inline rather than iteratively call advance()
- # for performance reasons
+ # Skip spaces here rather than iteratively calling advance() for performance reasons
while current < self.size:
char = self.sql[current]
@@ -971,12 +976,10 @@ class Tokenizer(metaclass=_Tokenizer):
else:
break
- n = current - self._current
- self._start = current
- self._advance(n if n > 1 else 1)
+ offset = current - self._current if current > self._current else 1
- if self._char is None:
- break
+ self._start = current
+ self._advance(offset)
if not self._char.isspace():
if self._char.isdigit():
@@ -1004,12 +1007,9 @@ class Tokenizer(metaclass=_Tokenizer):
def _advance(self, i: int = 1, alnum: bool = False) -> None:
if self.WHITE_SPACE.get(self._char) is TokenType.BREAK:
# Ensures we don't count an extra line if we get a \r\n line break sequence
- if self._char == "\r" and self._peek == "\n":
- i = 2
- self._start += 1
-
- self._col = 1
- self._line += 1
+ if not (self._char == "\r" and self._peek == "\n"):
+ self._col = 1
+ self._line += 1
else:
self._col += i
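The rewritten branch counts a "\r\n" pair as a single line break instead of special-casing it with i = 2. A minimal sketch checking the reported token lines (Token.line is 1-based):

    from sqlglot.tokens import Tokenizer

    tokens = Tokenizer().tokenize("SELECT 1\r\nFROM t")
    print({tok.text: tok.line for tok in tokens})
    # expected-ish: SELECT and 1 on line 1, FROM and t on line 2
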
@@ -1268,13 +1268,27 @@ class Tokenizer(metaclass=_Tokenizer):
return True
self._advance()
- tag = "" if self._char == end else self._extract_string(end)
+
+ if self._char == end:
+ tag = ""
+ else:
+ tag = self._extract_string(
+ end,
+ unescape_sequences=False,
+ raise_unmatched=not self.HEREDOC_TAG_IS_IDENTIFIER,
+ )
+
+ if self._end and tag and self.HEREDOC_TAG_IS_IDENTIFIER:
+ self._advance(-len(tag))
+ self._add(self.HEREDOC_STRING_ALTERNATIVE)
+ return True
+
end = f"{start}{tag}{end}"
else:
return False
self._advance(len(start))
- text = self._extract_string(end)
+ text = self._extract_string(end, unescape_sequences=token_type != TokenType.RAW_STRING)
if base:
try:
@@ -1289,7 +1303,7 @@ class Tokenizer(metaclass=_Tokenizer):
def _scan_identifier(self, identifier_end: str) -> None:
self._advance()
- text = self._extract_string(identifier_end, self._IDENTIFIER_ESCAPES)
+ text = self._extract_string(identifier_end, escapes=self._IDENTIFIER_ESCAPES)
self._add(TokenType.IDENTIFIER, text)
def _scan_var(self) -> None:
@@ -1306,13 +1320,30 @@ class Tokenizer(metaclass=_Tokenizer):
else self.KEYWORDS.get(self._text.upper(), TokenType.VAR)
)
- def _extract_string(self, delimiter: str, escapes=None) -> str:
+ def _extract_string(
+ self,
+ delimiter: str,
+ escapes: t.Optional[t.Set[str]] = None,
+ unescape_sequences: bool = True,
+ raise_unmatched: bool = True,
+ ) -> str:
text = ""
delim_size = len(delimiter)
escapes = self._STRING_ESCAPES if escapes is None else escapes
while True:
if (
+ unescape_sequences
+ and self.dialect.UNESCAPED_SEQUENCES
+ and self._peek
+ and self._char in self.STRING_ESCAPES
+ ):
+ unescaped_sequence = self.dialect.UNESCAPED_SEQUENCES.get(self._char + self._peek)
+ if unescaped_sequence:
+ self._advance(2)
+ text += unescaped_sequence
+ continue
+ if (
self._char in escapes
and (self._peek == delimiter or self._peek in escapes)
and (self._char not in self._QUOTES or self._char == self._peek)
@@ -1333,18 +1364,10 @@ class Tokenizer(metaclass=_Tokenizer):
break
if self._end:
- raise TokenError(f"Missing {delimiter} from {self._line}:{self._start}")
+ if not raise_unmatched:
+ return text + self._char
- if (
- self.dialect.ESCAPE_SEQUENCES
- and self._peek
- and self._char in self.STRING_ESCAPES
- ):
- escaped_sequence = self.dialect.ESCAPE_SEQUENCES.get(self._char + self._peek)
- if escaped_sequence:
- self._advance(2)
- text += escaped_sequence
- continue
+ raise TokenError(f"Missing {delimiter} from {self._line}:{self._start}")
current = self._current - 1
self._advance(alnum=True)
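_extract_string now applies the dialect's UNESCAPED_SEQUENCES at the top of the scan loop and can return instead of raising when raise_unmatched is False (used for heredoc tags). A minimal sketch, assuming a dialect with backslash sequences such as BigQuery; the exact unescaped output is an assumption:

    from sqlglot.dialects import BigQuery

    tokens = BigQuery().tokenize(r"SELECT 'a\nb'")
    print(repr(tokens[-1].text))  # expected-ish: 'a\nb' with a literal newline
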
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py
index 04c1f7b..f44c18c 100644
--- a/sqlglot/transforms.py
+++ b/sqlglot/transforms.py
@@ -447,7 +447,7 @@ def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
if inner_with.recursive:
top_level_with.set("recursive", True)
- top_level_with.expressions.extend(inner_with.expressions)
+ top_level_with.set("expressions", inner_with.expressions + top_level_with.expressions)
return expression
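Prepending the inner CTEs keeps their definitions ahead of the lifted outer CTE that references them; the old extend() appended them after it. A minimal sketch:

    import sqlglot
    from sqlglot.transforms import move_ctes_to_top_level

    sql = "WITH a AS (WITH b AS (SELECT 1 AS x) SELECT x FROM b) SELECT * FROM a"
    print(move_ctes_to_top_level(sqlglot.parse_one(sql)).sql())
    # expected-ish: WITH b AS (SELECT 1 AS x), a AS (SELECT x FROM b) SELECT * FROM a
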
@@ -464,7 +464,7 @@ def ensure_bools(expression: exp.Expression) -> exp.Expression:
):
node.replace(node.neq(0))
- for node, *_ in expression.walk():
+ for node in expression.walk():
ensure_bools(node, _ensure_bool)
return expression
@@ -561,9 +561,7 @@ def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Exp
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
- """
- Convert struct arguments to aliases: STRUCT(1 AS y) .
- """
+ """Converts struct arguments to aliases, e.g. STRUCT(1 AS y)."""
if isinstance(expression, exp.Struct):
expression.set(
"expressions",