Diffstat (limited to 'sqlglot')
-rw-r--r--  sqlglot/__init__.py | 20
-rw-r--r--  sqlglot/_typing.py | 8
-rw-r--r--  sqlglot/dataframe/sql/_typing.py (renamed from sqlglot/dataframe/sql/_typing.pyi) | 4
-rw-r--r--  sqlglot/dataframe/sql/dataframe.py | 12
-rw-r--r--  sqlglot/dataframe/sql/operations.py | 2
-rw-r--r--  sqlglot/dataframe/sql/session.py | 20
-rw-r--r--  sqlglot/dataframe/sql/util.py | 2
-rw-r--r--  sqlglot/dialects/bigquery.py | 97
-rw-r--r--  sqlglot/dialects/clickhouse.py | 279
-rw-r--r--  sqlglot/dialects/databricks.py | 2
-rw-r--r--  sqlglot/dialects/dialect.py | 99
-rw-r--r--  sqlglot/dialects/drill.py | 8
-rw-r--r--  sqlglot/dialects/duckdb.py | 33
-rw-r--r--  sqlglot/dialects/hive.py | 38
-rw-r--r--  sqlglot/dialects/mysql.py | 64
-rw-r--r--  sqlglot/dialects/oracle.py | 10
-rw-r--r--  sqlglot/dialects/postgres.py | 45
-rw-r--r--  sqlglot/dialects/presto.py | 94
-rw-r--r--  sqlglot/dialects/redshift.py | 58
-rw-r--r--  sqlglot/dialects/snowflake.py | 111
-rw-r--r--  sqlglot/dialects/spark.py | 6
-rw-r--r--  sqlglot/dialects/spark2.py | 96
-rw-r--r--  sqlglot/dialects/sqlite.py | 19
-rw-r--r--  sqlglot/dialects/starrocks.py | 10
-rw-r--r--  sqlglot/dialects/tableau.py | 39
-rw-r--r--  sqlglot/dialects/teradata.py | 14
-rw-r--r--  sqlglot/dialects/trino.py | 2
-rw-r--r--  sqlglot/dialects/tsql.py | 90
-rw-r--r--  sqlglot/diff.py | 17
-rw-r--r--  sqlglot/executor/__init__.py | 16
-rw-r--r--  sqlglot/executor/env.py | 15
-rw-r--r--  sqlglot/executor/python.py | 19
-rw-r--r--  sqlglot/expressions.py | 1354
-rw-r--r--  sqlglot/generator.py | 368
-rw-r--r--  sqlglot/helper.py | 47
-rw-r--r--  sqlglot/lineage.py | 28
-rw-r--r--  sqlglot/optimizer/canonicalize.py | 10
-rw-r--r--  sqlglot/optimizer/eliminate_ctes.py | 39
-rw-r--r--  sqlglot/optimizer/eliminate_subqueries.py | 19
-rw-r--r--  sqlglot/optimizer/expand_laterals.py | 34
-rw-r--r--  sqlglot/optimizer/expand_multi_table_selects.py | 24
-rw-r--r--  sqlglot/optimizer/isolate_table_selects.py | 2
-rw-r--r--  sqlglot/optimizer/lower_identities.py | 88
-rw-r--r--  sqlglot/optimizer/merge_subqueries.py | 17
-rw-r--r--  sqlglot/optimizer/normalize.py | 35
-rw-r--r--  sqlglot/optimizer/normalize_identifiers.py | 36
-rw-r--r--  sqlglot/optimizer/optimize_joins.py | 7
-rw-r--r--  sqlglot/optimizer/optimizer.py | 37
-rw-r--r--  sqlglot/optimizer/pushdown_predicates.py | 42
-rw-r--r--  sqlglot/optimizer/pushdown_projections.py | 9
-rw-r--r--  sqlglot/optimizer/qualify.py | 80
-rw-r--r--  sqlglot/optimizer/qualify_columns.py | 221
-rw-r--r--  sqlglot/optimizer/qualify_tables.py | 45
-rw-r--r--  sqlglot/optimizer/scope.py | 43
-rw-r--r--  sqlglot/optimizer/simplify.py | 32
-rw-r--r--  sqlglot/optimizer/unnest_subqueries.py | 23
-rw-r--r--  sqlglot/parser.py | 896
-rw-r--r--  sqlglot/planner.py | 17
-rw-r--r--  sqlglot/schema.py | 201
-rw-r--r--  sqlglot/tokens.py | 246
-rw-r--r--  sqlglot/transforms.py | 44
61 files changed, 3221 insertions, 2172 deletions
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py
index f7440e0..8fb623a 100644
--- a/sqlglot/__init__.py
+++ b/sqlglot/__init__.py
@@ -6,6 +6,7 @@
from __future__ import annotations
+import logging
import typing as t
from sqlglot import expressions as exp
@@ -45,12 +46,19 @@ from sqlglot.schema import MappingSchema as MappingSchema, Schema as Schema
from sqlglot.tokens import Tokenizer as Tokenizer, TokenType as TokenType
if t.TYPE_CHECKING:
+ from sqlglot._typing import E
from sqlglot.dialects.dialect import DialectType as DialectType
- T = t.TypeVar("T", bound=Expression)
+logger = logging.getLogger("sqlglot")
-__version__ = "12.2.0"
+try:
+ from sqlglot._version import __version__, __version_tuple__
+except ImportError:
+ logger.error(
+ "Unable to set __version__, run `pip install -e .` or `python setup.py develop` first."
+ )
+
pretty = False
"""Whether to format generated SQL by default."""
@@ -79,9 +87,9 @@ def parse(sql: str, read: DialectType = None, **opts) -> t.List[t.Optional[Expre
def parse_one(
sql: str,
read: None = None,
- into: t.Type[T] = ...,
+ into: t.Type[E] = ...,
**opts,
-) -> T:
+) -> E:
...
@@ -89,9 +97,9 @@ def parse_one(
def parse_one(
sql: str,
read: DialectType,
- into: t.Type[T],
+ into: t.Type[E],
**opts,
-) -> T:
+) -> E:
...
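
The retyped overloads above are easiest to see from the caller's side: passing `into` narrows `parse_one`'s return type through the shared `E` TypeVar. A minimal sketch, assuming only the public `parse_one` API:

import sqlglot
from sqlglot import exp

# `into` narrows the static return type from Expression to exp.Select,
# so type checkers accept Select-specific attributes without a cast.
select = sqlglot.parse_one("SELECT a FROM t", into=exp.Select)
assert isinstance(select, exp.Select)
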
diff --git a/sqlglot/_typing.py b/sqlglot/_typing.py
new file mode 100644
index 0000000..2acbbf7
--- /dev/null
+++ b/sqlglot/_typing.py
@@ -0,0 +1,8 @@
+from __future__ import annotations
+
+import typing as t
+
+import sqlglot
+
+E = t.TypeVar("E", bound="sqlglot.exp.Expression")
+T = t.TypeVar("T")
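
The `E` TypeVar introduced here is what lets generic helpers preserve the concrete expression subclass they were given. A hypothetical helper (not part of this commit) as a sketch:

from sqlglot import exp
from sqlglot._typing import E

def copied(node: E) -> E:
    # E is bound to exp.Expression, so a caller passing an exp.Select
    # statically gets an exp.Select back, not a bare Expression.
    return node.copy()

select: exp.Select = copied(exp.select("a").from_("t"))
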
diff --git a/sqlglot/dataframe/sql/_typing.pyi b/sqlglot/dataframe/sql/_typing.py
index 1682ec1..fb46026 100644
--- a/sqlglot/dataframe/sql/_typing.pyi
+++ b/sqlglot/dataframe/sql/_typing.py
@@ -11,6 +11,8 @@ if t.TYPE_CHECKING:
ColumnLiterals = t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
ColumnOrName = t.Union[Column, str]
-ColumnOrLiteral = t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
+ColumnOrLiteral = t.Union[
+ Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime
+]
SchemaInput = t.Union[str, t.List[str], StructType, t.Dict[str, t.Optional[str]]]
OutputExpressionContainer = t.Union[exp.Select, exp.Create, exp.Insert]
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
index f3a6f6f..3fc9232 100644
--- a/sqlglot/dataframe/sql/dataframe.py
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -127,7 +127,7 @@ class DataFrame:
sequence_id: t.Optional[str] = None,
**kwargs,
) -> t.Tuple[exp.CTE, str]:
- name = self.spark._random_name
+ name = self._create_hash_from_expression(expression)
expression_to_cte = expression.copy()
expression_to_cte.set("with", None)
cte = exp.Select().with_(name, as_=expression_to_cte, **kwargs).ctes[0]
@@ -263,7 +263,7 @@ class DataFrame:
return [Column(x) for x in (expression.find(exp.Select) or exp.Select()).expressions]
@classmethod
- def _create_hash_from_expression(cls, expression: exp.Select):
+ def _create_hash_from_expression(cls, expression: exp.Expression) -> str:
value = expression.sql(dialect="spark").encode("utf-8")
return f"t{zlib.crc32(value)}"[:6]
@@ -299,7 +299,7 @@ class DataFrame:
for expression_type, select_expression in select_expressions:
select_expression = select_expression.transform(replace_id_value, replacement_mapping)
if optimize:
- select_expression = optimize_func(select_expression, identify="always")
+ select_expression = t.cast(exp.Select, optimize_func(select_expression))
select_expression = df._replace_cte_names_with_hashes(select_expression)
expression: t.Union[exp.Select, exp.Cache, exp.Drop]
if expression_type == exp.Cache:
@@ -570,9 +570,9 @@ class DataFrame:
r_expressions.append(l_column)
r_columns_unused.remove(l_column)
else:
- r_expressions.append(exp.alias_(exp.Null(), l_column))
+ r_expressions.append(exp.alias_(exp.Null(), l_column, copy=False))
for r_column in r_columns_unused:
- l_expressions.append(exp.alias_(exp.Null(), r_column))
+ l_expressions.append(exp.alias_(exp.Null(), r_column, copy=False))
r_expressions.append(r_column)
r_df = (
other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
@@ -761,7 +761,7 @@ class DataFrame:
raise ValueError("Tried to rename a column that doesn't exist")
for existing_column in existing_columns:
if isinstance(existing_column, exp.Column):
- existing_column.replace(exp.alias_(existing_column.copy(), new))
+ existing_column.replace(exp.alias_(existing_column, new))
else:
existing_column.set("alias", exp.to_identifier(new))
return self.copy(expression=expression)
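
The naming change in the first hunk swaps random UUID-based CTE names for a hash of the generated Spark SQL, so building the same plan twice yields the same names. The scheme, sketched as a standalone function:

import zlib

def hash_name(spark_sql: str) -> str:
    # "t" prefix plus the CRC32 of the SQL text, truncated to six characters,
    # mirroring DataFrame._create_hash_from_expression above.
    return f"t{zlib.crc32(spark_sql.encode('utf-8'))}"[:6]

# Deterministic: the same SQL always maps to the same CTE name.
assert hash_name("SELECT a FROM t") == hash_name("SELECT a FROM t")
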
diff --git a/sqlglot/dataframe/sql/operations.py b/sqlglot/dataframe/sql/operations.py
index d51335c..e4c106b 100644
--- a/sqlglot/dataframe/sql/operations.py
+++ b/sqlglot/dataframe/sql/operations.py
@@ -41,7 +41,7 @@ def operation(op: Operation):
self.last_op = Operation.NO_OP
last_op = self.last_op
new_op = op if op != Operation.NO_OP else last_op
- if new_op < last_op or (last_op == new_op and new_op == Operation.SELECT):
+ if new_op < last_op or (last_op == new_op == Operation.SELECT):
self = self._convert_leaf_to_cte()
df: t.Union[DataFrame, GroupedData] = func(self, *args, **kwargs)
df.last_op = new_op # type: ignore
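
The rewritten predicate leans on Python's comparison chaining, which is exactly equivalent to the two-clause form it replaces. A toy check with a stand-in for the real Operation enum:

from enum import IntEnum

class Op(IntEnum):  # stand-in for sqlglot's Operation enum
    SELECT = 0
    NO_OP = 1

last_op = new_op = Op.SELECT
# a == b == c evaluates as (a == b) and (b == c)
assert (last_op == new_op == Op.SELECT) == (last_op == new_op and new_op == Op.SELECT)
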
diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py
index af589b0..b883359 100644
--- a/sqlglot/dataframe/sql/session.py
+++ b/sqlglot/dataframe/sql/session.py
@@ -87,15 +87,13 @@ class SparkSession:
select_kwargs = {
"expressions": sel_columns,
"from": exp.From(
- expressions=[
- exp.Values(
- expressions=data_expressions,
- alias=exp.TableAlias(
- this=exp.to_identifier(self._auto_incrementing_name),
- columns=[exp.to_identifier(col_name) for col_name in column_mapping],
- ),
+ this=exp.Values(
+ expressions=data_expressions,
+ alias=exp.TableAlias(
+ this=exp.to_identifier(self._auto_incrementing_name),
+ columns=[exp.to_identifier(col_name) for col_name in column_mapping],
),
- ],
+ ),
),
}
@@ -128,10 +126,6 @@ class SparkSession:
return name
@property
- def _random_name(self) -> str:
- return "r" + uuid.uuid4().hex
-
- @property
def _random_branch_id(self) -> str:
id = self._random_id
self.known_branch_ids.add(id)
@@ -145,7 +139,7 @@ class SparkSession:
@property
def _random_id(self) -> str:
- id = self._random_name
+ id = "r" + uuid.uuid4().hex
self.known_ids.add(id)
return id
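
The first hunk reflects the broader AST change in this commit: `exp.From` now holds its single relation in `this` rather than a one-element `expressions` list. A sketch of the resulting shape, assuming this revision's builder API:

from sqlglot import exp

select = exp.select("a").from_("tbl")
# The FROM clause now exposes its relation directly via `this`:
table = select.args["from"].this
assert isinstance(table, exp.Table) and table.name == "tbl"
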
diff --git a/sqlglot/dataframe/sql/util.py b/sqlglot/dataframe/sql/util.py
index 575d18a..4b9fbb1 100644
--- a/sqlglot/dataframe/sql/util.py
+++ b/sqlglot/dataframe/sql/util.py
@@ -27,6 +27,6 @@ def get_tables_from_expression_with_join(expression: exp.Select) -> t.List[exp.T
if not expression.args.get("joins"):
return []
- left_table = expression.args["from"].args["expressions"][0]
+ left_table = expression.args["from"].this
other_tables = [join.this for join in expression.args["joins"]]
return [left_table] + other_tables
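
With the single-relation `From`, the helper reads the left table straight off `this`. A usage sketch:

import sqlglot
from sqlglot import exp
from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join

expr = sqlglot.parse_one("SELECT * FROM a JOIN b ON a.x = b.x", into=exp.Select)
# -> the Table node for `a` (FROM clause) followed by `b` (the join)
assert [tbl.name for tbl in get_tables_from_expression_with_join(expr)] == ["a", "b"]
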
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 9705b35..1a58337 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -1,5 +1,3 @@
-"""Supports BigQuery Standard SQL."""
-
from __future__ import annotations
import re
@@ -18,11 +16,9 @@ from sqlglot.dialects.dialect import (
timestrtotime_sql,
ts_or_ds_to_date_sql,
)
-from sqlglot.helper import seq_get
+from sqlglot.helper import seq_get, split_num_words
from sqlglot.tokens import TokenType
-E = t.TypeVar("E", bound=exp.Expression)
-
def _date_add_sql(
data_type: str, kind: str
@@ -96,19 +92,12 @@ def _unqualify_unnest(expression: exp.Expression) -> exp.Expression:
These are added by the optimizer's qualify_column step.
"""
if isinstance(expression, exp.Select):
- unnests = {
- unnest.alias
- for unnest in expression.args.get("from", exp.From(expressions=[])).expressions
- if isinstance(unnest, exp.Unnest) and unnest.alias
- }
-
- if unnests:
- expression = expression.copy()
-
- for select in expression.expressions:
- for column in select.find_all(exp.Column):
- if column.table in unnests:
- column.set("table", None)
+ for unnest in expression.find_all(exp.Unnest):
+ if isinstance(unnest.parent, (exp.From, exp.Join)) and unnest.alias:
+ for select in expression.selects:
+ for column in select.find_all(exp.Column):
+ if column.table == unnest.alias:
+ column.set("table", None)
return expression
@@ -127,16 +116,20 @@ class BigQuery(Dialect):
}
class Tokenizer(tokens.Tokenizer):
- QUOTES = [
- (prefix + quote, quote) if prefix else quote
- for quote in ["'", '"', '"""', "'''"]
- for prefix in ["", "r", "R"]
- ]
+ QUOTES = ["'", '"', '"""', "'''"]
COMMENTS = ["--", "#", ("/*", "*/")]
IDENTIFIERS = ["`"]
STRING_ESCAPES = ["\\"]
+
HEX_STRINGS = [("0x", ""), ("0X", "")]
- BYTE_STRINGS = [("b'", "'"), ("B'", "'")]
+
+ BYTE_STRINGS = [
+ (prefix + q, q) for q in t.cast(t.List[str], QUOTES) for prefix in ("b", "B")
+ ]
+
+ RAW_STRINGS = [
+ (prefix + q, q) for q in t.cast(t.List[str], QUOTES) for prefix in ("r", "R")
+ ]
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
@@ -144,11 +137,11 @@ class BigQuery(Dialect):
"BEGIN": TokenType.COMMAND,
"BEGIN TRANSACTION": TokenType.BEGIN,
"CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
+ "BYTES": TokenType.BINARY,
"DECLARE": TokenType.COMMAND,
- "GEOGRAPHY": TokenType.GEOGRAPHY,
"FLOAT64": TokenType.DOUBLE,
"INT64": TokenType.BIGINT,
- "BYTES": TokenType.BINARY,
+ "RECORD": TokenType.STRUCT,
"NOT DETERMINISTIC": TokenType.VOLATILE,
"UNKNOWN": TokenType.NULL,
}
@@ -161,7 +154,7 @@ class BigQuery(Dialect):
LOG_DEFAULTS_TO_LN = True
FUNCTIONS = {
- **parser.Parser.FUNCTIONS, # type: ignore
+ **parser.Parser.FUNCTIONS,
"DATE_TRUNC": lambda args: exp.DateTrunc(
unit=exp.Literal.string(str(seq_get(args, 1))),
this=seq_get(args, 0),
@@ -191,28 +184,28 @@ class BigQuery(Dialect):
}
FUNCTION_PARSERS = {
- **parser.Parser.FUNCTION_PARSERS, # type: ignore
+ **parser.Parser.FUNCTION_PARSERS,
"ARRAY": lambda self: self.expression(exp.Array, expressions=[self._parse_statement()]),
}
FUNCTION_PARSERS.pop("TRIM")
NO_PAREN_FUNCTIONS = {
- **parser.Parser.NO_PAREN_FUNCTIONS, # type: ignore
+ **parser.Parser.NO_PAREN_FUNCTIONS,
TokenType.CURRENT_DATETIME: exp.CurrentDatetime,
}
NESTED_TYPE_TOKENS = {
- *parser.Parser.NESTED_TYPE_TOKENS, # type: ignore
+ *parser.Parser.NESTED_TYPE_TOKENS,
TokenType.TABLE,
}
ID_VAR_TOKENS = {
- *parser.Parser.ID_VAR_TOKENS, # type: ignore
+ *parser.Parser.ID_VAR_TOKENS,
TokenType.VALUES,
}
PROPERTY_PARSERS = {
- **parser.Parser.PROPERTY_PARSERS, # type: ignore
+ **parser.Parser.PROPERTY_PARSERS,
"NOT DETERMINISTIC": lambda self: self.expression(
exp.StabilityProperty, this=exp.Literal.string("VOLATILE")
),
@@ -220,19 +213,50 @@ class BigQuery(Dialect):
}
CONSTRAINT_PARSERS = {
- **parser.Parser.CONSTRAINT_PARSERS, # type: ignore
+ **parser.Parser.CONSTRAINT_PARSERS,
"OPTIONS": lambda self: exp.Properties(expressions=self._parse_with_property()),
}
+ def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]:
+ this = super()._parse_table_part(schema=schema)
+
+ # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#table_names
+ if isinstance(this, exp.Identifier):
+ table_name = this.name
+ while self._match(TokenType.DASH, advance=False) and self._next:
+ self._advance(2)
+ table_name += f"-{self._prev.text}"
+
+ this = exp.Identifier(this=table_name, quoted=this.args.get("quoted"))
+
+ return this
+
+ def _parse_table_parts(self, schema: bool = False) -> exp.Table:
+ table = super()._parse_table_parts(schema=schema)
+ if isinstance(table.this, exp.Identifier) and "." in table.name:
+ catalog, db, this, *rest = (
+ t.cast(t.Optional[exp.Expression], exp.to_identifier(x))
+ for x in split_num_words(table.name, ".", 3)
+ )
+
+ if rest and this:
+ this = exp.Dot.build(t.cast(t.List[exp.Expression], [this, *rest]))
+
+ table = exp.Table(this=this, db=db, catalog=catalog)
+
+ return table
+
class Generator(generator.Generator):
EXPLICIT_UNION = True
INTERVAL_ALLOWS_PLURAL_FORM = False
JOIN_HINTS = False
TABLE_HINTS = False
LIMIT_FETCH = "LIMIT"
+ RENAME_TABLE_WITH_DB = False
TRANSFORMS = {
- **generator.Generator.TRANSFORMS, # type: ignore
+ **generator.Generator.TRANSFORMS,
+ exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.AtTimeZone: lambda self, e: self.func(
"TIMESTAMP", self.func("DATETIME", e.this, e.args.get("zone"))
@@ -259,6 +283,7 @@ class BigQuery(Dialect):
exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"),
exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"),
exp.TimeStrToTime: timestrtotime_sql,
+ exp.TryCast: lambda self, e: f"SAFE_CAST({self.sql(e, 'this')} AS {self.sql(e, 'to')})",
exp.TsOrDsToDate: ts_or_ds_to_date_sql("bigquery"),
exp.TsOrDsAdd: _date_add_sql("DATE", "ADD"),
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
@@ -274,7 +299,7 @@ class BigQuery(Dialect):
}
TYPE_MAPPING = {
- **generator.Generator.TYPE_MAPPING, # type: ignore
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.BIGDECIMAL: "BIGNUMERIC",
exp.DataType.Type.BIGINT: "INT64",
exp.DataType.Type.BINARY: "BYTES",
@@ -297,7 +322,7 @@ class BigQuery(Dialect):
}
PROPERTIES_LOCATION = {
- **generator.Generator.PROPERTIES_LOCATION, # type: ignore
+ **generator.Generator.PROPERTIES_LOCATION,
exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
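
Two of the BigQuery changes are easiest to see end to end: `_parse_table_part` glues dash-separated name pieces back into one identifier, and `exp.TryCast` renders as `SAFE_CAST`. A sketch; the printed output is the expected shape, not verified against this exact revision:

import sqlglot

# Dashed project names parse as a single identifier instead of `my - project`:
sql = "SELECT * FROM my-project.mydataset.mytable"
print(sqlglot.transpile(sql, read="bigquery", write="bigquery")[0])

# TRY_CAST maps onto BigQuery's error-swallowing SAFE_CAST:
print(sqlglot.transpile("SELECT TRY_CAST(x AS INT64)", read="bigquery", write="bigquery")[0])
# expected shape: SELECT SAFE_CAST(x AS INT64)
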
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index 2a49066..c8a9525 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -3,11 +3,16 @@ from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens
-from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql
+from sqlglot.dialects.dialect import (
+ Dialect,
+ inline_array_sql,
+ no_pivot_sql,
+ rename_func,
+ var_map_sql,
+)
from sqlglot.errors import ParseError
-from sqlglot.helper import ensure_list, seq_get
from sqlglot.parser import parse_var_map
-from sqlglot.tokens import TokenType
+from sqlglot.tokens import Token, TokenType
def _lower_func(sql: str) -> str:
@@ -28,65 +33,122 @@ class ClickHouse(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"ASOF": TokenType.ASOF,
- "GLOBAL": TokenType.GLOBAL,
- "DATETIME64": TokenType.DATETIME,
+ "ATTACH": TokenType.COMMAND,
+ "DATETIME64": TokenType.DATETIME64,
"FINAL": TokenType.FINAL,
"FLOAT32": TokenType.FLOAT,
"FLOAT64": TokenType.DOUBLE,
- "INT8": TokenType.TINYINT,
- "UINT8": TokenType.UTINYINT,
+ "GLOBAL": TokenType.GLOBAL,
+ "INT128": TokenType.INT128,
"INT16": TokenType.SMALLINT,
- "UINT16": TokenType.USMALLINT,
+ "INT256": TokenType.INT256,
"INT32": TokenType.INT,
- "UINT32": TokenType.UINT,
"INT64": TokenType.BIGINT,
- "UINT64": TokenType.UBIGINT,
- "INT128": TokenType.INT128,
+ "INT8": TokenType.TINYINT,
+ "MAP": TokenType.MAP,
+ "TUPLE": TokenType.STRUCT,
"UINT128": TokenType.UINT128,
- "INT256": TokenType.INT256,
+ "UINT16": TokenType.USMALLINT,
"UINT256": TokenType.UINT256,
- "TUPLE": TokenType.STRUCT,
+ "UINT32": TokenType.UINT,
+ "UINT64": TokenType.UBIGINT,
+ "UINT8": TokenType.UTINYINT,
}
class Parser(parser.Parser):
FUNCTIONS = {
- **parser.Parser.FUNCTIONS, # type: ignore
- "EXPONENTIALTIMEDECAYEDAVG": lambda params, args: exp.ExponentialTimeDecayedAvg(
- this=seq_get(args, 0),
- time=seq_get(args, 1),
- decay=seq_get(params, 0),
- ),
- "GROUPUNIQARRAY": lambda params, args: exp.GroupUniqArray(
- this=seq_get(args, 0), size=seq_get(params, 0)
- ),
- "HISTOGRAM": lambda params, args: exp.Histogram(
- this=seq_get(args, 0), bins=seq_get(params, 0)
- ),
+ **parser.Parser.FUNCTIONS,
+ "ANY": exp.AnyValue.from_arg_list,
"MAP": parse_var_map,
"MATCH": exp.RegexpLike.from_arg_list,
- "QUANTILE": lambda params, args: exp.Quantile(this=args, quantile=params),
- "QUANTILES": lambda params, args: exp.Quantiles(parameters=params, expressions=args),
- "QUANTILEIF": lambda params, args: exp.QuantileIf(parameters=params, expressions=args),
+ "UNIQ": exp.ApproxDistinct.from_arg_list,
+ }
+
+ FUNCTIONS_WITH_ALIASED_ARGS = {*parser.Parser.FUNCTIONS_WITH_ALIASED_ARGS, "TUPLE"}
+
+ FUNCTION_PARSERS = {
+ **parser.Parser.FUNCTION_PARSERS,
+ "QUANTILE": lambda self: self._parse_quantile(),
}
- FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy()
FUNCTION_PARSERS.pop("MATCH")
+ NO_PAREN_FUNCTION_PARSERS = parser.Parser.NO_PAREN_FUNCTION_PARSERS.copy()
+ NO_PAREN_FUNCTION_PARSERS.pop(TokenType.ANY)
+
RANGE_PARSERS = {
**parser.Parser.RANGE_PARSERS,
TokenType.GLOBAL: lambda self, this: self._match(TokenType.IN)
and self._parse_in(this, is_global=True),
}
- JOIN_KINDS = {*parser.Parser.JOIN_KINDS, TokenType.ANY, TokenType.ASOF} # type: ignore
+ # The PLACEHOLDER entry is popped because 1) it doesn't affect Clickhouse (it corresponds to
+ # the postgres-specific JSONBContains parser) and 2) it makes parsing the ternary op simpler.
+ COLUMN_OPERATORS = parser.Parser.COLUMN_OPERATORS.copy()
+ COLUMN_OPERATORS.pop(TokenType.PLACEHOLDER)
+
+ JOIN_KINDS = {
+ *parser.Parser.JOIN_KINDS,
+ TokenType.ANY,
+ TokenType.ASOF,
+ TokenType.ANTI,
+ TokenType.SEMI,
+ }
- TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY} # type: ignore
+ TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {
+ TokenType.ANY,
+ TokenType.ASOF,
+ TokenType.SEMI,
+ TokenType.ANTI,
+ TokenType.SETTINGS,
+ TokenType.FORMAT,
+ }
LOG_DEFAULTS_TO_LN = True
- def _parse_in(
- self, this: t.Optional[exp.Expression], is_global: bool = False
- ) -> exp.Expression:
+ QUERY_MODIFIER_PARSERS = {
+ **parser.Parser.QUERY_MODIFIER_PARSERS,
+ "settings": lambda self: self._parse_csv(self._parse_conjunction)
+ if self._match(TokenType.SETTINGS)
+ else None,
+ "format": lambda self: self._parse_id_var() if self._match(TokenType.FORMAT) else None,
+ }
+
+ def _parse_conjunction(self) -> t.Optional[exp.Expression]:
+ this = super()._parse_conjunction()
+
+ if self._match(TokenType.PLACEHOLDER):
+ return self.expression(
+ exp.If,
+ this=this,
+ true=self._parse_conjunction(),
+ false=self._match(TokenType.COLON) and self._parse_conjunction(),
+ )
+
+ return this
+
+ def _parse_placeholder(self) -> t.Optional[exp.Expression]:
+ """
+ Parse a placeholder expression like SELECT {abc: UInt32} or FROM {table: Identifier}
+ https://clickhouse.com/docs/en/sql-reference/syntax#defining-and-using-query-parameters
+ """
+ if not self._match(TokenType.L_BRACE):
+ return None
+
+ this = self._parse_id_var()
+ self._match(TokenType.COLON)
+ kind = self._parse_types(check_func=False) or (
+ self._match_text_seq("IDENTIFIER") and "Identifier"
+ )
+
+ if not kind:
+ self.raise_error("Expecting a placeholder type or 'Identifier' for tables")
+ elif not self._match(TokenType.R_BRACE):
+ self.raise_error("Expecting }")
+
+ return self.expression(exp.Placeholder, this=this, kind=kind)
+
+ def _parse_in(self, this: t.Optional[exp.Expression], is_global: bool = False) -> exp.In:
this = super()._parse_in(this)
this.set("is_global", is_global)
return this
@@ -120,81 +182,142 @@ class ClickHouse(Dialect):
return self.expression(exp.CTE, this=statement, alias=statement and statement.this)
+ def _parse_join_side_and_kind(
+ self,
+ ) -> t.Tuple[t.Optional[Token], t.Optional[Token], t.Optional[Token]]:
+ is_global = self._match(TokenType.GLOBAL) and self._prev
+ kind_pre = self._match_set(self.JOIN_KINDS, advance=False) and self._prev
+ if kind_pre:
+ kind = self._match_set(self.JOIN_KINDS) and self._prev
+ side = self._match_set(self.JOIN_SIDES) and self._prev
+ return is_global, side, kind
+ return (
+ is_global,
+ self._match_set(self.JOIN_SIDES) and self._prev,
+ self._match_set(self.JOIN_KINDS) and self._prev,
+ )
+
+ def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Expression]:
+ join = super()._parse_join(skip_join_token)
+
+ if join:
+ join.set("global", join.args.pop("natural", None))
+ return join
+
+ def _parse_function(
+ self, functions: t.Optional[t.Dict[str, t.Callable]] = None, anonymous: bool = False
+ ) -> t.Optional[exp.Expression]:
+ func = super()._parse_function(functions, anonymous)
+
+ if isinstance(func, exp.Anonymous):
+ params = self._parse_func_params(func)
+
+ if params:
+ return self.expression(
+ exp.ParameterizedAgg,
+ this=func.this,
+ expressions=func.expressions,
+ params=params,
+ )
+
+ return func
+
+ def _parse_func_params(
+ self, this: t.Optional[exp.Func] = None
+ ) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
+ if self._match_pair(TokenType.R_PAREN, TokenType.L_PAREN):
+ return self._parse_csv(self._parse_lambda)
+ if self._match(TokenType.L_PAREN):
+ params = self._parse_csv(self._parse_lambda)
+ self._match_r_paren(this)
+ return params
+ return None
+
+ def _parse_quantile(self) -> exp.Quantile:
+ this = self._parse_lambda()
+ params = self._parse_func_params()
+ if params:
+ return self.expression(exp.Quantile, this=params[0], quantile=this)
+ return self.expression(exp.Quantile, this=this, quantile=exp.Literal.number(0.5))
+
+ def _parse_wrapped_id_vars(
+ self, optional: bool = False
+ ) -> t.List[t.Optional[exp.Expression]]:
+ return super()._parse_wrapped_id_vars(optional=True)
+
class Generator(generator.Generator):
STRUCT_DELIMITER = ("(", ")")
TYPE_MAPPING = {
- **generator.Generator.TYPE_MAPPING, # type: ignore
- exp.DataType.Type.NULLABLE: "Nullable",
- exp.DataType.Type.DATETIME: "DateTime64",
- exp.DataType.Type.MAP: "Map",
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.ARRAY: "Array",
+ exp.DataType.Type.BIGINT: "Int64",
+ exp.DataType.Type.DATETIME64: "DateTime64",
+ exp.DataType.Type.DOUBLE: "Float64",
+ exp.DataType.Type.FLOAT: "Float32",
+ exp.DataType.Type.INT: "Int32",
+ exp.DataType.Type.INT128: "Int128",
+ exp.DataType.Type.INT256: "Int256",
+ exp.DataType.Type.MAP: "Map",
+ exp.DataType.Type.NULLABLE: "Nullable",
+ exp.DataType.Type.SMALLINT: "Int16",
exp.DataType.Type.STRUCT: "Tuple",
exp.DataType.Type.TINYINT: "Int8",
- exp.DataType.Type.UTINYINT: "UInt8",
- exp.DataType.Type.SMALLINT: "Int16",
- exp.DataType.Type.USMALLINT: "UInt16",
- exp.DataType.Type.INT: "Int32",
- exp.DataType.Type.UINT: "UInt32",
- exp.DataType.Type.BIGINT: "Int64",
exp.DataType.Type.UBIGINT: "UInt64",
- exp.DataType.Type.INT128: "Int128",
+ exp.DataType.Type.UINT: "UInt32",
exp.DataType.Type.UINT128: "UInt128",
- exp.DataType.Type.INT256: "Int256",
exp.DataType.Type.UINT256: "UInt256",
- exp.DataType.Type.FLOAT: "Float32",
- exp.DataType.Type.DOUBLE: "Float64",
+ exp.DataType.Type.USMALLINT: "UInt16",
+ exp.DataType.Type.UTINYINT: "UInt8",
}
TRANSFORMS = {
- **generator.Generator.TRANSFORMS, # type: ignore
+ **generator.Generator.TRANSFORMS,
+ exp.AnyValue: rename_func("any"),
+ exp.ApproxDistinct: rename_func("uniq"),
exp.Array: inline_array_sql,
- exp.ExponentialTimeDecayedAvg: lambda self, e: f"exponentialTimeDecayedAvg{self._param_args_sql(e, 'decay', ['this', 'time'])}",
+ exp.CastToStrType: rename_func("CAST"),
exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
- exp.GroupUniqArray: lambda self, e: f"groupUniqArray{self._param_args_sql(e, 'size', 'this')}",
- exp.Histogram: lambda self, e: f"histogram{self._param_args_sql(e, 'bins', 'this')}",
exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)),
- exp.Quantile: lambda self, e: f"quantile{self._param_args_sql(e, 'quantile', 'this')}",
- exp.Quantiles: lambda self, e: f"quantiles{self._param_args_sql(e, 'parameters', 'expressions')}",
- exp.QuantileIf: lambda self, e: f"quantileIf{self._param_args_sql(e, 'parameters', 'expressions')}",
+ exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
+ exp.Pivot: no_pivot_sql,
+ exp.Quantile: lambda self, e: self.func("quantile", e.args.get("quantile"))
+ + f"({self.sql(e, 'this')})",
exp.RegexpLike: lambda self, e: f"match({self.format_args(e.this, e.expression)})",
exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
}
PROPERTIES_LOCATION = {
- **generator.Generator.PROPERTIES_LOCATION, # type: ignore
+ **generator.Generator.PROPERTIES_LOCATION,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
+ exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
}
JOIN_HINTS = False
TABLE_HINTS = False
EXPLICIT_UNION = True
-
- def _param_args_sql(
- self,
- expression: exp.Expression,
- param_names: str | t.List[str],
- arg_names: str | t.List[str],
- ) -> str:
- params = self.format_args(
- *(
- arg
- for name in ensure_list(param_names)
- for arg in ensure_list(expression.args.get(name))
- )
- )
- args = self.format_args(
- *(
- arg
- for name in ensure_list(arg_names)
- for arg in ensure_list(expression.args.get(name))
- )
- )
- return f"({params})({args})"
+ GROUPINGS_SEP = ""
def cte_sql(self, expression: exp.CTE) -> str:
if isinstance(expression.this, exp.Alias):
return self.sql(expression, "this")
return super().cte_sql(expression)
+
+ def after_limit_modifiers(self, expression: exp.Expression) -> t.List[str]:
+ return super().after_limit_modifiers(expression) + [
+ self.seg("SETTINGS ") + self.expressions(expression, key="settings", flat=True)
+ if expression.args.get("settings")
+ else "",
+ self.seg("FORMAT ") + self.sql(expression, "format")
+ if expression.args.get("format")
+ else "",
+ ]
+
+ def parameterizedagg_sql(self, expression: exp.Anonymous) -> str:
+ params = self.expressions(expression, "params", flat=True)
+ return self.func(expression.name, *expression.expressions) + f"({params})"
+
+ def placeholder_sql(self, expression: exp.Placeholder) -> str:
+ return f"{{{expression.name}: {self.sql(expression, 'kind')}}}"
diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py
index 51112a0..2149aca 100644
--- a/sqlglot/dialects/databricks.py
+++ b/sqlglot/dialects/databricks.py
@@ -25,7 +25,7 @@ class Databricks(Spark):
class Generator(Spark.Generator):
TRANSFORMS = {
- **Spark.Generator.TRANSFORMS, # type: ignore
+ **Spark.Generator.TRANSFORMS,
exp.DateAdd: generate_date_delta_with_unit_sql,
exp.DateDiff: generate_date_delta_with_unit_sql,
exp.JSONExtract: lambda self, e: self.binary(e, ":"),
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 71269f2..890a3c3 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -8,10 +8,16 @@ from sqlglot.generator import Generator
from sqlglot.helper import flatten, seq_get
from sqlglot.parser import Parser
from sqlglot.time import format_time
-from sqlglot.tokens import Token, Tokenizer
+from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie
-E = t.TypeVar("E", bound=exp.Expression)
+if t.TYPE_CHECKING:
+ from sqlglot._typing import E
+
+
+# Only Snowflake is currently known to resolve unquoted identifiers as uppercase.
+# https://docs.snowflake.com/en/sql-reference/identifiers-syntax
+RESOLVES_IDENTIFIERS_AS_UPPERCASE = {"snowflake"}
class Dialects(str, Enum):
@@ -42,6 +48,19 @@ class Dialects(str, Enum):
class _Dialect(type):
classes: t.Dict[str, t.Type[Dialect]] = {}
+ def __eq__(cls, other: t.Any) -> bool:
+ if cls is other:
+ return True
+ if isinstance(other, str):
+ return cls is cls.get(other)
+ if isinstance(other, Dialect):
+ return cls is type(other)
+
+ return False
+
+ def __hash__(cls) -> int:
+ return hash(cls.__name__.lower())
+
@classmethod
def __getitem__(cls, key: str) -> t.Type[Dialect]:
return cls.classes[key]
@@ -70,17 +89,20 @@ class _Dialect(type):
klass.tokenizer_class._IDENTIFIERS.items()
)[0]
- klass.bit_start, klass.bit_end = seq_get(
- list(klass.tokenizer_class._BIT_STRINGS.items()), 0
- ) or (None, None)
-
- klass.hex_start, klass.hex_end = seq_get(
- list(klass.tokenizer_class._HEX_STRINGS.items()), 0
- ) or (None, None)
+ def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
+ return next(
+ (
+ (s, e)
+ for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
+ if t == token_type
+ ),
+ (None, None),
+ )
- klass.byte_start, klass.byte_end = seq_get(
- list(klass.tokenizer_class._BYTE_STRINGS.items()), 0
- ) or (None, None)
+ klass.bit_start, klass.bit_end = get_start_end(TokenType.BIT_STRING)
+ klass.hex_start, klass.hex_end = get_start_end(TokenType.HEX_STRING)
+ klass.byte_start, klass.byte_end = get_start_end(TokenType.BYTE_STRING)
+ klass.raw_start, klass.raw_end = get_start_end(TokenType.RAW_STRING)
return klass
@@ -110,6 +132,12 @@ class Dialect(metaclass=_Dialect):
parser_class = None
generator_class = None
+ def __eq__(self, other: t.Any) -> bool:
+ return type(self) == other
+
+ def __hash__(self) -> int:
+ return hash(type(self))
+
@classmethod
def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
if not dialect:
@@ -192,6 +220,8 @@ class Dialect(metaclass=_Dialect):
"hex_end": self.hex_end,
"byte_start": self.byte_start,
"byte_end": self.byte_end,
+ "raw_start": self.raw_start,
+ "raw_end": self.raw_end,
"identifier_start": self.identifier_start,
"identifier_end": self.identifier_end,
"string_escape": self.tokenizer_class.STRING_ESCAPES[0],
@@ -275,7 +305,7 @@ def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
self.unsupported("PIVOT unsupported")
- return self.sql(expression)
+ return ""
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
@@ -328,7 +358,7 @@ def var_map_sql(
def format_time_lambda(
exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
-) -> t.Callable[[t.Sequence], E]:
+) -> t.Callable[[t.List], E]:
"""Helper used for time expressions.
Args:
@@ -340,7 +370,7 @@ def format_time_lambda(
A callable that can be used to return the appropriately formatted time expression.
"""
- def _format_time(args: t.Sequence):
+ def _format_time(args: t.List):
return exp_class(
this=seq_get(args, 0),
format=Dialect[dialect].format_time(
@@ -377,12 +407,12 @@ def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
def parse_date_delta(
exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
-) -> t.Callable[[t.Sequence], E]:
- def inner_func(args: t.Sequence) -> E:
+) -> t.Callable[[t.List], E]:
+ def inner_func(args: t.List) -> E:
unit_based = len(args) == 3
this = args[2] if unit_based else seq_get(args, 0)
unit = args[0] if unit_based else exp.Literal.string("DAY")
- unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit
+ unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
return inner_func
@@ -390,8 +420,8 @@ def parse_date_delta(
def parse_date_delta_with_interval(
expression_class: t.Type[E],
-) -> t.Callable[[t.Sequence], t.Optional[E]]:
- def func(args: t.Sequence) -> t.Optional[E]:
+) -> t.Callable[[t.List], t.Optional[E]]:
+ def func(args: t.List) -> t.Optional[E]:
if len(args) < 2:
return None
@@ -409,7 +439,7 @@ def parse_date_delta_with_interval(
return func
-def date_trunc_to_time(args: t.Sequence) -> exp.DateTrunc | exp.TimestampTrunc:
+def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
unit = seq_get(args, 0)
this = seq_get(args, 1)
@@ -424,7 +454,7 @@ def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
)
-def locate_to_strposition(args: t.Sequence) -> exp.Expression:
+def locate_to_strposition(args: t.List) -> exp.Expression:
return exp.StrPosition(
this=seq_get(args, 1),
substr=seq_get(args, 0),
@@ -483,7 +513,7 @@ def trim_sql(self: Generator, expression: exp.Trim) -> str:
return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
-def str_to_time_sql(self, expression: exp.Expression) -> str:
+def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
return self.func("STRPTIME", expression.this, self.format_time(expression))
@@ -496,3 +526,26 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
return f"CAST({self.sql(expression, 'this')} AS DATE)"
return _ts_or_ds_to_date_sql
+
+
+# Spark, DuckDB use (almost) the same naming scheme for the output columns of the PIVOT operator
+def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
+ names = []
+ for agg in aggregations:
+ if isinstance(agg, exp.Alias):
+ names.append(agg.alias)
+ else:
+ """
+ This case corresponds to aggregations without aliases being used as suffixes
+ (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
+ be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
+ Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
+ """
+ agg_all_unquoted = agg.transform(
+ lambda node: exp.Identifier(this=node.name, quoted=False)
+ if isinstance(node, exp.Identifier)
+ else node
+ )
+ names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
+
+ return names
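
The new metaclass `__eq__`/`__hash__` pair means dialect classes and instances compare cleanly against their registry names, which is what makes string-keyed sets like RESOLVES_IDENTIFIERS_AS_UPPERCASE above convenient. A sketch:

from sqlglot.dialects.dialect import Dialect

assert Dialect["snowflake"] == "snowflake"  # class vs. registry name
assert Dialect["duckdb"] != "snowflake"
assert Dialect["duckdb"]() == Dialect["duckdb"]  # instance vs. class, via Dialect.__eq__
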
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
index 7ad555e..924b979 100644
--- a/sqlglot/dialects/drill.py
+++ b/sqlglot/dialects/drill.py
@@ -95,7 +95,7 @@ class Drill(Dialect):
STRICT_CAST = False
FUNCTIONS = {
- **parser.Parser.FUNCTIONS, # type: ignore
+ **parser.Parser.FUNCTIONS,
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "drill"),
"TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list,
"TO_CHAR": format_time_lambda(exp.TimeToStr, "drill"),
@@ -108,7 +108,7 @@ class Drill(Dialect):
TABLE_HINTS = False
TYPE_MAPPING = {
- **generator.Generator.TYPE_MAPPING, # type: ignore
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.SMALLINT: "INTEGER",
exp.DataType.Type.TINYINT: "INTEGER",
@@ -121,13 +121,13 @@ class Drill(Dialect):
}
PROPERTIES_LOCATION = {
- **generator.Generator.PROPERTIES_LOCATION, # type: ignore
+ **generator.Generator.PROPERTIES_LOCATION,
exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
TRANSFORMS = {
- **generator.Generator.TRANSFORMS, # type: ignore
+ **generator.Generator.TRANSFORMS,
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.ArrayContains: rename_func("REPEATED_CONTAINS"),
exp.ArraySize: rename_func("REPEATED_COUNT"),
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index bce956e..662882d 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -11,9 +11,9 @@ from sqlglot.dialects.dialect import (
datestrtodate_sql,
format_time_lambda,
no_comment_column_constraint_sql,
- no_pivot_sql,
no_properties_sql,
no_safe_divide_sql,
+ pivot_column_names,
rename_func,
str_position_sql,
str_to_time_sql,
@@ -31,10 +31,11 @@ def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> s
return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
-def _date_add_sql(self: generator.Generator, expression: exp.DateAdd) -> str:
+def _date_delta_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
this = self.sql(expression, "this")
unit = self.sql(expression, "unit").strip("'") or "DAY"
- return f"{this} + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
+ op = "+" if isinstance(expression, exp.DateAdd) else "-"
+ return f"{this} {op} {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
def _array_sort_sql(self: generator.Generator, expression: exp.ArraySort) -> str:
@@ -50,11 +51,11 @@ def _sort_array_sql(self: generator.Generator, expression: exp.SortArray) -> str
return f"ARRAY_SORT({this})"
-def _sort_array_reverse(args: t.Sequence) -> exp.Expression:
+def _sort_array_reverse(args: t.List) -> exp.Expression:
return exp.SortArray(this=seq_get(args, 0), asc=exp.false())
-def _parse_date_diff(args: t.Sequence) -> exp.Expression:
+def _parse_date_diff(args: t.List) -> exp.Expression:
return exp.DateDiff(
this=seq_get(args, 2),
expression=seq_get(args, 1),
@@ -89,11 +90,14 @@ def _regexp_extract_sql(self: generator.Generator, expression: exp.RegexpExtract
class DuckDB(Dialect):
+ null_ordering = "nulls_are_last"
+
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"~": TokenType.RLIKE,
":=": TokenType.EQ,
+ "//": TokenType.DIV,
"ATTACH": TokenType.COMMAND,
"BINARY": TokenType.VARBINARY,
"BPCHAR": TokenType.TEXT,
@@ -104,6 +108,7 @@ class DuckDB(Dialect):
"INT1": TokenType.TINYINT,
"LOGICAL": TokenType.BOOLEAN,
"NUMERIC": TokenType.DOUBLE,
+ "PIVOT_WIDER": TokenType.PIVOT,
"SIGNED": TokenType.INT,
"STRING": TokenType.VARCHAR,
"UBIGINT": TokenType.UBIGINT,
@@ -114,8 +119,7 @@ class DuckDB(Dialect):
class Parser(parser.Parser):
FUNCTIONS = {
- **parser.Parser.FUNCTIONS, # type: ignore
- "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
+ **parser.Parser.FUNCTIONS,
"ARRAY_LENGTH": exp.ArraySize.from_arg_list,
"ARRAY_SORT": exp.SortArray.from_arg_list,
"ARRAY_REVERSE_SORT": _sort_array_reverse,
@@ -152,11 +156,17 @@ class DuckDB(Dialect):
TokenType.UTINYINT,
}
+ def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]:
+ if len(aggregations) == 1:
+ return super()._pivot_column_names(aggregations)
+ return pivot_column_names(aggregations, dialect="duckdb")
+
class Generator(generator.Generator):
JOIN_HINTS = False
TABLE_HINTS = False
LIMIT_FETCH = "LIMIT"
STRUCT_DELIMITER = ("(", ")")
+ RENAME_TABLE_WITH_DB = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@@ -175,7 +185,8 @@ class DuckDB(Dialect):
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.DataType: _datatype_sql,
- exp.DateAdd: _date_add_sql,
+ exp.DateAdd: _date_delta_sql,
+ exp.DateSub: _date_delta_sql,
exp.DateDiff: lambda self, e: self.func(
"DATE_DIFF", f"'{e.args.get('unit', 'day')}'", e.expression, e.this
),
@@ -183,13 +194,13 @@ class DuckDB(Dialect):
exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.dateint_format}) AS INT)",
exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.dateint_format}) AS DATE)",
exp.Explode: rename_func("UNNEST"),
+ exp.IntDiv: lambda self, e: self.binary(e, "//"),
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONBExtract: arrow_json_extract_sql,
exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
exp.LogicalOr: rename_func("BOOL_OR"),
exp.LogicalAnd: rename_func("BOOL_AND"),
- exp.Pivot: no_pivot_sql,
exp.Properties: no_properties_sql,
exp.RegexpExtract: _regexp_extract_sql,
exp.RegexpLike: rename_func("REGEXP_MATCHES"),
@@ -232,11 +243,11 @@ class DuckDB(Dialect):
STAR_MAPPING = {**generator.Generator.STAR_MAPPING, "except": "EXCLUDE"}
PROPERTIES_LOCATION = {
- **generator.Generator.PROPERTIES_LOCATION, # type: ignore
+ **generator.Generator.PROPERTIES_LOCATION,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
def tablesample_sql(
- self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
+ self, expression: exp.TableSample, seed_prefix: str = "SEED", sep: str = " AS "
) -> str:
return super().tablesample_sql(expression, seed_prefix="REPEATABLE", sep=sep)
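
DuckDB date arithmetic now goes through one `_date_delta_sql` helper that flips the operator for `DateSub`, and `exp.IntDiv` maps onto the newly tokenized `//`. A sketch; outputs are expected shapes, not verified against this exact revision:

import sqlglot

# DateSub renders with `-` where DateAdd uses `+`:
print(sqlglot.transpile("SELECT DATE_SUB(d, INTERVAL 2 DAY)", read="mysql", write="duckdb")[0])
# expected shape: SELECT d - INTERVAL 2 DAY

# Integer division uses DuckDB's // operator:
print(sqlglot.transpile("SELECT 7 DIV 2", read="mysql", write="duckdb")[0])
# expected shape: SELECT 7 // 2
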
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 871a180..fbd626a 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -147,13 +147,6 @@ def _to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str
return f"TO_DATE({this})"
-def _index_sql(self: generator.Generator, expression: exp.Index) -> str:
- this = self.sql(expression, "this")
- table = self.sql(expression, "table")
- columns = self.sql(expression, "columns")
- return f"{this} ON TABLE {table} {columns}"
-
-
class Hive(Dialect):
alias_post_tablesample = True
@@ -225,8 +218,7 @@ class Hive(Dialect):
STRICT_CAST = False
FUNCTIONS = {
- **parser.Parser.FUNCTIONS, # type: ignore
- "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
+ **parser.Parser.FUNCTIONS,
"BASE64": exp.ToBase64.from_arg_list,
"COLLECT_LIST": exp.ArrayAgg.from_arg_list,
"DATE_ADD": lambda args: exp.TsOrDsAdd(
@@ -271,21 +263,29 @@ class Hive(Dialect):
}
PROPERTY_PARSERS = {
- **parser.Parser.PROPERTY_PARSERS, # type: ignore
+ **parser.Parser.PROPERTY_PARSERS,
"WITH SERDEPROPERTIES": lambda self: exp.SerdeProperties(
expressions=self._parse_wrapped_csv(self._parse_property)
),
}
+ QUERY_MODIFIER_PARSERS = {
+ **parser.Parser.QUERY_MODIFIER_PARSERS,
+ "distribute": lambda self: self._parse_sort(exp.Distribute, "DISTRIBUTE", "BY"),
+ "sort": lambda self: self._parse_sort(exp.Sort, "SORT", "BY"),
+ "cluster": lambda self: self._parse_sort(exp.Cluster, "CLUSTER", "BY"),
+ }
+
class Generator(generator.Generator):
LIMIT_FETCH = "LIMIT"
TABLESAMPLE_WITH_METHOD = False
TABLESAMPLE_SIZE_IS_PERCENT = True
JOIN_HINTS = False
TABLE_HINTS = False
+ INDEX_ON = "ON TABLE"
TYPE_MAPPING = {
- **generator.Generator.TYPE_MAPPING, # type: ignore
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TEXT: "STRING",
exp.DataType.Type.DATETIME: "TIMESTAMP",
exp.DataType.Type.VARBINARY: "BINARY",
@@ -294,7 +294,7 @@ class Hive(Dialect):
}
TRANSFORMS = {
- **generator.Generator.TRANSFORMS, # type: ignore
+ **generator.Generator.TRANSFORMS,
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.Select: transforms.preprocess(
[
@@ -319,7 +319,6 @@ class Hive(Dialect):
exp.FileFormatProperty: lambda self, e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}",
exp.FromBase64: rename_func("UNBASE64"),
exp.If: if_sql,
- exp.Index: _index_sql,
exp.ILike: no_ilike_sql,
exp.JSONExtract: rename_func("GET_JSON_OBJECT"),
exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"),
@@ -342,7 +341,6 @@ class Hive(Dialect):
exp.StrToTime: _str_to_time_sql,
exp.StrToUnix: _str_to_unix_sql,
exp.StructExtract: struct_extract_sql,
- exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'this')}",
exp.TimeStrToDate: rename_func("TO_DATE"),
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
@@ -363,14 +361,13 @@ class Hive(Dialect):
exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"),
exp.NumberToStr: rename_func("FORMAT_NUMBER"),
exp.LastDateOfMonth: rename_func("LAST_DAY"),
- exp.National: lambda self, e: self.sql(e, "this"),
+ exp.National: lambda self, e: self.national_sql(e, prefix=""),
}
PROPERTIES_LOCATION = {
- **generator.Generator.PROPERTIES_LOCATION, # type: ignore
+ **generator.Generator.PROPERTIES_LOCATION,
exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA,
exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
- exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
@@ -396,3 +393,10 @@ class Hive(Dialect):
expression = exp.DataType.build(expression.this)
return super().datatype_sql(expression)
+
+ def after_having_modifiers(self, expression: exp.Expression) -> t.List[str]:
+ return super().after_having_modifiers(expression) + [
+ self.sql(expression, "distribute"),
+ self.sql(expression, "sort"),
+ self.sql(expression, "cluster"),
+ ]
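
The three query-modifier parsers restore Hive's clause trio, and `after_having_modifiers` re-emits them in order. A round-trip sketch:

import sqlglot

sql = "SELECT a, COUNT(1) FROM t GROUP BY a DISTRIBUTE BY a SORT BY a"
print(sqlglot.transpile(sql, read="hive", write="hive")[0])
# expected shape: the DISTRIBUTE BY / SORT BY clauses survive after GROUP BY
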
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 5342624..2b41860 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -1,5 +1,7 @@
from __future__ import annotations
+import typing as t
+
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
@@ -11,6 +13,7 @@ from sqlglot.dialects.dialect import (
min_or_least,
no_ilike_sql,
no_paren_current_date_sql,
+ no_pivot_sql,
no_tablesample_sql,
no_trycast_sql,
parse_date_delta_with_interval,
@@ -21,14 +24,14 @@ from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
-def _show_parser(*args, **kwargs):
- def _parse(self):
+def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[MySQL.Parser], exp.Show]:
+ def _parse(self: MySQL.Parser) -> exp.Show:
return self._parse_show_mysql(*args, **kwargs)
return _parse
-def _date_trunc_sql(self, expression):
+def _date_trunc_sql(self: generator.Generator, expression: exp.DateTrunc) -> str:
expr = self.sql(expression, "this")
unit = expression.text("unit")
@@ -54,17 +57,17 @@ def _date_trunc_sql(self, expression):
return f"STR_TO_DATE({concat}, '{date_format}')"
-def _str_to_date(args):
+def _str_to_date(args: t.List) -> exp.StrToDate:
date_format = MySQL.format_time(seq_get(args, 1))
return exp.StrToDate(this=seq_get(args, 0), format=date_format)
-def _str_to_date_sql(self, expression):
+def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate | exp.StrToTime) -> str:
date_format = self.format_time(expression)
return f"STR_TO_DATE({self.sql(expression.this)}, {date_format})"
-def _trim_sql(self, expression):
+def _trim_sql(self: generator.Generator, expression: exp.Trim) -> str:
target = self.sql(expression, "this")
trim_type = self.sql(expression, "position")
remove_chars = self.sql(expression, "expression")
@@ -79,8 +82,8 @@ def _trim_sql(self, expression):
return f"TRIM({trim_type}{remove_chars}{from_part}{target})"
-def _date_add_sql(kind):
- def func(self, expression):
+def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]:
+ def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
this = self.sql(expression, "this")
unit = expression.text("unit").upper() or "DAY"
return (
@@ -175,10 +178,10 @@ class MySQL(Dialect):
COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SHOW}
class Parser(parser.Parser):
- FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA, TokenType.DATABASE} # type: ignore
+ FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA, TokenType.DATABASE}
FUNCTIONS = {
- **parser.Parser.FUNCTIONS, # type: ignore
+ **parser.Parser.FUNCTIONS,
"DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"),
"DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
@@ -191,7 +194,7 @@ class MySQL(Dialect):
}
FUNCTION_PARSERS = {
- **parser.Parser.FUNCTION_PARSERS, # type: ignore
+ **parser.Parser.FUNCTION_PARSERS,
"GROUP_CONCAT": lambda self: self.expression(
exp.GroupConcat,
this=self._parse_lambda(),
@@ -199,13 +202,8 @@ class MySQL(Dialect):
),
}
- PROPERTY_PARSERS = {
- **parser.Parser.PROPERTY_PARSERS, # type: ignore
- "ENGINE": lambda self: self._parse_property_assignment(exp.EngineProperty),
- }
-
STATEMENT_PARSERS = {
- **parser.Parser.STATEMENT_PARSERS, # type: ignore
+ **parser.Parser.STATEMENT_PARSERS,
TokenType.SHOW: lambda self: self._parse_show(),
}
@@ -286,7 +284,13 @@ class MySQL(Dialect):
LOG_DEFAULTS_TO_LN = True
- def _parse_show_mysql(self, this, target=False, full=None, global_=None):
+ def _parse_show_mysql(
+ self,
+ this: str,
+ target: bool | str = False,
+ full: t.Optional[bool] = None,
+ global_: t.Optional[bool] = None,
+ ) -> exp.Show:
if target:
if isinstance(target, str):
self._match_text_seq(target)
@@ -342,10 +346,12 @@ class MySQL(Dialect):
offset=offset,
limit=limit,
mutex=mutex,
- **{"global": global_},
+ **{"global": global_}, # type: ignore
)
- def _parse_oldstyle_limit(self):
+ def _parse_oldstyle_limit(
+ self,
+ ) -> t.Tuple[t.Optional[exp.Expression], t.Optional[exp.Expression]]:
limit = None
offset = None
if self._match_text_seq("LIMIT"):
@@ -355,23 +361,20 @@ class MySQL(Dialect):
elif len(parts) == 2:
limit = parts[1]
offset = parts[0]
+
return offset, limit
- def _parse_set_item_charset(self, kind):
+ def _parse_set_item_charset(self, kind: str) -> exp.Expression:
this = self._parse_string() or self._parse_id_var()
+ return self.expression(exp.SetItem, this=this, kind=kind)
- return self.expression(
- exp.SetItem,
- this=this,
- kind=kind,
- )
-
- def _parse_set_item_names(self):
+ def _parse_set_item_names(self) -> exp.Expression:
charset = self._parse_string() or self._parse_id_var()
if self._match_text_seq("COLLATE"):
collate = self._parse_string() or self._parse_id_var()
else:
collate = None
+
return self.expression(
exp.SetItem,
this=charset,
@@ -386,7 +389,7 @@ class MySQL(Dialect):
TABLE_HINTS = False
TRANSFORMS = {
- **generator.Generator.TRANSFORMS, # type: ignore
+ **generator.Generator.TRANSFORMS,
exp.CurrentDate: no_paren_current_date_sql,
exp.DateDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression),
exp.DateAdd: _date_add_sql("ADD"),
@@ -403,6 +406,7 @@ class MySQL(Dialect):
exp.Min: min_or_least,
exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
+ exp.Pivot: no_pivot_sql,
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.StrPosition: strposition_to_locate_sql,
exp.StrToDate: _str_to_date_sql,
@@ -422,7 +426,7 @@ class MySQL(Dialect):
TYPE_MAPPING.pop(exp.DataType.Type.LONGBLOB)
PROPERTIES_LOCATION = {
- **generator.Generator.PROPERTIES_LOCATION, # type: ignore
+ **generator.Generator.PROPERTIES_LOCATION,
exp.TransientProperty: exp.Properties.Location.UNSUPPORTED,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
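
The annotated `_parse_show_mysql` is the shared entry point for the SHOW variants registered via the `_show_parser` helper above. A sketch of the result type:

import sqlglot
from sqlglot import exp

show = sqlglot.parse_one("SHOW FULL TABLES FROM mydb", read="mysql")
# this='TABLES', with the FULL flag and target db captured as args
assert isinstance(show, exp.Show)
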
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index c8af1c6..7722753 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -8,7 +8,7 @@ from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
-def _parse_xml_table(self) -> exp.XMLTable:
+def _parse_xml_table(self: parser.Parser) -> exp.XMLTable:
this = self._parse_string()
passing = None
@@ -66,7 +66,7 @@ class Oracle(Dialect):
WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER, TokenType.KEEP}
FUNCTIONS = {
- **parser.Parser.FUNCTIONS, # type: ignore
+ **parser.Parser.FUNCTIONS,
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
}
@@ -107,7 +107,7 @@ class Oracle(Dialect):
TABLE_HINTS = False
TYPE_MAPPING = {
- **generator.Generator.TYPE_MAPPING, # type: ignore
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TINYINT: "NUMBER",
exp.DataType.Type.SMALLINT: "NUMBER",
exp.DataType.Type.INT: "NUMBER",
@@ -122,7 +122,7 @@ class Oracle(Dialect):
}
TRANSFORMS = {
- **generator.Generator.TRANSFORMS, # type: ignore
+ **generator.Generator.TRANSFORMS,
exp.DateStrToDate: lambda self, e: self.func(
"TO_DATE", e.this, exp.Literal.string("YYYY-MM-DD")
),
@@ -143,7 +143,7 @@ class Oracle(Dialect):
}
PROPERTIES_LOCATION = {
- **generator.Generator.PROPERTIES_LOCATION, # type: ignore
+ **generator.Generator.PROPERTIES_LOCATION,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 2132778..ab61880 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -12,6 +12,7 @@ from sqlglot.dialects.dialect import (
max_or_greatest,
min_or_least,
no_paren_current_date_sql,
+ no_pivot_sql,
no_tablesample_sql,
no_trycast_sql,
rename_func,
@@ -33,8 +34,8 @@ DATE_DIFF_FACTOR = {
}
-def _date_add_sql(kind):
- def func(self, expression):
+def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]:
+ def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
from sqlglot.optimizer.simplify import simplify
this = self.sql(expression, "this")
@@ -51,7 +52,7 @@ def _date_add_sql(kind):
return func
-def _date_diff_sql(self, expression):
+def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
unit = expression.text("unit").upper()
factor = DATE_DIFF_FACTOR.get(unit)
@@ -77,7 +78,7 @@ def _date_diff_sql(self, expression):
return f"CAST({unit} AS BIGINT)"
-def _substring_sql(self, expression):
+def _substring_sql(self: generator.Generator, expression: exp.Substring) -> str:
this = self.sql(expression, "this")
start = self.sql(expression, "start")
length = self.sql(expression, "length")
@@ -88,7 +89,7 @@ def _substring_sql(self, expression):
return f"SUBSTRING({this}{from_part}{for_part})"
-def _string_agg_sql(self, expression):
+def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> str:
expression = expression.copy()
separator = expression.args.get("separator") or exp.Literal.string(",")
@@ -102,13 +103,13 @@ def _string_agg_sql(self, expression):
return f"STRING_AGG({self.format_args(this, separator)}{order})"
-def _datatype_sql(self, expression):
+def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
if expression.this == exp.DataType.Type.ARRAY:
return f"{self.expressions(expression, flat=True)}[]"
return self.datatype_sql(expression)
-def _auto_increment_to_serial(expression):
+def _auto_increment_to_serial(expression: exp.Expression) -> exp.Expression:
auto = expression.find(exp.AutoIncrementColumnConstraint)
if auto:
@@ -126,7 +127,7 @@ def _auto_increment_to_serial(expression):
return expression
-def _serial_to_generated(expression):
+def _serial_to_generated(expression: exp.Expression) -> exp.Expression:
kind = expression.args["kind"]
if kind.this == exp.DataType.Type.SERIAL:
@@ -144,6 +145,7 @@ def _serial_to_generated(expression):
constraints = expression.args["constraints"]
generated = exp.ColumnConstraint(kind=exp.GeneratedAsIdentityColumnConstraint(this=False))
notnull = exp.ColumnConstraint(kind=exp.NotNullColumnConstraint())
+
if notnull not in constraints:
constraints.insert(0, notnull)
if generated not in constraints:
@@ -152,7 +154,7 @@ def _serial_to_generated(expression):
return expression
-def _generate_series(args):
+def _generate_series(args: t.List) -> exp.Expression:
# The goal is to convert step values like '1 day' or INTERVAL '1 day' into INTERVAL '1' day
step = seq_get(args, 2)
@@ -168,11 +170,12 @@ def _generate_series(args):
return exp.GenerateSeries.from_arg_list(args)
-def _to_timestamp(args):
+def _to_timestamp(args: t.List) -> exp.Expression:
# TO_TIMESTAMP accepts either a single double argument or (text, text)
if len(args) == 1:
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TABLE
return exp.UnixToTime.from_arg_list(args)
+
# https://www.postgresql.org/docs/current/functions-formatting.html
return format_time_lambda(exp.StrToTime, "postgres")(args)
@@ -255,7 +258,7 @@ class Postgres(Dialect):
STRICT_CAST = False
FUNCTIONS = {
- **parser.Parser.FUNCTIONS, # type: ignore
+ **parser.Parser.FUNCTIONS,
"DATE_TRUNC": lambda args: exp.TimestampTrunc(
this=seq_get(args, 1), unit=seq_get(args, 0)
),
@@ -271,7 +274,7 @@ class Postgres(Dialect):
}
BITWISE = {
- **parser.Parser.BITWISE, # type: ignore
+ **parser.Parser.BITWISE,
TokenType.HASH: exp.BitwiseXor,
}
@@ -280,7 +283,7 @@ class Postgres(Dialect):
}
RANGE_PARSERS = {
- **parser.Parser.RANGE_PARSERS, # type: ignore
+ **parser.Parser.RANGE_PARSERS,
TokenType.DAMP: binary_range_parser(exp.ArrayOverlaps),
TokenType.AT_GT: binary_range_parser(exp.ArrayContains),
TokenType.LT_AT: binary_range_parser(exp.ArrayContained),
@@ -303,14 +306,14 @@ class Postgres(Dialect):
return self.expression(exp.Extract, this=part, expression=value)
class Generator(generator.Generator):
- INTERVAL_ALLOWS_PLURAL_FORM = False
+ SINGLE_STRING_INTERVAL = True
LOCKING_READS_SUPPORTED = True
JOIN_HINTS = False
TABLE_HINTS = False
PARAMETER_TOKEN = "$"
TYPE_MAPPING = {
- **generator.Generator.TYPE_MAPPING, # type: ignore
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TINYINT: "SMALLINT",
exp.DataType.Type.FLOAT: "REAL",
exp.DataType.Type.DOUBLE: "DOUBLE PRECISION",
@@ -320,14 +323,9 @@ class Postgres(Dialect):
}
TRANSFORMS = {
- **generator.Generator.TRANSFORMS, # type: ignore
+ **generator.Generator.TRANSFORMS,
exp.BitwiseXor: lambda self, e: self.binary(e, "#"),
- exp.ColumnDef: transforms.preprocess(
- [
- _auto_increment_to_serial,
- _serial_to_generated,
- ],
- ),
+ exp.ColumnDef: transforms.preprocess([_auto_increment_to_serial, _serial_to_generated]),
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONBExtract: lambda self, e: self.binary(e, "#>"),
@@ -348,6 +346,7 @@ class Postgres(Dialect):
exp.ArrayContains: lambda self, e: self.binary(e, "@>"),
exp.ArrayContained: lambda self, e: self.binary(e, "<@"),
exp.Merge: transforms.preprocess([transforms.remove_target_from_merge]),
+ exp.Pivot: no_pivot_sql,
exp.RegexpLike: lambda self, e: self.binary(e, "~"),
exp.RegexpILike: lambda self, e: self.binary(e, "~*"),
exp.StrPosition: str_position_sql,
@@ -369,7 +368,7 @@ class Postgres(Dialect):
}
PROPERTIES_LOCATION = {
- **generator.Generator.PROPERTIES_LOCATION, # type: ignore
+ **generator.Generator.PROPERTIES_LOCATION,
exp.TransientProperty: exp.Properties.Location.UNSUPPORTED,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
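
The `_to_timestamp` helper above dispatches on arity: a single (numeric) argument maps to `exp.UnixToTime`, while the two-argument `(text, format)` form goes through `format_time_lambda` and becomes `exp.StrToTime`. A minimal sketch of that dispatch, assuming this commit's sqlglot is installed (outputs elided):

    from sqlglot import exp, parse_one

    # One argument: an epoch value, parsed as exp.UnixToTime.
    print(repr(parse_one("SELECT TO_TIMESTAMP(1284352323)", read="postgres").find(exp.UnixToTime)))

    # Two arguments: (text, format), parsed as exp.StrToTime with the format
    # normalized via the Postgres time mapping.
    print(repr(parse_one("SELECT TO_TIMESTAMP('05 Dec 2000', 'DD Mon YYYY')", read="postgres").find(exp.StrToTime)))
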
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 6133a27..52a04a4 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import (
format_time_lambda,
if_sql,
no_ilike_sql,
+ no_pivot_sql,
no_safe_divide_sql,
rename_func,
struct_extract_sql,
@@ -127,39 +128,12 @@ def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> s
)
-def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) -> str:
- start = expression.args["start"]
- end = expression.args["end"]
- step = expression.args.get("step")
-
- target_type = None
-
- if isinstance(start, exp.Cast):
- target_type = start.to
- elif isinstance(end, exp.Cast):
- target_type = end.to
-
- if target_type and target_type.this == exp.DataType.Type.TIMESTAMP:
- to = target_type.copy()
-
- if target_type is start.to:
- end = exp.Cast(this=end, to=to)
- else:
- start = exp.Cast(this=start, to=to)
-
- sql = self.func("SEQUENCE", start, end, step)
- if isinstance(expression.parent, exp.Table):
- sql = f"UNNEST({sql})"
-
- return sql
-
-
def _ensure_utf8(charset: exp.Literal) -> None:
if charset.name.lower() != "utf-8":
raise UnsupportedError(f"Unsupported charset {charset}")
-def _approx_percentile(args: t.Sequence) -> exp.Expression:
+def _approx_percentile(args: t.List) -> exp.Expression:
if len(args) == 4:
return exp.ApproxQuantile(
this=seq_get(args, 0),
@@ -176,7 +150,7 @@ def _approx_percentile(args: t.Sequence) -> exp.Expression:
return exp.ApproxQuantile.from_arg_list(args)
-def _from_unixtime(args: t.Sequence) -> exp.Expression:
+def _from_unixtime(args: t.List) -> exp.Expression:
if len(args) == 3:
return exp.UnixToTime(
this=seq_get(args, 0),
@@ -191,22 +165,39 @@ def _from_unixtime(args: t.Sequence) -> exp.Expression:
return exp.UnixToTime.from_arg_list(args)
+def _unnest_sequence(expression: exp.Expression) -> exp.Expression:
+ if isinstance(expression, exp.Table):
+ if isinstance(expression.this, exp.GenerateSeries):
+ unnest = exp.Unnest(expressions=[expression.this])
+
+ if expression.alias:
+ return exp.alias_(
+ unnest,
+ alias="_u",
+ table=[expression.alias],
+ copy=False,
+ )
+ return unnest
+ return expression
+
+
class Presto(Dialect):
index_offset = 1
null_ordering = "nulls_are_last"
- time_format = MySQL.time_format # type: ignore
- time_mapping = MySQL.time_mapping # type: ignore
+ time_format = MySQL.time_format
+ time_mapping = MySQL.time_mapping
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"START": TokenType.BEGIN,
+ "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
"ROW": TokenType.STRUCT,
}
class Parser(parser.Parser):
FUNCTIONS = {
- **parser.Parser.FUNCTIONS, # type: ignore
+ **parser.Parser.FUNCTIONS,
"APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list,
"APPROX_PERCENTILE": _approx_percentile,
"CARDINALITY": exp.ArraySize.from_arg_list,
@@ -252,13 +243,13 @@ class Presto(Dialect):
STRUCT_DELIMITER = ("(", ")")
PROPERTIES_LOCATION = {
- **generator.Generator.PROPERTIES_LOCATION, # type: ignore
+ **generator.Generator.PROPERTIES_LOCATION,
exp.LocationProperty: exp.Properties.Location.UNSUPPORTED,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
TYPE_MAPPING = {
- **generator.Generator.TYPE_MAPPING, # type: ignore
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.FLOAT: "REAL",
exp.DataType.Type.BINARY: "VARBINARY",
@@ -268,8 +259,9 @@ class Presto(Dialect):
}
TRANSFORMS = {
- **generator.Generator.TRANSFORMS, # type: ignore
+ **generator.Generator.TRANSFORMS,
exp.ApproxDistinct: _approx_distinct_sql,
+ exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
exp.ArrayConcat: rename_func("CONCAT"),
exp.ArrayContains: rename_func("CONTAINS"),
@@ -293,7 +285,7 @@ class Presto(Dialect):
exp.Decode: _decode_sql,
exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)",
exp.Encode: _encode_sql,
- exp.GenerateSeries: _sequence_sql,
+ exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.Hex: rename_func("TO_HEX"),
exp.If: if_sql,
@@ -301,10 +293,10 @@ class Presto(Dialect):
exp.Initcap: _initcap_sql,
exp.Lateral: _explode_to_unnest_sql,
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
- exp.LogicalOr: rename_func("BOOL_OR"),
exp.LogicalAnd: rename_func("BOOL_AND"),
+ exp.LogicalOr: rename_func("BOOL_OR"),
+ exp.Pivot: no_pivot_sql,
exp.Quantile: _quantile_sql,
- exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
exp.SafeDivide: no_safe_divide_sql,
exp.Schema: _schema_sql,
exp.Select: transforms.preprocess(
@@ -320,8 +312,7 @@ class Presto(Dialect):
exp.StrToTime: _str_to_time_sql,
exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))",
exp.StructExtract: struct_extract_sql,
- exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT='{e.name.upper()}'",
- exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
+ exp.Table: transforms.preprocess([_unnest_sequence]),
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToDate: timestrtotime_sql,
exp.TimeStrToTime: timestrtotime_sql,
@@ -336,6 +327,7 @@ class Presto(Dialect):
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
exp.VariancePop: rename_func("VAR_POP"),
+ exp.With: transforms.preprocess([transforms.add_recursive_cte_column_names]),
exp.WithinGroup: transforms.preprocess(
[transforms.remove_within_group_for_percentiles]
),
@@ -351,3 +343,25 @@ class Presto(Dialect):
modes = expression.args.get("modes")
modes = f" {', '.join(modes)}" if modes else ""
return f"START TRANSACTION{modes}"
+
+ def generateseries_sql(self, expression: exp.GenerateSeries) -> str:
+ start = expression.args["start"]
+ end = expression.args["end"]
+ step = expression.args.get("step")
+
+ if isinstance(start, exp.Cast):
+ target_type = start.to
+ elif isinstance(end, exp.Cast):
+ target_type = end.to
+ else:
+ target_type = None
+
+ if target_type and target_type.is_type(exp.DataType.Type.TIMESTAMP):
+ to = target_type.copy()
+
+ if target_type is start.to:
+ end = exp.Cast(this=end, to=to)
+ else:
+ start = exp.Cast(this=start, to=to)
+
+ return self.func("SEQUENCE", start, end, step)
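
Taken together, `generateseries_sql` and `_unnest_sequence` split the removed `_sequence_sql` in two: the generator always emits `SEQUENCE(...)`, and the `exp.Table` preprocess wraps it in `UNNEST` only when it appears as a relation. A sketch with hypothetical table/column names; exact output may vary by version:

    import sqlglot

    # As a scalar, GENERATE_SERIES maps straight to SEQUENCE.
    print(sqlglot.transpile("SELECT GENERATE_SERIES(1, 5)", read="postgres", write="presto")[0])

    # As a relation, the Table preprocess wraps it, roughly:
    # SELECT * FROM UNNEST(SEQUENCE(1, 5)) AS _u(t)
    print(sqlglot.transpile("SELECT * FROM GENERATE_SERIES(1, 5) AS t", read="postgres", write="presto")[0])
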
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index 1b7cf31..55e393a 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -8,21 +8,21 @@ from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
-def _json_sql(self, e) -> str:
- return f'{self.sql(e, "this")}."{e.expression.name}"'
+def _json_sql(self: Postgres.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar) -> str:
+ return f'{self.sql(expression, "this")}."{expression.expression.name}"'
class Redshift(Postgres):
time_format = "'YYYY-MM-DD HH:MI:SS'"
time_mapping = {
- **Postgres.time_mapping, # type: ignore
+ **Postgres.time_mapping,
"MON": "%b",
"HH": "%H",
}
class Parser(Postgres.Parser):
FUNCTIONS = {
- **Postgres.Parser.FUNCTIONS, # type: ignore
+ **Postgres.Parser.FUNCTIONS,
"DATEADD": lambda args: exp.DateAdd(
this=seq_get(args, 2),
expression=seq_get(args, 1),
@@ -45,7 +45,7 @@ class Redshift(Postgres):
isinstance(this, exp.DataType)
and this.this == exp.DataType.Type.VARCHAR
and this.expressions
- and this.expressions[0] == exp.column("MAX")
+ and this.expressions[0].this == exp.column("MAX")
):
this.set("expressions", [exp.Var(this="MAX")])
@@ -57,9 +57,7 @@ class Redshift(Postgres):
STRING_ESCAPES = ["\\"]
KEYWORDS = {
- **Postgres.Tokenizer.KEYWORDS, # type: ignore
- "GEOMETRY": TokenType.GEOMETRY,
- "GEOGRAPHY": TokenType.GEOGRAPHY,
+ **Postgres.Tokenizer.KEYWORDS,
"HLLSKETCH": TokenType.HLLSKETCH,
"SUPER": TokenType.SUPER,
"SYSDATE": TokenType.CURRENT_TIMESTAMP,
@@ -76,22 +74,22 @@ class Redshift(Postgres):
class Generator(Postgres.Generator):
LOCKING_READS_SUPPORTED = False
- SINGLE_STRING_INTERVAL = True
+ RENAME_TABLE_WITH_DB = False
TYPE_MAPPING = {
- **Postgres.Generator.TYPE_MAPPING, # type: ignore
+ **Postgres.Generator.TYPE_MAPPING,
exp.DataType.Type.BINARY: "VARBYTE",
exp.DataType.Type.VARBINARY: "VARBYTE",
exp.DataType.Type.INT: "INTEGER",
}
PROPERTIES_LOCATION = {
- **Postgres.Generator.PROPERTIES_LOCATION, # type: ignore
+ **Postgres.Generator.PROPERTIES_LOCATION,
exp.LikeProperty: exp.Properties.Location.POST_WITH,
}
TRANSFORMS = {
- **Postgres.Generator.TRANSFORMS, # type: ignore
+ **Postgres.Generator.TRANSFORMS,
exp.CurrentTimestamp: lambda self, e: "SYSDATE",
exp.DateAdd: lambda self, e: self.func(
"DATEADD", exp.var(e.text("unit") or "day"), e.expression, e.this
@@ -107,10 +105,13 @@ class Redshift(Postgres):
exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
}
+        # Postgres maps exp.Pivot to no_pivot_sql, but Redshift supports pivots
+ TRANSFORMS.pop(exp.Pivot)
+
# Redshift uses the POW | POWER (expr1, expr2) syntax instead of expr1 ^ expr2 (postgres)
TRANSFORMS.pop(exp.Pow)
- RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot"}
+ RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot", "type"}
def values_sql(self, expression: exp.Values) -> str:
"""
@@ -120,37 +121,36 @@ class Redshift(Postgres):
evaluate the expression. You may need to increase `sys.setrecursionlimit` to run and it can also be
very slow.
"""
- if not isinstance(expression.unnest().parent, exp.From):
+
+ # The VALUES clause is still valid in an `INSERT INTO ..` statement, for example
+ if not expression.find_ancestor(exp.From, exp.Join):
return super().values_sql(expression)
- rows = [tuple_exp.expressions for tuple_exp in expression.expressions]
+
+ column_names = expression.alias and expression.args["alias"].columns
+
selects = []
+ rows = [tuple_exp.expressions for tuple_exp in expression.expressions]
+
for i, row in enumerate(rows):
- if i == 0 and expression.alias:
+ if i == 0 and column_names:
row = [
exp.alias_(value, column_name)
- for value, column_name in zip(row, expression.args["alias"].args["columns"])
+ for value, column_name in zip(row, column_names)
]
+
selects.append(exp.Select(expressions=row))
- subquery_expression = selects[0]
+
+ subquery_expression: exp.Select | exp.Union = selects[0]
if len(selects) > 1:
for select in selects[1:]:
subquery_expression = exp.union(subquery_expression, select, distinct=False)
+
return self.subquery_sql(subquery_expression.subquery(expression.alias))
def with_properties(self, properties: exp.Properties) -> str:
"""Redshift doesn't have `WITH` as part of their with_properties so we remove it"""
return self.properties(properties, prefix=" ", suffix="")
- def renametable_sql(self, expression: exp.RenameTable) -> str:
- """Redshift only supports defining the table name itself (not the db) when renaming tables"""
- expression = expression.copy()
- target_table = expression.this
- for arg in target_table.args:
- if arg != "this":
- target_table.set(arg, None)
- this = self.sql(expression, "this")
- return f"RENAME TO {this}"
-
def datatype_sql(self, expression: exp.DataType) -> str:
"""
Redshift converts the `TEXT` data type to `VARCHAR(255)` by default when people more generally mean
@@ -162,6 +162,8 @@ class Redshift(Postgres):
expression = expression.copy()
expression.set("this", exp.DataType.Type.VARCHAR)
precision = expression.args.get("expressions")
+
if not precision:
expression.append("expressions", exp.Var(this="MAX"))
+
return super().datatype_sql(expression)
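
The reworked `values_sql` only rewrites a VALUES list when it is actually used as a relation: `find_ancestor(exp.From, exp.Join)` replaces the old parent check, so `INSERT INTO ... VALUES` passes through unchanged while derived tables become a subquery of unioned SELECTs. A sketch with hypothetical table and column names:

    import sqlglot

    # Used as a relation: rewritten into a subquery of SELECTs joined by UNION ALL.
    print(sqlglot.transpile("SELECT a, b FROM (VALUES (1, 2), (3, 4)) AS t(a, b)", write="redshift")[0])

    # Used under INSERT: left as a plain VALUES clause.
    print(sqlglot.transpile("INSERT INTO t VALUES (1, 2)", write="redshift")[0])
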
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 70dcaa9..756e8e9 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -18,7 +18,7 @@ from sqlglot.dialects.dialect import (
var_map_sql,
)
from sqlglot.expressions import Literal
-from sqlglot.helper import flatten, seq_get
+from sqlglot.helper import seq_get
from sqlglot.parser import binary_range_parser
from sqlglot.tokens import TokenType
@@ -30,7 +30,7 @@ def _check_int(s: str) -> bool:
# from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html
-def _snowflake_to_timestamp(args: t.Sequence) -> t.Union[exp.StrToTime, exp.UnixToTime]:
+def _snowflake_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime]:
if len(args) == 2:
first_arg, second_arg = args
if second_arg.is_string:
@@ -52,8 +52,12 @@ def _snowflake_to_timestamp(args: t.Sequence) -> t.Union[exp.StrToTime, exp.Unix
return exp.UnixToTime(this=first_arg, scale=timescale)
+ from sqlglot.optimizer.simplify import simplify_literals
+
+ # The first argument might be an expression like 40 * 365 * 86400, so we try to
+ # reduce it using `simplify_literals` first and then check if it's a Literal.
first_arg = seq_get(args, 0)
- if not isinstance(first_arg, Literal):
+ if not isinstance(simplify_literals(first_arg, root=True), Literal):
# case: <variant_expr>
return format_time_lambda(exp.StrToTime, "snowflake", default=True)(args)
@@ -69,6 +73,19 @@ def _snowflake_to_timestamp(args: t.Sequence) -> t.Union[exp.StrToTime, exp.Unix
return exp.UnixToTime.from_arg_list(args)
+def _parse_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]:
+ expression = parser.parse_var_map(args)
+
+ if isinstance(expression, exp.StarMap):
+ return expression
+
+ return exp.Struct(
+ expressions=[
+ t.cast(exp.Condition, k).eq(v) for k, v in zip(expression.keys, expression.values)
+ ]
+ )
+
+
def _unix_to_time_sql(self: generator.Generator, expression: exp.UnixToTime) -> str:
scale = expression.args.get("scale")
timestamp = self.sql(expression, "this")
@@ -116,7 +133,7 @@ def _parse_date_part(self: parser.Parser) -> t.Optional[exp.Expression]:
# https://docs.snowflake.com/en/sql-reference/functions/div0
-def _div0_to_if(args: t.Sequence) -> exp.Expression:
+def _div0_to_if(args: t.List) -> exp.Expression:
cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0))
true = exp.Literal.number(0)
false = exp.Div(this=seq_get(args, 0), expression=seq_get(args, 1))
@@ -124,13 +141,13 @@ def _div0_to_if(args: t.Sequence) -> exp.Expression:
# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
-def _zeroifnull_to_if(args: t.Sequence) -> exp.Expression:
+def _zeroifnull_to_if(args: t.List) -> exp.Expression:
cond = exp.Is(this=seq_get(args, 0), expression=exp.Null())
return exp.If(this=cond, true=exp.Literal.number(0), false=seq_get(args, 0))
# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
-def _nullifzero_to_if(args: t.Sequence) -> exp.Expression:
+def _nullifzero_to_if(args: t.List) -> exp.Expression:
cond = exp.EQ(this=seq_get(args, 0), expression=exp.Literal.number(0))
return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0))
@@ -143,6 +160,12 @@ def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
return self.datatype_sql(expression)
+def _parse_convert_timezone(args: t.List) -> exp.Expression:
+ if len(args) == 3:
+ return exp.Anonymous(this="CONVERT_TIMEZONE", expressions=args)
+ return exp.AtTimeZone(this=seq_get(args, 1), zone=seq_get(args, 0))
+
+
class Snowflake(Dialect):
null_ordering = "nulls_are_large"
time_format = "'yyyy-mm-dd hh24:mi:ss'"
@@ -177,17 +200,14 @@ class Snowflake(Dialect):
}
class Parser(parser.Parser):
- QUOTED_PIVOT_COLUMNS = True
+ IDENTIFY_PIVOT_STRINGS = True
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"ARRAYAGG": exp.ArrayAgg.from_arg_list,
"ARRAY_CONSTRUCT": exp.Array.from_arg_list,
"ARRAY_TO_STRING": exp.ArrayJoin.from_arg_list,
- "CONVERT_TIMEZONE": lambda args: exp.AtTimeZone(
- this=seq_get(args, 1),
- zone=seq_get(args, 0),
- ),
+ "CONVERT_TIMEZONE": _parse_convert_timezone,
"DATE_TRUNC": date_trunc_to_time,
"DATEADD": lambda args: exp.DateAdd(
this=seq_get(args, 2),
@@ -202,7 +222,7 @@ class Snowflake(Dialect):
"DIV0": _div0_to_if,
"IFF": exp.If.from_arg_list,
"NULLIFZERO": _nullifzero_to_if,
- "OBJECT_CONSTRUCT": parser.parse_var_map,
+ "OBJECT_CONSTRUCT": _parse_object_construct,
"RLIKE": exp.RegexpLike.from_arg_list,
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
"TO_ARRAY": exp.Array.from_arg_list,
@@ -224,7 +244,7 @@ class Snowflake(Dialect):
}
COLUMN_OPERATORS = {
- **parser.Parser.COLUMN_OPERATORS, # type: ignore
+ **parser.Parser.COLUMN_OPERATORS,
TokenType.COLON: lambda self, this, path: self.expression(
exp.Bracket,
this=this,
@@ -232,14 +252,16 @@ class Snowflake(Dialect):
),
}
+ TIMESTAMPS = parser.Parser.TIMESTAMPS.copy() - {TokenType.TIME}
+
RANGE_PARSERS = {
- **parser.Parser.RANGE_PARSERS, # type: ignore
+ **parser.Parser.RANGE_PARSERS,
TokenType.LIKE_ANY: binary_range_parser(exp.LikeAny),
TokenType.ILIKE_ANY: binary_range_parser(exp.ILikeAny),
}
ALTER_PARSERS = {
- **parser.Parser.ALTER_PARSERS, # type: ignore
+ **parser.Parser.ALTER_PARSERS,
"UNSET": lambda self: self._parse_alter_table_set_tag(unset=True),
"SET": lambda self: self._parse_alter_table_set_tag(),
}
@@ -256,17 +278,20 @@ class Snowflake(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
+ "CHAR VARYING": TokenType.VARCHAR,
+ "CHARACTER VARYING": TokenType.VARCHAR,
"EXCLUDE": TokenType.EXCEPT,
"ILIKE ANY": TokenType.ILIKE_ANY,
"LIKE ANY": TokenType.LIKE_ANY,
"MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
+ "MINUS": TokenType.EXCEPT,
+ "NCHAR VARYING": TokenType.VARCHAR,
"PUT": TokenType.COMMAND,
"RENAME": TokenType.REPLACE,
"TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
"TIMESTAMP_NTZ": TokenType.TIMESTAMP,
"TIMESTAMP_TZ": TokenType.TIMESTAMPTZ,
"TIMESTAMPNTZ": TokenType.TIMESTAMP,
- "MINUS": TokenType.EXCEPT,
"SAMPLE": TokenType.TABLE_SAMPLE,
}
@@ -285,7 +310,7 @@ class Snowflake(Dialect):
TABLE_HINTS = False
TRANSFORMS = {
- **generator.Generator.TRANSFORMS, # type: ignore
+ **generator.Generator.TRANSFORMS,
exp.Array: inline_array_sql,
exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.ArrayJoin: rename_func("ARRAY_TO_STRING"),
@@ -299,6 +324,7 @@ class Snowflake(Dialect):
exp.DateStrToDate: datestrtodate_sql,
exp.DataType: _datatype_sql,
exp.DayOfWeek: rename_func("DAYOFWEEK"),
+ exp.Extract: rename_func("DATE_PART"),
exp.If: rename_func("IFF"),
exp.LogicalAnd: rename_func("BOOLAND_AGG"),
exp.LogicalOr: rename_func("BOOLOR_AGG"),
@@ -312,6 +338,10 @@ class Snowflake(Dialect):
"POSITION", e.args.get("substr"), e.this, e.args.get("position")
),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.Struct: lambda self, e: self.func(
+ "OBJECT_CONSTRUCT",
+ *(arg for expression in e.expressions for arg in expression.flatten()),
+ ),
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
exp.TimeToStr: lambda self, e: self.func(
@@ -326,7 +356,7 @@ class Snowflake(Dialect):
}
TYPE_MAPPING = {
- **generator.Generator.TYPE_MAPPING, # type: ignore
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",
}
@@ -336,7 +366,7 @@ class Snowflake(Dialect):
}
PROPERTIES_LOCATION = {
- **generator.Generator.PROPERTIES_LOCATION, # type: ignore
+ **generator.Generator.PROPERTIES_LOCATION,
exp.SetProperty: exp.Properties.Location.UNSUPPORTED,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
@@ -351,53 +381,10 @@ class Snowflake(Dialect):
self.unsupported("INTERSECT with All is not supported in Snowflake")
return super().intersect_op(expression)
- def values_sql(self, expression: exp.Values) -> str:
- """Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted.
-
- We also want to make sure that after we find matches where we need to unquote a column that we prevent users
- from adding quotes to the column by using the `identify` argument when generating the SQL.
- """
- alias = expression.args.get("alias")
- if alias and alias.args.get("columns"):
- expression = expression.transform(
- lambda node: exp.Identifier(**{**node.args, "quoted": False})
- if isinstance(node, exp.Identifier)
- and isinstance(node.parent, exp.TableAlias)
- and node.arg_key == "columns"
- else node,
- )
- return self.no_identify(lambda: super(self.__class__, self).values_sql(expression))
- return super().values_sql(expression)
-
def settag_sql(self, expression: exp.SetTag) -> str:
action = "UNSET" if expression.args.get("unset") else "SET"
return f"{action} TAG {self.expressions(expression)}"
- def select_sql(self, expression: exp.Select) -> str:
- """Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted and also
- that all columns in a SELECT are unquoted. We also want to make sure that after we find matches where we need
- to unquote a column that we prevent users from adding quotes to the column by using the `identify` argument when
- generating the SQL.
-
- Note: We make an assumption that any columns referenced in a VALUES expression should be unquoted throughout the
- expression. This might not be true in a case where the same column name can be sourced from another table that can
- properly quote but should be true in most cases.
- """
- values_identifiers = set(
- flatten(
- (v.args.get("alias") or exp.Alias()).args.get("columns", [])
- for v in expression.find_all(exp.Values)
- )
- )
- if values_identifiers:
- expression = expression.transform(
- lambda node: exp.Identifier(**{**node.args, "quoted": False})
- if isinstance(node, exp.Identifier) and node in values_identifiers
- else node,
- )
- return self.no_identify(lambda: super(self.__class__, self).select_sql(expression))
- return super().select_sql(expression)
-
def describe_sql(self, expression: exp.Describe) -> str:
# Default to table if kind is unknown
kind_value = expression.args.get("kind") or "TABLE"
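
`_parse_object_construct` above converts explicit key/value pairs into an `exp.Struct` (falling back to `exp.StarMap` for `OBJECT_CONSTRUCT(*)`), and the new `exp.Struct` transform flattens the equalities back out on generation. An illustrative round trip, assuming this commit's behavior:

    from sqlglot import exp, parse_one

    e = parse_one("SELECT OBJECT_CONSTRUCT('a', 1, 'b', 2)", read="snowflake")

    # Parsed as a Struct of key-value equalities rather than a VAR_MAP.
    print(repr(e.find(exp.Struct)))

    # Generated back via the Struct transform, which re-flattens the pairs.
    print(e.sql(dialect="snowflake"))
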
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 939f2fd..b7d1641 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -7,10 +7,10 @@ from sqlglot.dialects.spark2 import Spark2
from sqlglot.helper import seq_get
-def _parse_datediff(args: t.Sequence) -> exp.Expression:
+def _parse_datediff(args: t.List) -> exp.Expression:
"""
Although Spark docs don't mention the "unit" argument, Spark3 added support for
- it at some point. Databricks also supports this variation (see below).
+ it at some point. Databricks also supports this variant (see below).
For example, in spark-sql (v3.3.1):
- SELECT DATEDIFF('2020-01-01', '2020-01-05') results in -4
@@ -36,7 +36,7 @@ def _parse_datediff(args: t.Sequence) -> exp.Expression:
class Spark(Spark2):
class Parser(Spark2.Parser):
FUNCTIONS = {
- **Spark2.Parser.FUNCTIONS, # type: ignore
+ **Spark2.Parser.FUNCTIONS,
"DATEDIFF": _parse_datediff,
}
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
index 584671f..912b86b 100644
--- a/sqlglot/dialects/spark2.py
+++ b/sqlglot/dialects/spark2.py
@@ -3,7 +3,12 @@ from __future__ import annotations
import typing as t
from sqlglot import exp, parser, transforms
-from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func, trim_sql
+from sqlglot.dialects.dialect import (
+ create_with_partitions_sql,
+ pivot_column_names,
+ rename_func,
+ trim_sql,
+)
from sqlglot.dialects.hive import Hive
from sqlglot.helper import seq_get
@@ -26,7 +31,7 @@ def _map_sql(self: Hive.Generator, expression: exp.Map) -> str:
return f"MAP_FROM_ARRAYS({keys}, {values})"
-def _parse_as_cast(to_type: str) -> t.Callable[[t.Sequence], exp.Expression]:
+def _parse_as_cast(to_type: str) -> t.Callable[[t.List], exp.Expression]:
return lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build(to_type))
@@ -53,10 +58,56 @@ def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str:
raise ValueError("Improper scale for timestamp")
+def _unalias_pivot(expression: exp.Expression) -> exp.Expression:
+ """
+ Spark doesn't allow PIVOT aliases, so we need to remove them and possibly wrap a
+ pivoted source in a subquery with the same alias to preserve the query's semantics.
+
+ Example:
+ >>> from sqlglot import parse_one
+ >>> expr = parse_one("SELECT piv.x FROM tbl PIVOT (SUM(a) FOR b IN ('x')) piv")
+ >>> print(_unalias_pivot(expr).sql(dialect="spark"))
+ SELECT piv.x FROM (SELECT * FROM tbl PIVOT(SUM(a) FOR b IN ('x'))) AS piv
+ """
+ if isinstance(expression, exp.From) and expression.this.args.get("pivots"):
+ pivot = expression.this.args["pivots"][0]
+ if pivot.alias:
+ alias = pivot.args["alias"].pop()
+ return exp.From(
+ this=expression.this.replace(
+ exp.select("*").from_(expression.this.copy()).subquery(alias=alias)
+ )
+ )
+
+ return expression
+
+
+def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
+ """
+ Spark doesn't allow the column referenced in the PIVOT's field to be qualified,
+ so we need to unqualify it.
+
+ Example:
+ >>> from sqlglot import parse_one
+ >>> expr = parse_one("SELECT * FROM tbl PIVOT (SUM(tbl.sales) FOR tbl.quarter IN ('Q1', 'Q2'))")
+ >>> print(_unqualify_pivot_columns(expr).sql(dialect="spark"))
+        SELECT * FROM tbl PIVOT(SUM(tbl.sales) FOR quarter IN ('Q1', 'Q2'))
+ """
+ if isinstance(expression, exp.Pivot):
+ expression.args["field"].transform(
+ lambda node: exp.column(node.output_name, quoted=node.this.quoted)
+ if isinstance(node, exp.Column)
+ else node,
+ copy=False,
+ )
+
+ return expression
+
+
class Spark2(Hive):
class Parser(Hive.Parser):
FUNCTIONS = {
- **Hive.Parser.FUNCTIONS, # type: ignore
+ **Hive.Parser.FUNCTIONS,
"MAP_FROM_ARRAYS": exp.Map.from_arg_list,
"TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
"LEFT": lambda args: exp.Substring(
@@ -110,7 +161,7 @@ class Spark2(Hive):
}
FUNCTION_PARSERS = {
- **parser.Parser.FUNCTION_PARSERS, # type: ignore
+ **parser.Parser.FUNCTION_PARSERS,
"BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
"BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
"MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
@@ -131,43 +182,21 @@ class Spark2(Hive):
kind="COLUMNS",
)
- def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]:
- # Spark doesn't add a suffix to the pivot columns when there's a single aggregation
- if len(pivot_columns) == 1:
+ def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]:
+ if len(aggregations) == 1:
return [""]
-
- names = []
- for agg in pivot_columns:
- if isinstance(agg, exp.Alias):
- names.append(agg.alias)
- else:
- """
- This case corresponds to aggregations without aliases being used as suffixes
- (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
- be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
- Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
-
- Moreover, function names are lowercased in order to mimic Spark's naming scheme.
- """
- agg_all_unquoted = agg.transform(
- lambda node: exp.Identifier(this=node.name, quoted=False)
- if isinstance(node, exp.Identifier)
- else node
- )
- names.append(agg_all_unquoted.sql(dialect="spark", normalize_functions="lower"))
-
- return names
+ return pivot_column_names(aggregations, dialect="spark")
class Generator(Hive.Generator):
TYPE_MAPPING = {
- **Hive.Generator.TYPE_MAPPING, # type: ignore
+ **Hive.Generator.TYPE_MAPPING,
exp.DataType.Type.TINYINT: "BYTE",
exp.DataType.Type.SMALLINT: "SHORT",
exp.DataType.Type.BIGINT: "LONG",
}
PROPERTIES_LOCATION = {
- **Hive.Generator.PROPERTIES_LOCATION, # type: ignore
+ **Hive.Generator.PROPERTIES_LOCATION,
exp.EngineProperty: exp.Properties.Location.UNSUPPORTED,
exp.AutoIncrementProperty: exp.Properties.Location.UNSUPPORTED,
exp.CharacterSetProperty: exp.Properties.Location.UNSUPPORTED,
@@ -175,7 +204,7 @@ class Spark2(Hive):
}
TRANSFORMS = {
- **Hive.Generator.TRANSFORMS, # type: ignore
+ **Hive.Generator.TRANSFORMS,
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
@@ -188,11 +217,12 @@ class Spark2(Hive):
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
+ exp.From: transforms.preprocess([_unalias_pivot]),
exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.LogicalOr: rename_func("BOOL_OR"),
exp.Map: _map_sql,
- exp.Pivot: transforms.preprocess([transforms.unqualify_pivot_columns]),
+ exp.Pivot: transforms.preprocess([_unqualify_pivot_columns]),
exp.Reduce: rename_func("AGGREGATE"),
exp.StrToDate: _str_to_date,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
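
Both pivot preprocessors are exercised by the doctests above; end to end, the same behavior is reachable through `transpile`. For example, mirroring the `_unalias_pivot` doctest:

    import sqlglot

    sql = "SELECT piv.x FROM tbl PIVOT (SUM(a) FOR b IN ('x')) piv"

    # Spark can't alias a PIVOT directly, so the pivoted source is wrapped in a
    # subquery that carries the alias:
    # SELECT piv.x FROM (SELECT * FROM tbl PIVOT(SUM(a) FOR b IN ('x'))) AS piv
    print(sqlglot.transpile(sql, write="spark")[0])
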
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index f2efe32..56e7773 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -7,6 +7,7 @@ from sqlglot.dialects.dialect import (
arrow_json_extract_sql,
count_if_to_sum,
no_ilike_sql,
+ no_pivot_sql,
no_tablesample_sql,
no_trycast_sql,
rename_func,
@@ -14,7 +15,7 @@ from sqlglot.dialects.dialect import (
from sqlglot.tokens import TokenType
-def _date_add_sql(self, expression):
+def _date_add_sql(self: generator.Generator, expression: exp.DateAdd) -> str:
modifier = expression.expression
modifier = modifier.name if modifier.is_string else self.sql(modifier)
unit = expression.args.get("unit")
@@ -67,7 +68,7 @@ class SQLite(Dialect):
class Parser(parser.Parser):
FUNCTIONS = {
- **parser.Parser.FUNCTIONS, # type: ignore
+ **parser.Parser.FUNCTIONS,
"EDITDIST3": exp.Levenshtein.from_arg_list,
}
@@ -76,7 +77,7 @@ class SQLite(Dialect):
TABLE_HINTS = False
TYPE_MAPPING = {
- **generator.Generator.TYPE_MAPPING, # type: ignore
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.BOOLEAN: "INTEGER",
exp.DataType.Type.TINYINT: "INTEGER",
exp.DataType.Type.SMALLINT: "INTEGER",
@@ -98,7 +99,7 @@ class SQLite(Dialect):
}
TRANSFORMS = {
- **generator.Generator.TRANSFORMS, # type: ignore
+ **generator.Generator.TRANSFORMS,
exp.CountIf: count_if_to_sum,
exp.Create: transforms.preprocess([_transform_create]),
exp.CurrentDate: lambda *_: "CURRENT_DATE",
@@ -114,6 +115,7 @@ class SQLite(Dialect):
exp.Levenshtein: rename_func("EDITDIST3"),
exp.LogicalOr: rename_func("MAX"),
exp.LogicalAnd: rename_func("MIN"),
+ exp.Pivot: no_pivot_sql,
exp.Select: transforms.preprocess(
[transforms.eliminate_distinct_on, transforms.eliminate_qualify]
),
@@ -163,12 +165,15 @@ class SQLite(Dialect):
return f"CAST({sql} AS INTEGER)"
# https://www.sqlite.org/lang_aggfunc.html#group_concat
- def groupconcat_sql(self, expression):
+ def groupconcat_sql(self, expression: exp.GroupConcat) -> str:
this = expression.this
distinct = expression.find(exp.Distinct)
+
if distinct:
this = distinct.expressions[0]
- distinct = "DISTINCT "
+ distinct_sql = "DISTINCT "
+ else:
+ distinct_sql = ""
if isinstance(expression.this, exp.Order):
self.unsupported("SQLite GROUP_CONCAT doesn't support ORDER BY.")
@@ -176,7 +181,7 @@ class SQLite(Dialect):
this = expression.this.this
separator = expression.args.get("separator")
- return f"GROUP_CONCAT({distinct or ''}{self.format_args(this, separator)})"
+ return f"GROUP_CONCAT({distinct_sql}{self.format_args(this, separator)})"
def least_sql(self, expression: exp.Least) -> str:
if len(expression.expressions) > 1:
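
The rewritten `groupconcat_sql` emits the same SQL as before but replaces the old trick of reassigning `distinct` to a string with an explicit `distinct_sql` variable. An illustrative round trip (hypothetical column names):

    import sqlglot

    # DISTINCT is preserved; an ORDER BY inside GROUP_CONCAT would instead
    # trigger self.unsupported(), since SQLite doesn't allow it there.
    print(sqlglot.transpile("SELECT GROUP_CONCAT(DISTINCT x) FROM t", read="sqlite", write="sqlite")[0])
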
diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py
index 895588a..0390113 100644
--- a/sqlglot/dialects/starrocks.py
+++ b/sqlglot/dialects/starrocks.py
@@ -11,25 +11,24 @@ from sqlglot.helper import seq_get
class StarRocks(MySQL):
- class Parser(MySQL.Parser): # type: ignore
+ class Parser(MySQL.Parser):
FUNCTIONS = {
**MySQL.Parser.FUNCTIONS,
- "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
"DATE_TRUNC": lambda args: exp.TimestampTrunc(
this=seq_get(args, 1), unit=seq_get(args, 0)
),
}
- class Generator(MySQL.Generator): # type: ignore
+ class Generator(MySQL.Generator):
TYPE_MAPPING = {
- **MySQL.Generator.TYPE_MAPPING, # type: ignore
+ **MySQL.Generator.TYPE_MAPPING,
exp.DataType.Type.TEXT: "STRING",
exp.DataType.Type.TIMESTAMP: "DATETIME",
exp.DataType.Type.TIMESTAMPTZ: "DATETIME",
}
TRANSFORMS = {
- **MySQL.Generator.TRANSFORMS, # type: ignore
+ **MySQL.Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.JSONExtractScalar: arrow_json_extract_sql,
exp.JSONExtract: arrow_json_extract_sql,
@@ -43,4 +42,5 @@ class StarRocks(MySQL):
exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
}
+
TRANSFORMS.pop(exp.DateTrunc)
diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py
index 51e685b..d5fba17 100644
--- a/sqlglot/dialects/tableau.py
+++ b/sqlglot/dialects/tableau.py
@@ -4,41 +4,38 @@ from sqlglot import exp, generator, parser, transforms
from sqlglot.dialects.dialect import Dialect
-def _if_sql(self, expression):
- return f"IF {self.sql(expression, 'this')} THEN {self.sql(expression, 'true')} ELSE {self.sql(expression, 'false')} END"
-
-
-def _coalesce_sql(self, expression):
- return f"IFNULL({self.sql(expression, 'this')}, {self.expressions(expression)})"
-
-
-def _count_sql(self, expression):
- this = expression.this
- if isinstance(this, exp.Distinct):
- return f"COUNTD({self.expressions(this, flat=True)})"
- return f"COUNT({self.sql(expression, 'this')})"
-
-
class Tableau(Dialect):
class Generator(generator.Generator):
JOIN_HINTS = False
TABLE_HINTS = False
TRANSFORMS = {
- **generator.Generator.TRANSFORMS, # type: ignore
- exp.If: _if_sql,
- exp.Coalesce: _coalesce_sql,
- exp.Count: _count_sql,
+ **generator.Generator.TRANSFORMS,
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
}
PROPERTIES_LOCATION = {
- **generator.Generator.PROPERTIES_LOCATION, # type: ignore
+ **generator.Generator.PROPERTIES_LOCATION,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
+ def if_sql(self, expression: exp.If) -> str:
+ this = self.sql(expression, "this")
+ true = self.sql(expression, "true")
+ false = self.sql(expression, "false")
+ return f"IF {this} THEN {true} ELSE {false} END"
+
+ def coalesce_sql(self, expression: exp.Coalesce) -> str:
+ return f"IFNULL({self.sql(expression, 'this')}, {self.expressions(expression)})"
+
+ def count_sql(self, expression: exp.Count) -> str:
+ this = expression.this
+ if isinstance(this, exp.Distinct):
+ return f"COUNTD({self.expressions(this, flat=True)})"
+ return f"COUNT({self.sql(expression, 'this')})"
+
class Parser(parser.Parser):
FUNCTIONS = {
- **parser.Parser.FUNCTIONS, # type: ignore
+ **parser.Parser.FUNCTIONS,
"COUNTD": lambda args: exp.Count(this=exp.Distinct(expressions=args)),
}
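
Moving `_if_sql`, `_coalesce_sql` and `_count_sql` onto the Generator as `if_sql`, `coalesce_sql` and `count_sql` changes where the logic lives, not what it emits. A quick sketch of the output with hypothetical columns:

    import sqlglot

    # count_sql: COUNT(DISTINCT ...) becomes Tableau's COUNTD.
    print(sqlglot.transpile("SELECT COUNT(DISTINCT x) FROM t", write="tableau")[0])

    # coalesce_sql: COALESCE becomes IFNULL.
    print(sqlglot.transpile("SELECT COALESCE(x, 1) FROM t", write="tableau")[0])
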
diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py
index a79eaeb..9b39178 100644
--- a/sqlglot/dialects/teradata.py
+++ b/sqlglot/dialects/teradata.py
@@ -75,12 +75,12 @@ class Teradata(Dialect):
FUNC_TOKENS.remove(TokenType.REPLACE)
STATEMENT_PARSERS = {
- **parser.Parser.STATEMENT_PARSERS, # type: ignore
+ **parser.Parser.STATEMENT_PARSERS,
TokenType.REPLACE: lambda self: self._parse_create(),
}
FUNCTION_PARSERS = {
- **parser.Parser.FUNCTION_PARSERS, # type: ignore
+ **parser.Parser.FUNCTION_PARSERS,
"RANGE_N": lambda self: self._parse_rangen(),
"TRANSLATE": lambda self: self._parse_translate(self.STRICT_CAST),
}
@@ -106,7 +106,7 @@ class Teradata(Dialect):
exp.Update,
**{ # type: ignore
"this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS),
- "from": self._parse_from(),
+ "from": self._parse_from(modifiers=True),
"expressions": self._match(TokenType.SET)
and self._parse_csv(self._parse_equality),
"where": self._parse_where(),
@@ -135,13 +135,15 @@ class Teradata(Dialect):
TABLE_HINTS = False
TYPE_MAPPING = {
- **generator.Generator.TYPE_MAPPING, # type: ignore
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.GEOMETRY: "ST_GEOMETRY",
}
PROPERTIES_LOCATION = {
- **generator.Generator.PROPERTIES_LOCATION, # type: ignore
- exp.PartitionedByProperty: exp.Properties.Location.POST_INDEX,
+ **generator.Generator.PROPERTIES_LOCATION,
+ exp.OnCommitProperty: exp.Properties.Location.POST_INDEX,
+ exp.PartitionedByProperty: exp.Properties.Location.POST_EXPRESSION,
+ exp.StabilityProperty: exp.Properties.Location.POST_CREATE,
}
TRANSFORMS = {
diff --git a/sqlglot/dialects/trino.py b/sqlglot/dialects/trino.py
index c7b34fe..af0f78d 100644
--- a/sqlglot/dialects/trino.py
+++ b/sqlglot/dialects/trino.py
@@ -7,7 +7,7 @@ from sqlglot.dialects.presto import Presto
class Trino(Presto):
class Generator(Presto.Generator):
TRANSFORMS = {
- **Presto.Generator.TRANSFORMS, # type: ignore
+ **Presto.Generator.TRANSFORMS,
exp.ArraySum: lambda self, e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
}
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 03de99c..f6ad888 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -16,6 +16,9 @@ from sqlglot.helper import seq_get
from sqlglot.time import format_time
from sqlglot.tokens import TokenType
+if t.TYPE_CHECKING:
+ from sqlglot._typing import E
+
FULL_FORMAT_TIME_MAPPING = {
"weekday": "%A",
"dw": "%A",
@@ -50,13 +53,17 @@ DATE_FMT_RE = re.compile("([dD]{1,2})|([mM]{1,2})|([yY]{1,4})|([hH]{1,2})|([sS]{
TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"}
-def _format_time_lambda(exp_class, full_format_mapping=None, default=None):
- def _format_time(args):
+def _format_time_lambda(
+ exp_class: t.Type[E], full_format_mapping: t.Optional[bool] = None
+) -> t.Callable[[t.List], E]:
+ def _format_time(args: t.List) -> E:
+ assert len(args) == 2
+
return exp_class(
- this=seq_get(args, 1),
+ this=args[1],
format=exp.Literal.string(
format_time(
- seq_get(args, 0).name or (TSQL.time_format if default is True else default),
+ args[0].name,
{**TSQL.time_mapping, **FULL_FORMAT_TIME_MAPPING}
if full_format_mapping
else TSQL.time_mapping,
@@ -67,13 +74,17 @@ def _format_time_lambda(exp_class, full_format_mapping=None, default=None):
return _format_time
-def _parse_format(args):
- fmt = seq_get(args, 1)
- number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.this)
+def _parse_format(args: t.List) -> exp.Expression:
+ assert len(args) == 2
+
+ fmt = args[1]
+ number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.name)
+
if number_fmt:
- return exp.NumberToStr(this=seq_get(args, 0), format=fmt)
+ return exp.NumberToStr(this=args[0], format=fmt)
+
return exp.TimeToStr(
- this=seq_get(args, 0),
+ this=args[0],
format=exp.Literal.string(
format_time(fmt.name, TSQL.format_time_mapping)
if len(fmt.name) == 1
@@ -82,7 +93,7 @@ def _parse_format(args):
)
-def _parse_eomonth(args):
+def _parse_eomonth(args: t.List) -> exp.Expression:
date = seq_get(args, 0)
month_lag = seq_get(args, 1)
unit = DATE_DELTA_INTERVAL.get("month")
@@ -96,7 +107,7 @@ def _parse_eomonth(args):
return exp.LastDateOfMonth(this=exp.DateAdd(this=date, expression=month_lag, unit=unit))
-def _parse_hashbytes(args):
+def _parse_hashbytes(args: t.List) -> exp.Expression:
kind, data = args
kind = kind.name.upper() if kind.is_string else ""
@@ -110,40 +121,47 @@ def _parse_hashbytes(args):
return exp.SHA2(this=data, length=exp.Literal.number(256))
if kind == "SHA2_512":
return exp.SHA2(this=data, length=exp.Literal.number(512))
+
return exp.func("HASHBYTES", *args)
-def generate_date_delta_with_unit_sql(self, e):
- func = "DATEADD" if isinstance(e, exp.DateAdd) else "DATEDIFF"
- return self.func(func, e.text("unit"), e.expression, e.this)
+def generate_date_delta_with_unit_sql(
+ self: generator.Generator, expression: exp.DateAdd | exp.DateDiff
+) -> str:
+ func = "DATEADD" if isinstance(expression, exp.DateAdd) else "DATEDIFF"
+ return self.func(func, expression.text("unit"), expression.expression, expression.this)
-def _format_sql(self, e):
+def _format_sql(self: generator.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str:
fmt = (
- e.args["format"]
- if isinstance(e, exp.NumberToStr)
- else exp.Literal.string(format_time(e.text("format"), TSQL.inverse_time_mapping))
+ expression.args["format"]
+ if isinstance(expression, exp.NumberToStr)
+ else exp.Literal.string(
+ format_time(
+ expression.text("format"), t.cast(t.Dict[str, str], TSQL.inverse_time_mapping)
+ )
+ )
)
- return self.func("FORMAT", e.this, fmt)
+ return self.func("FORMAT", expression.this, fmt)
-def _string_agg_sql(self, e):
- e = e.copy()
+def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> str:
+ expression = expression.copy()
- this = e.this
- distinct = e.find(exp.Distinct)
+ this = expression.this
+ distinct = expression.find(exp.Distinct)
if distinct:
# exp.Distinct can appear below an exp.Order or an exp.GroupConcat expression
self.unsupported("T-SQL STRING_AGG doesn't support DISTINCT.")
this = distinct.pop().expressions[0]
order = ""
- if isinstance(e.this, exp.Order):
- if e.this.this:
- this = e.this.this.pop()
- order = f" WITHIN GROUP ({self.sql(e.this)[1:]})" # Order has a leading space
+ if isinstance(expression.this, exp.Order):
+ if expression.this.this:
+ this = expression.this.this.pop()
+ order = f" WITHIN GROUP ({self.sql(expression.this)[1:]})" # Order has a leading space
- separator = e.args.get("separator") or exp.Literal.string(",")
+ separator = expression.args.get("separator") or exp.Literal.string(",")
return f"STRING_AGG({self.format_args(this, separator)}){order}"
@@ -292,7 +310,7 @@ class TSQL(Dialect):
class Parser(parser.Parser):
FUNCTIONS = {
- **parser.Parser.FUNCTIONS, # type: ignore
+ **parser.Parser.FUNCTIONS,
"CHARINDEX": lambda args: exp.StrPosition(
this=seq_get(args, 1),
substr=seq_get(args, 0),
@@ -332,13 +350,13 @@ class TSQL(Dialect):
DataType.Type.NCHAR,
}
- RETURNS_TABLE_TOKENS = parser.Parser.ID_VAR_TOKENS - { # type: ignore
+ RETURNS_TABLE_TOKENS = parser.Parser.ID_VAR_TOKENS - {
TokenType.TABLE,
- *parser.Parser.TYPE_TOKENS, # type: ignore
+ *parser.Parser.TYPE_TOKENS,
}
STATEMENT_PARSERS = {
- **parser.Parser.STATEMENT_PARSERS, # type: ignore
+ **parser.Parser.STATEMENT_PARSERS,
TokenType.END: lambda self: self._parse_command(),
}
@@ -377,7 +395,7 @@ class TSQL(Dialect):
return system_time
- def _parse_table_parts(self, schema: bool = False) -> exp.Expression:
+ def _parse_table_parts(self, schema: bool = False) -> exp.Table:
table = super()._parse_table_parts(schema=schema)
table.set("system_time", self._parse_system_time())
return table
@@ -450,7 +468,7 @@ class TSQL(Dialect):
LOCKING_READS_SUPPORTED = True
TYPE_MAPPING = {
- **generator.Generator.TYPE_MAPPING, # type: ignore
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.DECIMAL: "NUMERIC",
exp.DataType.Type.DATETIME: "DATETIME2",
@@ -458,7 +476,7 @@ class TSQL(Dialect):
}
TRANSFORMS = {
- **generator.Generator.TRANSFORMS, # type: ignore
+ **generator.Generator.TRANSFORMS,
exp.DateAdd: generate_date_delta_with_unit_sql,
exp.DateDiff: generate_date_delta_with_unit_sql,
exp.CurrentDate: rename_func("GETDATE"),
@@ -480,7 +498,7 @@ class TSQL(Dialect):
TRANSFORMS.pop(exp.ReturnsProperty)
PROPERTIES_LOCATION = {
- **generator.Generator.PROPERTIES_LOCATION, # type: ignore
+ **generator.Generator.PROPERTIES_LOCATION,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
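
`_parse_format` now asserts exactly two arguments and dispatches on the format string: formats in `TRANSPILE_SAFE_NUMBER_FMT`, or with no date tokens per `DATE_FMT_RE`, become `exp.NumberToStr`; anything date-like becomes `exp.TimeToStr`. Illustrative parses:

    from sqlglot import exp, parse_one

    # 'C' is a transpile-safe number format, so this is a NumberToStr.
    print(repr(parse_one("SELECT FORMAT(x, 'C')", read="tsql").find(exp.NumberToStr)))

    # 'yyyy-MM-dd' contains date tokens, so this is a TimeToStr.
    print(repr(parse_one("SELECT FORMAT(x, 'yyyy-MM-dd')", read="tsql").find(exp.TimeToStr)))
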
diff --git a/sqlglot/diff.py b/sqlglot/diff.py
index 86665e0..c10d640 100644
--- a/sqlglot/diff.py
+++ b/sqlglot/diff.py
@@ -53,7 +53,8 @@ class Keep:
if t.TYPE_CHECKING:
- T = t.TypeVar("T")
+ from sqlglot._typing import T
+
Edit = t.Union[Insert, Remove, Move, Update, Keep]
@@ -240,7 +241,7 @@ class ChangeDistiller:
return matching_set
def _compute_leaf_matching_set(self) -> t.Set[t.Tuple[int, int]]:
- candidate_matchings: t.List[t.Tuple[float, int, exp.Expression, exp.Expression]] = []
+ candidate_matchings: t.List[t.Tuple[float, int, int, exp.Expression, exp.Expression]] = []
source_leaves = list(_get_leaves(self._source))
target_leaves = list(_get_leaves(self._target))
for source_leaf in source_leaves:
@@ -252,6 +253,7 @@ class ChangeDistiller:
candidate_matchings,
(
-similarity_score,
+ -_parent_similarity_score(source_leaf, target_leaf),
len(candidate_matchings),
source_leaf,
target_leaf,
@@ -261,7 +263,7 @@ class ChangeDistiller:
# Pick best matchings based on the highest score
matching_set = set()
while candidate_matchings:
- _, _, source_leaf, target_leaf = heappop(candidate_matchings)
+ _, _, _, source_leaf, target_leaf = heappop(candidate_matchings)
if (
id(source_leaf) in self._unmatched_source_nodes
and id(target_leaf) in self._unmatched_target_nodes
@@ -327,6 +329,15 @@ def _is_same_type(source: exp.Expression, target: exp.Expression) -> bool:
return False
+def _parent_similarity_score(
+ source: t.Optional[exp.Expression], target: t.Optional[exp.Expression]
+) -> int:
+ if source is None or target is None or type(source) is not type(target):
+ return 0
+
+ return 1 + _parent_similarity_score(source.parent, target.parent)
+
+
def _expression_only_args(expression: exp.Expression) -> t.List[exp.Expression]:
args: t.List[t.Union[exp.Expression, t.List]] = []
if expression:
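
The extra heap key from `_parent_similarity_score` acts as a tie-breaker: among leaf pairs with equal textual similarity, the pair whose ancestor chains share the longest run of identical node types is matched first, keeping structurally analogous leaves paired together. Illustrative usage (the edit script's exact contents depend on the inputs):

    from sqlglot import parse_one
    from sqlglot.diff import diff

    # The literal 1s under the two WHERE clauses share deeper ancestor types
    # than a 1 in a SELECT list would, so they are matched to each other first.
    edits = diff(parse_one("SELECT 1 FROM t WHERE x = 1"), parse_one("SELECT 2 FROM t WHERE x = 1"))
    print(edits)
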
diff --git a/sqlglot/executor/__init__.py b/sqlglot/executor/__init__.py
index a67c155..017d5bc 100644
--- a/sqlglot/executor/__init__.py
+++ b/sqlglot/executor/__init__.py
@@ -14,9 +14,10 @@ from sqlglot import maybe_parse
from sqlglot.errors import ExecuteError
from sqlglot.executor.python import PythonExecutor
from sqlglot.executor.table import Table, ensure_tables
+from sqlglot.helper import dict_depth
from sqlglot.optimizer import optimize
from sqlglot.planner import Plan
-from sqlglot.schema import ensure_schema
+from sqlglot.schema import ensure_schema, flatten_schema, nested_get, nested_set
logger = logging.getLogger("sqlglot")
@@ -52,10 +53,15 @@ def execute(
tables_ = ensure_tables(tables)
if not schema:
- schema = {
- name: {column: type(table[0][column]).__name__ for column in table.columns}
- for name, table in tables_.mapping.items()
- }
+ schema = {}
+ flattened_tables = flatten_schema(tables_.mapping, depth=dict_depth(tables_.mapping))
+
+ for keys in flattened_tables:
+ table = nested_get(tables_.mapping, *zip(keys, keys))
+ assert table is not None
+
+ for column in table.columns:
+ nested_set(schema, [*keys, column], type(table[0][column]).__name__)
schema = ensure_schema(schema, dialect=read)
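
Because schema inference now flattens the table mapping with `flatten_schema`/`nested_get`/`nested_set`, nested (qualified) table mappings should work without an explicit schema. A sketch with hypothetical names:

    from sqlglot.executor import execute

    # Column types are still inferred from the first row of each table, but the
    # tables may now be nested under database/catalog keys.
    tables = {"db": {"t": [{"a": 1, "b": "x"}]}}
    print(execute("SELECT a FROM db.t", tables=tables))
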
diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py
index 8f64cce..51cffbd 100644
--- a/sqlglot/executor/env.py
+++ b/sqlglot/executor/env.py
@@ -5,6 +5,7 @@ import statistics
from functools import wraps
from sqlglot import exp
+from sqlglot.generator import Generator
from sqlglot.helper import PYTHON_VERSION
@@ -102,6 +103,8 @@ def cast(this, to):
return datetime.date.fromisoformat(this)
if to == exp.DataType.Type.DATETIME:
return datetime.datetime.fromisoformat(this)
+ if to == exp.DataType.Type.BOOLEAN:
+ return bool(this)
if to in exp.DataType.TEXT_TYPES:
return str(this)
if to in {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE}:
@@ -119,9 +122,11 @@ def ordered(this, desc, nulls_first):
@null_if_any
def interval(this, unit):
- if unit == "DAY":
- return datetime.timedelta(days=float(this))
- raise NotImplementedError
+ unit = unit.lower()
+ plural = unit + "s"
+ if plural in Generator.TIME_PART_SINGULARS:
+ unit = plural
+ return datetime.timedelta(**{unit: float(this)})
ENV = {
@@ -147,7 +152,9 @@ ENV = {
"COALESCE": lambda *args: next((a for a in args if a is not None), None),
"CONCAT": null_if_any(lambda *args: "".join(args)),
"CONCATWS": null_if_any(lambda this, *args: this.join(args)),
+ "DATESTRTODATE": null_if_any(lambda arg: datetime.date.fromisoformat(arg)),
"DIV": null_if_any(lambda e, this: e / this),
+ "DOT": null_if_any(lambda e, this: e[this]),
"EQ": null_if_any(lambda this, e: this == e),
"EXTRACT": null_if_any(lambda this, e: getattr(e, this)),
"GT": null_if_any(lambda this, e: this > e),
@@ -162,6 +169,7 @@ ENV = {
"LOWER": null_if_any(lambda arg: arg.lower()),
"LT": null_if_any(lambda this, e: this < e),
"LTE": null_if_any(lambda this, e: this <= e),
+ "MAP": null_if_any(lambda *args: dict(zip(*args))), # type: ignore
"MOD": null_if_any(lambda e, this: e % this),
"MUL": null_if_any(lambda e, this: e * this),
"NEQ": null_if_any(lambda this, e: this != e),
@@ -180,4 +188,5 @@ ENV = {
"CURRENTTIMESTAMP": datetime.datetime.now,
"CURRENTTIME": datetime.datetime.now,
"CURRENTDATE": datetime.date.today,
+ "STRFTIME": null_if_any(lambda fmt, arg: datetime.datetime.fromisoformat(arg).strftime(fmt)),
}
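
The new `interval` helper generalizes beyond days by pluralizing the unit and passing it to `datetime.timedelta` as a keyword, using `Generator.TIME_PART_SINGULARS` (whose keys are the plural forms) to validate it. A standalone sketch of the same logic, minus the `null_if_any` wrapper:

    import datetime

    from sqlglot.generator import Generator

    def interval_sketch(this, unit):
        # "DAY" -> "days", which timedelta accepts as a keyword argument.
        unit = unit.lower()
        plural = unit + "s"
        if plural in Generator.TIME_PART_SINGULARS:
            unit = plural
        return datetime.timedelta(**{unit: float(this)})

    print(interval_sketch(2, "DAY"))      # 2 days, 0:00:00
    print(interval_sketch(30, "MINUTE"))  # 0:30:00
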
diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py
index b71cc6a..f114e5c 100644
--- a/sqlglot/executor/python.py
+++ b/sqlglot/executor/python.py
@@ -360,11 +360,19 @@ def _ordered_py(self, expression):
def _rename(self, e):
try:
- if "expressions" in e.args:
- this = self.sql(e, "this")
- this = f"{this}, " if this else ""
- return f"{e.key.upper()}({this}{self.expressions(e)})"
- return self.func(e.key, *e.args.values())
+ values = list(e.args.values())
+
+ if len(values) == 1:
+ values = values[0]
+ if not isinstance(values, list):
+ return self.func(e.key, values)
+ return self.func(e.key, *values)
+
+ if isinstance(e, exp.Func) and e.is_var_len_args:
+ *head, tail = values
+ return self.func(e.key, *head, *tail)
+
+ return self.func(e.key, *values)
except Exception as ex:
raise Exception(f"Could not rename {repr(e)}") from ex
@@ -413,6 +421,7 @@ class Python(Dialect):
exp.Distinct: lambda self, e: f"set({self.sql(e, 'this')})",
exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})",
exp.In: lambda self, e: f"{self.sql(e, 'this')} in ({self.expressions(e, flat=True)})",
+ exp.Interval: lambda self, e: f"INTERVAL({self.sql(e.this)}, '{self.sql(e.unit)}')",
exp.Is: lambda self, e: self.binary(e, "is"),
exp.Lambda: _lambda_sql,
exp.Not: lambda self, e: f"not {self.sql(e.this)}",
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 9e7379d..a4c4e95 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -21,6 +21,7 @@ from collections import deque
from copy import deepcopy
from enum import auto
+from sqlglot._typing import E
from sqlglot.errors import ParseError
from sqlglot.helper import (
AutoName,
@@ -28,7 +29,6 @@ from sqlglot.helper import (
ensure_collection,
ensure_list,
seq_get,
- split_num_words,
subclasses,
)
from sqlglot.tokens import Token
@@ -36,8 +36,6 @@ from sqlglot.tokens import Token
if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType
-E = t.TypeVar("E", bound="Expression")
-
class _Expression(type):
def __new__(cls, clsname, bases, attrs):
@@ -200,11 +198,11 @@ class Expression(metaclass=_Expression):
return self.text("this")
@property
- def alias_or_name(self):
+ def alias_or_name(self) -> str:
return self.alias or self.name
@property
- def output_name(self):
+ def output_name(self) -> str:
"""
Name of the output column if this expression is a selection.
@@ -264,7 +262,7 @@ class Expression(metaclass=_Expression):
if comments:
self.comments.extend(comments)
- def append(self, arg_key, value):
+ def append(self, arg_key: str, value: t.Any) -> None:
"""
Appends value to arg_key if it's a list or sets it as a new list.
@@ -277,7 +275,7 @@ class Expression(metaclass=_Expression):
self.args[arg_key].append(value)
self._set_parent(arg_key, value)
- def set(self, arg_key, value):
+ def set(self, arg_key: str, value: t.Any) -> None:
"""
Sets `arg_key` to `value`.
@@ -288,7 +286,7 @@ class Expression(metaclass=_Expression):
self.args[arg_key] = value
self._set_parent(arg_key, value)
- def _set_parent(self, arg_key, value):
+ def _set_parent(self, arg_key: str, value: t.Any) -> None:
if hasattr(value, "parent"):
value.parent = self
value.arg_key = arg_key
@@ -299,7 +297,7 @@ class Expression(metaclass=_Expression):
v.arg_key = arg_key
@property
- def depth(self):
+ def depth(self) -> int:
"""
Returns the depth of this tree.
"""
@@ -318,26 +316,28 @@ class Expression(metaclass=_Expression):
if hasattr(vs, "parent"):
yield k, vs
- def find(self, *expression_types: t.Type[E], bfs=True) -> E | None:
+ def find(self, *expression_types: t.Type[E], bfs: bool = True) -> t.Optional[E]:
"""
Returns the first node in this tree which matches at least one of
the specified types.
Args:
expression_types: the expression type(s) to match.
+ bfs: whether to search the AST using the BFS algorithm (DFS is used if false).
Returns:
The node which matches the criteria or None if no such node was found.
"""
return next(self.find_all(*expression_types, bfs=bfs), None)
- def find_all(self, *expression_types: t.Type[E], bfs=True) -> t.Iterator[E]:
+ def find_all(self, *expression_types: t.Type[E], bfs: bool = True) -> t.Iterator[E]:
"""
Returns a generator object which visits all nodes in this tree and only
yields those that match at least one of the specified expression types.
Args:
expression_types: the expression type(s) to match.
+ bfs: whether to search the AST using the BFS algorithm (DFS is used if false).
Returns:
The generator object.
@@ -346,7 +346,7 @@ class Expression(metaclass=_Expression):
if isinstance(expression, expression_types):
yield expression
- def find_ancestor(self, *expression_types: t.Type[E]) -> E | None:
+ def find_ancestor(self, *expression_types: t.Type[E]) -> t.Optional[E]:
"""
Returns a nearest parent matching expression_types.
@@ -362,14 +362,14 @@ class Expression(metaclass=_Expression):
return t.cast(E, ancestor)
@property
- def parent_select(self):
+ def parent_select(self) -> t.Optional[Select]:
"""
Returns the parent select statement.
"""
return self.find_ancestor(Select)
@property
- def same_parent(self):
+ def same_parent(self) -> bool:
"""Returns if the parent is the same class as itself."""
return type(self.parent) is self.__class__
@@ -469,10 +469,10 @@ class Expression(metaclass=_Expression):
if not type(node) is self.__class__:
yield node.unnest() if unnest else node
- def __str__(self):
+ def __str__(self) -> str:
return self.sql()
- def __repr__(self):
+ def __repr__(self) -> str:
return self._to_s()
def sql(self, dialect: DialectType = None, **opts) -> str:
@@ -541,6 +541,14 @@ class Expression(metaclass=_Expression):
replace_children(new_node, lambda child: child.transform(fun, *args, copy=False, **kwargs))
return new_node
+ @t.overload
+ def replace(self, expression: E) -> E:
+ ...
+
+ @t.overload
+ def replace(self, expression: None) -> None:
+ ...
+
def replace(self, expression):
"""
Swap out this expression with a new expression.
@@ -554,7 +562,7 @@ class Expression(metaclass=_Expression):
'SELECT y FROM tbl'
Args:
- expression (Expression|None): new node
+ expression: new node
Returns:
The new expression or expressions.
@@ -568,7 +576,7 @@ class Expression(metaclass=_Expression):
replace_children(parent, lambda child: expression if child is self else child)
return expression
- def pop(self):
+ def pop(self: E) -> E:
"""
Remove this expression from its AST.
@@ -578,7 +586,7 @@ class Expression(metaclass=_Expression):
self.replace(None)
return self
- def assert_is(self, type_):
+ def assert_is(self, type_: t.Type[E]) -> E:
"""
Assert that this `Expression` is an instance of `type_`.
@@ -656,7 +664,13 @@ ExpOrStr = t.Union[str, Expression]
class Condition(Expression):
- def and_(self, *expressions, dialect=None, copy=True, **opts):
+ def and_(
+ self,
+ *expressions: t.Optional[ExpOrStr],
+ dialect: DialectType = None,
+ copy: bool = True,
+ **opts,
+ ) -> Condition:
"""
AND this condition with one or multiple expressions.
@@ -665,18 +679,24 @@ class Condition(Expression):
'x = 1 AND y = 1'
Args:
- *expressions (str | Expression): the SQL code strings to parse.
+ *expressions: the SQL code strings to parse.
If an `Expression` instance is passed, it will be used as-is.
- dialect (str): the dialect used to parse the input expression.
- copy (bool): whether or not to copy the involved expressions (only applies to Expressions).
- opts (kwargs): other options to use to parse the input expressions.
+ dialect: the dialect used to parse the input expression.
+ copy: whether or not to copy the involved expressions (only applies to Expressions).
+ opts: other options to use to parse the input expressions.
Returns:
- And: the new condition.
+ The new And condition.
"""
return and_(self, *expressions, dialect=dialect, copy=copy, **opts)
- def or_(self, *expressions, dialect=None, copy=True, **opts):
+ def or_(
+ self,
+ *expressions: t.Optional[ExpOrStr],
+ dialect: DialectType = None,
+ copy: bool = True,
+ **opts,
+ ) -> Condition:
"""
OR this condition with one or multiple expressions.
@@ -685,18 +705,18 @@ class Condition(Expression):
'x = 1 OR y = 1'
Args:
- *expressions (str | Expression): the SQL code strings to parse.
+ *expressions: the SQL code strings to parse.
If an `Expression` instance is passed, it will be used as-is.
- dialect (str): the dialect used to parse the input expression.
- copy (bool): whether or not to copy the involved expressions (only applies to Expressions).
- opts (kwargs): other options to use to parse the input expressions.
+ dialect: the dialect used to parse the input expression.
+ copy: whether or not to copy the involved expressions (only applies to Expressions).
+ opts: other options to use to parse the input expressions.
Returns:
- Or: the new condition.
+ The new Or condition.
"""
return or_(self, *expressions, dialect=dialect, copy=copy, **opts)
- def not_(self, copy=True):
+ def not_(self, copy: bool = True):
"""
Wrap this condition with NOT.
@@ -705,14 +725,24 @@ class Condition(Expression):
'NOT x = 1'
Args:
- copy (bool): whether or not to copy this object.
+ copy: whether or not to copy this object.
Returns:
- Not: the new condition.
+ The new Not instance.
"""
return not_(self, copy=copy)
- def _binop(self, klass: t.Type[E], other: ExpOrStr, reverse=False) -> E:
+ def as_(
+ self,
+ alias: str | Identifier,
+ quoted: t.Optional[bool] = None,
+ dialect: DialectType = None,
+ copy: bool = True,
+ **opts,
+ ) -> Alias:
+ return alias_(self, alias, quoted=quoted, dialect=dialect, copy=copy, **opts)
+
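A minimal usage sketch of the new `Condition.as_` builder added above (the column and alias names are illustrative):

    >>> from sqlglot import exp
    >>> exp.column("x").as_("y").sql()
    'x AS y'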
+ def _binop(self, klass: t.Type[E], other: t.Any, reverse: bool = False) -> E:
this = self.copy()
other = convert(other, copy=True)
if not isinstance(this, klass) and not isinstance(other, klass):
@@ -728,7 +758,7 @@ class Condition(Expression):
)
def isin(
- self, *expressions: t.Any, query: t.Optional[ExpOrStr] = None, copy=True, **opts
+ self, *expressions: t.Any, query: t.Optional[ExpOrStr] = None, copy: bool = True, **opts
) -> In:
return In(
this=_maybe_copy(self, copy),
@@ -736,92 +766,95 @@ class Condition(Expression):
query=maybe_parse(query, copy=copy, **opts) if query else None,
)
- def between(self, low: t.Any, high: t.Any, copy=True, **opts) -> Between:
+ def between(self, low: t.Any, high: t.Any, copy: bool = True, **opts) -> Between:
return Between(
this=_maybe_copy(self, copy),
low=convert(low, copy=copy, **opts),
high=convert(high, copy=copy, **opts),
)
+ def is_(self, other: ExpOrStr) -> Is:
+ return self._binop(Is, other)
+
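The new `is_` helper goes through the same `_binop` path as the other comparisons; a minimal sketch, using an `exp.Null` node as the right-hand side:

    >>> from sqlglot import exp
    >>> exp.column("x").is_(exp.Null()).sql()
    'x IS NULL'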
def like(self, other: ExpOrStr) -> Like:
return self._binop(Like, other)
def ilike(self, other: ExpOrStr) -> ILike:
return self._binop(ILike, other)
- def eq(self, other: ExpOrStr) -> EQ:
+ def eq(self, other: t.Any) -> EQ:
return self._binop(EQ, other)
- def neq(self, other: ExpOrStr) -> NEQ:
+ def neq(self, other: t.Any) -> NEQ:
return self._binop(NEQ, other)
def rlike(self, other: ExpOrStr) -> RegexpLike:
return self._binop(RegexpLike, other)
- def __lt__(self, other: ExpOrStr) -> LT:
+ def __lt__(self, other: t.Any) -> LT:
return self._binop(LT, other)
- def __le__(self, other: ExpOrStr) -> LTE:
+ def __le__(self, other: t.Any) -> LTE:
return self._binop(LTE, other)
- def __gt__(self, other: ExpOrStr) -> GT:
+ def __gt__(self, other: t.Any) -> GT:
return self._binop(GT, other)
- def __ge__(self, other: ExpOrStr) -> GTE:
+ def __ge__(self, other: t.Any) -> GTE:
return self._binop(GTE, other)
- def __add__(self, other: ExpOrStr) -> Add:
+ def __add__(self, other: t.Any) -> Add:
return self._binop(Add, other)
- def __radd__(self, other: ExpOrStr) -> Add:
+ def __radd__(self, other: t.Any) -> Add:
return self._binop(Add, other, reverse=True)
- def __sub__(self, other: ExpOrStr) -> Sub:
+ def __sub__(self, other: t.Any) -> Sub:
return self._binop(Sub, other)
- def __rsub__(self, other: ExpOrStr) -> Sub:
+ def __rsub__(self, other: t.Any) -> Sub:
return self._binop(Sub, other, reverse=True)
- def __mul__(self, other: ExpOrStr) -> Mul:
+ def __mul__(self, other: t.Any) -> Mul:
return self._binop(Mul, other)
- def __rmul__(self, other: ExpOrStr) -> Mul:
+ def __rmul__(self, other: t.Any) -> Mul:
return self._binop(Mul, other, reverse=True)
- def __truediv__(self, other: ExpOrStr) -> Div:
+ def __truediv__(self, other: t.Any) -> Div:
return self._binop(Div, other)
- def __rtruediv__(self, other: ExpOrStr) -> Div:
+ def __rtruediv__(self, other: t.Any) -> Div:
return self._binop(Div, other, reverse=True)
- def __floordiv__(self, other: ExpOrStr) -> IntDiv:
+ def __floordiv__(self, other: t.Any) -> IntDiv:
return self._binop(IntDiv, other)
- def __rfloordiv__(self, other: ExpOrStr) -> IntDiv:
+ def __rfloordiv__(self, other: t.Any) -> IntDiv:
return self._binop(IntDiv, other, reverse=True)
- def __mod__(self, other: ExpOrStr) -> Mod:
+ def __mod__(self, other: t.Any) -> Mod:
return self._binop(Mod, other)
- def __rmod__(self, other: ExpOrStr) -> Mod:
+ def __rmod__(self, other: t.Any) -> Mod:
return self._binop(Mod, other, reverse=True)
- def __pow__(self, other: ExpOrStr) -> Pow:
+ def __pow__(self, other: t.Any) -> Pow:
return self._binop(Pow, other)
- def __rpow__(self, other: ExpOrStr) -> Pow:
+ def __rpow__(self, other: t.Any) -> Pow:
return self._binop(Pow, other, reverse=True)
- def __and__(self, other: ExpOrStr) -> And:
+ def __and__(self, other: t.Any) -> And:
return self._binop(And, other)
- def __rand__(self, other: ExpOrStr) -> And:
+ def __rand__(self, other: t.Any) -> And:
return self._binop(And, other, reverse=True)
- def __or__(self, other: ExpOrStr) -> Or:
+ def __or__(self, other: t.Any) -> Or:
return self._binop(Or, other)
- def __ror__(self, other: ExpOrStr) -> Or:
+ def __ror__(self, other: t.Any) -> Or:
return self._binop(Or, other, reverse=True)
def __neg__(self) -> Neg:
@@ -837,12 +870,11 @@ class Predicate(Condition):
class DerivedTable(Expression):
@property
- def alias_column_names(self):
+ def alias_column_names(self) -> t.List[str]:
table_alias = self.args.get("alias")
if not table_alias:
return []
- column_list = table_alias.assert_is(TableAlias).args.get("columns") or []
- return [c.name for c in column_list]
+ return [c.name for c in table_alias.args.get("columns") or []]
@property
def selects(self):
@@ -854,7 +886,9 @@ class DerivedTable(Expression):
class Unionable(Expression):
- def union(self, expression, distinct=True, dialect=None, **opts):
+ def union(
+ self, expression: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts
+ ) -> Unionable:
"""
Builds a UNION expression.
@@ -864,17 +898,20 @@ class Unionable(Expression):
'SELECT * FROM foo UNION SELECT * FROM bla'
Args:
- expression (str | Expression): the SQL code string.
+ expression: the SQL code string.
If an `Expression` instance is passed, it will be used as-is.
- distinct (bool): set the DISTINCT flag if and only if this is true.
- dialect (str): the dialect used to parse the input expression.
- opts (kwargs): other options to use to parse the input expressions.
+ distinct: set the DISTINCT flag if and only if this is true.
+ dialect: the dialect used to parse the input expression.
+ opts: other options to use to parse the input expressions.
+
Returns:
- Union: the Union expression.
+ The new Union expression.
"""
return union(left=self, right=expression, distinct=distinct, dialect=dialect, **opts)
- def intersect(self, expression, distinct=True, dialect=None, **opts):
+ def intersect(
+ self, expression: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts
+ ) -> Unionable:
"""
Builds an INTERSECT expression.
@@ -884,17 +921,20 @@ class Unionable(Expression):
'SELECT * FROM foo INTERSECT SELECT * FROM bla'
Args:
- expression (str | Expression): the SQL code string.
+ expression: the SQL code string.
If an `Expression` instance is passed, it will be used as-is.
- distinct (bool): set the DISTINCT flag if and only if this is true.
- dialect (str): the dialect used to parse the input expression.
- opts (kwargs): other options to use to parse the input expressions.
+ distinct: set the DISTINCT flag if and only if this is true.
+ dialect: the dialect used to parse the input expression.
+ opts: other options to use to parse the input expressions.
+
Returns:
- Intersect: the Intersect expression
+ The new Intersect expression.
"""
return intersect(left=self, right=expression, distinct=distinct, dialect=dialect, **opts)
- def except_(self, expression, distinct=True, dialect=None, **opts):
+ def except_(
+ self, expression: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts
+ ) -> Unionable:
"""
Builds an EXCEPT expression.
@@ -904,13 +944,14 @@ class Unionable(Expression):
'SELECT * FROM foo EXCEPT SELECT * FROM bla'
Args:
- expression (str | Expression): the SQL code string.
+ expression: the SQL code string.
If an `Expression` instance is passed, it will be used as-is.
- distinct (bool): set the DISTINCT flag if and only if this is true.
- dialect (str): the dialect used to parse the input expression.
- opts (kwargs): other options to use to parse the input expressions.
+ distinct: set the DISTINCT flag if and only if this is true.
+ dialect: the dialect used to parse the input expression.
+ opts: other options to use to parse the input expressions.
+
Returns:
- Except: the Except expression
+ The new Except expression.
"""
return except_(left=self, right=expression, distinct=distinct, dialect=dialect, **opts)
@@ -949,6 +990,17 @@ class Create(Expression):
"indexes": False,
"no_schema_binding": False,
"begin": False,
+ "clone": False,
+ }
+
+
+# https://docs.snowflake.com/en/sql-reference/sql/create-clone
+class Clone(Expression):
+ arg_types = {
+ "this": True,
+ "when": False,
+ "kind": False,
+ "expression": False,
}
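A minimal sketch of the new Snowflake CLONE support, assuming the dialect changes in this commit round-trip the basic form:

    >>> import sqlglot
    >>> sqlglot.transpile("CREATE TABLE t CLONE src", read="snowflake", write="snowflake")[0]
    'CREATE TABLE t CLONE src'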
@@ -1038,6 +1090,10 @@ class ByteString(Condition):
pass
+class RawString(Condition):
+ pass
+
+
class Column(Condition):
arg_types = {"this": True, "table": False, "db": False, "catalog": False, "join_mark": False}
@@ -1060,7 +1116,11 @@ class Column(Condition):
@property
def parts(self) -> t.List[Identifier]:
"""Return the parts of a column in order catalog, db, table, name."""
- return [part for part in reversed(list(self.args.values())) if part]
+ return [
+ t.cast(Identifier, self.args[part])
+ for part in ("catalog", "db", "table", "this")
+ if self.args.get(part)
+ ]
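A sketch of the reworked `Column.parts` ordering for a fully qualified column:

    >>> import sqlglot
    >>> from sqlglot import exp
    >>> [p.name for p in sqlglot.parse_one("SELECT a.b.c").find(exp.Column).parts]
    ['a', 'b', 'c']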
def to_dot(self) -> Dot:
"""Converts the column into a dot expression."""
@@ -1116,6 +1176,27 @@ class Comment(Expression):
arg_types = {"this": True, "kind": True, "expression": True, "exists": False}
+# https://clickhouse.com/docs/en/engines/table-engines/mergetree-family/mergetree#mergetree-table-ttl
+class MergeTreeTTLAction(Expression):
+ arg_types = {
+ "this": True,
+ "delete": False,
+ "recompress": False,
+ "to_disk": False,
+ "to_volume": False,
+ }
+
+
+# https://clickhouse.com/docs/en/engines/table-engines/mergetree-family/mergetree#mergetree-table-ttl
+class MergeTreeTTL(Expression):
+ arg_types = {
+ "expressions": True,
+ "where": False,
+ "group": False,
+ "aggregates": False,
+ }
+
+
class ColumnConstraint(Expression):
arg_types = {"this": False, "kind": True}
@@ -1172,6 +1253,8 @@ class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind):
# this: True -> ALWAYS, this: False -> BY DEFAULT
arg_types = {
"this": False,
+ "expression": False,
+ "on_null": False,
"start": False,
"increment": False,
"minvalue": False,
@@ -1202,7 +1285,7 @@ class TitleColumnConstraint(ColumnConstraintKind):
class UniqueColumnConstraint(ColumnConstraintKind):
- arg_types: t.Dict[str, t.Any] = {}
+ arg_types = {"this": False}
class UppercaseColumnConstraint(ColumnConstraintKind):
@@ -1255,7 +1338,7 @@ class Delete(Expression):
def where(
self,
- *expressions: ExpOrStr,
+ *expressions: t.Optional[ExpOrStr],
append: bool = True,
dialect: DialectType = None,
copy: bool = True,
@@ -1367,10 +1450,6 @@ class PrimaryKey(Expression):
arg_types = {"expressions": True, "options": False}
-class Unique(Expression):
- arg_types = {"expressions": True}
-
-
# https://www.postgresql.org/docs/9.1/sql-selectinto.html
# https://docs.aws.amazon.com/redshift/latest/dg/r_SELECT_INTO.html#r_SELECT_INTO-examples
class Into(Expression):
@@ -1378,7 +1457,13 @@ class Into(Expression):
class From(Expression):
- arg_types = {"expressions": True}
+ @property
+ def name(self) -> str:
+ return self.this.name
+
+ @property
+ def alias_or_name(self) -> str:
+ return self.this.alias_or_name
class Having(Expression):
@@ -1397,7 +1482,7 @@ class Identifier(Expression):
arg_types = {"this": True, "quoted": False}
@property
- def quoted(self):
+ def quoted(self) -> bool:
return bool(self.args.get("quoted"))
@property
@@ -1407,7 +1492,7 @@ class Identifier(Expression):
return self.this.lower()
@property
- def output_name(self):
+ def output_name(self) -> str:
return self.name
@@ -1420,6 +1505,7 @@ class Index(Expression):
"unique": False,
"primary": False,
"amp": False, # teradata
+ "partition_by": False, # teradata
}
@@ -1436,6 +1522,42 @@ class Insert(Expression):
"alternative": False,
}
+ def with_(
+ self,
+ alias: ExpOrStr,
+ as_: ExpOrStr,
+ recursive: t.Optional[bool] = None,
+ append: bool = True,
+ dialect: DialectType = None,
+ copy: bool = True,
+ **opts,
+ ) -> Insert:
+ """
+ Append to or set the common table expressions.
+
+ Example:
+ >>> insert("SELECT x FROM cte", "t").with_("cte", as_="SELECT * FROM tbl").sql()
+ 'WITH cte AS (SELECT * FROM tbl) INSERT INTO t SELECT x FROM cte'
+
+ Args:
+ alias: the SQL code string to parse as the table name.
+ If an `Expression` instance is passed, this is used as-is.
+ as_: the SQL code string to parse as the table expression.
+ If an `Expression` instance is passed, it will be used as-is.
+ recursive: set the RECURSIVE part of the expression. Defaults to `False`.
+ append: if `True`, add to any existing expressions.
+ Otherwise, this resets the expressions.
+ dialect: the dialect used to parse the input expression.
+ copy: if `False`, modify this expression instance in-place.
+ opts: other options to use to parse the input expressions.
+
+ Returns:
+ The modified expression.
+ """
+ return _apply_cte_builder(
+ self, alias, as_, recursive=recursive, append=append, dialect=dialect, copy=copy, **opts
+ )
+
class OnConflict(Expression):
arg_types = {
@@ -1492,6 +1614,7 @@ class Group(Expression):
"grouping_sets": False,
"cube": False,
"rollup": False,
+ "totals": False,
}
@@ -1519,7 +1642,7 @@ class Literal(Condition):
return cls(this=str(string), is_string=True)
@property
- def output_name(self):
+ def output_name(self) -> str:
return self.name
@@ -1531,26 +1654,34 @@ class Join(Expression):
"kind": False,
"using": False,
"natural": False,
+ "global": False,
"hint": False,
}
@property
- def kind(self):
+ def kind(self) -> str:
return self.text("kind").upper()
@property
- def side(self):
+ def side(self) -> str:
return self.text("side").upper()
@property
- def hint(self):
+ def hint(self) -> str:
return self.text("hint").upper()
@property
- def alias_or_name(self):
+ def alias_or_name(self) -> str:
return self.this.alias_or_name
- def on(self, *expressions, append=True, dialect=None, copy=True, **opts):
+ def on(
+ self,
+ *expressions: t.Optional[ExpOrStr],
+ append: bool = True,
+ dialect: DialectType = None,
+ copy: bool = True,
+ **opts,
+ ) -> Join:
"""
Append to or set the ON expressions.
@@ -1560,17 +1691,17 @@ class Join(Expression):
'JOIN x ON y = 1'
Args:
- *expressions (str | Expression): the SQL code strings to parse.
+ *expressions: the SQL code strings to parse.
If an `Expression` instance is passed, it will be used as-is.
Multiple expressions are combined with an AND operator.
- append (bool): if `True`, AND the new expressions to any existing expression.
+ append: if `True`, AND the new expressions to any existing expression.
Otherwise, this resets the expression.
- dialect (str): the dialect used to parse the input expressions.
- copy (bool): if `False`, modify this expression instance in-place.
- opts (kwargs): other options to use to parse the input expressions.
+ dialect: the dialect used to parse the input expressions.
+ copy: if `False`, modify this expression instance in-place.
+ opts: other options to use to parse the input expressions.
Returns:
- Join: the modified join expression.
+ The modified Join expression.
"""
join = _apply_conjunction_builder(
*expressions,
@@ -1587,7 +1718,14 @@ class Join(Expression):
return join
- def using(self, *expressions, append=True, dialect=None, copy=True, **opts):
+ def using(
+ self,
+ *expressions: t.Optional[ExpOrStr],
+ append: bool = True,
+ dialect: DialectType = None,
+ copy: bool = True,
+ **opts,
+ ) -> Join:
"""
Append to or set the USING expressions.
@@ -1597,16 +1735,16 @@ class Join(Expression):
'JOIN x USING (foo, bla)'
Args:
- *expressions (str | Expression): the SQL code strings to parse.
+ *expressions: the SQL code strings to parse.
If an `Expression` instance is passed, it will be used as-is.
- append (bool): if `True`, concatenate the new expressions to the existing "using" list.
+ append: if `True`, concatenate the new expressions to the existing "using" list.
Otherwise, this resets the expression.
- dialect (str): the dialect used to parse the input expressions.
- copy (bool): if `False`, modify this expression instance in-place.
- opts (kwargs): other options to use to parse the input expressions.
+ dialect: the dialect used to parse the input expressions.
+ copy: if `False`, modify this expression instance in-place.
+ opts: other options to use to parse the input expressions.
Returns:
- Join: the modified join expression.
+ The modified Join expression.
"""
join = _apply_list_builder(
*expressions,
@@ -1677,10 +1815,6 @@ class Property(Expression):
arg_types = {"this": True, "value": True}
-class AfterJournalProperty(Property):
- arg_types = {"no": True, "dual": False, "local": False}
-
-
class AlgorithmProperty(Property):
arg_types = {"this": True}
@@ -1706,7 +1840,13 @@ class CollateProperty(Property):
class DataBlocksizeProperty(Property):
- arg_types = {"size": False, "units": False, "min": False, "default": False}
+ arg_types = {
+ "size": False,
+ "units": False,
+ "minimum": False,
+ "maximum": False,
+ "default": False,
+ }
class DefinerProperty(Property):
@@ -1760,7 +1900,13 @@ class IsolatedLoadingProperty(Property):
class JournalProperty(Property):
- arg_types = {"no": True, "dual": False, "before": False}
+ arg_types = {
+ "no": False,
+ "dual": False,
+ "before": False,
+ "local": False,
+ "after": False,
+ }
class LanguageProperty(Property):
@@ -1798,11 +1944,11 @@ class MergeBlockRatioProperty(Property):
class NoPrimaryIndexProperty(Property):
- arg_types = {"this": False}
+ arg_types = {}
class OnCommitProperty(Property):
- arg_type = {"this": False}
+ arg_type = {"delete": False}
class PartitionedByProperty(Property):
@@ -1846,6 +1992,10 @@ class SetProperty(Property):
arg_types = {"multi": True}
+class SettingsProperty(Property):
+ arg_types = {"expressions": True}
+
+
class SortKeyProperty(Property):
arg_types = {"this": True, "compound": False}
@@ -1858,12 +2008,8 @@ class StabilityProperty(Property):
arg_types = {"this": True}
-class TableFormatProperty(Property):
- arg_types = {"this": True}
-
-
class TemporaryProperty(Property):
- arg_types = {"global_": True}
+ arg_types = {}
class TransientProperty(Property):
@@ -1903,7 +2049,6 @@ class Properties(Expression):
"RETURNS": ReturnsProperty,
"ROW_FORMAT": RowFormatProperty,
"SORTKEY": SortKeyProperty,
- "TABLE_FORMAT": TableFormatProperty,
}
PROPERTY_TO_NAME = {v: k for k, v in NAME_TO_PROPERTY.items()}
@@ -1932,7 +2077,7 @@ class Properties(Expression):
UNSUPPORTED = auto()
@classmethod
- def from_dict(cls, properties_dict) -> Properties:
+ def from_dict(cls, properties_dict: t.Dict) -> Properties:
expressions = []
for key, value in properties_dict.items():
property_cls = cls.NAME_TO_PROPERTY.get(key.upper())
@@ -1961,7 +2106,7 @@ class Tuple(Expression):
arg_types = {"expressions": False}
def isin(
- self, *expressions: t.Any, query: t.Optional[ExpOrStr] = None, copy=True, **opts
+ self, *expressions: t.Any, query: t.Optional[ExpOrStr] = None, copy: bool = True, **opts
) -> In:
return In(
this=_maybe_copy(self, copy),
@@ -1971,7 +2116,7 @@ class Tuple(Expression):
class Subqueryable(Unionable):
- def subquery(self, alias=None, copy=True) -> Subquery:
+ def subquery(self, alias: t.Optional[ExpOrStr] = None, copy: bool = True) -> Subquery:
"""
Convert this expression to an aliased expression that can be used as a Subquery.
@@ -1988,12 +2133,14 @@ class Subqueryable(Unionable):
The new Subquery instance.
"""
instance = _maybe_copy(self, copy)
- return Subquery(
- this=instance,
- alias=TableAlias(this=to_identifier(alias)) if alias else None,
- )
+ if not isinstance(alias, Expression):
+ alias = TableAlias(this=to_identifier(alias)) if alias else None
+
+ return Subquery(this=instance, alias=alias)
- def limit(self, expression, dialect=None, copy=True, **opts) -> Select:
+ def limit(
+ self, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts
+ ) -> Select:
raise NotImplementedError
@property
@@ -2013,14 +2160,14 @@ class Subqueryable(Unionable):
def with_(
self,
- alias,
- as_,
- recursive=None,
- append=True,
- dialect=None,
- copy=True,
+ alias: ExpOrStr,
+ as_: ExpOrStr,
+ recursive: t.Optional[bool] = None,
+ append: bool = True,
+ dialect: DialectType = None,
+ copy: bool = True,
**opts,
- ):
+ ) -> Subqueryable:
"""
Append to or set the common table expressions.
@@ -2029,43 +2176,22 @@ class Subqueryable(Unionable):
'WITH tbl2 AS (SELECT * FROM tbl) SELECT x FROM tbl2'
Args:
- alias (str | Expression): the SQL code string to parse as the table name.
+ alias: the SQL code string to parse as the table name.
If an `Expression` instance is passed, this is used as-is.
- as_ (str | Expression): the SQL code string to parse as the table expression.
+ as_: the SQL code string to parse as the table expression.
If an `Expression` instance is passed, it will be used as-is.
- recursive (bool): set the RECURSIVE part of the expression. Defaults to `False`.
- append (bool): if `True`, add to any existing expressions.
+ recursive: set the RECURSIVE part of the expression. Defaults to `False`.
+ append: if `True`, add to any existing expressions.
Otherwise, this resets the expressions.
- dialect (str): the dialect used to parse the input expression.
- copy (bool): if `False`, modify this expression instance in-place.
- opts (kwargs): other options to use to parse the input expressions.
+ dialect: the dialect used to parse the input expression.
+ copy: if `False`, modify this expression instance in-place.
+ opts: other options to use to parse the input expressions.
Returns:
- Select: the modified expression.
+ The modified expression.
"""
- alias_expression = maybe_parse(
- alias,
- dialect=dialect,
- into=TableAlias,
- **opts,
- )
- as_expression = maybe_parse(
- as_,
- dialect=dialect,
- **opts,
- )
- cte = CTE(
- this=as_expression,
- alias=alias_expression,
- )
- return _apply_child_list_builder(
- cte,
- instance=self,
- arg="with",
- append=append,
- copy=copy,
- into=With,
- properties={"recursive": recursive or False},
+ return _apply_cte_builder(
+ self, alias, as_, recursive=recursive, append=append, dialect=dialect, copy=copy, **opts
)
@@ -2085,8 +2211,10 @@ QUERY_MODIFIERS = {
"order": False,
"limit": False,
"offset": False,
- "lock": False,
+ "locks": False,
"sample": False,
+ "settings": False,
+ "format": False,
}
@@ -2111,6 +2239,15 @@ class Table(Expression):
def catalog(self) -> str:
return self.text("catalog")
+ @property
+ def parts(self) -> t.List[Identifier]:
+ """Return the parts of a table in order catalog, db, table."""
+ return [
+ t.cast(Identifier, self.args[part])
+ for part in ("catalog", "db", "this")
+ if self.args.get(part)
+ ]
+
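The analogous `Table.parts` property, sketched the same way:

    >>> import sqlglot
    >>> from sqlglot import exp
    >>> [p.name for p in sqlglot.parse_one("SELECT * FROM c.db.tbl").find(exp.Table).parts]
    ['c', 'db', 'tbl']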
# See the TSQL "Querying data in a system-versioned temporal table" page
class SystemTime(Expression):
@@ -2130,7 +2267,9 @@ class Union(Subqueryable):
**QUERY_MODIFIERS,
}
- def limit(self, expression, dialect=None, copy=True, **opts) -> Select:
+ def limit(
+ self, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts
+ ) -> Select:
"""
Set the LIMIT expression.
@@ -2139,16 +2278,16 @@ class Union(Subqueryable):
'SELECT * FROM (SELECT 1 UNION SELECT 1) AS _l_0 LIMIT 1'
Args:
- expression (str | int | Expression): the SQL code string to parse.
+ expression: the SQL code string to parse.
This can also be an integer.
If a `Limit` instance is passed, this is used as-is.
If another `Expression` instance is passed, it will be wrapped in a `Limit`.
- dialect (str): the dialect used to parse the input expression.
- copy (bool): if `False`, modify this expression instance in-place.
- opts (kwargs): other options to use to parse the input expressions.
+ dialect: the dialect used to parse the input expression.
+ copy: if `False`, modify this expression instance in-place.
+ opts: other options to use to parse the input expressions.
Returns:
- Select: The limited subqueryable.
+ The limited subqueryable.
"""
return (
select("*")
@@ -2158,7 +2297,7 @@ class Union(Subqueryable):
def select(
self,
- *expressions: ExpOrStr,
+ *expressions: t.Optional[ExpOrStr],
append: bool = True,
dialect: DialectType = None,
copy: bool = True,
@@ -2255,10 +2394,10 @@ class Schema(Expression):
arg_types = {"this": False, "expressions": False}
-# Used to represent the FOR UPDATE and FOR SHARE locking read types.
-# https://dev.mysql.com/doc/refman/8.0/en/innodb-locking-reads.html
+# https://dev.mysql.com/doc/refman/8.0/en/select.html
+# https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/SELECT.html
class Lock(Expression):
- arg_types = {"update": True}
+ arg_types = {"update": True, "expressions": False, "wait": False}
class Select(Subqueryable):
@@ -2275,7 +2414,9 @@ class Select(Subqueryable):
**QUERY_MODIFIERS,
}
- def from_(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
+ def from_(
+ self, expression: ExpOrStr, dialect: DialectType = None, copy: bool = True, **opts
+ ) -> Select:
"""
Set the FROM expression.
@@ -2284,31 +2425,35 @@ class Select(Subqueryable):
'SELECT x FROM tbl'
Args:
- *expressions (str | Expression): the SQL code strings to parse.
+ expression: the SQL code string to parse.
If a `From` instance is passed, this is used as-is.
If another `Expression` instance is passed, it will be wrapped in a `From`.
- append (bool): if `True`, add to any existing expressions.
- Otherwise, this flattens all the `From` expression into a single expression.
- dialect (str): the dialect used to parse the input expression.
- copy (bool): if `False`, modify this expression instance in-place.
- opts (kwargs): other options to use to parse the input expressions.
+ dialect: the dialect used to parse the input expression.
+ copy: if `False`, modify this expression instance in-place.
+ opts: other options to use to parse the input expressions.
Returns:
- Select: the modified expression.
+ The modified Select expression.
"""
- return _apply_child_list_builder(
- *expressions,
+ return _apply_builder(
+ expression=expression,
instance=self,
arg="from",
- append=append,
- copy=copy,
- prefix="FROM",
into=From,
+ prefix="FROM",
dialect=dialect,
+ copy=copy,
**opts,
)
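Since `from_` now takes a single source, a comma-separated FROM list can instead be expressed with an explicit join; a minimal sketch using the `join_type` parameter shown further below:

    >>> from sqlglot import exp
    >>> exp.select("x").from_("a").join("b", join_type="cross").sql()
    'SELECT x FROM a CROSS JOIN b'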
- def group_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
+ def group_by(
+ self,
+ *expressions: t.Optional[ExpOrStr],
+ append: bool = True,
+ dialect: DialectType = None,
+ copy: bool = True,
+ **opts,
+ ) -> Select:
"""
Set the GROUP BY expression.
@@ -2317,21 +2462,22 @@ class Select(Subqueryable):
'SELECT x, COUNT(1) FROM tbl GROUP BY x'
Args:
- *expressions (str | Expression): the SQL code strings to parse.
+ *expressions: the SQL code strings to parse.
If a `Group` instance is passed, this is used as-is.
If another `Expression` instance is passed, it will be wrapped in a `Group`.
If nothing is passed in, then a group by is not applied to the expression.
- append (bool): if `True`, add to any existing expressions.
+ append: if `True`, add to any existing expressions.
Otherwise, this flattens all the `Group` expressions into a single expression.
- dialect (str): the dialect used to parse the input expression.
- copy (bool): if `False`, modify this expression instance in-place.
- opts (kwargs): other options to use to parse the input expressions.
+ dialect: the dialect used to parse the input expression.
+ copy: if `False`, modify this expression instance in-place.
+ opts: other options to use to parse the input expressions.
Returns:
- Select: the modified expression.
+ The modified Select expression.
"""
if not expressions:
return self if not copy else self.copy()
+
return _apply_child_list_builder(
*expressions,
instance=self,
@@ -2344,7 +2490,14 @@ class Select(Subqueryable):
**opts,
)
- def order_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
+ def order_by(
+ self,
+ *expressions: t.Optional[ExpOrStr],
+ append: bool = True,
+ dialect: DialectType = None,
+ copy: bool = True,
+ **opts,
+ ) -> Select:
"""
Set the ORDER BY expression.
@@ -2353,17 +2506,17 @@ class Select(Subqueryable):
'SELECT x FROM tbl ORDER BY x DESC'
Args:
- *expressions (str | Expression): the SQL code strings to parse.
+ *expressions: the SQL code strings to parse.
If an `Order` instance is passed, this is used as-is.
If another `Expression` instance is passed, it will be wrapped in an `Order`.
- append (bool): if `True`, add to any existing expressions.
+ append: if `True`, add to any existing expressions.
Otherwise, this flattens all the `Order` expressions into a single expression.
- dialect (str): the dialect used to parse the input expression.
- copy (bool): if `False`, modify this expression instance in-place.
- opts (kwargs): other options to use to parse the input expressions.
+ dialect: the dialect used to parse the input expression.
+ copy: if `False`, modify this expression instance in-place.
+ opts: other options to use to parse the input expressions.
Returns:
- Select: the modified expression.
+ The modified Select expression.
"""
return _apply_child_list_builder(
*expressions,
@@ -2377,26 +2530,33 @@ class Select(Subqueryable):
**opts,
)
- def sort_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
+ def sort_by(
+ self,
+ *expressions: t.Optional[ExpOrStr],
+ append: bool = True,
+ dialect: DialectType = None,
+ copy: bool = True,
+ **opts,
+ ) -> Select:
"""
Set the SORT BY expression.
Example:
- >>> Select().from_("tbl").select("x").sort_by("x DESC").sql()
+ >>> Select().from_("tbl").select("x").sort_by("x DESC").sql(dialect="hive")
'SELECT x FROM tbl SORT BY x DESC'
Args:
- *expressions (str | Expression): the SQL code strings to parse.
+ *expressions: the SQL code strings to parse.
If a `Sort` instance is passed, this is used as-is.
If another `Expression` instance is passed, it will be wrapped in a `Sort`.
- append (bool): if `True`, add to any existing expressions.
+ append: if `True`, add to any existing expressions.
Otherwise, this flattens all the `Sort` expressions into a single expression.
- dialect (str): the dialect used to parse the input expression.
- copy (bool): if `False`, modify this expression instance in-place.
- opts (kwargs): other options to use to parse the input expressions.
+ dialect: the dialect used to parse the input expression.
+ copy: if `False`, modify this expression instance in-place.
+ opts: other options to use to parse the input expressions.
Returns:
- Select: the modified expression.
+ The modified Select expression.
"""
return _apply_child_list_builder(
*expressions,
@@ -2410,26 +2570,33 @@ class Select(Subqueryable):
**opts,
)
- def cluster_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
+ def cluster_by(
+ self,
+ *expressions: t.Optional[ExpOrStr],
+ append: bool = True,
+ dialect: DialectType = None,
+ copy: bool = True,
+ **opts,
+ ) -> Select:
"""
Set the CLUSTER BY expression.
Example:
- >>> Select().from_("tbl").select("x").cluster_by("x DESC").sql()
+ >>> Select().from_("tbl").select("x").cluster_by("x DESC").sql(dialect="hive")
'SELECT x FROM tbl CLUSTER BY x DESC'
Args:
- *expressions (str | Expression): the SQL code strings to parse.
+ *expressions: the SQL code strings to parse.
If a `Cluster` instance is passed, this is used as-is.
If another `Expression` instance is passed, it will be wrapped in a `Cluster`.
- append (bool): if `True`, add to any existing expressions.
+ append: if `True`, add to any existing expressions.
Otherwise, this flattens all the `Cluster` expressions into a single expression.
- dialect (str): the dialect used to parse the input expression.
- copy (bool): if `False`, modify this expression instance in-place.
- opts (kwargs): other options to use to parse the input expressions.
+ dialect: the dialect used to parse the input expression.
+ copy: if `False`, modify this expression instance in-place.
+ opts: other options to use to parse the input expressions.
Returns:
- Select: the modified expression.
+ The modified Select expression.
"""
return _apply_child_list_builder(
*expressions,
@@ -2443,7 +2610,9 @@ class Select(Subqueryable):
**opts,
)
- def limit(self, expression, dialect=None, copy=True, **opts) -> Select:
+ def limit(
+ self, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts
+ ) -> Select:
"""
Set the LIMIT expression.
@@ -2452,13 +2621,13 @@ class Select(Subqueryable):
'SELECT x FROM tbl LIMIT 10'
Args:
- expression (str | int | Expression): the SQL code string to parse.
+ expression: the SQL code string to parse.
This can also be an integer.
If a `Limit` instance is passed, this is used as-is.
If another `Expression` instance is passed, it will be wrapped in a `Limit`.
- dialect (str): the dialect used to parse the input expression.
- copy (bool): if `False`, modify this expression instance in-place.
- opts (kwargs): other options to use to parse the input expressions.
+ dialect: the dialect used to parse the input expression.
+ copy: if `False`, modify this expression instance in-place.
+ opts: other options to use to parse the input expressions.
Returns:
Select: the modified expression.
@@ -2474,7 +2643,9 @@ class Select(Subqueryable):
**opts,
)
- def offset(self, expression, dialect=None, copy=True, **opts) -> Select:
+ def offset(
+ self, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts
+ ) -> Select:
"""
Set the OFFSET expression.
@@ -2483,16 +2654,16 @@ class Select(Subqueryable):
'SELECT x FROM tbl OFFSET 10'
Args:
- expression (str | int | Expression): the SQL code string to parse.
+ expression: the SQL code string to parse.
This can also be an integer.
If an `Offset` instance is passed, this is used as-is.
If another `Expression` instance is passed, it will be wrapped in an `Offset`.
- dialect (str): the dialect used to parse the input expression.
- copy (bool): if `False`, modify this expression instance in-place.
- opts (kwargs): other options to use to parse the input expressions.
+ dialect: the dialect used to parse the input expression.
+ copy: if `False`, modify this expression instance in-place.
+ opts: other options to use to parse the input expressions.
Returns:
- Select: the modified expression.
+ The modified Select expression.
"""
return _apply_builder(
expression=expression,
@@ -2507,7 +2678,7 @@ class Select(Subqueryable):
def select(
self,
- *expressions: ExpOrStr,
+ *expressions: t.Optional[ExpOrStr],
append: bool = True,
dialect: DialectType = None,
copy: bool = True,
@@ -2530,7 +2701,7 @@ class Select(Subqueryable):
opts: other options to use to parse the input expressions.
Returns:
- Select: the modified expression.
+ The modified Select expression.
"""
return _apply_list_builder(
*expressions,
@@ -2542,7 +2713,14 @@ class Select(Subqueryable):
**opts,
)
- def lateral(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
+ def lateral(
+ self,
+ *expressions: t.Optional[ExpOrStr],
+ append: bool = True,
+ dialect: DialectType = None,
+ copy: bool = True,
+ **opts,
+ ) -> Select:
"""
Append to or set the LATERAL expressions.
@@ -2551,16 +2729,16 @@ class Select(Subqueryable):
'SELECT x FROM tbl LATERAL VIEW OUTER EXPLODE(y) tbl2 AS z'
Args:
- *expressions (str | Expression): the SQL code strings to parse.
+ *expressions: the SQL code strings to parse.
If an `Expression` instance is passed, it will be used as-is.
- append (bool): if `True`, add to any existing expressions.
+ append: if `True`, add to any existing expressions.
Otherwise, this resets the expressions.
- dialect (str): the dialect used to parse the input expressions.
- copy (bool): if `False`, modify this expression instance in-place.
- opts (kwargs): other options to use to parse the input expressions.
+ dialect: the dialect used to parse the input expressions.
+ copy: if `False`, modify this expression instance in-place.
+ opts: other options to use to parse the input expressions.
Returns:
- Select: the modified expression.
+ The modified Select expression.
"""
return _apply_list_builder(
*expressions,
@@ -2576,14 +2754,14 @@ class Select(Subqueryable):
def join(
self,
- expression,
- on=None,
- using=None,
- append=True,
- join_type=None,
- join_alias=None,
- dialect=None,
- copy=True,
+ expression: ExpOrStr,
+ on: t.Optional[ExpOrStr] = None,
+ using: t.Optional[ExpOrStr | t.List[ExpOrStr]] = None,
+ append: bool = True,
+ join_type: t.Optional[str] = None,
+ join_alias: t.Optional[Identifier | str] = None,
+ dialect: DialectType = None,
+ copy: bool = True,
**opts,
) -> Select:
"""
@@ -2602,18 +2780,19 @@ class Select(Subqueryable):
'SELECT * FROM tbl LEFT OUTER JOIN tbl2 ON tbl1.y = tbl2.y'
Args:
- expression (str | Expression): the SQL code string to parse.
+ expression: the SQL code string to parse.
If an `Expression` instance is passed, it will be used as-is.
- on (str | Expression): optionally specify the join "on" criteria as a SQL string.
+ on: optionally specify the join "on" criteria as a SQL string.
If an `Expression` instance is passed, it will be used as-is.
- using (str | Expression): optionally specify the join "using" criteria as a SQL string.
+ using: optionally specify the join "using" criteria as a SQL string.
If an `Expression` instance is passed, it will be used as-is.
- append (bool): if `True`, add to any existing expressions.
+ append: if `True`, add to any existing expressions.
Otherwise, this resets the expressions.
- join_type (str): If set, alter the parsed join type
- dialect (str): the dialect used to parse the input expressions.
- copy (bool): if `False`, modify this expression instance in-place.
- opts (kwargs): other options to use to parse the input expressions.
+ join_type: if set, alter the parsed join type.
+ join_alias: an optional alias for the joined source.
+ dialect: the dialect used to parse the input expressions.
+ copy: if `False`, modify this expression instance in-place.
+ opts: other options to use to parse the input expressions.
Returns:
Select: the modified expression.
@@ -2621,9 +2800,9 @@ class Select(Subqueryable):
parse_args = {"dialect": dialect, **opts}
try:
- expression = maybe_parse(expression, into=Join, prefix="JOIN", **parse_args)
+ expression = maybe_parse(expression, into=Join, prefix="JOIN", **parse_args) # type: ignore
except ParseError:
- expression = maybe_parse(expression, into=(Join, Expression), **parse_args)
+ expression = maybe_parse(expression, into=(Join, Expression), **parse_args) # type: ignore
join = expression if isinstance(expression, Join) else Join(this=expression)
@@ -2645,12 +2824,12 @@ class Select(Subqueryable):
join.set("kind", kind.text)
if on:
- on = and_(*ensure_collection(on), dialect=dialect, copy=copy, **opts)
+ on = and_(*ensure_list(on), dialect=dialect, copy=copy, **opts)
join.set("on", on)
if using:
join = _apply_list_builder(
- *ensure_collection(using),
+ *ensure_list(using),
instance=join,
arg="using",
append=append,
@@ -2660,6 +2839,7 @@ class Select(Subqueryable):
if join_alias:
join.set("this", alias_(join.this, join_alias, table=True))
+
return _apply_list_builder(
join,
instance=self,
@@ -2669,7 +2849,14 @@ class Select(Subqueryable):
**opts,
)
- def where(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
+ def where(
+ self,
+ *expressions: t.Optional[ExpOrStr],
+ append: bool = True,
+ dialect: DialectType = None,
+ copy: bool = True,
+ **opts,
+ ) -> Select:
"""
Append to or set the WHERE expressions.
@@ -2678,14 +2865,14 @@ class Select(Subqueryable):
"SELECT x FROM tbl WHERE x = 'a' OR x < 'b'"
Args:
- *expressions (str | Expression): the SQL code strings to parse.
+ *expressions: the SQL code strings to parse.
If an `Expression` instance is passed, it will be used as-is.
Multiple expressions are combined with an AND operator.
- append (bool): if `True`, AND the new expressions to any existing expression.
+ append: if `True`, AND the new expressions to any existing expression.
Otherwise, this resets the expression.
- dialect (str): the dialect used to parse the input expressions.
- copy (bool): if `False`, modify this expression instance in-place.
- opts (kwargs): other options to use to parse the input expressions.
+ dialect: the dialect used to parse the input expressions.
+ copy: if `False`, modify this expression instance in-place.
+ opts: other options to use to parse the input expressions.
Returns:
Select: the modified expression.
@@ -2701,7 +2888,14 @@ class Select(Subqueryable):
**opts,
)
- def having(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
+ def having(
+ self,
+ *expressions: t.Optional[ExpOrStr],
+ append: bool = True,
+ dialect: DialectType = None,
+ copy: bool = True,
+ **opts,
+ ) -> Select:
"""
Append to or set the HAVING expressions.
@@ -2710,17 +2904,17 @@ class Select(Subqueryable):
'SELECT x, COUNT(y) FROM tbl GROUP BY x HAVING COUNT(y) > 3'
Args:
- *expressions (str | Expression): the SQL code strings to parse.
+ *expressions: the SQL code strings to parse.
If an `Expression` instance is passed, it will be used as-is.
Multiple expressions are combined with an AND operator.
- append (bool): if `True`, AND the new expressions to any existing expression.
+ append: if `True`, AND the new expressions to any existing expression.
Otherwise, this resets the expression.
- dialect (str): the dialect used to parse the input expressions.
- copy (bool): if `False`, modify this expression instance in-place.
- opts (kwargs): other options to use to parse the input expressions.
+ dialect: the dialect used to parse the input expressions.
+ copy: if `False`, modify this expression instance in-place.
+ opts: other options to use to parse the input expressions.
Returns:
- Select: the modified expression.
+ The modified Select expression.
"""
return _apply_conjunction_builder(
*expressions,
@@ -2733,7 +2927,14 @@ class Select(Subqueryable):
**opts,
)
- def window(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
+ def window(
+ self,
+ *expressions: t.Optional[ExpOrStr],
+ append: bool = True,
+ dialect: DialectType = None,
+ copy: bool = True,
+ **opts,
+ ) -> Select:
return _apply_list_builder(
*expressions,
instance=self,
@@ -2745,7 +2946,14 @@ class Select(Subqueryable):
**opts,
)
- def qualify(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
+ def qualify(
+ self,
+ *expressions: t.Optional[ExpOrStr],
+ append: bool = True,
+ dialect: DialectType = None,
+ copy: bool = True,
+ **opts,
+ ) -> Select:
return _apply_conjunction_builder(
*expressions,
instance=self,
@@ -2757,7 +2965,9 @@ class Select(Subqueryable):
**opts,
)
- def distinct(self, *ons: ExpOrStr, distinct: bool = True, copy: bool = True) -> Select:
+ def distinct(
+ self, *ons: t.Optional[ExpOrStr], distinct: bool = True, copy: bool = True
+ ) -> Select:
"""
Set the DISTINCT expression.
@@ -2774,11 +2984,18 @@ class Select(Subqueryable):
Select: the modified expression.
"""
instance = _maybe_copy(self, copy)
- on = Tuple(expressions=[maybe_parse(on, copy=copy) for on in ons]) if ons else None
+ on = Tuple(expressions=[maybe_parse(on, copy=copy) for on in ons if on]) if ons else None
instance.set("distinct", Distinct(on=on) if distinct else None)
return instance
- def ctas(self, table, properties=None, dialect=None, copy=True, **opts) -> Create:
+ def ctas(
+ self,
+ table: ExpOrStr,
+ properties: t.Optional[t.Dict] = None,
+ dialect: DialectType = None,
+ copy: bool = True,
+ **opts,
+ ) -> Create:
"""
Convert this expression to a CREATE TABLE AS statement.
@@ -2787,15 +3004,15 @@ class Select(Subqueryable):
'CREATE TABLE x AS SELECT * FROM tbl'
Args:
- table (str | Expression): the SQL code string to parse as the table name.
+ table: the SQL code string to parse as the table name.
If another `Expression` instance is passed, it will be used as-is.
- properties (dict): an optional mapping of table properties
- dialect (str): the dialect used to parse the input table.
- copy (bool): if `False`, modify this expression instance in-place.
- opts (kwargs): other options to use to parse the input table.
+ properties: an optional mapping of table properties
+ dialect: the dialect used to parse the input table.
+ copy: if `False`, modify this expression instance in-place.
+ opts: other options to use to parse the input table.
Returns:
- Create: the CREATE TABLE AS expression
+ The new Create expression.
"""
instance = _maybe_copy(self, copy)
table_expression = maybe_parse(
@@ -2835,7 +3052,7 @@ class Select(Subqueryable):
"""
inst = _maybe_copy(self, copy)
- inst.set("lock", Lock(update=update))
+ inst.set("locks", [Lock(update=update)])
return inst
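A sketch of the `lock` builder against the new plural `locks` arg, assuming MySQL output:

    >>> from sqlglot import exp
    >>> exp.select("x").from_("tbl").lock().sql("mysql")
    'SELECT x FROM tbl FOR UPDATE'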
@@ -2874,7 +3091,7 @@ class Subquery(DerivedTable, Unionable):
return self.this.is_star
@property
- def output_name(self):
+ def output_name(self) -> str:
return self.alias
@@ -2903,13 +3120,17 @@ class Tag(Expression):
}
+# Represents both the standard SQL PIVOT operator and DuckDB's "simplified" PIVOT syntax
+# https://duckdb.org/docs/sql/statements/pivot
class Pivot(Expression):
arg_types = {
"this": False,
"alias": False,
"expressions": True,
- "field": True,
- "unpivot": True,
+ "field": False,
+ "unpivot": False,
+ "using": False,
+ "group": False,
"columns": False,
}
@@ -2948,7 +3169,7 @@ class Star(Expression):
return "*"
@property
- def output_name(self):
+ def output_name(self) -> str:
return self.name
@@ -2961,7 +3182,7 @@ class SessionParameter(Expression):
class Placeholder(Expression):
- arg_types = {"this": False}
+ arg_types = {"this": False, "kind": False}
class Null(Condition):
@@ -2976,6 +3197,10 @@ class Boolean(Condition):
pass
+class DataTypeSize(Expression):
+ arg_types = {"this": True, "expression": False}
+
+
class DataType(Expression):
arg_types = {
"this": True,
@@ -2986,68 +3211,69 @@ class DataType(Expression):
}
class Type(AutoName):
- CHAR = auto()
- NCHAR = auto()
- VARCHAR = auto()
- NVARCHAR = auto()
- TEXT = auto()
- MEDIUMTEXT = auto()
- LONGTEXT = auto()
- MEDIUMBLOB = auto()
- LONGBLOB = auto()
- BINARY = auto()
- VARBINARY = auto()
- INT = auto()
- UINT = auto()
- TINYINT = auto()
- UTINYINT = auto()
- SMALLINT = auto()
- USMALLINT = auto()
- BIGINT = auto()
- UBIGINT = auto()
- INT128 = auto()
- UINT128 = auto()
- INT256 = auto()
- UINT256 = auto()
- FLOAT = auto()
- DOUBLE = auto()
- DECIMAL = auto()
+ ARRAY = auto()
BIGDECIMAL = auto()
+ BIGINT = auto()
+ BIGSERIAL = auto()
+ BINARY = auto()
BIT = auto()
BOOLEAN = auto()
- JSON = auto()
- JSONB = auto()
- INTERVAL = auto()
- TIME = auto()
- TIMESTAMP = auto()
- TIMESTAMPTZ = auto()
- TIMESTAMPLTZ = auto()
+ CHAR = auto()
DATE = auto()
DATETIME = auto()
- ARRAY = auto()
- MAP = auto()
- UUID = auto()
+ DATETIME64 = auto()
+ DECIMAL = auto()
+ DOUBLE = auto()
+ FLOAT = auto()
GEOGRAPHY = auto()
GEOMETRY = auto()
- STRUCT = auto()
- NULLABLE = auto()
HLLSKETCH = auto()
HSTORE = auto()
- SUPER = auto()
- SERIAL = auto()
- SMALLSERIAL = auto()
- BIGSERIAL = auto()
- XML = auto()
- UNIQUEIDENTIFIER = auto()
- MONEY = auto()
- SMALLMONEY = auto()
- ROWVERSION = auto()
IMAGE = auto()
- VARIANT = auto()
- OBJECT = auto()
INET = auto()
+ INT = auto()
+ INT128 = auto()
+ INT256 = auto()
+ INTERVAL = auto()
+ JSON = auto()
+ JSONB = auto()
+ LONGBLOB = auto()
+ LONGTEXT = auto()
+ MAP = auto()
+ MEDIUMBLOB = auto()
+ MEDIUMTEXT = auto()
+ MONEY = auto()
+ NCHAR = auto()
NULL = auto()
+ NULLABLE = auto()
+ NVARCHAR = auto()
+ OBJECT = auto()
+ ROWVERSION = auto()
+ SERIAL = auto()
+ SMALLINT = auto()
+ SMALLMONEY = auto()
+ SMALLSERIAL = auto()
+ STRUCT = auto()
+ SUPER = auto()
+ TEXT = auto()
+ TIME = auto()
+ TIMESTAMP = auto()
+ TIMESTAMPTZ = auto()
+ TIMESTAMPLTZ = auto()
+ TINYINT = auto()
+ UBIGINT = auto()
+ UINT = auto()
+ USMALLINT = auto()
+ UTINYINT = auto()
UNKNOWN = auto() # Sentinel value, useful for type annotation
+ UINT128 = auto()
+ UINT256 = auto()
+ UNIQUEIDENTIFIER = auto()
+ UUID = auto()
+ VARBINARY = auto()
+ VARCHAR = auto()
+ VARIANT = auto()
+ XML = auto()
TEXT_TYPES = {
Type.CHAR,
@@ -3079,6 +3305,7 @@ class DataType(Expression):
Type.TIMESTAMPLTZ,
Type.DATE,
Type.DATETIME,
+ Type.DATETIME64,
}
@classmethod
@@ -3092,6 +3319,7 @@ class DataType(Expression):
data_type_exp: t.Optional[Expression] = DataType(this=DataType.Type[dtype.upper()])
else:
data_type_exp = parse_one(dtype, read=dialect, into=DataType)
+
if data_type_exp is None:
raise ValueError(f"Unparsable data type value: {dtype}")
elif isinstance(dtype, DataType.Type):
@@ -3100,6 +3328,7 @@ class DataType(Expression):
return dtype
else:
raise ValueError(f"Invalid data type: {type(dtype)}. Expected str or DataType.Type")
+
return DataType(**{**data_type_exp.args, **kwargs})
def is_type(self, dtype: DataType.Type) -> bool:
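For reference, a minimal `DataType.build` sketch:

    >>> from sqlglot import exp
    >>> exp.DataType.build("varchar(255)").sql()
    'VARCHAR(255)'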
@@ -3361,7 +3590,7 @@ class Alias(Expression):
arg_types = {"this": True, "alias": False}
@property
- def output_name(self):
+ def output_name(self) -> str:
return self.alias
@@ -3411,12 +3640,17 @@ class TimeUnit(Expression):
args["unit"] = Var(this=unit.name)
elif isinstance(unit, Week):
unit.set("this", Var(this=unit.this.name))
+
super().__init__(**args)
class Interval(TimeUnit):
arg_types = {"this": False, "unit": False}
+ @property
+ def unit(self) -> t.Optional[Var]:
+ return self.args.get("unit")
+
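A minimal sketch of the new `Interval.unit` accessor:

    >>> import sqlglot
    >>> sqlglot.parse_one("INTERVAL '1' DAY").unit.name
    'DAY'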
class IgnoreNulls(Expression):
pass
@@ -3480,6 +3714,10 @@ class AggFunc(Func):
pass
+class ParameterizedAgg(AggFunc):
+ arg_types = {"this": True, "expressions": True, "params": True}
+
+
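ParameterizedAgg generalizes the ClickHouse-specific aggregate classes removed further down in this diff (GroupUniqArray, Histogram, Quantiles, QuantileIf); a hedged round-trip sketch, assuming the ClickHouse dialect parses parameterized calls:

    >>> import sqlglot
    >>> sqlglot.transpile("SELECT quantiles(0.5, 0.9)(x) FROM t", read="clickhouse", write="clickhouse")[0]
    'SELECT quantiles(0.5, 0.9)(x) FROM t'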
class Abs(Func):
pass
@@ -3498,6 +3736,7 @@ class Hll(AggFunc):
class ApproxDistinct(AggFunc):
arg_types = {"this": True, "accuracy": False}
+ _sql_names = ["APPROX_DISTINCT", "APPROX_COUNT_DISTINCT"]
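With the added alias, both spellings parse to ApproxDistinct; generation uses the first `_sql_names` entry:

    >>> import sqlglot
    >>> sqlglot.parse_one("APPROX_COUNT_DISTINCT(x)").sql()
    'APPROX_DISTINCT(x)'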
class Array(Func):
@@ -3600,17 +3839,21 @@ class Cast(Func):
return self.this.name
@property
- def to(self):
+ def to(self) -> DataType:
return self.args["to"]
@property
- def output_name(self):
+ def output_name(self) -> str:
return self.name
def is_type(self, dtype: DataType.Type) -> bool:
return self.to.is_type(dtype)
+class CastToStrType(Func):
+ arg_types = {"this": True, "expression": True}
+
+
class Collate(Binary):
pass
@@ -3796,10 +4039,6 @@ class Explode(Func):
pass
-class ExponentialTimeDecayedAvg(AggFunc):
- arg_types = {"this": True, "time": False, "decay": False}
-
-
class Floor(Func):
arg_types = {"this": True, "decimals": False}
@@ -3821,18 +4060,10 @@ class GroupConcat(Func):
arg_types = {"this": True, "separator": False}
-class GroupUniqArray(AggFunc):
- arg_types = {"this": True, "size": False}
-
-
class Hex(Func):
pass
-class Histogram(AggFunc):
- arg_types = {"this": True, "bins": False}
-
-
class If(Func):
arg_types = {"this": True, "true": True, "false": False}
@@ -3843,7 +4074,7 @@ class IfNull(Func):
class Initcap(Func):
- pass
+ arg_types = {"this": True, "expression": False}
class JSONKeyValue(Expression):
@@ -3861,6 +4092,14 @@ class JSONObject(Func):
}
+class OpenJSONColumnDef(Expression):
+ arg_types = {"this": True, "kind": True, "path": False, "as_json": False}
+
+
+class OpenJSON(Func):
+ arg_types = {"this": True, "path": False, "expressions": False}
+
+
class JSONBContains(Binary):
_sql_names = ["JSONB_CONTAINS"]
@@ -3945,6 +4184,14 @@ class VarMap(Func):
arg_types = {"keys": True, "values": True}
is_var_len_args = True
+ @property
+ def keys(self) -> t.List[Expression]:
+ return self.args["keys"].expressions
+
+ @property
+ def values(self) -> t.List[Expression]:
+ return self.args["values"].expressions
+
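A sketch of the new `keys`/`values` accessors, assuming Hive's MAP(...) parses to a VarMap:

    >>> import sqlglot
    >>> from sqlglot import exp
    >>> m = sqlglot.parse_one("SELECT MAP('a', 1)", read="hive").find(exp.VarMap)
    >>> [k.sql() for k in m.keys]
    ["'a'"]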
# https://dev.mysql.com/doc/refman/8.0/en/fulltext-search.html
class MatchAgainst(Func):
@@ -3993,17 +4240,6 @@ class Quantile(AggFunc):
arg_types = {"this": True, "quantile": True}
-# Clickhouse-specific:
-# https://clickhouse.com/docs/en/sql-reference/aggregate-functions/reference/quantiles/#quantiles
-class Quantiles(AggFunc):
- arg_types = {"parameters": True, "expressions": True}
- is_var_len_args = True
-
-
-class QuantileIf(AggFunc):
- arg_types = {"parameters": True, "expressions": True}
-
-
class ApproxQuantile(Quantile):
arg_types = {"this": True, "quantile": True, "accuracy": False, "weight": False}
@@ -4089,6 +4325,10 @@ class Substring(Func):
arg_types = {"this": True, "start": False, "length": False}
+class StandardHash(Func):
+ arg_types = {"this": True, "expression": False}
+
+
class StrPosition(Func):
arg_types = {
"this": True,
@@ -4328,15 +4568,19 @@ def maybe_parse(
return sql_or_expression.copy()
return sql_or_expression
+ if sql_or_expression is None:
+ raise ParseError(f"SQL cannot be None")
+
import sqlglot
sql = str(sql_or_expression)
if prefix:
sql = f"{prefix} {sql}"
+
return sqlglot.parse_one(sql, read=dialect, into=into, **opts)
-def _maybe_copy(instance, copy=True):
+def _maybe_copy(instance: E, copy: bool = True) -> E:
return instance.copy() if copy else instance
@@ -4383,16 +4627,18 @@ def _apply_child_list_builder(
instance = _maybe_copy(instance, copy)
parsed = []
for expression in expressions:
- if _is_wrong_expression(expression, into):
- expression = into(expressions=[expression])
- expression = maybe_parse(
- expression,
- into=into,
- dialect=dialect,
- prefix=prefix,
- **opts,
- )
- parsed.extend(expression.expressions)
+ if expression is not None:
+ if _is_wrong_expression(expression, into):
+ expression = into(expressions=[expression])
+
+ expression = maybe_parse(
+ expression,
+ into=into,
+ dialect=dialect,
+ prefix=prefix,
+ **opts,
+ )
+ parsed.extend(expression.expressions)
existing = instance.args.get(arg)
if append and existing:
@@ -4402,6 +4648,7 @@ def _apply_child_list_builder(
for k, v in (properties or {}).items():
child.set(k, v)
instance.set(arg, child)
+
return instance
@@ -4427,6 +4674,7 @@ def _apply_list_builder(
**opts,
)
for expression in expressions
+ if expression is not None
]
existing_expressions = inst.args.get(arg)
@@ -4463,25 +4711,59 @@ def _apply_conjunction_builder(
return inst
-def _combine(expressions, operator, dialect=None, copy=True, **opts):
- expressions = [
- condition(expression, dialect=dialect, copy=copy, **opts) for expression in expressions
+def _apply_cte_builder(
+ instance: E,
+ alias: ExpOrStr,
+ as_: ExpOrStr,
+ recursive: t.Optional[bool] = None,
+ append: bool = True,
+ dialect: DialectType = None,
+ copy: bool = True,
+ **opts,
+) -> E:
+ alias_expression = maybe_parse(alias, dialect=dialect, into=TableAlias, **opts)
+ as_expression = maybe_parse(as_, dialect=dialect, **opts)
+ cte = CTE(this=as_expression, alias=alias_expression)
+ return _apply_child_list_builder(
+ cte,
+ instance=instance,
+ arg="with",
+ append=append,
+ copy=copy,
+ into=With,
+ properties={"recursive": recursive or False},
+ )
+
+
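This helper centralizes WITH-clause construction for the statement builders. A hedged sketch via the public Select API (using it for other statement types is an assumption based on this refactor):

    from sqlglot import exp

    q = exp.select("*").from_("cte").with_("cte", as_="SELECT 1 AS x")
    print(q.sql())  # WITH cte AS (SELECT 1 AS x) SELECT * FROM cte
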
+def _combine(
+ expressions: t.Sequence[t.Optional[ExpOrStr]],
+ operator: t.Type[Connector],
+ dialect: DialectType = None,
+ copy: bool = True,
+ **opts,
+) -> Expression:
+ conditions = [
+ condition(expression, dialect=dialect, copy=copy, **opts)
+ for expression in expressions
+ if expression is not None
]
- this = expressions[0]
- if expressions[1:]:
+
+ this, *rest = conditions
+ if rest:
this = _wrap(this, Connector)
- for expression in expressions[1:]:
+ for expression in rest:
this = operator(this=this, expression=_wrap(expression, Connector))
+
return this
def _wrap(expression: E, kind: t.Type[Expression]) -> E | Paren:
- if isinstance(expression, kind):
- return Paren(this=expression)
- return expression
+ return Paren(this=expression) if isinstance(expression, kind) else expression
-def union(left, right, distinct=True, dialect=None, **opts):
+def union(
+ left: ExpOrStr, right: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts
+) -> Union:
"""
Initializes a syntax tree from one UNION expression.
@@ -4490,15 +4772,16 @@ def union(left, right, distinct=True, dialect=None, **opts):
'SELECT * FROM foo UNION SELECT * FROM bla'
Args:
- left (str | Expression): the SQL code string corresponding to the left-hand side.
+ left: the SQL code string corresponding to the left-hand side.
If an `Expression` instance is passed, it will be used as-is.
- right (str | Expression): the SQL code string corresponding to the right-hand side.
+ right: the SQL code string corresponding to the right-hand side.
If an `Expression` instance is passed, it will be used as-is.
- distinct (bool): set the DISTINCT flag if and only if this is true.
- dialect (str): the dialect used to parse the input expression.
- opts (kwargs): other options to use to parse the input expressions.
+ distinct: set the DISTINCT flag if and only if this is true.
+ dialect: the dialect used to parse the input expression.
+ opts: other options to use to parse the input expressions.
+
Returns:
- Union: the syntax tree for the UNION expression.
+ The new Union instance.
"""
left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts)
right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts)
@@ -4506,7 +4789,9 @@ def union(left, right, distinct=True, dialect=None, **opts):
return Union(this=left, expression=right, distinct=distinct)
-def intersect(left, right, distinct=True, dialect=None, **opts):
+def intersect(
+ left: ExpOrStr, right: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts
+) -> Intersect:
"""
Initializes a syntax tree from one INTERSECT expression.
@@ -4515,15 +4800,16 @@ def intersect(left, right, distinct=True, dialect=None, **opts):
'SELECT * FROM foo INTERSECT SELECT * FROM bla'
Args:
- left (str | Expression): the SQL code string corresponding to the left-hand side.
+ left: the SQL code string corresponding to the left-hand side.
If an `Expression` instance is passed, it will be used as-is.
- right (str | Expression): the SQL code string corresponding to the right-hand side.
+ right: the SQL code string corresponding to the right-hand side.
If an `Expression` instance is passed, it will be used as-is.
- distinct (bool): set the DISTINCT flag if and only if this is true.
- dialect (str): the dialect used to parse the input expression.
- opts (kwargs): other options to use to parse the input expressions.
+ distinct: set the DISTINCT flag if and only if this is true.
+ dialect: the dialect used to parse the input expression.
+ opts: other options to use to parse the input expressions.
+
Returns:
- Intersect: the syntax tree for the INTERSECT expression.
+ The new Intersect instance.
"""
left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts)
right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts)
@@ -4531,7 +4817,9 @@ def intersect(left, right, distinct=True, dialect=None, **opts):
return Intersect(this=left, expression=right, distinct=distinct)
-def except_(left, right, distinct=True, dialect=None, **opts):
+def except_(
+ left: ExpOrStr, right: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts
+) -> Except:
"""
Initializes a syntax tree from one EXCEPT expression.
@@ -4540,15 +4828,16 @@ def except_(left, right, distinct=True, dialect=None, **opts):
'SELECT * FROM foo EXCEPT SELECT * FROM bla'
Args:
- left (str | Expression): the SQL code string corresponding to the left-hand side.
+ left: the SQL code string corresponding to the left-hand side.
If an `Expression` instance is passed, it will be used as-is.
- right (str | Expression): the SQL code string corresponding to the right-hand side.
+ right: the SQL code string corresponding to the right-hand side.
If an `Expression` instance is passed, it will be used as-is.
- distinct (bool): set the DISTINCT flag if and only if this is true.
- dialect (str): the dialect used to parse the input expression.
- opts (kwargs): other options to use to parse the input expressions.
+ distinct: set the DISTINCT flag if and only if this is true.
+ dialect: the dialect used to parse the input expression.
+ opts: other options to use to parse the input expressions.
+
Returns:
- Except: the syntax tree for the EXCEPT statement.
+ The new Except instance.
"""
left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts)
right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts)
@@ -4578,7 +4867,7 @@ def select(*expressions: ExpOrStr, dialect: DialectType = None, **opts) -> Selec
return Select().select(*expressions, dialect=dialect, **opts)
-def from_(*expressions, dialect=None, **opts) -> Select:
+def from_(expression: ExpOrStr, dialect: DialectType = None, **opts) -> Select:
"""
Initializes a syntax tree from a FROM expression.
@@ -4587,9 +4876,9 @@ def from_(*expressions, dialect=None, **opts) -> Select:
'SELECT col1, col2 FROM tbl'
Args:
- *expressions (str | Expression): the SQL code string to parse as the FROM expressions of a
+    expression: the SQL code string to parse as the FROM expression of a
SELECT statement. If an Expression instance is passed, this is used as-is.
- dialect (str): the dialect used to parse the input expression (in the case that the
+ dialect: the dialect used to parse the input expression (in the case that the
input expression is a SQL string).
**opts: other options to use to parse the input expressions (again, in the case
that the input expression is a SQL string).
@@ -4597,7 +4886,7 @@ def from_(*expressions, dialect=None, **opts) -> Select:
Returns:
Select: the syntax tree for the SELECT statement.
"""
- return Select().from_(*expressions, dialect=dialect, **opts)
+ return Select().from_(expression, dialect=dialect, **opts)
def update(
@@ -4680,7 +4969,54 @@ def delete(
return delete_expr
-def condition(expression, dialect=None, copy=True, **opts) -> Condition:
+def insert(
+ expression: ExpOrStr,
+ into: ExpOrStr,
+ columns: t.Optional[t.Sequence[ExpOrStr]] = None,
+ overwrite: t.Optional[bool] = None,
+ dialect: DialectType = None,
+ copy: bool = True,
+ **opts,
+) -> Insert:
+ """
+ Builds an INSERT statement.
+
+ Example:
+ >>> insert("VALUES (1, 2, 3)", "tbl").sql()
+ 'INSERT INTO tbl VALUES (1, 2, 3)'
+
+ Args:
+        expression: the SQL string or expression for the INSERT statement.
+        into: the table to insert data into.
+        columns: optionally the table's column names.
+        overwrite: whether to INSERT OVERWRITE or not.
+ dialect: the dialect used to parse the input expressions.
+ copy: whether or not to copy the expression.
+ **opts: other options to use to parse the input expressions.
+
+ Returns:
+ Insert: the syntax tree for the INSERT statement.
+ """
+ expr = maybe_parse(expression, dialect=dialect, copy=copy, **opts)
+ this: Table | Schema = maybe_parse(into, into=Table, dialect=dialect, copy=copy, **opts)
+
+ if columns:
+ this = _apply_list_builder(
+ *columns,
+ instance=Schema(this=this),
+ arg="expressions",
+ into=Identifier,
+ copy=False,
+ dialect=dialect,
+ **opts,
+ )
+
+ return Insert(this=this, expression=expr, overwrite=overwrite)
+
+
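Beyond the docstring example, a sketch with columns and overwrite; the exact keyword rendering (e.g. OVERWRITE TABLE) depends on the target dialect:

    from sqlglot import exp

    stmt = exp.insert("SELECT * FROM src", "tbl", columns=["a", "b"], overwrite=True)
    print(stmt.sql())  # roughly: INSERT OVERWRITE TABLE tbl (a, b) SELECT * FROM src
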
+def condition(
+ expression: ExpOrStr, dialect: DialectType = None, copy: bool = True, **opts
+) -> Condition:
"""
Initialize a logical condition expression.
@@ -4695,18 +5031,18 @@ def condition(expression, dialect=None, copy=True, **opts) -> Condition:
'SELECT * FROM tbl WHERE x = 1 AND y = 1'
Args:
- *expression (str | Expression): the SQL code string to parse.
+        expression: the SQL code string to parse.
If an Expression instance is passed, this is used as-is.
- dialect (str): the dialect used to parse the input expression (in the case that the
+ dialect: the dialect used to parse the input expression (in the case that the
input expression is a SQL string).
- copy (bool): Whether or not to copy `expression` (only applies to expressions).
+ copy: Whether or not to copy `expression` (only applies to expressions).
**opts: other options to use to parse the input expressions (again, in the case
that the input expression is a SQL string).
Returns:
- Condition: the expression
+ The new Condition instance
"""
- return maybe_parse( # type: ignore
+ return maybe_parse(
expression,
into=Condition,
dialect=dialect,
@@ -4715,7 +5051,9 @@ def condition(expression, dialect=None, copy=True, **opts) -> Condition:
)
-def and_(*expressions, dialect=None, copy=True, **opts) -> And:
+def and_(
+ *expressions: t.Optional[ExpOrStr], dialect: DialectType = None, copy: bool = True, **opts
+) -> Condition:
"""
Combine multiple conditions with an AND logical operator.
@@ -4724,19 +5062,21 @@ def and_(*expressions, dialect=None, copy=True, **opts) -> And:
'x = 1 AND (y = 1 AND z = 1)'
Args:
- *expressions (str | Expression): the SQL code strings to parse.
+ *expressions: the SQL code strings to parse.
If an Expression instance is passed, this is used as-is.
- dialect (str): the dialect used to parse the input expression.
- copy (bool): whether or not to copy `expressions` (only applies to Expressions).
+ dialect: the dialect used to parse the input expression.
+ copy: whether or not to copy `expressions` (only applies to Expressions).
**opts: other options to use to parse the input expressions.
Returns:
        The new condition.
"""
- return _combine(expressions, And, dialect, copy=copy, **opts)
+ return t.cast(Condition, _combine(expressions, And, dialect, copy=copy, **opts))
-def or_(*expressions, dialect=None, copy=True, **opts) -> Or:
+def or_(
+ *expressions: t.Optional[ExpOrStr], dialect: DialectType = None, copy: bool = True, **opts
+) -> Condition:
"""
Combine multiple conditions with an OR logical operator.
@@ -4745,19 +5085,19 @@ def or_(*expressions, dialect=None, copy=True, **opts) -> Or:
'x = 1 OR (y = 1 OR z = 1)'
Args:
- *expressions (str | Expression): the SQL code strings to parse.
+ *expressions: the SQL code strings to parse.
If an Expression instance is passed, this is used as-is.
- dialect (str): the dialect used to parse the input expression.
- copy (bool): whether or not to copy `expressions` (only applies to Expressions).
+ dialect: the dialect used to parse the input expression.
+ copy: whether or not to copy `expressions` (only applies to Expressions).
**opts: other options to use to parse the input expressions.
Returns:
        The new condition.
"""
- return _combine(expressions, Or, dialect, copy=copy, **opts)
+ return t.cast(Condition, _combine(expressions, Or, dialect, copy=copy, **opts))
-def not_(expression, dialect=None, copy=True, **opts) -> Not:
+def not_(expression: ExpOrStr, dialect: DialectType = None, copy: bool = True, **opts) -> Not:
"""
Wrap a condition with a NOT operator.
@@ -4766,13 +5106,14 @@ def not_(expression, dialect=None, copy=True, **opts) -> Not:
"NOT this_suit = 'black'"
Args:
- expression (str | Expression): the SQL code strings to parse.
+ expression: the SQL code string to parse.
If an Expression instance is passed, this is used as-is.
- dialect (str): the dialect used to parse the input expression.
+ dialect: the dialect used to parse the input expression.
+ copy: whether to copy the expression or not.
**opts: other options to use to parse the input expressions.
Returns:
- Not: the new condition
+ The new condition.
"""
this = condition(
expression,
@@ -4783,29 +5124,47 @@ def not_(expression, dialect=None, copy=True, **opts) -> Not:
return Not(this=_wrap(this, Connector))
-def paren(expression, copy=True) -> Paren:
- return Paren(this=_maybe_copy(expression, copy))
+def paren(expression: ExpOrStr, copy: bool = True) -> Paren:
+ """
+ Wrap an expression in parentheses.
+
+ Example:
+ >>> paren("5 + 3").sql()
+ '(5 + 3)'
+
+ Args:
+ expression: the SQL code string to parse.
+ If an Expression instance is passed, this is used as-is.
+ copy: whether to copy the expression or not.
+
+ Returns:
+ The wrapped expression.
+ """
+ return Paren(this=maybe_parse(expression, copy=copy))
SAFE_IDENTIFIER_RE = re.compile(r"^[_a-zA-Z][\w]*$")
@t.overload
-def to_identifier(name: None, quoted: t.Optional[bool] = None) -> None:
+def to_identifier(name: None, quoted: t.Optional[bool] = None, copy: bool = True) -> None:
...
@t.overload
-def to_identifier(name: str | Identifier, quoted: t.Optional[bool] = None) -> Identifier:
+def to_identifier(
+ name: str | Identifier, quoted: t.Optional[bool] = None, copy: bool = True
+) -> Identifier:
...
-def to_identifier(name, quoted=None):
+def to_identifier(name, quoted=None, copy=True):
"""Builds an identifier.
Args:
name: The name to turn into an identifier.
        quoted: Whether or not to force quote the identifier.
+        copy: Whether or not to copy a passed-in Identifier node.
Returns:
The identifier ast node.
@@ -4815,7 +5174,7 @@ def to_identifier(name, quoted=None):
return None
if isinstance(name, Identifier):
- identifier = name
+ identifier = _maybe_copy(name, copy)
elif isinstance(name, str):
identifier = Identifier(
this=name,
@@ -4858,13 +5217,17 @@ def to_table(sql_path: None, **kwargs) -> None:
...
-def to_table(sql_path: t.Optional[str | Table], **kwargs) -> t.Optional[Table]:
+def to_table(
+ sql_path: t.Optional[str | Table], dialect: DialectType = None, **kwargs
+) -> t.Optional[Table]:
"""
Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional.
If a table is passed in then that table is returned.
Args:
sql_path: a `[catalog].[schema].[table]` string.
+ dialect: the source dialect according to which the table name will be parsed.
+ kwargs: the kwargs to instantiate the resulting `Table` expression with.
Returns:
A table expression.
@@ -4874,8 +5237,12 @@ def to_table(sql_path: t.Optional[str | Table], **kwargs) -> t.Optional[Table]:
if not isinstance(sql_path, str):
raise ValueError(f"Invalid type provided for a table: {type(sql_path)}")
- catalog, db, table_name = (to_identifier(x) for x in split_num_words(sql_path, ".", 3))
- return Table(this=table_name, db=db, catalog=catalog, **kwargs)
+ table = maybe_parse(sql_path, into=Table, dialect=dialect)
+ if table:
+ for k, v in kwargs.items():
+ table.set(k, v)
+
+ return table
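Parsing the path with a real dialect, instead of naively splitting on dots, lets quoted multi-word parts survive; a brief sketch:

    from sqlglot import exp

    t = exp.to_table('"my db"."my schema".events', dialect="snowflake")
    print(t.text("catalog"), t.text("db"), t.name)  # my db  my schema  events
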
def to_column(sql_path: str | Column, **kwargs) -> Column:
@@ -4902,6 +5269,7 @@ def alias_(
table: bool | t.Sequence[str | Identifier] = False,
quoted: t.Optional[bool] = None,
dialect: DialectType = None,
+ copy: bool = True,
**opts,
):
"""Create an Alias expression.
@@ -4921,18 +5289,17 @@ def alias_(
table: Whether or not to create a table alias, can also be a list of columns.
quoted: whether or not to quote the alias
dialect: the dialect used to parse the input expression.
+ copy: Whether or not to copy the expression.
**opts: other options to use to parse the input expressions.
Returns:
Alias: the aliased expression
"""
- exp = maybe_parse(expression, dialect=dialect, **opts)
+ exp = maybe_parse(expression, dialect=dialect, copy=copy, **opts)
alias = to_identifier(alias, quoted=quoted)
if table:
table_alias = TableAlias(this=alias)
-
- exp = exp.copy() if isinstance(expression, Expression) else exp
exp.set("alias", table_alias)
if not isinstance(table, bool):
@@ -4948,13 +5315,17 @@ def alias_(
# [1]: https://cloud.google.com/bigquery/docs/reference/standard-sql/window-function-calls
if "alias" in exp.arg_types and not isinstance(exp, Window):
- exp = exp.copy()
exp.set("alias", alias)
return exp
return Alias(this=exp, alias=alias)
-def subquery(expression, alias=None, dialect=None, **opts):
+def subquery(
+ expression: ExpOrStr,
+ alias: t.Optional[Identifier | str] = None,
+ dialect: DialectType = None,
+ **opts,
+) -> Select:
"""
Build a subquery expression.
@@ -4963,14 +5334,14 @@ def subquery(expression, alias=None, dialect=None, **opts):
'SELECT x FROM (SELECT x FROM tbl) AS bar'
Args:
- expression (str | Expression): the SQL code strings to parse.
+        expression: the SQL code string to parse.
If an Expression instance is passed, this is used as-is.
- alias (str | Expression): the alias name to use.
- dialect (str): the dialect used to parse the input expression.
+ alias: the alias name to use.
+ dialect: the dialect used to parse the input expression.
**opts: other options to use to parse the input expressions.
Returns:
- Select: a new select with the subquery expression included
+ A new Select instance with the subquery expression included.
"""
expression = maybe_parse(expression, dialect=dialect, **opts).subquery(alias)
@@ -4988,13 +5359,14 @@ def column(
Build a Column.
Args:
- col: column name
- table: table name
- db: db name
- catalog: catalog name
- quoted: whether or not to force quote each part
+ col: Column name.
+ table: Table name.
+ db: Database name.
+ catalog: Catalog name.
+ quoted: Whether to force quotes on the column's identifiers.
+
Returns:
- Column: column instance
+ The new Column instance.
"""
return Column(
this=to_identifier(col, quoted=quoted),
@@ -5016,22 +5388,30 @@ def cast(expression: ExpOrStr, to: str | DataType | DataType.Type, **opts) -> Ca
to: The datatype to cast to.
Returns:
- A cast node.
+ The new Cast instance.
"""
expression = maybe_parse(expression, **opts)
return Cast(this=expression, to=DataType.build(to, **opts))
-def table_(table, db=None, catalog=None, quoted=None, alias=None) -> Table:
+def table_(
+ table: Identifier | str,
+ db: t.Optional[Identifier | str] = None,
+ catalog: t.Optional[Identifier | str] = None,
+ quoted: t.Optional[bool] = None,
+ alias: t.Optional[Identifier | str] = None,
+) -> Table:
"""Build a Table.
Args:
- table (str | Expression): column name
- db (str | Expression): db name
- catalog (str | Expression): catalog name
+ table: Table name.
+ db: Database name.
+ catalog: Catalog name.
+        quoted: Whether to force quotes on the table's identifiers.
+ alias: Table's alias.
Returns:
- Table: table instance
+ The new Table instance.
"""
return Table(
this=to_identifier(table, quoted=quoted),
@@ -5160,7 +5540,7 @@ def convert(value: t.Any, copy: bool = False) -> Expression:
raise ValueError(f"Cannot convert {value}")
-def replace_children(expression, fun, *args, **kwargs):
+def replace_children(expression: Expression, fun: t.Callable, *args, **kwargs) -> None:
"""
Replace children of an expression with the result of a lambda fun(child) -> exp.
"""
@@ -5182,7 +5562,7 @@ def replace_children(expression, fun, *args, **kwargs):
expression.args[k] = new_child_nodes if is_list_arg else seq_get(new_child_nodes, 0)
-def column_table_names(expression):
+def column_table_names(expression: Expression) -> t.List[str]:
"""
Return all table names referenced through columns in an expression.
@@ -5192,19 +5572,19 @@ def column_table_names(expression):
['c', 'a']
Args:
- expression (sqlglot.Expression): expression to find table names
+ expression: expression to find table names.
Returns:
- list: A list of unique names
+ A list of unique names.
"""
return list(dict.fromkeys(column.table for column in expression.find_all(Column)))
-def table_name(table) -> str:
+def table_name(table: Table | str) -> str:
"""Get the full name of a table as a string.
Args:
- table (exp.Table | str): table expression node or string.
+ table: table expression node or string.
Examples:
>>> from sqlglot import exp, parse_one
@@ -5220,23 +5600,15 @@ def table_name(table) -> str:
if not table:
raise ValueError(f"Cannot parse {table}")
- return ".".join(
- part
- for part in (
- table.text("catalog"),
- table.text("db"),
- table.name,
- )
- if part
- )
+ return ".".join(part for part in (table.text("catalog"), table.text("db"), table.name) if part)
-def replace_tables(expression, mapping):
+def replace_tables(expression: E, mapping: t.Dict[str, str]) -> E:
"""Replace all tables in expression according to the mapping.
Args:
- expression (sqlglot.Expression): expression node to be transformed and replaced.
- mapping (Dict[str, str]): mapping of table names.
+ expression: expression node to be transformed and replaced.
+ mapping: mapping of table names.
Examples:
>>> from sqlglot import exp, parse_one
@@ -5247,7 +5619,7 @@ def replace_tables(expression, mapping):
The mapped expression.
"""
- def _replace_tables(node):
+ def _replace_tables(node: Expression) -> Expression:
if isinstance(node, Table):
new_name = mapping.get(table_name(node))
if new_name:
@@ -5260,11 +5632,11 @@ def replace_tables(expression, mapping):
return expression.transform(_replace_tables)
-def replace_placeholders(expression, *args, **kwargs):
+def replace_placeholders(expression: Expression, *args, **kwargs) -> Expression:
"""Replace placeholders in an expression.
Args:
- expression (sqlglot.Expression): expression node to be transformed and replaced.
+ expression: expression node to be transformed and replaced.
args: positional names that will substitute unnamed placeholders in the given order.
kwargs: keyword arguments that will substitute named placeholders.
@@ -5280,7 +5652,7 @@ def replace_placeholders(expression, *args, **kwargs):
The mapped expression.
"""
- def _replace_placeholders(node, args, **kwargs):
+ def _replace_placeholders(node: Expression, args, **kwargs) -> Expression:
if isinstance(node, Placeholder):
if node.name:
new_name = kwargs.get(node.name)
@@ -5378,21 +5750,21 @@ def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func:
return function
-def true():
+def true() -> Boolean:
"""
Returns a true Boolean expression.
"""
return Boolean(this=True)
-def false():
+def false() -> Boolean:
"""
Returns a false Boolean expression.
"""
return Boolean(this=False)
-def null():
+def null() -> Null:
"""
Returns a Null expression.
"""
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index d7dcea0..f1ec398 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -31,6 +31,8 @@ class Generator:
hex_end (str): specifies which ending character to use to delimit hex literals. Default: None.
byte_start (str): specifies which starting character to use to delimit byte literals. Default: None.
byte_end (str): specifies which ending character to use to delimit byte literals. Default: None.
+ raw_start (str): specifies which starting character to use to delimit raw literals. Default: None.
+ raw_end (str): specifies which ending character to use to delimit raw literals. Default: None.
        identify (bool | str): 'always': always quote; 'safe': quote identifiers if they don't contain an uppercase character; True defaults to 'always'.
        normalize (bool): if set to True, all identifiers will be lowercased.
string_escape (str): specifies a string escape character. Default: '.
@@ -76,11 +78,12 @@ class Generator:
exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG",
exp.MaterializedProperty: lambda self, e: "MATERIALIZED",
exp.NoPrimaryIndexProperty: lambda self, e: "NO PRIMARY INDEX",
- exp.OnCommitProperty: lambda self, e: "ON COMMIT PRESERVE ROWS",
+ exp.OnCommitProperty: lambda self, e: f"ON COMMIT {'DELETE' if e.args.get('delete') else 'PRESERVE'} ROWS",
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
exp.SetProperty: lambda self, e: f"{'MULTI' if e.args.get('multi') else ''}SET",
+        exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{self.expressions(e)}",
exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
- exp.TemporaryProperty: lambda self, e: f"{'GLOBAL ' if e.args.get('global_') else ''}TEMPORARY",
+        exp.TemporaryProperty: lambda self, e: "TEMPORARY",
exp.TransientProperty: lambda self, e: "TRANSIENT",
exp.StabilityProperty: lambda self, e: e.name,
exp.VolatileProperty: lambda self, e: "VOLATILE",
@@ -133,6 +136,15 @@ class Generator:
# Whether or not limit and fetch are supported (possible values: "ALL", "LIMIT", "FETCH")
LIMIT_FETCH = "ALL"
+ # Whether a table is allowed to be renamed with a db
+ RENAME_TABLE_WITH_DB = True
+
+ # The separator for grouping sets and rollups
+ GROUPINGS_SEP = ","
+
+    # The string used for creating an index on a table
+ INDEX_ON = "ON"
+
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
@@ -167,7 +179,6 @@ class Generator:
PARAMETER_TOKEN = "@"
PROPERTIES_LOCATION = {
- exp.AfterJournalProperty: exp.Properties.Location.POST_NAME,
exp.AlgorithmProperty: exp.Properties.Location.POST_CREATE,
exp.AutoIncrementProperty: exp.Properties.Location.POST_SCHEMA,
exp.BlockCompressionProperty: exp.Properties.Location.POST_NAME,
@@ -196,7 +207,9 @@ class Generator:
exp.MergeBlockRatioProperty: exp.Properties.Location.POST_NAME,
exp.NoPrimaryIndexProperty: exp.Properties.Location.POST_EXPRESSION,
exp.OnCommitProperty: exp.Properties.Location.POST_EXPRESSION,
+ exp.Order: exp.Properties.Location.POST_SCHEMA,
exp.PartitionedByProperty: exp.Properties.Location.POST_WITH,
+ exp.PrimaryKey: exp.Properties.Location.POST_SCHEMA,
exp.Property: exp.Properties.Location.POST_WITH,
exp.ReturnsProperty: exp.Properties.Location.POST_SCHEMA,
exp.RowFormatProperty: exp.Properties.Location.POST_SCHEMA,
@@ -204,13 +217,15 @@ class Generator:
exp.RowFormatSerdeProperty: exp.Properties.Location.POST_SCHEMA,
exp.SchemaCommentProperty: exp.Properties.Location.POST_SCHEMA,
exp.SerdeProperties: exp.Properties.Location.POST_SCHEMA,
+ exp.Set: exp.Properties.Location.POST_SCHEMA,
+ exp.SettingsProperty: exp.Properties.Location.POST_SCHEMA,
exp.SetProperty: exp.Properties.Location.POST_CREATE,
exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA,
exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE,
exp.StabilityProperty: exp.Properties.Location.POST_SCHEMA,
- exp.TableFormatProperty: exp.Properties.Location.POST_WITH,
exp.TemporaryProperty: exp.Properties.Location.POST_CREATE,
exp.TransientProperty: exp.Properties.Location.POST_CREATE,
+ exp.MergeTreeTTL: exp.Properties.Location.POST_SCHEMA,
exp.VolatileProperty: exp.Properties.Location.POST_CREATE,
exp.WithDataProperty: exp.Properties.Location.POST_EXPRESSION,
exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME,
@@ -221,7 +236,7 @@ class Generator:
RESERVED_KEYWORDS: t.Set[str] = set()
WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.With)
- UNWRAPPED_INTERVAL_VALUES = (exp.Literal, exp.Paren, exp.Column)
+ UNWRAPPED_INTERVAL_VALUES = (exp.Column, exp.Literal, exp.Neg, exp.Paren)
SENTINEL_LINE_BREAK = "__SQLGLOT__LB__"
@@ -239,6 +254,8 @@ class Generator:
"hex_end",
"byte_start",
"byte_end",
+ "raw_start",
+ "raw_end",
"identify",
"normalize",
"string_escape",
@@ -276,6 +293,8 @@ class Generator:
hex_end=None,
byte_start=None,
byte_end=None,
+ raw_start=None,
+ raw_end=None,
identify=False,
normalize=False,
string_escape=None,
@@ -308,6 +327,8 @@ class Generator:
self.hex_end = hex_end
self.byte_start = byte_start
self.byte_end = byte_end
+ self.raw_start = raw_start
+ self.raw_end = raw_end
self.identify = identify
self.normalize = normalize
self.string_escape = string_escape or "'"
@@ -399,7 +420,11 @@ class Generator:
return sql
if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
- return f"{comments_sql}{self.sep()}{sql}"
+ return (
+ f"{self.sep()}{comments_sql}{sql}"
+ if sql[0].isspace()
+ else f"{comments_sql}{self.sep()}{sql}"
+ )
return f"{sql} {comments_sql}"
@@ -567,7 +592,9 @@ class Generator:
) -> str:
this = ""
if expression.this is not None:
- this = " ALWAYS " if expression.this else " BY DEFAULT "
+ on_null = "ON NULL " if expression.args.get("on_null") else ""
+ this = " ALWAYS " if expression.this else f" BY DEFAULT {on_null}"
+
start = expression.args.get("start")
start = f"START WITH {start}" if start else ""
increment = expression.args.get("increment")
@@ -578,14 +605,20 @@ class Generator:
maxvalue = f" MAXVALUE {maxvalue}" if maxvalue else ""
cycle = expression.args.get("cycle")
cycle_sql = ""
+
if cycle is not None:
cycle_sql = f"{' NO' if not cycle else ''} CYCLE"
cycle_sql = cycle_sql.strip() if not start and not increment else cycle_sql
+
sequence_opts = ""
if start or increment or cycle_sql:
sequence_opts = f"{start}{increment}{minvalue}{maxvalue}{cycle_sql}"
sequence_opts = f" ({sequence_opts.strip()})"
- return f"GENERATED{this}AS IDENTITY{sequence_opts}"
+
+ expr = self.sql(expression, "expression")
+ expr = f"({expr})" if expr else "IDENTITY"
+
+ return f"GENERATED{this}AS {expr}{sequence_opts}"
def notnullcolumnconstraint_sql(self, expression: exp.NotNullColumnConstraint) -> str:
return f"{'' if expression.args.get('allow_null') else 'NOT '}NULL"
@@ -596,8 +629,10 @@ class Generator:
return f"PRIMARY KEY{' DESC' if desc else ' ASC'}"
return f"PRIMARY KEY"
- def uniquecolumnconstraint_sql(self, _) -> str:
- return "UNIQUE"
+ def uniquecolumnconstraint_sql(self, expression: exp.UniqueColumnConstraint) -> str:
+ this = self.sql(expression, "this")
+ this = f" {this}" if this else ""
+ return f"UNIQUE{this}"
def create_sql(self, expression: exp.Create) -> str:
kind = self.sql(expression, "kind").upper()
@@ -653,33 +688,9 @@ class Generator:
prefix=" ",
)
- indexes = expression.args.get("indexes")
- if indexes:
- indexes_sql: t.List[str] = []
- for index in indexes:
- ind_unique = " UNIQUE" if index.args.get("unique") else ""
- ind_primary = " PRIMARY" if index.args.get("primary") else ""
- ind_amp = " AMP" if index.args.get("amp") else ""
- ind_name = f" {index.name}" if index.name else ""
- ind_columns = (
- f' ({self.expressions(index, key="columns", flat=True)})'
- if index.args.get("columns")
- else ""
- )
- ind_sql = f"{ind_unique}{ind_primary}{ind_amp} INDEX{ind_name}{ind_columns}"
-
- if indexes_sql:
- indexes_sql.append(ind_sql)
- else:
- indexes_sql.append(
- f"{ind_sql}{postindex_props_sql}"
- if index.args.get("primary")
- else f"{postindex_props_sql}{ind_sql}"
- )
-
- index_sql = "".join(indexes_sql)
- else:
- index_sql = postindex_props_sql
+ indexes = self.expressions(expression, key="indexes", indent=False, sep=" ")
+ indexes = f" {indexes}" if indexes else ""
+ index_sql = indexes + postindex_props_sql
replace = " OR REPLACE" if expression.args.get("replace") else ""
unique = " UNIQUE" if expression.args.get("unique") else ""
@@ -711,9 +722,23 @@ class Generator:
" WITH NO SCHEMA BINDING" if expression.args.get("no_schema_binding") else ""
)
- expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties_sql}{expression_sql}{postexpression_props_sql}{index_sql}{no_schema_binding}"
+ clone = self.sql(expression, "clone")
+ clone = f" {clone}" if clone else ""
+
+ expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties_sql}{expression_sql}{postexpression_props_sql}{index_sql}{no_schema_binding}{clone}"
return self.prepend_ctes(expression, expression_sql)
+ def clone_sql(self, expression: exp.Clone) -> str:
+ this = self.sql(expression, "this")
+ when = self.sql(expression, "when")
+
+ if when:
+ kind = self.sql(expression, "kind")
+ expr = self.sql(expression, "expression")
+ return f"CLONE {this} {when} ({kind} => {expr})"
+
+ return f"CLONE {this}"
+
def describe_sql(self, expression: exp.Describe) -> str:
return f"DESCRIBE {self.sql(expression, 'this')}"
@@ -757,6 +782,17 @@ class Generator:
return f"{self.byte_start}{this}{self.byte_end}"
return this
+ def rawstring_sql(self, expression: exp.RawString) -> str:
+ if self.raw_start:
+ return f"{self.raw_start}{expression.name}{self.raw_end}"
+ return self.sql(exp.Literal.string(expression.name.replace("\\", "\\\\")))
+
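A sketch of the fallback behavior, assuming a source dialect such as BigQuery delimits raw literals (r'...') while the target does not, so backslashes get escaped into a plain string:

    import sqlglot

    # BigQuery raw string; DuckDB has no raw literals, so the backslash is doubled.
    print(sqlglot.transpile(r"SELECT r'a\nb'", read="bigquery", write="duckdb")[0])
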
+ def datatypesize_sql(self, expression: exp.DataTypeSize) -> str:
+ this = self.sql(expression, "this")
+ specifier = self.sql(expression, "expression")
+ specifier = f" {specifier}" if specifier else ""
+ return f"{this}{specifier}"
+
def datatype_sql(self, expression: exp.DataType) -> str:
type_value = expression.this
type_sql = self.TYPE_MAPPING.get(type_value, type_value.value)
@@ -768,7 +804,8 @@ class Generator:
nested = f"{self.STRUCT_DELIMITER[0]}{interior}{self.STRUCT_DELIMITER[1]}"
if expression.args.get("values") is not None:
delimiters = ("[", "]") if type_value == exp.DataType.Type.ARRAY else ("(", ")")
- values = f"{delimiters[0]}{self.expressions(expression, key='values')}{delimiters[1]}"
+ values = self.expressions(expression, key="values", flat=True)
+ values = f"{delimiters[0]}{values}{delimiters[1]}"
else:
nested = f"({interior})"
@@ -836,10 +873,17 @@ class Generator:
return ""
def index_sql(self, expression: exp.Index) -> str:
- this = self.sql(expression, "this")
+ unique = "UNIQUE " if expression.args.get("unique") else ""
+ primary = "PRIMARY " if expression.args.get("primary") else ""
+ amp = "AMP " if expression.args.get("amp") else ""
+ name = f"{expression.name} " if expression.name else ""
table = self.sql(expression, "table")
- columns = self.sql(expression, "columns")
- return f"{this} ON {table} {columns}"
+ table = f"{self.INDEX_ON} {table} " if table else ""
+ index = "INDEX " if not table else ""
+ columns = self.expressions(expression, key="columns", flat=True)
+ partition_by = self.expressions(expression, key="partition_by", flat=True)
+ partition_by = f" PARTITION BY {partition_by}" if partition_by else ""
+ return f"{unique}{primary}{amp}{index}{name}{table}({columns}){partition_by}"
def identifier_sql(self, expression: exp.Identifier) -> str:
text = expression.name
@@ -861,8 +905,9 @@ class Generator:
output_format = f"OUTPUTFORMAT {output_format}" if output_format else ""
return self.sep().join((input_format, output_format))
- def national_sql(self, expression: exp.National) -> str:
- return f"N{self.sql(expression, 'this')}"
+ def national_sql(self, expression: exp.National, prefix: str = "N") -> str:
+ string = self.sql(exp.Literal.string(expression.name))
+ return f"{prefix}{string}"
def partition_sql(self, expression: exp.Partition) -> str:
return f"PARTITION({self.expressions(expression)})"
@@ -955,23 +1000,18 @@ class Generator:
def journalproperty_sql(self, expression: exp.JournalProperty) -> str:
no = "NO " if expression.args.get("no") else ""
+ local = expression.args.get("local")
+ local = f"{local} " if local else ""
dual = "DUAL " if expression.args.get("dual") else ""
before = "BEFORE " if expression.args.get("before") else ""
- return f"{no}{dual}{before}JOURNAL"
+ after = "AFTER " if expression.args.get("after") else ""
+ return f"{no}{local}{dual}{before}{after}JOURNAL"
def freespaceproperty_sql(self, expression: exp.FreespaceProperty) -> str:
freespace = self.sql(expression, "this")
percent = " PERCENT" if expression.args.get("percent") else ""
return f"FREESPACE={freespace}{percent}"
- def afterjournalproperty_sql(self, expression: exp.AfterJournalProperty) -> str:
- no = "NO " if expression.args.get("no") else ""
- dual = "DUAL " if expression.args.get("dual") else ""
- local = ""
- if expression.args.get("local") is not None:
- local = "LOCAL " if expression.args.get("local") else "NOT LOCAL "
- return f"{no}{dual}{local}AFTER JOURNAL"
-
def checksumproperty_sql(self, expression: exp.ChecksumProperty) -> str:
if expression.args.get("default"):
property = "DEFAULT"
@@ -992,19 +1032,19 @@ class Generator:
def datablocksizeproperty_sql(self, expression: exp.DataBlocksizeProperty) -> str:
default = expression.args.get("default")
- min = expression.args.get("min")
- if default is not None or min is not None:
+ minimum = expression.args.get("minimum")
+ maximum = expression.args.get("maximum")
+ if default or minimum or maximum:
if default:
- property = "DEFAULT"
- elif min:
- property = "MINIMUM"
+ prop = "DEFAULT"
+ elif minimum:
+ prop = "MINIMUM"
else:
- property = "MAXIMUM"
- return f"{property} DATABLOCKSIZE"
- else:
- units = expression.args.get("units")
- units = f" {units}" if units else ""
- return f"DATABLOCKSIZE={self.sql(expression, 'size')}{units}"
+ prop = "MAXIMUM"
+ return f"{prop} DATABLOCKSIZE"
+ units = expression.args.get("units")
+ units = f" {units}" if units else ""
+ return f"DATABLOCKSIZE={self.sql(expression, 'size')}{units}"
def blockcompressionproperty_sql(self, expression: exp.BlockCompressionProperty) -> str:
autotemp = expression.args.get("autotemp")
@@ -1014,16 +1054,16 @@ class Generator:
never = expression.args.get("never")
if autotemp is not None:
- property = f"AUTOTEMP({self.expressions(autotemp)})"
+ prop = f"AUTOTEMP({self.expressions(autotemp)})"
elif always:
- property = "ALWAYS"
+ prop = "ALWAYS"
elif default:
- property = "DEFAULT"
+ prop = "DEFAULT"
elif manual:
- property = "MANUAL"
+ prop = "MANUAL"
elif never:
- property = "NEVER"
- return f"BLOCKCOMPRESSION={property}"
+ prop = "NEVER"
+ return f"BLOCKCOMPRESSION={prop}"
def isolatedloadingproperty_sql(self, expression: exp.IsolatedLoadingProperty) -> str:
no = expression.args.get("no")
@@ -1138,21 +1178,24 @@ class Generator:
alias = self.sql(expression, "alias")
alias = f"{sep}{alias}" if alias else ""
- hints = self.expressions(expression, key="hints", sep=", ", flat=True)
+ hints = self.expressions(expression, key="hints", flat=True)
hints = f" WITH ({hints})" if hints and self.TABLE_HINTS else ""
- laterals = self.expressions(expression, key="laterals", sep="")
+ pivots = self.expressions(expression, key="pivots", sep=" ", flat=True)
+ pivots = f" {pivots}" if pivots else ""
joins = self.expressions(expression, key="joins", sep="")
- pivots = self.expressions(expression, key="pivots", sep="")
+ laterals = self.expressions(expression, key="laterals", sep="")
system_time = expression.args.get("system_time")
system_time = f" {self.sql(expression, 'system_time')}" if system_time else ""
- return f"{table}{system_time}{alias}{hints}{laterals}{joins}{pivots}"
+ return f"{table}{system_time}{alias}{hints}{pivots}{joins}{laterals}"
def tablesample_sql(
self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
) -> str:
if self.alias_post_tablesample and expression.this.alias:
- this = self.sql(expression.this, "this")
+ table = expression.this.copy()
+ table.set("alias", None)
+ this = self.sql(table)
alias = f"{sep}{self.sql(expression.this, 'alias')}"
else:
this = self.sql(expression, "this")
@@ -1177,14 +1220,22 @@ class Generator:
return f"{this} {kind} {method}({bucket}{percent}{rows}{size}){seed}{alias}"
def pivot_sql(self, expression: exp.Pivot) -> str:
- this = self.sql(expression, "this")
+ expressions = self.expressions(expression, flat=True)
+
+ if expression.this:
+ this = self.sql(expression, "this")
+ on = f"{self.seg('ON')} {expressions}"
+ using = self.expressions(expression, key="using", flat=True)
+ using = f"{self.seg('USING')} {using}" if using else ""
+ group = self.sql(expression, "group")
+ return f"PIVOT {this}{on}{using}{group}"
+
alias = self.sql(expression, "alias")
alias = f" AS {alias}" if alias else ""
unpivot = expression.args.get("unpivot")
direction = "UNPIVOT" if unpivot else "PIVOT"
- expressions = self.expressions(expression, key="expressions")
field = self.sql(expression, "field")
- return f"{this} {direction}({expressions} FOR {field}){alias}"
+ return f"{direction}({expressions} FOR {field}){alias}"
def tuple_sql(self, expression: exp.Tuple) -> str:
return f"({self.expressions(expression, flat=True)})"
@@ -1218,8 +1269,7 @@ class Generator:
return f"{self.seg('INTO')}{temporary or unlogged} {self.sql(expression, 'this')}"
def from_sql(self, expression: exp.From) -> str:
- expressions = self.expressions(expression, flat=True)
- return f"{self.seg('FROM')} {expressions}"
+ return f"{self.seg('FROM')} {self.sql(expression, 'this')}"
def group_sql(self, expression: exp.Group) -> str:
group_by = self.op_expressions("GROUP BY", expression)
@@ -1242,10 +1292,16 @@ class Generator:
rollup_sql = self.expressions(expression, key="rollup", indent=False)
rollup_sql = f"{self.seg('ROLLUP')} {self.wrap(rollup_sql)}" if rollup_sql else ""
- groupings = csv(grouping_sets, cube_sql, rollup_sql, sep=",")
+ groupings = csv(
+ grouping_sets,
+ cube_sql,
+ rollup_sql,
+ self.seg("WITH TOTALS") if expression.args.get("totals") else "",
+ sep=self.GROUPINGS_SEP,
+ )
if expression.args.get("expressions") and groupings:
- group_by = f"{group_by},"
+ group_by = f"{group_by}{self.GROUPINGS_SEP}"
return f"{group_by}{groupings}"
@@ -1254,18 +1310,16 @@ class Generator:
return f"{self.seg('HAVING')}{self.sep()}{this}"
def join_sql(self, expression: exp.Join) -> str:
- op_sql = self.seg(
- " ".join(
- op
- for op in (
- "NATURAL" if expression.args.get("natural") else None,
- expression.side,
- expression.kind,
- expression.hint if self.JOIN_HINTS else None,
- "JOIN",
- )
- if op
+ op_sql = " ".join(
+ op
+ for op in (
+ "NATURAL" if expression.args.get("natural") else None,
+ "GLOBAL" if expression.args.get("global") else None,
+ expression.side,
+ expression.kind,
+ expression.hint if self.JOIN_HINTS else None,
)
+ if op
)
on_sql = self.sql(expression, "on")
using = expression.args.get("using")
@@ -1273,6 +1327,8 @@ class Generator:
if not on_sql and using:
on_sql = csv(*(self.sql(column) for column in using))
+ this_sql = self.sql(expression, "this")
+
if on_sql:
on_sql = self.indent(on_sql, skip_first=True)
space = self.seg(" " * self.pad) if self.pretty else " "
@@ -1280,10 +1336,11 @@ class Generator:
on_sql = f"{space}USING ({on_sql})"
else:
on_sql = f"{space}ON {on_sql}"
+ elif not op_sql:
+ return f", {this_sql}"
- expression_sql = self.sql(expression, "expression")
- this_sql = self.sql(expression, "this")
- return f"{expression_sql}{op_sql} {this_sql}{on_sql}"
+ op_sql = f"{op_sql} JOIN" if op_sql else "JOIN"
+ return f"{self.seg(op_sql)} {this_sql}{on_sql}"
def lambda_sql(self, expression: exp.Lambda, arrow_sep: str = "->") -> str:
args = self.expressions(expression, flat=True)
@@ -1336,12 +1393,22 @@ class Generator:
return f"PRAGMA {self.sql(expression, 'this')}"
def lock_sql(self, expression: exp.Lock) -> str:
- if self.LOCKING_READS_SUPPORTED:
- lock_type = "UPDATE" if expression.args["update"] else "SHARE"
- return self.seg(f"FOR {lock_type}")
+ if not self.LOCKING_READS_SUPPORTED:
+ self.unsupported("Locking reads using 'FOR UPDATE/SHARE' are not supported")
+ return ""
- self.unsupported("Locking reads using 'FOR UPDATE/SHARE' are not supported")
- return ""
+ lock_type = "FOR UPDATE" if expression.args["update"] else "FOR SHARE"
+ expressions = self.expressions(expression, flat=True)
+ expressions = f" OF {expressions}" if expressions else ""
+ wait = expression.args.get("wait")
+
+ if wait is not None:
+ if isinstance(wait, exp.Literal):
+ wait = f" WAIT {self.sql(wait)}"
+ else:
+ wait = " NOWAIT" if wait else " SKIP LOCKED"
+
+ return f"{lock_type}{expressions}{wait or ''}"
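A sketch of the richer locking clauses this enables, assuming the parser populates `expressions` (the OF list) and `wait` as this method expects; syntax per PostgreSQL:

    import sqlglot

    sql = "SELECT * FROM t FOR UPDATE OF t SKIP LOCKED"  # wait=False maps to SKIP LOCKED
    print(sqlglot.transpile(sql, read="postgres", write="postgres")[0])
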
def literal_sql(self, expression: exp.Literal) -> str:
text = expression.this or ""
@@ -1460,27 +1527,33 @@ class Generator:
return csv(
*sqls,
- *[self.sql(sql) for sql in expression.args.get("joins") or []],
+ *[self.sql(join) for join in expression.args.get("joins") or []],
self.sql(expression, "match"),
- *[self.sql(sql) for sql in expression.args.get("laterals") or []],
+ *[self.sql(lateral) for lateral in expression.args.get("laterals") or []],
self.sql(expression, "where"),
self.sql(expression, "group"),
self.sql(expression, "having"),
- self.sql(expression, "qualify"),
- self.seg("WINDOW ") + self.expressions(expression, key="windows", flat=True)
- if expression.args.get("windows")
- else "",
- self.sql(expression, "distribute"),
- self.sql(expression, "sort"),
- self.sql(expression, "cluster"),
+ *self.after_having_modifiers(expression),
self.sql(expression, "order"),
self.sql(expression, "offset") if fetch else self.sql(limit),
self.sql(limit) if fetch else self.sql(expression, "offset"),
- self.sql(expression, "lock"),
- self.sql(expression, "sample"),
+ *self.after_limit_modifiers(expression),
sep="",
)
+ def after_having_modifiers(self, expression: exp.Expression) -> t.List[str]:
+ return [
+ self.sql(expression, "qualify"),
+ self.seg("WINDOW ") + self.expressions(expression, key="windows", flat=True)
+ if expression.args.get("windows")
+ else "",
+ ]
+
+ def after_limit_modifiers(self, expression: exp.Expression) -> t.List[str]:
+ locks = self.expressions(expression, key="locks", sep=" ")
+ locks = f" {locks}" if locks else ""
+ return [locks, self.sql(expression, "sample")]
+
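These hooks turn the modifier tail into per-dialect extension points. A sketch of a hypothetical subclass override (not a real dialect):

    from sqlglot.generator import Generator

    class NoSampleGenerator(Generator):
        def after_limit_modifiers(self, expression):
            # Hypothetical: keep locks but drop TABLESAMPLE from the output.
            locks = self.expressions(expression, key="locks", sep=" ")
            return [f" {locks}" if locks else ""]
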
def select_sql(self, expression: exp.Select) -> str:
hint = self.sql(expression, "hint")
distinct = self.sql(expression, "distinct")
@@ -1529,13 +1602,10 @@ class Generator:
alias = self.sql(expression, "alias")
alias = f"{sep}{alias}" if alias else ""
- sql = self.query_modifiers(
- expression,
- self.wrap(expression),
- alias,
- self.expressions(expression, key="pivots", sep=" "),
- )
+ pivots = self.expressions(expression, key="pivots", sep=" ", flat=True)
+ pivots = f" {pivots}" if pivots else ""
+ sql = self.query_modifiers(expression, self.wrap(expression), alias, pivots)
return self.prepend_ctes(expression, sql)
def qualify_sql(self, expression: exp.Qualify) -> str:
@@ -1712,10 +1782,6 @@ class Generator:
options = f" {options}" if options else ""
return f"PRIMARY KEY ({expressions}){options}"
- def unique_sql(self, expression: exp.Unique) -> str:
- columns = self.expressions(expression, key="expressions")
- return f"UNIQUE ({columns})"
-
def if_sql(self, expression: exp.If) -> str:
return self.case_sql(
exp.Case(ifs=[expression.copy()], default=expression.args.get("false"))
@@ -1745,6 +1811,26 @@ class Generator:
encoding = f" ENCODING {encoding}" if encoding else ""
return f"JSON_OBJECT({expressions}{null_handling}{unique_keys}{return_type}{format_json}{encoding})"
+ def openjsoncolumndef_sql(self, expression: exp.OpenJSONColumnDef) -> str:
+ this = self.sql(expression, "this")
+ kind = self.sql(expression, "kind")
+ path = self.sql(expression, "path")
+ path = f" {path}" if path else ""
+ as_json = " AS JSON" if expression.args.get("as_json") else ""
+ return f"{this} {kind}{path}{as_json}"
+
+ def openjson_sql(self, expression: exp.OpenJSON) -> str:
+ this = self.sql(expression, "this")
+ path = self.sql(expression, "path")
+ path = f", {path}" if path else ""
+ expressions = self.expressions(expression)
+ with_ = (
+ f" WITH ({self.seg(self.indent(expressions), sep='')}{self.seg(')', sep='')}"
+ if expressions
+ else ""
+ )
+ return f"OPENJSON({this}{path}){with_}"
+
def in_sql(self, expression: exp.In) -> str:
query = expression.args.get("query")
unnest = expression.args.get("unnest")
@@ -1773,7 +1859,7 @@ class Generator:
if self.SINGLE_STRING_INTERVAL:
this = expression.this.name if expression.this else ""
- return f"INTERVAL '{this}{unit}'"
+ return f"INTERVAL '{this}{unit}'" if this else f"INTERVAL{unit}"
this = self.sql(expression, "this")
if this:
@@ -1883,6 +1969,28 @@ class Generator:
expression_sql = self.sql(expression, "expression")
return f"COMMENT{exists_sql}ON {kind} {this} IS {expression_sql}"
+ def mergetreettlaction_sql(self, expression: exp.MergeTreeTTLAction) -> str:
+ this = self.sql(expression, "this")
+ delete = " DELETE" if expression.args.get("delete") else ""
+ recompress = self.sql(expression, "recompress")
+ recompress = f" RECOMPRESS {recompress}" if recompress else ""
+ to_disk = self.sql(expression, "to_disk")
+ to_disk = f" TO DISK {to_disk}" if to_disk else ""
+ to_volume = self.sql(expression, "to_volume")
+ to_volume = f" TO VOLUME {to_volume}" if to_volume else ""
+ return f"{this}{delete}{recompress}{to_disk}{to_volume}"
+
+ def mergetreettl_sql(self, expression: exp.MergeTreeTTL) -> str:
+ where = self.sql(expression, "where")
+ group = self.sql(expression, "group")
+ aggregates = self.expressions(expression, key="aggregates")
+ aggregates = self.seg("SET") + self.seg(aggregates) if aggregates else ""
+
+ if not (where or group or aggregates) and len(expression.expressions) == 1:
+ return f"TTL {self.expressions(expression, flat=True)}"
+
+ return f"TTL{self.seg(self.expressions(expression))}{where}{group}{aggregates}"
+
def transaction_sql(self, expression: exp.Transaction) -> str:
return "BEGIN"
@@ -1919,6 +2027,11 @@ class Generator:
return f"ALTER COLUMN {this} DROP DEFAULT"
def renametable_sql(self, expression: exp.RenameTable) -> str:
+ if not self.RENAME_TABLE_WITH_DB:
+ # Remove db from tables
+ expression = expression.transform(
+ lambda n: exp.table_(n.this) if isinstance(n, exp.Table) else n
+ )
this = self.sql(expression, "this")
return f"RENAME TO {this}"
@@ -2208,3 +2321,12 @@ class Generator:
self.unsupported("Format argument unsupported for TO_CHAR/TO_VARCHAR function")
return self.sql(exp.cast(expression.this, "text"))
+
+
+def cached_generator(
+ cache: t.Optional[t.Dict[int, str]] = None
+) -> t.Callable[[exp.Expression], str]:
+ """Returns a cached generator."""
+ cache = {} if cache is None else cache
+ generator = Generator(normalize=True, identify="safe")
+ return lambda e: generator.generate(e, cache)
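A usage sketch: the returned callable reuses a single normalizing Generator plus a shared cache, which helps when the same subexpressions are serialized repeatedly (e.g. during normalization):

    from sqlglot import parse_one
    from sqlglot.generator import cached_generator

    gen = cached_generator()
    expr = parse_one("a AND b OR a AND c")
    print(gen(expr))  # repeated calls over equal subtrees hit the shared cache
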
diff --git a/sqlglot/helper.py b/sqlglot/helper.py
index b2f0520..4215fee 100644
--- a/sqlglot/helper.py
+++ b/sqlglot/helper.py
@@ -9,14 +9,14 @@ from collections.abc import Collection
from contextlib import contextmanager
from copy import copy
from enum import Enum
+from itertools import count
if t.TYPE_CHECKING:
from sqlglot import exp
+ from sqlglot._typing import E, T
+ from sqlglot.dialects.dialect import DialectType
from sqlglot.expressions import Expression
- T = t.TypeVar("T")
- E = t.TypeVar("E", bound=Expression)
-
CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
PYTHON_VERSION = sys.version_info[:2]
logger = logging.getLogger("sqlglot")
@@ -25,7 +25,7 @@ logger = logging.getLogger("sqlglot")
class AutoName(Enum):
"""This is used for creating enum classes where `auto()` is the string form of the corresponding value's name."""
- def _generate_next_value_(name, _start, _count, _last_values): # type: ignore
+ def _generate_next_value_(name, _start, _count, _last_values):
return name
@@ -92,7 +92,7 @@ def ensure_collection(value):
)
-def csv(*args, sep: str = ", ") -> str:
+def csv(*args: str, sep: str = ", ") -> str:
"""
Formats any number of string arguments as CSV.
@@ -304,9 +304,18 @@ def find_new_name(taken: t.Collection[str], base: str) -> str:
return new
+def name_sequence(prefix: str) -> t.Callable[[], str]:
+ """Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a")."""
+ sequence = count()
+ return lambda: f"{prefix}{next(sequence)}"
+
+
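For example:

    from sqlglot.helper import name_sequence

    new_name = name_sequence("_q_")
    new_name()  # '_q_0'
    new_name()  # '_q_1'
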
def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
"""Returns a dictionary created from an object's attributes."""
- return {**{k: copy(v) for k, v in vars(obj).copy().items()}, **kwargs}
+ return {
+ **{k: v.copy() if hasattr(v, "copy") else copy(v) for k, v in vars(obj).items()},
+ **kwargs,
+ }
def split_num_words(
@@ -381,15 +390,6 @@ def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
yield value
-def count_params(function: t.Callable) -> int:
- """
- Returns the number of formal parameters expected by a function, without counting "self"
- and "cls", in case of instance and class methods, respectively.
- """
- count = function.__code__.co_argcount
- return count - 1 if inspect.ismethod(function) else count
-
-
def dict_depth(d: t.Dict) -> int:
"""
Get the nesting depth of a dictionary.
@@ -430,12 +430,23 @@ def first(it: t.Iterable[T]) -> T:
return next(i for i in it)
-def should_identify(text: str, identify: str | bool) -> bool:
+def case_sensitive(text: str, dialect: DialectType) -> bool:
+    """Checks if `text` contains any case-sensitive characters, depending on the given dialect."""
+ from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE
+
+ unsafe = str.islower if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
+ return any(unsafe(char) for char in text)
+
+
+def should_identify(text: str, identify: str | bool, dialect: DialectType = None) -> bool:
"""Checks if text should be identified given an identify option.
Args:
text: the text to check.
- identify: "always" | True - always returns true, "safe" - true if no upper case
+ identify:
+ "always" or `True`: always returns true.
+            "safe": true if there are no case-sensitive characters in `text` for the given `dialect`.
+        dialect: the dialect used to decide whether the text should be identified.
Returns:
Whether or not a string should be identified.
@@ -443,5 +454,5 @@ def should_identify(text: str, identify: str | bool) -> bool:
if identify is True or identify == "always":
return True
if identify == "safe":
- return not any(char.isupper() for char in text)
+ return not case_sensitive(text, dialect)
return False
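A sketch of the dialect-aware behavior: quoting is "safe" only when it cannot change how an identifier resolves, so the unsafe case flips between uppercase and lowercase depending on the dialect's resolution rules (Snowflake resolving identifiers to uppercase is an assumption here):

    from sqlglot.helper import should_identify

    should_identify("foo", "safe")                       # True: no uppercase characters
    should_identify("Foo", "safe")                       # False: mixed case is unsafe to quote
    should_identify("FOO", "safe", dialect="snowflake")  # True: no lowercase characters
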
diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py
index 0eac870..04a8073 100644
--- a/sqlglot/lineage.py
+++ b/sqlglot/lineage.py
@@ -5,10 +5,8 @@ import typing as t
from dataclasses import dataclass, field
from sqlglot import Schema, exp, maybe_parse
-from sqlglot.optimizer import Scope, build_scope, optimize
-from sqlglot.optimizer.expand_laterals import expand_laterals
-from sqlglot.optimizer.qualify_columns import qualify_columns
-from sqlglot.optimizer.qualify_tables import qualify_tables
+from sqlglot.errors import SqlglotError
+from sqlglot.optimizer import Scope, build_scope, qualify
if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType
@@ -40,8 +38,8 @@ def lineage(
sql: str | exp.Expression,
schema: t.Optional[t.Dict | Schema] = None,
sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
- rules: t.Sequence[t.Callable] = (qualify_tables, qualify_columns, expand_laterals),
dialect: DialectType = None,
+ **kwargs,
) -> Node:
"""Build the lineage graph for a column of a SQL query.
@@ -50,8 +48,8 @@ def lineage(
sql: The SQL string or expression.
schema: The schema of tables.
sources: A mapping of queries which will be used to continue building lineage.
- rules: Optimizer rules to apply, by default only qualifying tables and columns.
dialect: The dialect of input SQL.
+ **kwargs: Qualification optimizer kwargs.
Returns:
A lineage node.
@@ -68,8 +66,17 @@ def lineage(
},
)
- optimized = optimize(expression, schema=schema, rules=rules)
- scope = build_scope(optimized)
+ qualified = qualify.qualify(
+ expression,
+ dialect=dialect,
+ schema=schema,
+ **{"validate_qualify_columns": False, "identify": False, **kwargs}, # type: ignore
+ )
+
+ scope = build_scope(qualified)
+
+ if not scope:
+        raise SqlglotError("Cannot build lineage, sql must be a SELECT")
def to_node(
column_name: str,
@@ -109,10 +116,7 @@ def lineage(
# a version that has only the column we care about.
# "x", SELECT x, y FROM foo
# => "x", SELECT x FROM foo
- source = optimize(
- scope.expression.select(select, append=False), schema=schema, rules=rules
- )
- select = source.selects[0]
+ source = t.cast(exp.Expression, scope.expression.select(select, append=False))
else:
source = scope.expression
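With the optimizer pipeline replaced by qualify, a call now looks roughly like this; extra qualification options ride through **kwargs (the column-first signature is taken from the upstream API):

    from sqlglot.lineage import lineage

    node = lineage(
        "x",
        "SELECT x FROM (SELECT a AS x FROM t)",
        schema={"t": {"a": "int"}},
    )
    print(node.name)
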
diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py
index ef929ac..da2fce8 100644
--- a/sqlglot/optimizer/canonicalize.py
+++ b/sqlglot/optimizer/canonicalize.py
@@ -3,10 +3,9 @@ from __future__ import annotations
import itertools
from sqlglot import exp
-from sqlglot.helper import should_identify
-def canonicalize(expression: exp.Expression, identify: str = "safe") -> exp.Expression:
+def canonicalize(expression: exp.Expression) -> exp.Expression:
"""Converts a sql expression into a standard form.
This method relies on annotate_types because many of the
@@ -14,19 +13,14 @@ def canonicalize(expression: exp.Expression, identify: str = "safe") -> exp.Expr
Args:
expression: The expression to canonicalize.
- identify: Whether or not to force identify identifier.
"""
- exp.replace_children(expression, canonicalize, identify=identify)
+ exp.replace_children(expression, canonicalize)
expression = add_text_to_concat(expression)
expression = coerce_type(expression)
expression = remove_redundant_casts(expression)
expression = ensure_bool_predicates(expression)
- if isinstance(expression, exp.Identifier):
- if should_identify(expression.this, identify):
- expression.set("quoted", True)
-
return expression
diff --git a/sqlglot/optimizer/eliminate_ctes.py b/sqlglot/optimizer/eliminate_ctes.py
index 7b862c6..6f1865c 100644
--- a/sqlglot/optimizer/eliminate_ctes.py
+++ b/sqlglot/optimizer/eliminate_ctes.py
@@ -19,24 +19,25 @@ def eliminate_ctes(expression):
"""
root = build_scope(expression)
- ref_count = root.ref_count()
-
- # Traverse the scope tree in reverse so we can remove chains of unused CTEs
- for scope in reversed(list(root.traverse())):
- if scope.is_cte:
- count = ref_count[id(scope)]
- if count <= 0:
- cte_node = scope.expression.parent
- with_node = cte_node.parent
- cte_node.pop()
-
- # Pop the entire WITH clause if this is the last CTE
- if len(with_node.expressions) <= 0:
- with_node.pop()
-
- # Decrement the ref count for all sources this CTE selects from
- for _, source in scope.selected_sources.values():
- if isinstance(source, Scope):
- ref_count[id(source)] -= 1
+ if root:
+ ref_count = root.ref_count()
+
+ # Traverse the scope tree in reverse so we can remove chains of unused CTEs
+ for scope in reversed(list(root.traverse())):
+ if scope.is_cte:
+ count = ref_count[id(scope)]
+ if count <= 0:
+ cte_node = scope.expression.parent
+ with_node = cte_node.parent
+ cte_node.pop()
+
+ # Pop the entire WITH clause if this is the last CTE
+ if len(with_node.expressions) <= 0:
+ with_node.pop()
+
+ # Decrement the ref count for all sources this CTE selects from
+ for _, source in scope.selected_sources.values():
+ if isinstance(source, Scope):
+ ref_count[id(source)] -= 1
return expression
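The new guard matters because build_scope can now return None for statements
without scopes (see the scope.py changes below); a sketch:

    import sqlglot
    from sqlglot.optimizer.eliminate_ctes import eliminate_ctes

    # DDL has no SELECT scopes, so build_scope(...) is None and the input
    # is returned unchanged instead of raising
    eliminate_ctes(sqlglot.parse_one("DROP TABLE x"))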
diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py
index a39fe96..84f50e9 100644
--- a/sqlglot/optimizer/eliminate_subqueries.py
+++ b/sqlglot/optimizer/eliminate_subqueries.py
@@ -16,9 +16,9 @@ def eliminate_subqueries(expression):
'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'
This also deduplicates common subqueries:
- >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y JOIN (SELECT * FROM x) AS z")
+ >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y CROSS JOIN (SELECT * FROM x) AS z")
>>> eliminate_subqueries(expression).sql()
- 'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y JOIN y AS z'
+ 'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y CROSS JOIN y AS z'
Args:
expression (sqlglot.Expression): expression
@@ -32,6 +32,9 @@ def eliminate_subqueries(expression):
root = build_scope(expression)
+ if not root:
+ return expression
+
# Map of alias->Scope|Table
# These are all aliases that are already used in the expression.
# We don't want to create new CTEs that conflict with these names.
@@ -112,7 +115,7 @@ def _eliminate_union(scope, existing_ctes, taken):
# Try to maintain the selections
expressions = scope.selects
selects = [
- exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name)
+ exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name, copy=False)
for e in expressions
if e.alias_or_name
]
@@ -120,7 +123,9 @@ def _eliminate_union(scope, existing_ctes, taken):
if len(selects) != len(expressions):
selects = ["*"]
- scope.expression.replace(exp.select(*selects).from_(exp.alias_(exp.table_(alias), alias=alias)))
+ scope.expression.replace(
+ exp.select(*selects).from_(exp.alias_(exp.table_(alias), alias=alias, copy=False))
+ )
if not duplicate_cte_alias:
existing_ctes[scope.expression] = alias
@@ -131,6 +136,10 @@ def _eliminate_union(scope, existing_ctes, taken):
def _eliminate_derived_table(scope, existing_ctes, taken):
+ # This ensures we don't drop the "pivot" arg from a pivoted subquery
+ if scope.parent.pivots:
+ return None
+
parent = scope.expression.parent
name, cte = _new_cte(scope, existing_ctes, taken)
@@ -153,7 +162,7 @@ def _eliminate_cte(scope, existing_ctes, taken):
for child_scope in scope.parent.traverse():
for table, source in child_scope.selected_sources.values():
if source is scope:
- new_table = exp.alias_(exp.table_(name), alias=table.alias_or_name)
+ new_table = exp.alias_(exp.table_(name), alias=table.alias_or_name, copy=False)
table.replace(new_table)
return cte
diff --git a/sqlglot/optimizer/expand_laterals.py b/sqlglot/optimizer/expand_laterals.py
deleted file mode 100644
index 5b2f706..0000000
--- a/sqlglot/optimizer/expand_laterals.py
+++ /dev/null
@@ -1,34 +0,0 @@
-from __future__ import annotations
-
-import typing as t
-
-from sqlglot import exp
-
-
-def expand_laterals(expression: exp.Expression) -> exp.Expression:
- """
- Expand lateral column alias references.
-
- This assumes `qualify_columns` has already been run.
-
- Example:
- >>> import sqlglot
- >>> sql = "SELECT x.a + 1 AS b, b + 1 AS c FROM x"
- >>> expression = sqlglot.parse_one(sql)
- >>> expand_laterals(expression).sql()
- 'SELECT x.a + 1 AS b, x.a + 1 + 1 AS c FROM x'
-
- Args:
- expression: expression to optimize
- Returns:
- optimized expression
- """
- for select in expression.find_all(exp.Select):
- alias_to_expression: t.Dict[str, exp.Expression] = {}
- for projection in select.expressions:
- for column in projection.find_all(exp.Column):
- if not column.table and column.name in alias_to_expression:
- column.replace(alias_to_expression[column.name].copy())
- if isinstance(projection, exp.Alias):
- alias_to_expression[projection.alias] = projection.this
- return expression
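The deleted rule is not lost: the same alias expansion now lives inside
qualify_columns (as _expand_alias_refs, below). A rough sketch of the
equivalent behavior, output hedged:

    import sqlglot
    from sqlglot.optimizer.qualify_columns import qualify_columns

    sql = "SELECT x.a + 1 AS b, b + 1 AS c FROM x"
    qualify_columns(sqlglot.parse_one(sql), schema={"x": {"a": "int"}}).sql()
    # -> roughly 'SELECT x.a + 1 AS b, x.a + 1 + 1 AS c FROM x'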
diff --git a/sqlglot/optimizer/expand_multi_table_selects.py b/sqlglot/optimizer/expand_multi_table_selects.py
deleted file mode 100644
index 86f0c2d..0000000
--- a/sqlglot/optimizer/expand_multi_table_selects.py
+++ /dev/null
@@ -1,24 +0,0 @@
-from sqlglot import exp
-
-
-def expand_multi_table_selects(expression):
- """
- Replace multiple FROM expressions with JOINs.
-
- Example:
- >>> from sqlglot import parse_one
- >>> expand_multi_table_selects(parse_one("SELECT * FROM x, y")).sql()
- 'SELECT * FROM x CROSS JOIN y'
- """
- for from_ in expression.find_all(exp.From):
- parent = from_.parent
-
- for query in from_.expressions[1:]:
- parent.join(
- query,
- join_type="CROSS",
- copy=False,
- )
- from_.expressions.remove(query)
-
- return expression
diff --git a/sqlglot/optimizer/isolate_table_selects.py b/sqlglot/optimizer/isolate_table_selects.py
index 5d78353..5dfa4aa 100644
--- a/sqlglot/optimizer/isolate_table_selects.py
+++ b/sqlglot/optimizer/isolate_table_selects.py
@@ -21,7 +21,7 @@ def isolate_table_selects(expression, schema=None):
source.replace(
exp.select("*")
.from_(
- alias(source.copy(), source.name or source.alias, table=True),
+ alias(source, source.name or source.alias, table=True),
copy=False,
)
.subquery(source.alias, copy=False)
diff --git a/sqlglot/optimizer/lower_identities.py b/sqlglot/optimizer/lower_identities.py
deleted file mode 100644
index fae1726..0000000
--- a/sqlglot/optimizer/lower_identities.py
+++ /dev/null
@@ -1,88 +0,0 @@
-from sqlglot import exp
-
-
-def lower_identities(expression):
- """
- Convert all unquoted identifiers to lower case.
-
- Assuming the schema is all lower case, this essentially makes identifiers case-insensitive.
-
- Example:
- >>> import sqlglot
- >>> expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar')
- >>> lower_identities(expression).sql()
- 'SELECT bar.a AS A FROM "Foo".bar'
-
- Args:
- expression (sqlglot.Expression): expression to quote
- Returns:
- sqlglot.Expression: quoted expression
- """
- # We need to leave the output aliases unchanged, so the selects need special handling
- _lower_selects(expression)
-
- # These clauses can reference output aliases and also need special handling
- _lower_order(expression)
- _lower_having(expression)
-
- # We've already handled these args, so don't traverse into them
- traversed = {"expressions", "order", "having"}
-
- if isinstance(expression, exp.Subquery):
- # Root subquery, e.g. (SELECT A AS A FROM X) LIMIT 1
- lower_identities(expression.this)
- traversed |= {"this"}
-
- if isinstance(expression, exp.Union):
- # Union, e.g. SELECT A AS A FROM X UNION SELECT A AS A FROM X
- lower_identities(expression.left)
- lower_identities(expression.right)
- traversed |= {"this", "expression"}
-
- for k, v in expression.iter_expressions():
- if k in traversed:
- continue
- v.transform(_lower, copy=False)
-
- return expression
-
-
-def _lower_selects(expression):
- for e in expression.expressions:
- # Leave output aliases as-is
- e.unalias().transform(_lower, copy=False)
-
-
-def _lower_order(expression):
- order = expression.args.get("order")
-
- if not order:
- return
-
- output_aliases = {e.alias for e in expression.expressions if isinstance(e, exp.Alias)}
-
- for ordered in order.expressions:
- # Don't lower references to output aliases
- if not (
- isinstance(ordered.this, exp.Column)
- and not ordered.this.table
- and ordered.this.name in output_aliases
- ):
- ordered.transform(_lower, copy=False)
-
-
-def _lower_having(expression):
- having = expression.args.get("having")
-
- if not having:
- return
-
- # Don't lower references to output aliases
- for agg in having.find_all(exp.AggFunc):
- agg.transform(_lower, copy=False)
-
-
-def _lower(node):
- if isinstance(node, exp.Identifier) and not node.quoted:
- node.set("this", node.this.lower())
- return node
diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py
index c3467b2..f9c9664 100644
--- a/sqlglot/optimizer/merge_subqueries.py
+++ b/sqlglot/optimizer/merge_subqueries.py
@@ -13,15 +13,15 @@ def merge_subqueries(expression, leave_tables_isolated=False):
Example:
>>> import sqlglot
- >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
+ >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
>>> merge_subqueries(expression).sql()
- 'SELECT x.a FROM x JOIN y'
+ 'SELECT x.a FROM x CROSS JOIN y'
If `leave_tables_isolated` is True, this will not merge inner queries into outer
queries if it would result in multiple table selects in a single query:
- >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
+ >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
>>> merge_subqueries(expression, leave_tables_isolated=True).sql()
- 'SELECT a FROM (SELECT x.a FROM x) JOIN y'
+ 'SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y'
Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html
@@ -154,7 +154,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
inner_from = inner_scope.expression.args.get("from")
if not inner_from:
return False
- inner_from_table = inner_from.expressions[0].alias_or_name
+ inner_from_table = inner_from.alias_or_name
inner_projections = {s.alias_or_name: s for s in inner_scope.selects}
return any(
col.table != inner_from_table
@@ -167,6 +167,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
and isinstance(inner_select, exp.Select)
and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
and inner_select.args.get("from")
+ and not outer_scope.pivots
and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions)
and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
and not (
@@ -210,7 +211,7 @@ def _rename_inner_sources(outer_scope, inner_scope, alias):
elif isinstance(source, exp.Table) and source.alias:
source.set("alias", new_alias)
elif isinstance(source, exp.Table):
- source.replace(exp.alias_(source.copy(), new_alias))
+ source.replace(exp.alias_(source, new_alias))
for column in inner_scope.source_columns(conflict):
column.set("table", exp.to_identifier(new_name))
@@ -228,7 +229,7 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
node_to_replace (exp.Subquery|exp.Table)
alias (str)
"""
- new_subquery = inner_scope.expression.args.get("from").expressions[0]
+ new_subquery = inner_scope.expression.args["from"].this
node_to_replace.replace(new_subquery)
for join_hint in outer_scope.join_hints:
tables = join_hint.find_all(exp.Table)
@@ -319,7 +320,7 @@ def _merge_where(outer_scope, inner_scope, from_or_join):
# Merge predicates from an outer join to the ON clause
# if it only has columns that are already joined
from_ = expression.args.get("from")
- sources = {table.alias_or_name for table in from_.expressions} if from_ else {}
+ sources = {from_.alias_or_name} if from_ else {}
for join in expression.args["joins"]:
source = join.alias_or_name
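The from_.expressions[0] to from_.this swaps reflect that exp.From now wraps a
single source in its `this` arg (comma-separated FROM lists are represented as
joins instead). A sketch:

    import sqlglot

    from_ = sqlglot.parse_one("SELECT * FROM x AS t").args["from"]
    from_.this.sql()  # -> 'x AS t'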
diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py
index b013312..1db094e 100644
--- a/sqlglot/optimizer/normalize.py
+++ b/sqlglot/optimizer/normalize.py
@@ -1,12 +1,12 @@
from __future__ import annotations
import logging
-import typing as t
from sqlglot import exp
from sqlglot.errors import OptimizeError
+from sqlglot.generator import cached_generator
from sqlglot.helper import while_changing
-from sqlglot.optimizer.simplify import flatten, uniq_sort
+from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort
logger = logging.getLogger("sqlglot")
@@ -28,13 +28,16 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
Returns:
sqlglot.Expression: normalized expression
"""
- cache: t.Dict[int, str] = {}
+ generate = cached_generator()
for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))):
if isinstance(node, exp.Connector):
if normalized(node, dnf=dnf):
continue
+ root = node is expression
+ original = node.copy()
+ node.transform(rewrite_between, copy=False)
distance = normalization_distance(node, dnf=dnf)
if distance > max_distance:
@@ -43,11 +46,9 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
)
return expression
- root = node is expression
- original = node.copy()
try:
node = node.replace(
- while_changing(node, lambda e: distributive_law(e, dnf, max_distance, cache))
+ while_changing(node, lambda e: distributive_law(e, dnf, max_distance, generate))
)
except OptimizeError as e:
logger.info(e)
@@ -111,7 +112,7 @@ def _predicate_lengths(expression, dnf):
return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
-def distributive_law(expression, dnf, max_distance, cache=None):
+def distributive_law(expression, dnf, max_distance, generate):
"""
x OR (y AND z) -> (x OR y) AND (x OR z)
(x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)
@@ -124,7 +125,7 @@ def distributive_law(expression, dnf, max_distance, cache=None):
if distance > max_distance:
raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")
- exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, cache))
+ exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, generate))
to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)
if isinstance(expression, from_exp):
@@ -135,30 +136,30 @@ def distributive_law(expression, dnf, max_distance, cache=None):
if isinstance(a, to_exp) and isinstance(b, to_exp):
if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
- return _distribute(a, b, from_func, to_func, cache)
- return _distribute(b, a, from_func, to_func, cache)
+ return _distribute(a, b, from_func, to_func, generate)
+ return _distribute(b, a, from_func, to_func, generate)
if isinstance(a, to_exp):
- return _distribute(b, a, from_func, to_func, cache)
+ return _distribute(b, a, from_func, to_func, generate)
if isinstance(b, to_exp):
- return _distribute(a, b, from_func, to_func, cache)
+ return _distribute(a, b, from_func, to_func, generate)
return expression
-def _distribute(a, b, from_func, to_func, cache):
+def _distribute(a, b, from_func, to_func, generate):
if isinstance(a, exp.Connector):
exp.replace_children(
a,
lambda c: to_func(
- uniq_sort(flatten(from_func(c, b.left)), cache),
- uniq_sort(flatten(from_func(c, b.right)), cache),
+ uniq_sort(flatten(from_func(c, b.left)), generate),
+ uniq_sort(flatten(from_func(c, b.right)), generate),
copy=False,
),
)
else:
a = to_func(
- uniq_sort(flatten(from_func(a, b.left)), cache),
- uniq_sort(flatten(from_func(a, b.right)), cache),
+ uniq_sort(flatten(from_func(a, b.left)), generate),
+ uniq_sort(flatten(from_func(a, b.right)), generate),
copy=False,
)
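cached_generator is assumed to bundle the old ad-hoc `{id: sql}` cache with a
Generator, returning a generate(expression) -> str callable, matching its use
in uniq_sort and distributive_law above. A sketch:

    import sqlglot
    from sqlglot.generator import cached_generator

    generate = cached_generator()
    generate(sqlglot.parse_one("x = 1"))  # -> 'x = 1', memoized per expression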
diff --git a/sqlglot/optimizer/normalize_identifiers.py b/sqlglot/optimizer/normalize_identifiers.py
new file mode 100644
index 0000000..1e5c104
--- /dev/null
+++ b/sqlglot/optimizer/normalize_identifiers.py
@@ -0,0 +1,36 @@
+from sqlglot import exp
+from sqlglot._typing import E
+from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE, DialectType
+
+
+def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
+ """
+ Normalize all unquoted identifiers to either lower or upper case, depending on
+ the dialect. This essentially makes those identifiers case-insensitive.
+
+ Example:
+ >>> import sqlglot
+ >>> expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar')
+ >>> normalize_identifiers(expression).sql()
+ 'SELECT bar.a AS a FROM "Foo".bar'
+
+ Args:
+ expression: The expression to transform.
+ dialect: The dialect to use in order to decide how to normalize identifiers.
+
+ Returns:
+ The transformed expression.
+ """
+ return expression.transform(_normalize, dialect, copy=False)
+
+
+def _normalize(node: exp.Expression, dialect: DialectType = None) -> exp.Expression:
+ if isinstance(node, exp.Identifier) and not node.quoted:
+ node.set(
+ "this",
+ node.this.upper()
+ if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE
+ else node.this.lower(),
+ )
+
+ return node
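A hedged companion to the docstring example: for dialects in
RESOLVES_IDENTIFIERS_AS_UPPERCASE (Snowflake is assumed to be one), the same
input folds the other way, and quoted identifiers stay untouched:

    import sqlglot
    from sqlglot.optimizer.normalize_identifiers import normalize_identifiers

    expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar')
    normalize_identifiers(expression, dialect="snowflake").sql()
    # -> 'SELECT BAR.A AS A FROM "Foo".BAR'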
diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py
index 8589657..43436cb 100644
--- a/sqlglot/optimizer/optimize_joins.py
+++ b/sqlglot/optimizer/optimize_joins.py
@@ -1,6 +1,8 @@
from sqlglot import exp
from sqlglot.helper import tsort
+JOIN_ATTRS = ("on", "side", "kind", "using", "natural")
+
def optimize_joins(expression):
"""
@@ -45,7 +47,7 @@ def reorder_joins(expression):
Reorder joins by topological sort order based on predicate references.
"""
for from_ in expression.find_all(exp.From):
- head = from_.expressions[0]
+ head = from_.this
parent = from_.parent
joins = {join.this.alias_or_name: join for join in parent.args.get("joins", [])}
dag = {head.alias_or_name: []}
@@ -65,6 +67,9 @@ def normalize(expression):
Remove INNER and OUTER from joins as they are optional.
"""
for join in expression.find_all(exp.Join):
+ if not any(join.args.get(k) for k in JOIN_ATTRS):
+ join.set("kind", "CROSS")
+
if join.kind != "CROSS":
join.set("kind", None)
return expression
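A sketch of the new kind handling: a bare JOIN with none of the JOIN_ATTRS set
becomes an explicit CROSS join, while INNER/OUTER keywords are still stripped:

    import sqlglot
    from sqlglot.optimizer.optimize_joins import optimize_joins

    optimize_joins(sqlglot.parse_one("SELECT * FROM x JOIN y")).sql()
    # -> 'SELECT * FROM x CROSS JOIN y'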
diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py
index c165ffe..dbe33a2 100644
--- a/sqlglot/optimizer/optimizer.py
+++ b/sqlglot/optimizer/optimizer.py
@@ -10,36 +10,29 @@ from sqlglot.optimizer.canonicalize import canonicalize
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
from sqlglot.optimizer.eliminate_joins import eliminate_joins
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
-from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
-from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
-from sqlglot.optimizer.lower_identities import lower_identities
from sqlglot.optimizer.merge_subqueries import merge_subqueries
from sqlglot.optimizer.normalize import normalize
from sqlglot.optimizer.optimize_joins import optimize_joins
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
from sqlglot.optimizer.pushdown_projections import pushdown_projections
-from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns
-from sqlglot.optimizer.qualify_tables import qualify_tables
+from sqlglot.optimizer.qualify import qualify
+from sqlglot.optimizer.qualify_columns import quote_identifiers
from sqlglot.optimizer.simplify import simplify
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
from sqlglot.schema import ensure_schema
RULES = (
- lower_identities,
- qualify_tables,
- isolate_table_selects,
- qualify_columns,
+ qualify,
pushdown_projections,
- validate_qualify_columns,
normalize,
unnest_subqueries,
- expand_multi_table_selects,
pushdown_predicates,
optimize_joins,
eliminate_subqueries,
merge_subqueries,
eliminate_joins,
eliminate_ctes,
+ quote_identifiers,
annotate_types,
canonicalize,
simplify,
@@ -54,7 +47,7 @@ def optimize(
dialect: DialectType = None,
rules: t.Sequence[t.Callable] = RULES,
**kwargs,
-):
+) -> exp.Expression:
"""
Rewrite a sqlglot AST into an optimized form.
@@ -72,14 +65,23 @@ def optimize(
dialect: The dialect to parse the sql string.
rules: sequence of optimizer rules to use.
Many of the rules require tables and columns to be qualified.
- Do not remove qualify_tables or qualify_columns from the sequence of rules unless you know
- what you're doing!
+ Do not remove `qualify` from the sequence of rules unless you know what you're doing!
**kwargs: If a rule has a keyword argument with a same name in **kwargs, it will be passed in.
+
Returns:
- sqlglot.Expression: optimized expression
+ The optimized expression.
"""
schema = ensure_schema(schema or sqlglot.schema, dialect=dialect)
- possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs}
+ possible_kwargs = {
+ "db": db,
+ "catalog": catalog,
+ "schema": schema,
+ "dialect": dialect,
+ "isolate_tables": True, # needed for other optimizations to perform well
+ "quote_identifiers": False, # this happens in canonicalize
+ **kwargs,
+ }
+
expression = exp.maybe_parse(expression, dialect=dialect, copy=True)
for rule in rules:
# Find any additional rule parameters, beyond `expression`
@@ -88,4 +90,5 @@ def optimize(
param: possible_kwargs[param] for param in rule_params if param in possible_kwargs
}
expression = rule(expression, **rule_kwargs)
- return expression
+
+ return t.cast(exp.Expression, expression)
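Usage is unchanged from the caller's perspective; a sketch of the consolidated
pipeline, output hedged:

    import sqlglot
    from sqlglot.optimizer import optimize

    optimize(sqlglot.parse_one("SELECT col FROM tbl"), schema={"tbl": {"col": "int"}}).sql()
    # -> roughly 'SELECT "tbl"."col" AS "col" FROM "tbl" AS "tbl"'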
diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py
index ba5c8b5..96dda33 100644
--- a/sqlglot/optimizer/pushdown_predicates.py
+++ b/sqlglot/optimizer/pushdown_predicates.py
@@ -21,26 +21,28 @@ def pushdown_predicates(expression):
sqlglot.Expression: optimized expression
"""
root = build_scope(expression)
- scope_ref_count = root.ref_count()
-
- for scope in reversed(list(root.traverse())):
- select = scope.expression
- where = select.args.get("where")
- if where:
- selected_sources = scope.selected_sources
- # a right join can only push down to itself and not the source FROM table
- for k, (node, source) in selected_sources.items():
- parent = node.find_ancestor(exp.Join, exp.From)
- if isinstance(parent, exp.Join) and parent.side == "RIGHT":
- selected_sources = {k: (node, source)}
- break
- pushdown(where.this, selected_sources, scope_ref_count)
-
- # joins should only pushdown into itself, not to other joins
- # so we limit the selected sources to only itself
- for join in select.args.get("joins") or []:
- name = join.this.alias_or_name
- pushdown(join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count)
+
+ if root:
+ scope_ref_count = root.ref_count()
+
+ for scope in reversed(list(root.traverse())):
+ select = scope.expression
+ where = select.args.get("where")
+ if where:
+ selected_sources = scope.selected_sources
+ # a right join can only push down to itself and not the source FROM table
+ for k, (node, source) in selected_sources.items():
+ parent = node.find_ancestor(exp.Join, exp.From)
+ if isinstance(parent, exp.Join) and parent.side == "RIGHT":
+ selected_sources = {k: (node, source)}
+ break
+ pushdown(where.this, selected_sources, scope_ref_count)
+
+ # a join's ON predicate should only push down into that join itself, not
+ # into other joins, so we limit the selected sources to the join's own source
+ for join in select.args.get("joins") or []:
+ name = join.this.alias_or_name
+ pushdown(join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count)
return expression
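As in eliminate_ctes, the body is now wrapped in `if root:` so non-SELECT input
no longer crashes; the rule itself is untouched. A sketch, output hedged:

    import sqlglot
    from sqlglot.optimizer.pushdown_predicates import pushdown_predicates

    sql = "SELECT y.a FROM (SELECT x.a FROM x) AS y WHERE y.a = 1"
    pushdown_predicates(sqlglot.parse_one(sql)).sql()
    # -> roughly 'SELECT y.a FROM (SELECT x.a FROM x WHERE x.a = 1) AS y WHERE TRUE'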
diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py
index 2e51117..be3ddb2 100644
--- a/sqlglot/optimizer/pushdown_projections.py
+++ b/sqlglot/optimizer/pushdown_projections.py
@@ -39,8 +39,9 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
for scope in reversed(traverse_scope(expression)):
parent_selections = referenced_columns.get(scope, {SELECT_ALL})
- if scope.expression.args.get("distinct"):
- # We can't remove columns SELECT DISTINCT nor UNION DISTINCT
+ if scope.expression.args.get("distinct") or scope.parent and scope.parent.pivots:
+ # We can't remove columns from SELECT DISTINCT or UNION DISTINCT. The same
+ # holds if we select from a pivoted source in the parent scope.
parent_selections = {SELECT_ALL}
if isinstance(scope.expression, exp.Union):
@@ -105,7 +106,9 @@ def _remove_unused_selections(scope, parent_selections, schema):
for name in sorted(parent_selections):
if name not in names:
- new_selections.append(alias(exp.column(name, table=resolver.get_table(name)), name))
+ new_selections.append(
+ alias(exp.column(name, table=resolver.get_table(name)), name, copy=False)
+ )
# If there are no remaining selections, just select a single constant
if not new_selections:
diff --git a/sqlglot/optimizer/qualify.py b/sqlglot/optimizer/qualify.py
new file mode 100644
index 0000000..5fdbde8
--- /dev/null
+++ b/sqlglot/optimizer/qualify.py
@@ -0,0 +1,80 @@
+from __future__ import annotations
+
+import typing as t
+
+from sqlglot import exp
+from sqlglot.dialects.dialect import DialectType
+from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
+from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
+from sqlglot.optimizer.qualify_columns import (
+ qualify_columns as qualify_columns_func,
+ quote_identifiers as quote_identifiers_func,
+ validate_qualify_columns as validate_qualify_columns_func,
+)
+from sqlglot.optimizer.qualify_tables import qualify_tables
+from sqlglot.schema import Schema, ensure_schema
+
+
+def qualify(
+ expression: exp.Expression,
+ dialect: DialectType = None,
+ db: t.Optional[str] = None,
+ catalog: t.Optional[str] = None,
+ schema: t.Optional[dict | Schema] = None,
+ expand_alias_refs: bool = True,
+ infer_schema: t.Optional[bool] = None,
+ isolate_tables: bool = False,
+ qualify_columns: bool = True,
+ validate_qualify_columns: bool = True,
+ quote_identifiers: bool = True,
+ identify: bool = True,
+) -> exp.Expression:
+ """
+ Rewrite sqlglot AST to have normalized and qualified tables and columns.
+
+ This step is necessary for all further SQLGlot optimizations.
+
+ Example:
+ >>> import sqlglot
+ >>> schema = {"tbl": {"col": "INT"}}
+ >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
+ >>> qualify(expression, schema=schema).sql()
+ 'SELECT "tbl"."col" AS "col" FROM "tbl" AS "tbl"'
+
+ Args:
+ expression: Expression to qualify.
+ db: Default database name for tables.
+ catalog: Default catalog name for tables.
+ schema: Schema to infer column names and types.
+ expand_alias_refs: Whether or not to expand references to aliases.
+ infer_schema: Whether or not to infer the schema if missing.
+ isolate_tables: Whether or not to isolate table selects.
+ qualify_columns: Whether or not to qualify columns.
+ validate_qualify_columns: Whether or not to validate columns.
+ quote_identifiers: Whether or not to run the quote_identifiers step.
+ This step is necessary to ensure correctness for case sensitive queries,
+ but the flag is exposed so the step can instead be run at a later time.
+ identify: If True, quote all identifiers, else only necessary ones.
+
+ Returns:
+ The qualified expression.
+ """
+ schema = ensure_schema(schema, dialect=dialect)
+ expression = normalize_identifiers(expression, dialect=dialect)
+ expression = qualify_tables(expression, db=db, catalog=catalog, schema=schema)
+
+ if isolate_tables:
+ expression = isolate_table_selects(expression, schema=schema)
+
+ if qualify_columns:
+ expression = qualify_columns_func(
+ expression, schema, expand_alias_refs=expand_alias_refs, infer_schema=infer_schema
+ )
+
+ if quote_identifiers:
+ expression = quote_identifiers_func(expression, dialect=dialect, identify=identify)
+
+ if validate_qualify_columns:
+ validate_qualify_columns_func(expression)
+
+ return expression
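A sketch of opting out of individual steps via the new flags, output hedged:

    import sqlglot
    from sqlglot.optimizer.qualify import qualify

    qualify(
        sqlglot.parse_one("SELECT col FROM tbl"),
        schema={"tbl": {"col": "INT"}},
        quote_identifiers=False,        # defer quoting to a later pass
        validate_qualify_columns=False,
    ).sql()
    # -> roughly 'SELECT tbl.col AS col FROM tbl AS tbl'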
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 6ac39f0..4a31171 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -1,14 +1,23 @@
+from __future__ import annotations
+
import itertools
import typing as t
from sqlglot import alias, exp
+from sqlglot._typing import E
+from sqlglot.dialects.dialect import DialectType
from sqlglot.errors import OptimizeError
-from sqlglot.optimizer.expand_laterals import expand_laterals as _expand_laterals
-from sqlglot.optimizer.scope import Scope, traverse_scope
-from sqlglot.schema import ensure_schema
+from sqlglot.helper import case_sensitive, seq_get
+from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope
+from sqlglot.schema import Schema, ensure_schema
-def qualify_columns(expression, schema, expand_laterals=True):
+def qualify_columns(
+ expression: exp.Expression,
+ schema: dict | Schema,
+ expand_alias_refs: bool = True,
+ infer_schema: t.Optional[bool] = None,
+) -> exp.Expression:
"""
Rewrite sqlglot AST to have fully qualified columns.
@@ -20,32 +29,36 @@ def qualify_columns(expression, schema, expand_laterals=True):
'SELECT tbl.col AS col FROM tbl'
Args:
- expression (sqlglot.Expression): expression to qualify
- schema (dict|sqlglot.optimizer.Schema): Database schema
+ expression: expression to qualify
+ schema: Database schema
+ expand_alias_refs: whether or not to expand references to aliases
+ infer_schema: whether or not to infer the schema if missing
Returns:
sqlglot.Expression: qualified expression
"""
schema = ensure_schema(schema)
-
- if not schema.mapping and expand_laterals:
- expression = _expand_laterals(expression)
+ infer_schema = schema.empty if infer_schema is None else infer_schema
for scope in traverse_scope(expression):
- resolver = Resolver(scope, schema)
+ resolver = Resolver(scope, schema, infer_schema=infer_schema)
_pop_table_column_aliases(scope.ctes)
_pop_table_column_aliases(scope.derived_tables)
using_column_tables = _expand_using(scope, resolver)
+
+ if schema.empty and expand_alias_refs:
+ _expand_alias_refs(scope, resolver)
+
_qualify_columns(scope, resolver)
+
+ if not schema.empty and expand_alias_refs:
+ _expand_alias_refs(scope, resolver)
+
if not isinstance(scope.expression, exp.UDTF):
_expand_stars(scope, resolver, using_column_tables)
_qualify_outputs(scope)
- _expand_alias_refs(scope, resolver)
_expand_group_by(scope, resolver)
_expand_order_by(scope)
- if schema.mapping and expand_laterals:
- expression = _expand_laterals(expression)
-
return expression
@@ -55,9 +68,11 @@ def validate_qualify_columns(expression):
for scope in traverse_scope(expression):
if isinstance(scope.expression, exp.Select):
unqualified_columns.extend(scope.unqualified_columns)
- if scope.external_columns and not scope.is_correlated_subquery:
+ if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
column = scope.external_columns[0]
- raise OptimizeError(f"Unknown table: '{column.table}' for column '{column}'")
+ raise OptimizeError(
+ f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}"""
+ )
if unqualified_columns:
raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
@@ -142,52 +157,48 @@ def _expand_using(scope, resolver):
# Ensure selects keep their output name
if isinstance(column.parent, exp.Select):
- replacement = exp.alias_(replacement, alias=column.name)
+ replacement = alias(replacement, alias=column.name, copy=False)
scope.replace(column, replacement)
return column_tables
-def _expand_alias_refs(scope, resolver):
- selects = {}
-
- # Replace references to select aliases
- def transform(node, source_first=True):
- if isinstance(node, exp.Column) and not node.table:
- table = resolver.get_table(node.name)
+def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
+ expression = scope.expression
- # Source columns get priority over select aliases
- if source_first and table:
- node.set("table", table)
- return node
+ if not isinstance(expression, exp.Select):
+ return
- if not selects:
- for s in scope.selects:
- selects[s.alias_or_name] = s
- select = selects.get(node.name)
+ alias_to_expression: t.Dict[str, exp.Expression] = {}
- if select:
- scope.clear_cache()
- if isinstance(select, exp.Alias):
- select = select.this
- return select.copy()
+ def replace_columns(
+ node: t.Optional[exp.Expression], expand: bool = True, resolve_agg: bool = False
+ ):
+ if not node:
+ return
- node.set("table", table)
- elif isinstance(node, exp.Expression) and not isinstance(node, exp.Subqueryable):
- exp.replace_children(node, transform, source_first)
+ for column, *_ in walk_in_scope(node):
+ if not isinstance(column, exp.Column):
+ continue
+ table = resolver.get_table(column.name) if resolve_agg and not column.table else None
+ if table and column.find_ancestor(exp.AggFunc):
+ column.set("table", table)
+ elif expand and not column.table and column.name in alias_to_expression:
+ column.replace(alias_to_expression[column.name].copy())
- return node
+ for projection in scope.selects:
+ replace_columns(projection)
- for select in scope.expression.selects:
- transform(select)
+ if isinstance(projection, exp.Alias):
+ alias_to_expression[projection.alias] = projection.this
- for modifier, source_first in (
- ("where", True),
- ("group", True),
- ("having", False),
- ):
- transform(scope.expression.args.get(modifier), source_first=source_first)
+ replace_columns(expression.args.get("where"))
+ replace_columns(expression.args.get("group"))
+ replace_columns(expression.args.get("having"), resolve_agg=True)
+ replace_columns(expression.args.get("qualify"), resolve_agg=True)
+ replace_columns(expression.args.get("order"), expand=False, resolve_agg=True)
+ scope.clear_cache()
def _expand_group_by(scope, resolver):
@@ -242,6 +253,12 @@ def _qualify_columns(scope, resolver):
raise OptimizeError(f"Unknown column: {column_name}")
if not column_table:
+ if scope.pivots and not column.find_ancestor(exp.Pivot):
+ # If the column is not inside the Pivot expression itself, qualify it
+ # with the pivot's alias, since the pivoted source acts as a new relation
+ column.set("table", exp.to_identifier(scope.pivots[0].alias))
+ continue
+
column_table = resolver.get_table(column_name)
# column_table can be an empty string because BigQuery's UNNEST has no table alias
@@ -265,38 +282,12 @@ def _qualify_columns(scope, resolver):
if column_table:
column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))
- columns_missing_from_scope = []
-
- # Determine whether each reference in the order by clause is to a column or an alias.
- order = scope.expression.args.get("order")
-
- if order:
- for ordered in order.expressions:
- for column in ordered.find_all(exp.Column):
- if (
- not column.table
- and column.parent is not ordered
- and column.name in resolver.all_columns
- ):
- columns_missing_from_scope.append(column)
-
- # Determine whether each reference in the having clause is to a column or an alias.
- having = scope.expression.args.get("having")
-
- if having:
- for column in having.find_all(exp.Column):
- if (
- not column.table
- and column.find_ancestor(exp.AggFunc)
- and column.name in resolver.all_columns
- ):
- columns_missing_from_scope.append(column)
-
- for column in columns_missing_from_scope:
- column_table = resolver.get_table(column.name)
-
- if column_table:
- column.set("table", column_table)
+ for pivot in scope.pivots:
+ for column in pivot.find_all(exp.Column):
+ if not column.table and column.name in resolver.all_columns:
+ column_table = resolver.get_table(column.name)
+ if column_table:
+ column.set("table", column_table)
def _expand_stars(scope, resolver, using_column_tables):
@@ -307,6 +298,19 @@ def _expand_stars(scope, resolver, using_column_tables):
replace_columns = {}
coalesced_columns = set()
+ # TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future
+ pivot_columns = None
+ pivot_output_columns = None
+ pivot = seq_get(scope.pivots, 0)
+
+ has_pivoted_source = pivot and not pivot.args.get("unpivot")
+ if has_pivoted_source:
+ pivot_columns = set(col.output_name for col in pivot.find_all(exp.Column))
+
+ pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])]
+ if not pivot_output_columns:
+ pivot_output_columns = [col.alias_or_name for col in pivot.expressions]
+
for expression in scope.selects:
if isinstance(expression, exp.Star):
tables = list(scope.selected_sources)
@@ -323,9 +327,18 @@ def _expand_stars(scope, resolver, using_column_tables):
for table in tables:
if table not in scope.sources:
raise OptimizeError(f"Unknown table: {table}")
+
columns = resolver.get_source_columns(table, only_visible=True)
if columns and "*" not in columns:
+ if has_pivoted_source:
+ implicit_columns = [col for col in columns if col not in pivot_columns]
+ new_selections.extend(
+ exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
+ for name in implicit_columns + pivot_output_columns
+ )
+ continue
+
table_id = id(table)
for name in columns:
if name in using_column_tables and table in using_column_tables[name]:
@@ -337,16 +350,21 @@ def _expand_stars(scope, resolver, using_column_tables):
coalesce = [exp.column(name, table=table) for table in tables]
new_selections.append(
- exp.alias_(
- exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]), alias=name
+ alias(
+ exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
+ alias=name,
+ copy=False,
)
)
elif name not in except_columns.get(table_id, set()):
alias_ = replace_columns.get(table_id, {}).get(name, name)
- column = exp.column(name, table)
- new_selections.append(alias(column, alias_) if alias_ != name else column)
+ column = exp.column(name, table=table)
+ new_selections.append(
+ alias(column, alias_, copy=False) if alias_ != name else column
+ )
else:
return
+
scope.expression.set("expressions", new_selections)
@@ -388,9 +406,6 @@ def _qualify_outputs(scope):
selection = alias(
selection,
alias=selection.output_name or f"_col_{i}",
- quoted=True
- if isinstance(selection, exp.Column) and selection.this.quoted
- else None,
)
if aliased_column:
selection.set("alias", exp.to_identifier(aliased_column))
@@ -400,6 +415,23 @@ def _qualify_outputs(scope):
scope.expression.set("expressions", new_selections)
+def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
+ """Makes sure all identifiers that need to be quoted are quoted."""
+
+ def _quote(expression: E) -> E:
+ if isinstance(expression, exp.Identifier):
+ name = expression.this
+ expression.set(
+ "quoted",
+ identify
+ or case_sensitive(name, dialect=dialect)
+ or not exp.SAFE_IDENTIFIER_RE.match(name),
+ )
+ return expression
+
+ return expression.transform(_quote, copy=False)
+
+
class Resolver:
"""
Helper for resolving columns.
@@ -407,12 +439,13 @@ class Resolver:
This is a class so we can lazily load some things and easily share them across functions.
"""
- def __init__(self, scope, schema):
+ def __init__(self, scope, schema, infer_schema: bool = True):
self.scope = scope
self.schema = schema
self._source_columns = None
- self._unambiguous_columns = None
+ self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
self._all_columns = None
+ self._infer_schema = infer_schema
def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
"""
@@ -430,7 +463,7 @@ class Resolver:
table_name = self._unambiguous_columns.get(column_name)
- if not table_name:
+ if not table_name and self._infer_schema:
sources_without_schema = tuple(
source
for source, columns in self._get_all_source_columns().items()
@@ -450,11 +483,9 @@ class Resolver:
node_alias = node.args.get("alias")
if node_alias:
- return node_alias.this
+ return exp.to_identifier(node_alias.this)
- return exp.to_identifier(
- table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None
- )
+ return exp.to_identifier(table_name)
@property
def all_columns(self):
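A sketch of the new quote_identifiers helper: with identify=False, only
identifiers that are case sensitive for the dialect or that fail
SAFE_IDENTIFIER_RE get quoted, output hedged:

    import sqlglot
    from sqlglot.optimizer.qualify_columns import quote_identifiers

    quote_identifiers(sqlglot.parse_one("SELECT Foo.bar FROM Foo"), identify=False).sql()
    # -> roughly 'SELECT "Foo".bar FROM "Foo"'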
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py
index 1b451a6..fcc5f26 100644
--- a/sqlglot/optimizer/qualify_tables.py
+++ b/sqlglot/optimizer/qualify_tables.py
@@ -1,11 +1,19 @@
import itertools
+import typing as t
from sqlglot import alias, exp
-from sqlglot.helper import csv_reader
+from sqlglot._typing import E
+from sqlglot.helper import csv_reader, name_sequence
from sqlglot.optimizer.scope import Scope, traverse_scope
+from sqlglot.schema import Schema
-def qualify_tables(expression, db=None, catalog=None, schema=None):
+def qualify_tables(
+ expression: E,
+ db: t.Optional[str] = None,
+ catalog: t.Optional[str] = None,
+ schema: t.Optional[Schema] = None,
+) -> E:
"""
Rewrite sqlglot AST to have fully qualified tables. Additionally, this
replaces "join constructs" (*) by equivalent SELECT * subqueries.
@@ -21,19 +29,17 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
'SELECT * FROM (SELECT * FROM tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2) AS _q_0'
Args:
- expression (sqlglot.Expression): expression to qualify
- db (str): Database name
- catalog (str): Catalog name
+ expression: Expression to qualify
+ db: Database name
+ catalog: Catalog name
schema: A schema to populate
Returns:
- sqlglot.Expression: qualified expression
+ The qualified expression.
(*) See section 7.2.1.2 in https://www.postgresql.org/docs/current/queries-table-expressions.html
"""
- sequence = itertools.count()
-
- next_name = lambda: f"_q_{next(sequence)}"
+ next_alias_name = name_sequence("_q_")
for scope in traverse_scope(expression):
for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
@@ -44,10 +50,14 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
if not derived_table.args.get("alias"):
- alias_ = f"_q_{next(sequence)}"
+ alias_ = next_alias_name()
derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
scope.rename_source(None, alias_)
+ pivots = derived_table.args.get("pivots")
+ if pivots and not pivots[0].alias:
+ pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_alias_name())))
+
for name, source in scope.sources.items():
if isinstance(source, exp.Table):
if isinstance(source.this, exp.Identifier):
@@ -59,12 +69,19 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
if not source.alias:
source = source.replace(
alias(
- source.copy(),
- name if name else next_name(),
+ source,
+ name or source.name or next_alias_name(),
+ copy=True,
table=True,
)
)
+ pivots = source.args.get("pivots")
+ if pivots and not pivots[0].alias:
+ pivots[0].set(
+ "alias", exp.TableAlias(this=exp.to_identifier(next_alias_name()))
+ )
+
if schema and isinstance(source.this, exp.ReadCSV):
with csv_reader(source.this) as reader:
header = next(reader)
@@ -74,11 +91,11 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
)
elif isinstance(source, Scope) and source.is_udtf:
udtf = source.expression
- table_alias = udtf.args.get("alias") or exp.TableAlias(this=next_name())
+ table_alias = udtf.args.get("alias") or exp.TableAlias(this=next_alias_name())
udtf.set("alias", table_alias)
if not table_alias.name:
- table_alias.set("this", next_name())
+ table_alias.set("this", next_alias_name())
if isinstance(udtf, exp.Values) and not table_alias.columns:
for i, e in enumerate(udtf.expressions[0].expressions):
table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
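name_sequence is assumed to be a small helper replacing the repeated
itertools.count pattern, starting at 0 like the counters it replaces:

    from sqlglot.helper import name_sequence

    next_alias_name = name_sequence("_q_")
    next_alias_name()  # -> '_q_0'
    next_alias_name()  # -> '_q_1'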
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index e00b3c9..9ffb4d6 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -1,4 +1,5 @@
import itertools
+import typing as t
from collections import defaultdict
from enum import Enum, auto
@@ -83,6 +84,7 @@ class Scope:
self._columns = None
self._external_columns = None
self._join_hints = None
+ self._pivots = None
def branch(self, expression, scope_type, chain_sources=None, **kwargs):
"""Branch from the current scope to a new, inner scope"""
@@ -261,12 +263,14 @@ class Scope:
self._columns = []
for column in columns + external_columns:
- ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Having, exp.Hint)
+ ancestor = column.find_ancestor(
+ exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint
+ )
if (
not ancestor
- # Window functions can have an ORDER BY clause
- or not isinstance(ancestor.parent, exp.Select)
or column.table
+ or isinstance(ancestor, exp.Select)
+ or (isinstance(ancestor, exp.Order) and isinstance(ancestor.parent, exp.Window))
or (column.name not in named_selects and not isinstance(ancestor, exp.Hint))
):
self._columns.append(column)
@@ -370,6 +374,17 @@ class Scope:
return []
return self._join_hints
+ @property
+ def pivots(self):
+ if not self._pivots:
+ self._pivots = [
+ pivot
+ for node in self.tables + self.derived_tables
+ for pivot in node.args.get("pivots") or []
+ ]
+
+ return self._pivots
+
def source_columns(self, source_name):
"""
Get all columns in the current scope for a particular source.
@@ -463,7 +478,7 @@ class Scope:
return scope_ref_count
-def traverse_scope(expression):
+def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
"""
Traverse an expression by its "scopes".
@@ -488,10 +503,12 @@ def traverse_scope(expression):
Returns:
list[Scope]: scope instances
"""
+ if not isinstance(expression, exp.Unionable):
+ return []
return list(_traverse_scope(Scope(expression)))
-def build_scope(expression):
+def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
"""
Build a scope tree.
@@ -500,7 +517,10 @@ def build_scope(expression):
Returns:
Scope: root scope
"""
- return traverse_scope(expression)[-1]
+ scopes = traverse_scope(expression)
+ if scopes:
+ return scopes[-1]
+ return None
def _traverse_scope(scope):
@@ -585,7 +605,7 @@ def _traverse_tables(scope):
expressions = []
from_ = scope.expression.args.get("from")
if from_:
- expressions.extend(from_.expressions)
+ expressions.append(from_.this)
for join in scope.expression.args.get("joins") or []:
expressions.append(join.this)
@@ -601,8 +621,13 @@ def _traverse_tables(scope):
source_name = expression.alias_or_name
if table_name in scope.sources:
- # This is a reference to a parent source (e.g. a CTE), not an actual table.
- sources[source_name] = scope.sources[table_name]
+ # This is a reference to a parent source (e.g. a CTE), not an actual table, unless
+ # it is pivoted, because then we get back a new table and hence a new source.
+ pivots = expression.args.get("pivots")
+ if pivots:
+ sources[pivots[0].alias] = expression
+ else:
+ sources[source_name] = scope.sources[table_name]
elif source_name in sources:
sources[find_new_name(sources, table_name)] = expression
else:
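A sketch of the new Optional contract: traverse_scope returns an empty list and
build_scope returns None for expressions that are not Unionable:

    import sqlglot
    from sqlglot.optimizer.scope import build_scope, traverse_scope

    build_scope(sqlglot.parse_one("SELECT a FROM x")) is None   # False
    traverse_scope(sqlglot.parse_one("DROP TABLE x"))           # []
    build_scope(sqlglot.parse_one("DROP TABLE x")) is None      # True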
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index 0904189..e2772a0 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -5,11 +5,9 @@ from collections import deque
from decimal import Decimal
from sqlglot import exp
-from sqlglot.generator import Generator
+from sqlglot.generator import cached_generator
from sqlglot.helper import first, while_changing
-GENERATOR = Generator(normalize=True, identify="safe")
-
def simplify(expression):
"""
@@ -27,12 +25,12 @@ def simplify(expression):
sqlglot.Expression: simplified expression
"""
- cache = {}
+ generate = cached_generator()
def _simplify(expression, root=True):
node = expression
node = rewrite_between(node)
- node = uniq_sort(node, cache, root)
+ node = uniq_sort(node, generate, root)
node = absorb_and_eliminate(node, root)
exp.replace_children(node, lambda e: _simplify(e, False))
node = simplify_not(node)
@@ -247,7 +245,7 @@ def remove_compliments(expression, root=True):
return expression
-def uniq_sort(expression, cache=None, root=True):
+def uniq_sort(expression, generate, root=True):
"""
Uniq and sort a connector.
@@ -256,7 +254,7 @@ def uniq_sort(expression, cache=None, root=True):
if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
flattened = tuple(expression.flatten())
- deduped = {GENERATOR.generate(e, cache): e for e in flattened}
+ deduped = {generate(e): e for e in flattened}
arr = tuple(deduped.items())
# check if the operands are already sorted, if not sort them
@@ -388,14 +386,18 @@ def _simplify_binary(expression, a, b):
def simplify_parens(expression):
- if (
- isinstance(expression, exp.Paren)
- and not isinstance(expression.this, exp.Select)
- and (
- not isinstance(expression.parent, (exp.Condition, exp.Binary))
- or isinstance(expression.this, exp.Predicate)
- or not isinstance(expression.this, exp.Binary)
- )
+ if not isinstance(expression, exp.Paren):
+ return expression
+
+ this = expression.this
+ parent = expression.parent
+
+ if not isinstance(this, exp.Select) and (
+ not isinstance(parent, (exp.Condition, exp.Binary))
+ or isinstance(this, exp.Predicate)
+ or not isinstance(this, exp.Binary)
+ or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
+ or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
):
return expression.this
return expression
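The two new cases let simplify_parens flatten redundant parentheses between
same-operator associative nodes; a sketch:

    import sqlglot
    from sqlglot.optimizer.simplify import simplify

    simplify(sqlglot.parse_one("SELECT (x + 1) + 2")).sql()
    # -> 'SELECT x + 1 + 2'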
diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py
index a515489..09e3f2a 100644
--- a/sqlglot/optimizer/unnest_subqueries.py
+++ b/sqlglot/optimizer/unnest_subqueries.py
@@ -1,6 +1,5 @@
-import itertools
-
from sqlglot import exp
+from sqlglot.helper import name_sequence
from sqlglot.optimizer.scope import ScopeType, traverse_scope
@@ -22,7 +21,7 @@ def unnest_subqueries(expression):
Returns:
sqlglot.Expression: unnested expression
"""
- sequence = itertools.count()
+ next_alias_name = name_sequence("_u_")
for scope in traverse_scope(expression):
select = scope.expression
@@ -30,19 +29,19 @@ def unnest_subqueries(expression):
if not parent:
continue
if scope.external_columns:
- decorrelate(select, parent, scope.external_columns, sequence)
+ decorrelate(select, parent, scope.external_columns, next_alias_name)
elif scope.scope_type == ScopeType.SUBQUERY:
- unnest(select, parent, sequence)
+ unnest(select, parent, next_alias_name)
return expression
-def unnest(select, parent_select, sequence):
+def unnest(select, parent_select, next_alias_name):
if len(select.selects) > 1:
return
predicate = select.find_ancestor(exp.Condition)
- alias = _alias(sequence)
+ alias = next_alias_name()
if not predicate or parent_select is not predicate.parent_select:
return
@@ -87,13 +86,13 @@ def unnest(select, parent_select, sequence):
)
-def decorrelate(select, parent_select, external_columns, sequence):
+def decorrelate(select, parent_select, external_columns, next_alias_name):
where = select.args.get("where")
if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
return
- table_alias = _alias(sequence)
+ table_alias = next_alias_name()
keys = []
# for all external columns in the where statement, find the relevant predicate
@@ -136,7 +135,7 @@ def decorrelate(select, parent_select, external_columns, sequence):
group_by.append(key)
else:
if key not in key_aliases:
- key_aliases[key] = _alias(sequence)
+ key_aliases[key] = next_alias_name()
# all predicates that are equalities must also be in the unique
# so that we don't do a many to many join
if isinstance(predicate, exp.EQ) and key not in group_by:
@@ -244,10 +243,6 @@ def decorrelate(select, parent_select, external_columns, sequence):
)
-def _alias(sequence):
- return f"_u_{next(sequence)}"
-
-
def _replace(expression, condition):
return expression.replace(exp.condition(condition))
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index d8d9f88..e77bb5a 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -6,22 +6,17 @@ from collections import defaultdict
from sqlglot import exp
from sqlglot.errors import ErrorLevel, ParseError, concat_messages, merge_errors
-from sqlglot.helper import (
- apply_index_offset,
- count_params,
- ensure_collection,
- ensure_list,
- seq_get,
-)
+from sqlglot.helper import apply_index_offset, ensure_collection, ensure_list, seq_get
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import in_trie, new_trie
-logger = logging.getLogger("sqlglot")
+if t.TYPE_CHECKING:
+ from sqlglot._typing import E
-E = t.TypeVar("E", bound=exp.Expression)
+logger = logging.getLogger("sqlglot")
-def parse_var_map(args: t.Sequence) -> exp.Expression:
+def parse_var_map(args: t.List) -> exp.StarMap | exp.VarMap:
if len(args) == 1 and args[0].is_star:
return exp.StarMap(this=args[0])
@@ -36,7 +31,7 @@ def parse_var_map(args: t.Sequence) -> exp.Expression:
)
-def parse_like(args):
+def parse_like(args: t.List) -> exp.Expression:
like = exp.Like(this=seq_get(args, 1), expression=seq_get(args, 0))
return exp.Escape(this=like, expression=seq_get(args, 2)) if len(args) > 2 else like
@@ -65,7 +60,7 @@ class Parser(metaclass=_Parser):
Args:
error_level: the desired error level.
- Default: ErrorLevel.RAISE
+ Default: ErrorLevel.IMMEDIATE
error_message_context: determines the amount of context to capture from a
query string when displaying the error message (in number of characters).
Default: 50.
@@ -118,8 +113,8 @@ class Parser(metaclass=_Parser):
NESTED_TYPE_TOKENS = {
TokenType.ARRAY,
TokenType.MAP,
- TokenType.STRUCT,
TokenType.NULLABLE,
+ TokenType.STRUCT,
}
TYPE_TOKENS = {
@@ -158,6 +153,7 @@ class Parser(metaclass=_Parser):
TokenType.TIMESTAMPTZ,
TokenType.TIMESTAMPLTZ,
TokenType.DATETIME,
+ TokenType.DATETIME64,
TokenType.DATE,
TokenType.DECIMAL,
TokenType.BIGDECIMAL,
@@ -211,20 +207,18 @@ class Parser(metaclass=_Parser):
TokenType.VAR,
TokenType.ANTI,
TokenType.APPLY,
+ TokenType.ASC,
TokenType.AUTO_INCREMENT,
TokenType.BEGIN,
- TokenType.BOTH,
- TokenType.BUCKET,
TokenType.CACHE,
- TokenType.CASCADE,
TokenType.COLLATE,
TokenType.COMMAND,
TokenType.COMMENT,
TokenType.COMMIT,
- TokenType.COMPOUND,
TokenType.CONSTRAINT,
TokenType.DEFAULT,
TokenType.DELETE,
+ TokenType.DESC,
TokenType.DESCRIBE,
TokenType.DIV,
TokenType.END,
@@ -233,7 +227,6 @@ class Parser(metaclass=_Parser):
TokenType.FALSE,
TokenType.FIRST,
TokenType.FILTER,
- TokenType.FOLLOWING,
TokenType.FORMAT,
TokenType.FULL,
TokenType.IF,
@@ -241,41 +234,31 @@ class Parser(metaclass=_Parser):
TokenType.ISNULL,
TokenType.INTERVAL,
TokenType.KEEP,
- TokenType.LAZY,
- TokenType.LEADING,
TokenType.LEFT,
- TokenType.LOCAL,
- TokenType.MATERIALIZED,
+ TokenType.LOAD,
TokenType.MERGE,
TokenType.NATURAL,
TokenType.NEXT,
TokenType.OFFSET,
- TokenType.ONLY,
- TokenType.OPTIONS,
TokenType.ORDINALITY,
TokenType.OVERWRITE,
TokenType.PARTITION,
TokenType.PERCENT,
TokenType.PIVOT,
TokenType.PRAGMA,
- TokenType.PRECEDING,
TokenType.RANGE,
TokenType.REFERENCES,
TokenType.RIGHT,
TokenType.ROW,
TokenType.ROWS,
- TokenType.SEED,
TokenType.SEMI,
TokenType.SET,
+ TokenType.SETTINGS,
TokenType.SHOW,
- TokenType.SORTKEY,
TokenType.TEMPORARY,
TokenType.TOP,
- TokenType.TRAILING,
TokenType.TRUE,
- TokenType.UNBOUNDED,
TokenType.UNIQUE,
- TokenType.UNLOGGED,
TokenType.UNPIVOT,
TokenType.VOLATILE,
TokenType.WINDOW,
@@ -291,6 +274,7 @@ class Parser(metaclass=_Parser):
TokenType.APPLY,
TokenType.FULL,
TokenType.LEFT,
+ TokenType.LOCK,
TokenType.NATURAL,
TokenType.OFFSET,
TokenType.RIGHT,
@@ -301,7 +285,7 @@ class Parser(metaclass=_Parser):
UPDATE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.SET}
- TRIM_TYPES = {TokenType.LEADING, TokenType.TRAILING, TokenType.BOTH}
+ TRIM_TYPES = {"LEADING", "TRAILING", "BOTH"}
FUNC_TOKENS = {
TokenType.COMMAND,
@@ -322,6 +306,7 @@ class Parser(metaclass=_Parser):
TokenType.MERGE,
TokenType.OFFSET,
TokenType.PRIMARY_KEY,
+ TokenType.RANGE,
TokenType.REPLACE,
TokenType.ROW,
TokenType.UNNEST,
@@ -455,31 +440,31 @@ class Parser(metaclass=_Parser):
}
EXPRESSION_PARSERS = {
+ exp.Cluster: lambda self: self._parse_sort(exp.Cluster, "CLUSTER", "BY"),
exp.Column: lambda self: self._parse_column(),
+ exp.Condition: lambda self: self._parse_conjunction(),
exp.DataType: lambda self: self._parse_types(),
+ exp.Expression: lambda self: self._parse_statement(),
exp.From: lambda self: self._parse_from(),
exp.Group: lambda self: self._parse_group(),
+ exp.Having: lambda self: self._parse_having(),
exp.Identifier: lambda self: self._parse_id_var(),
- exp.Lateral: lambda self: self._parse_lateral(),
exp.Join: lambda self: self._parse_join(),
- exp.Order: lambda self: self._parse_order(),
- exp.Cluster: lambda self: self._parse_sort(TokenType.CLUSTER_BY, exp.Cluster),
- exp.Sort: lambda self: self._parse_sort(TokenType.SORT_BY, exp.Sort),
exp.Lambda: lambda self: self._parse_lambda(),
+ exp.Lateral: lambda self: self._parse_lateral(),
exp.Limit: lambda self: self._parse_limit(),
exp.Offset: lambda self: self._parse_offset(),
- exp.TableAlias: lambda self: self._parse_table_alias(),
- exp.Table: lambda self: self._parse_table(),
- exp.Condition: lambda self: self._parse_conjunction(),
- exp.Expression: lambda self: self._parse_statement(),
- exp.Properties: lambda self: self._parse_properties(),
- exp.Where: lambda self: self._parse_where(),
+ exp.Order: lambda self: self._parse_order(),
exp.Ordered: lambda self: self._parse_ordered(),
- exp.Having: lambda self: self._parse_having(),
- exp.With: lambda self: self._parse_with(),
- exp.Window: lambda self: self._parse_named_window(),
+ exp.Properties: lambda self: self._parse_properties(),
exp.Qualify: lambda self: self._parse_qualify(),
exp.Returning: lambda self: self._parse_returning(),
+ exp.Sort: lambda self: self._parse_sort(exp.Sort, "SORT", "BY"),
+ exp.Table: lambda self: self._parse_table_parts(),
+ exp.TableAlias: lambda self: self._parse_table_alias(),
+ exp.Where: lambda self: self._parse_where(),
+ exp.Window: lambda self: self._parse_named_window(),
+ exp.With: lambda self: self._parse_with(),
"JOIN_TYPE": lambda self: self._parse_join_side_and_kind(),
}
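
These per-node entries back the `into` argument of the top-level parsing API, so a bare fragment can be parsed directly as a specific expression type. A small usage sketch, assuming this build:

    import sqlglot
    from sqlglot import exp

    # exp.Condition dispatches to _parse_conjunction per the table above.
    cond = sqlglot.parse_one("x > 1 AND y < 2", into=exp.Condition)
    assert isinstance(cond, exp.Condition)
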
@@ -495,9 +480,13 @@ class Parser(metaclass=_Parser):
TokenType.DESCRIBE: lambda self: self._parse_describe(),
TokenType.DROP: lambda self: self._parse_drop(),
TokenType.END: lambda self: self._parse_commit_or_rollback(),
+ TokenType.FROM: lambda self: exp.select("*").from_(
+ t.cast(exp.From, self._parse_from(skip_from_token=True))
+ ),
TokenType.INSERT: lambda self: self._parse_insert(),
- TokenType.LOAD_DATA: lambda self: self._parse_load_data(),
+ TokenType.LOAD: lambda self: self._parse_load(),
TokenType.MERGE: lambda self: self._parse_merge(),
+ TokenType.PIVOT: lambda self: self._parse_simplified_pivot(),
TokenType.PRAGMA: lambda self: self.expression(exp.Pragma, this=self._parse_expression()),
TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
TokenType.SET: lambda self: self._parse_set(),
@@ -536,7 +525,10 @@ class Parser(metaclass=_Parser):
TokenType.HEX_STRING: lambda self, token: self.expression(exp.HexString, this=token.text),
TokenType.BYTE_STRING: lambda self, token: self.expression(exp.ByteString, this=token.text),
TokenType.INTRODUCER: lambda self, token: self._parse_introducer(token),
- TokenType.NATIONAL: lambda self, token: self._parse_national(token),
+ TokenType.NATIONAL_STRING: lambda self, token: self.expression(
+ exp.National, this=token.text
+ ),
+ TokenType.RAW_STRING: lambda self, token: self.expression(exp.RawString, this=token.text),
TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(),
}
@@ -551,91 +543,76 @@ class Parser(metaclass=_Parser):
RANGE_PARSERS = {
TokenType.BETWEEN: lambda self, this: self._parse_between(this),
TokenType.GLOB: binary_range_parser(exp.Glob),
- TokenType.OVERLAPS: binary_range_parser(exp.Overlaps),
+ TokenType.ILIKE: binary_range_parser(exp.ILike),
TokenType.IN: lambda self, this: self._parse_in(this),
+ TokenType.IRLIKE: binary_range_parser(exp.RegexpILike),
TokenType.IS: lambda self, this: self._parse_is(this),
TokenType.LIKE: binary_range_parser(exp.Like),
- TokenType.ILIKE: binary_range_parser(exp.ILike),
- TokenType.IRLIKE: binary_range_parser(exp.RegexpILike),
+ TokenType.OVERLAPS: binary_range_parser(exp.Overlaps),
TokenType.RLIKE: binary_range_parser(exp.RegexpLike),
TokenType.SIMILAR_TO: binary_range_parser(exp.SimilarTo),
}
- PROPERTY_PARSERS = {
- "AFTER": lambda self: self._parse_afterjournal(
- no=self._prev.text.upper() == "NO", dual=self._prev.text.upper() == "DUAL"
- ),
+ PROPERTY_PARSERS: t.Dict[str, t.Callable] = {
"ALGORITHM": lambda self: self._parse_property_assignment(exp.AlgorithmProperty),
"AUTO_INCREMENT": lambda self: self._parse_property_assignment(exp.AutoIncrementProperty),
- "BEFORE": lambda self: self._parse_journal(
- no=self._prev.text.upper() == "NO", dual=self._prev.text.upper() == "DUAL"
- ),
"BLOCKCOMPRESSION": lambda self: self._parse_blockcompression(),
"CHARACTER SET": lambda self: self._parse_character_set(),
"CHECKSUM": lambda self: self._parse_checksum(),
- "CLUSTER BY": lambda self: self.expression(
- exp.Cluster, expressions=self._parse_csv(self._parse_ordered)
- ),
+ "CLUSTER": lambda self: self._parse_cluster(),
"COLLATE": lambda self: self._parse_property_assignment(exp.CollateProperty),
"COMMENT": lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
- "DATABLOCKSIZE": lambda self: self._parse_datablocksize(
- default=self._prev.text.upper() == "DEFAULT"
- ),
+ "DATABLOCKSIZE": lambda self, **kwargs: self._parse_datablocksize(**kwargs),
"DEFINER": lambda self: self._parse_definer(),
"DETERMINISTIC": lambda self: self.expression(
exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE")
),
"DISTKEY": lambda self: self._parse_distkey(),
"DISTSTYLE": lambda self: self._parse_property_assignment(exp.DistStyleProperty),
+ "ENGINE": lambda self: self._parse_property_assignment(exp.EngineProperty),
"EXECUTE": lambda self: self._parse_property_assignment(exp.ExecuteAsProperty),
"EXTERNAL": lambda self: self.expression(exp.ExternalProperty),
- "FALLBACK": lambda self: self._parse_fallback(no=self._prev.text.upper() == "NO"),
+ "FALLBACK": lambda self, **kwargs: self._parse_fallback(**kwargs),
"FORMAT": lambda self: self._parse_property_assignment(exp.FileFormatProperty),
"FREESPACE": lambda self: self._parse_freespace(),
- "GLOBAL": lambda self: self._parse_temporary(global_=True),
"IMMUTABLE": lambda self: self.expression(
exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE")
),
- "JOURNAL": lambda self: self._parse_journal(
- no=self._prev.text.upper() == "NO", dual=self._prev.text.upper() == "DUAL"
- ),
+ "JOURNAL": lambda self, **kwargs: self._parse_journal(**kwargs),
"LANGUAGE": lambda self: self._parse_property_assignment(exp.LanguageProperty),
"LIKE": lambda self: self._parse_create_like(),
- "LOCAL": lambda self: self._parse_afterjournal(no=False, dual=False, local=True),
"LOCATION": lambda self: self._parse_property_assignment(exp.LocationProperty),
"LOCK": lambda self: self._parse_locking(),
"LOCKING": lambda self: self._parse_locking(),
- "LOG": lambda self: self._parse_log(no=self._prev.text.upper() == "NO"),
+ "LOG": lambda self, **kwargs: self._parse_log(**kwargs),
"MATERIALIZED": lambda self: self.expression(exp.MaterializedProperty),
- "MAX": lambda self: self._parse_datablocksize(),
- "MAXIMUM": lambda self: self._parse_datablocksize(),
- "MERGEBLOCKRATIO": lambda self: self._parse_mergeblockratio(
- no=self._prev.text.upper() == "NO", default=self._prev.text.upper() == "DEFAULT"
- ),
- "MIN": lambda self: self._parse_datablocksize(),
- "MINIMUM": lambda self: self._parse_datablocksize(),
+ "MERGEBLOCKRATIO": lambda self, **kwargs: self._parse_mergeblockratio(**kwargs),
"MULTISET": lambda self: self.expression(exp.SetProperty, multi=True),
- "NO": lambda self: self._parse_noprimaryindex(),
- "NOT": lambda self: self._parse_afterjournal(no=False, dual=False, local=False),
- "ON": lambda self: self._parse_oncommit(),
+ "NO": lambda self: self._parse_no_property(),
+ "ON": lambda self: self._parse_on_property(),
+ "ORDER BY": lambda self: self._parse_order(skip_order_token=True),
"PARTITION BY": lambda self: self._parse_partitioned_by(),
"PARTITIONED BY": lambda self: self._parse_partitioned_by(),
"PARTITIONED_BY": lambda self: self._parse_partitioned_by(),
+ "PRIMARY KEY": lambda self: self._parse_primary_key(),
"RETURNS": lambda self: self._parse_returns(),
"ROW": lambda self: self._parse_row(),
"ROW_FORMAT": lambda self: self._parse_property_assignment(exp.RowFormatProperty),
"SET": lambda self: self.expression(exp.SetProperty, multi=False),
+ "SETTINGS": lambda self: self.expression(
+ exp.SettingsProperty, expressions=self._parse_csv(self._parse_set_item)
+ ),
"SORTKEY": lambda self: self._parse_sortkey(),
"STABLE": lambda self: self.expression(
exp.StabilityProperty, this=exp.Literal.string("STABLE")
),
"STORED": lambda self: self._parse_stored(),
- "TABLE_FORMAT": lambda self: self._parse_property_assignment(exp.TableFormatProperty),
"TBLPROPERTIES": lambda self: self._parse_wrapped_csv(self._parse_property),
- "TEMP": lambda self: self._parse_temporary(global_=False),
- "TEMPORARY": lambda self: self._parse_temporary(global_=False),
+ "TEMP": lambda self: self.expression(exp.TemporaryProperty),
+ "TEMPORARY": lambda self: self.expression(exp.TemporaryProperty),
"TRANSIENT": lambda self: self.expression(exp.TransientProperty),
- "USING": lambda self: self._parse_property_assignment(exp.TableFormatProperty),
+ "TTL": lambda self: self._parse_ttl(),
+ "USING": lambda self: self._parse_property_assignment(exp.FileFormatProperty),
"VOLATILE": lambda self: self._parse_volatile_property(),
"WITH": lambda self: self._parse_with_property(),
}
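
With ENGINE and SETTINGS registered above, ClickHouse-style table properties should parse into exp.EngineProperty and exp.SettingsProperty nodes. A hedged round-trip sketch:

    import sqlglot

    sql = "CREATE TABLE t (x Int8) ENGINE=MergeTree SETTINGS index_granularity = 8192"
    # Expect both properties to survive a ClickHouse -> ClickHouse round trip.
    print(sqlglot.transpile(sql, read="clickhouse", write="clickhouse")[0])
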
@@ -679,6 +656,7 @@ class Parser(metaclass=_Parser):
"TITLE": lambda self: self.expression(
exp.TitleColumnConstraint, this=self._parse_var_or_string()
),
+ "TTL": lambda self: self.expression(exp.MergeTreeTTL, expressions=[self._parse_bitwise()]),
"UNIQUE": lambda self: self._parse_unique(),
"UPPERCASE": lambda self: self.expression(exp.UppercaseColumnConstraint),
}
@@ -704,6 +682,8 @@ class Parser(metaclass=_Parser):
),
}
+ FUNCTIONS_WITH_ALIASED_ARGS = {"STRUCT"}
+
FUNCTION_PARSERS: t.Dict[str, t.Callable] = {
"CAST": lambda self: self._parse_cast(self.STRICT_CAST),
"CONVERT": lambda self: self._parse_convert(self.STRICT_CAST),
@@ -712,7 +692,9 @@ class Parser(metaclass=_Parser):
"JSON_OBJECT": lambda self: self._parse_json_object(),
"LOG": lambda self: self._parse_logarithm(),
"MATCH": lambda self: self._parse_match_against(),
+ "OPENJSON": lambda self: self._parse_open_json(),
"POSITION": lambda self: self._parse_position(),
+ "SAFE_CAST": lambda self: self._parse_cast(False),
"STRING_AGG": lambda self: self._parse_string_agg(),
"SUBSTRING": lambda self: self._parse_substring(),
"TRIM": lambda self: self._parse_trim(),
@@ -721,19 +703,18 @@ class Parser(metaclass=_Parser):
}
QUERY_MODIFIER_PARSERS = {
+ "joins": lambda self: list(iter(self._parse_join, None)),
+ "laterals": lambda self: list(iter(self._parse_lateral, None)),
"match": lambda self: self._parse_match_recognize(),
"where": lambda self: self._parse_where(),
"group": lambda self: self._parse_group(),
"having": lambda self: self._parse_having(),
"qualify": lambda self: self._parse_qualify(),
"windows": lambda self: self._parse_window_clause(),
- "distribute": lambda self: self._parse_sort(TokenType.DISTRIBUTE_BY, exp.Distribute),
- "sort": lambda self: self._parse_sort(TokenType.SORT_BY, exp.Sort),
- "cluster": lambda self: self._parse_sort(TokenType.CLUSTER_BY, exp.Cluster),
"order": lambda self: self._parse_order(),
"limit": lambda self: self._parse_limit(),
"offset": lambda self: self._parse_offset(),
- "lock": lambda self: self._parse_lock(),
+ "locks": lambda self: self._parse_locks(),
"sample": lambda self: self._parse_table_sample(as_modifier=True),
}
@@ -763,8 +744,11 @@ class Parser(metaclass=_Parser):
INSERT_ALTERNATIVES = {"ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK"}
+ CLONE_KINDS = {"TIMESTAMP", "OFFSET", "STATEMENT"}
+
WINDOW_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.ROWS}
WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER}
+ WINDOW_SIDES = {"FOLLOWING", "PRECEDING"}
ADD_CONSTRAINT_TOKENS = {TokenType.CONSTRAINT, TokenType.PRIMARY_KEY, TokenType.FOREIGN_KEY}
@@ -772,8 +756,8 @@ class Parser(metaclass=_Parser):
CONVERT_TYPE_FIRST = False
- QUOTED_PIVOT_COLUMNS: t.Optional[bool] = None
PREFIXED_PIVOT_COLUMNS = False
+ IDENTIFY_PIVOT_STRINGS = False
LOG_BASE_FIRST = True
LOG_DEFAULTS_TO_LN = False
@@ -875,7 +859,7 @@ class Parser(metaclass=_Parser):
e.errors[0]["into_expression"] = expression_type
errors.append(e)
raise ParseError(
- f"Failed to parse into {expression_types}",
+ f"Failed to parse '{sql or raw_tokens}' into {expression_types}",
errors=merge_errors(errors),
) from errors[-1]
@@ -933,7 +917,7 @@ class Parser(metaclass=_Parser):
"""
token = token or self._curr or self._prev or Token.string("")
start = token.start
- end = token.end
+ end = token.end + 1
start_context = self.sql[max(start - self.error_message_context, 0) : start]
highlight = self.sql[start:end]
end_context = self.sql[end : end + self.error_message_context]
@@ -996,7 +980,7 @@ class Parser(metaclass=_Parser):
self.raise_error(error_message)
def _find_sql(self, start: Token, end: Token) -> str:
- return self.sql[start.start : end.end]
+ return self.sql[start.start : end.end + 1]
def _advance(self, times: int = 1) -> None:
self._index += times
@@ -1042,6 +1026,44 @@ class Parser(metaclass=_Parser):
exp.Comment, this=this, kind=kind.text, expression=self._parse_string(), exists=exists
)
+ # https://clickhouse.com/docs/en/engines/table-engines/mergetree-family/mergetree#mergetree-table-ttl
+ def _parse_ttl(self) -> exp.Expression:
+ def _parse_ttl_action() -> t.Optional[exp.Expression]:
+ this = self._parse_bitwise()
+
+ if self._match_text_seq("DELETE"):
+ return self.expression(exp.MergeTreeTTLAction, this=this, delete=True)
+ if self._match_text_seq("RECOMPRESS"):
+ return self.expression(
+ exp.MergeTreeTTLAction, this=this, recompress=self._parse_bitwise()
+ )
+ if self._match_text_seq("TO", "DISK"):
+ return self.expression(
+ exp.MergeTreeTTLAction, this=this, to_disk=self._parse_string()
+ )
+ if self._match_text_seq("TO", "VOLUME"):
+ return self.expression(
+ exp.MergeTreeTTLAction, this=this, to_volume=self._parse_string()
+ )
+
+ return this
+
+ expressions = self._parse_csv(_parse_ttl_action)
+ where = self._parse_where()
+ group = self._parse_group()
+
+ aggregates = None
+ if group and self._match(TokenType.SET):
+ aggregates = self._parse_csv(self._parse_set_item)
+
+ return self.expression(
+ exp.MergeTreeTTL,
+ expressions=expressions,
+ where=where,
+ group=group,
+ aggregates=aggregates,
+ )
+
def _parse_statement(self) -> t.Optional[exp.Expression]:
if self._curr is None:
return None
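
A hedged usage sketch for the new MergeTree TTL support; the exact tree shape may differ across builds:

    import sqlglot

    ddl = (
        "CREATE TABLE t (d DateTime, x Int32) ENGINE=MergeTree "
        "TTL d + INTERVAL 1 MONTH DELETE, d + INTERVAL 1 WEEK TO VOLUME 'slow'"
    )
    create = sqlglot.parse_one(ddl, read="clickhouse")
    # The TTL clause should surface as an exp.MergeTreeTTL property whose
    # expressions are exp.MergeTreeTTLAction nodes (delete=True, to_volume='slow').
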
@@ -1054,14 +1076,12 @@ class Parser(metaclass=_Parser):
expression = self._parse_expression()
expression = self._parse_set_operations(expression) if expression else self._parse_select()
-
- self._parse_query_modifiers(expression)
- return expression
+ return self._parse_query_modifiers(expression)
def _parse_drop(self) -> t.Optional[exp.Drop | exp.Command]:
start = self._prev
temporary = self._match(TokenType.TEMPORARY)
- materialized = self._match(TokenType.MATERIALIZED)
+ materialized = self._match_text_seq("MATERIALIZED")
kind = self._match_set(self.CREATABLES) and self._prev.text
if not kind:
return self._parse_as_command(start)
@@ -1073,7 +1093,7 @@ class Parser(metaclass=_Parser):
kind=kind,
temporary=temporary,
materialized=materialized,
- cascade=self._match(TokenType.CASCADE),
+ cascade=self._match_text_seq("CASCADE"),
constraints=self._match_text_seq("CONSTRAINTS"),
purge=self._match_text_seq("PURGE"),
)
@@ -1111,6 +1131,7 @@ class Parser(metaclass=_Parser):
indexes = None
no_schema_binding = None
begin = None
+ clone = None
if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE):
this = self._parse_user_defined_function(kind=create_token.token_type)
@@ -1128,7 +1149,7 @@ class Parser(metaclass=_Parser):
if return_:
expression = self.expression(exp.Return, this=expression)
elif create_token.token_type == TokenType.INDEX:
- this = self._parse_index()
+ this = self._parse_index(index=self._parse_id_var())
elif create_token.token_type in self.DB_CREATABLES:
table_parts = self._parse_table_parts(schema=True)
@@ -1166,33 +1187,40 @@ class Parser(metaclass=_Parser):
expression = self._parse_ddl_select()
if create_token.token_type == TokenType.TABLE:
- # exp.Properties.Location.POST_EXPRESSION
- temp_properties = self._parse_properties()
- if properties and temp_properties:
- properties.expressions.extend(temp_properties.expressions)
- elif temp_properties:
- properties = temp_properties
-
indexes = []
while True:
- index = self._parse_create_table_index()
+ index = self._parse_index()
- # exp.Properties.Location.POST_INDEX
- if self._match(TokenType.PARTITION_BY, advance=False):
- temp_properties = self._parse_properties()
- if properties and temp_properties:
- properties.expressions.extend(temp_properties.expressions)
- elif temp_properties:
- properties = temp_properties
+ # exp.Properties.Location.POST_EXPRESSION or exp.Properties.Location.POST_INDEX
+ temp_properties = self._parse_properties()
+ if properties and temp_properties:
+ properties.expressions.extend(temp_properties.expressions)
+ elif temp_properties:
+ properties = temp_properties
if not index:
break
else:
+ self._match(TokenType.COMMA)
indexes.append(index)
elif create_token.token_type == TokenType.VIEW:
if self._match_text_seq("WITH", "NO", "SCHEMA", "BINDING"):
no_schema_binding = True
+ if self._match_text_seq("CLONE"):
+ clone = self._parse_table(schema=True)
+ when = self._match_texts({"AT", "BEFORE"}) and self._prev.text.upper()
+ clone_kind = (
+ self._match(TokenType.L_PAREN)
+ and self._match_texts(self.CLONE_KINDS)
+ and self._prev.text.upper()
+ )
+ clone_expression = self._match(TokenType.FARROW) and self._parse_bitwise()
+ self._match(TokenType.R_PAREN)
+ clone = self.expression(
+ exp.Clone, this=clone, when=when, kind=clone_kind, expression=clone_expression
+ )
+
return self.expression(
exp.Create,
this=this,
@@ -1205,18 +1233,31 @@ class Parser(metaclass=_Parser):
indexes=indexes,
no_schema_binding=no_schema_binding,
begin=begin,
+ clone=clone,
)
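
A hedged sketch of the Snowflake zero-copy clone syntax this enables:

    import sqlglot

    create = sqlglot.parse_one(
        "CREATE TABLE new_t CLONE src_t AT (TIMESTAMP => TO_TIMESTAMP(1, 3))",
        read="snowflake",
    )
    # Expect create.args["clone"] to be an exp.Clone with when="AT" and
    # kind="TIMESTAMP" on this build.
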
def _parse_property_before(self) -> t.Optional[exp.Expression]:
+ # only used for teradata currently
self._match(TokenType.COMMA)
- # parsers look to _prev for no/dual/default, so need to consume first
- self._match_text_seq("NO")
- self._match_text_seq("DUAL")
- self._match_text_seq("DEFAULT")
+ kwargs = {
+ "no": self._match_text_seq("NO"),
+ "dual": self._match_text_seq("DUAL"),
+ "before": self._match_text_seq("BEFORE"),
+ "default": self._match_text_seq("DEFAULT"),
+ "local": (self._match_text_seq("LOCAL") and "LOCAL")
+ or (self._match_text_seq("NOT", "LOCAL") and "NOT LOCAL"),
+ "after": self._match_text_seq("AFTER"),
+ "minimum": self._match_texts(("MIN", "MINIMUM")),
+ "maximum": self._match_texts(("MAX", "MAXIMUM")),
+ }
- if self.PROPERTY_PARSERS.get(self._curr.text.upper()):
- return self.PROPERTY_PARSERS[self._curr.text.upper()](self)
+ if self._match_texts(self.PROPERTY_PARSERS):
+ parser = self.PROPERTY_PARSERS[self._prev.text.upper()]
+ try:
+ return parser(self, **{k: v for k, v in kwargs.items() if v})
+ except TypeError:
+ self.raise_error(f"Cannot parse property '{self._prev.text}'")
return None
@@ -1227,7 +1268,7 @@ class Parser(metaclass=_Parser):
if self._match_pair(TokenType.DEFAULT, TokenType.CHARACTER_SET):
return self._parse_character_set(default=True)
- if self._match_pair(TokenType.COMPOUND, TokenType.SORTKEY):
+ if self._match_text_seq("COMPOUND", "SORTKEY"):
return self._parse_sortkey(compound=True)
if self._match_text_seq("SQL", "SECURITY"):
@@ -1262,23 +1303,20 @@ class Parser(metaclass=_Parser):
def _parse_property_assignment(self, exp_class: t.Type[exp.Expression]) -> exp.Expression:
self._match(TokenType.EQ)
self._match(TokenType.ALIAS)
- return self.expression(
- exp_class,
- this=self._parse_var_or_string() or self._parse_number() or self._parse_id_var(),
- )
+ return self.expression(exp_class, this=self._parse_field())
- def _parse_properties(self, before=None) -> t.Optional[exp.Expression]:
+ def _parse_properties(self, before: t.Optional[bool] = None) -> t.Optional[exp.Expression]:
properties = []
while True:
if before:
- identified_property = self._parse_property_before()
+ prop = self._parse_property_before()
else:
- identified_property = self._parse_property()
+ prop = self._parse_property()
- if not identified_property:
+ if not prop:
break
- for p in ensure_list(identified_property):
+ for p in ensure_list(prop):
properties.append(p)
if properties:
@@ -1286,8 +1324,7 @@ class Parser(metaclass=_Parser):
return None
- def _parse_fallback(self, no=False) -> exp.Expression:
- self._match_text_seq("FALLBACK")
+ def _parse_fallback(self, no: bool = False) -> exp.Expression:
return self.expression(
exp.FallbackProperty, no=no, protection=self._match_text_seq("PROTECTION")
)
@@ -1345,23 +1382,13 @@ class Parser(metaclass=_Parser):
self._match(TokenType.EQ)
return self.expression(exp.WithJournalTableProperty, this=self._parse_table_parts())
- def _parse_log(self, no=False) -> exp.Expression:
- self._match_text_seq("LOG")
+ def _parse_log(self, no: bool = False) -> exp.Expression:
return self.expression(exp.LogProperty, no=no)
- def _parse_journal(self, no=False, dual=False) -> exp.Expression:
- before = self._match_text_seq("BEFORE")
- self._match_text_seq("JOURNAL")
- return self.expression(exp.JournalProperty, no=no, dual=dual, before=before)
-
- def _parse_afterjournal(self, no=False, dual=False, local=None) -> exp.Expression:
- self._match_text_seq("NOT")
- self._match_text_seq("LOCAL")
- self._match_text_seq("AFTER", "JOURNAL")
- return self.expression(exp.AfterJournalProperty, no=no, dual=dual, local=local)
+ def _parse_journal(self, **kwargs) -> exp.Expression:
+ return self.expression(exp.JournalProperty, **kwargs)
def _parse_checksum(self) -> exp.Expression:
- self._match_text_seq("CHECKSUM")
self._match(TokenType.EQ)
on = None
@@ -1377,49 +1404,55 @@ class Parser(metaclass=_Parser):
default=default,
)
+ def _parse_cluster(self) -> t.Optional[exp.Expression]:
+ if not self._match_text_seq("BY"):
+ self._retreat(self._index - 1)
+ return None
+ return self.expression(
+ exp.Cluster,
+ expressions=self._parse_csv(self._parse_ordered),
+ )
+
def _parse_freespace(self) -> exp.Expression:
- self._match_text_seq("FREESPACE")
self._match(TokenType.EQ)
return self.expression(
exp.FreespaceProperty, this=self._parse_number(), percent=self._match(TokenType.PERCENT)
)
- def _parse_mergeblockratio(self, no=False, default=False) -> exp.Expression:
- self._match_text_seq("MERGEBLOCKRATIO")
+ def _parse_mergeblockratio(self, no: bool = False, default: bool = False) -> exp.Expression:
if self._match(TokenType.EQ):
return self.expression(
exp.MergeBlockRatioProperty,
this=self._parse_number(),
percent=self._match(TokenType.PERCENT),
)
- else:
- return self.expression(
- exp.MergeBlockRatioProperty,
- no=no,
- default=default,
- )
+ return self.expression(
+ exp.MergeBlockRatioProperty,
+ no=no,
+ default=default,
+ )
- def _parse_datablocksize(self, default=None) -> exp.Expression:
- if default:
- self._match_text_seq("DATABLOCKSIZE")
- return self.expression(exp.DataBlocksizeProperty, default=True)
- elif self._match_texts(("MIN", "MINIMUM")):
- self._match_text_seq("DATABLOCKSIZE")
- return self.expression(exp.DataBlocksizeProperty, min=True)
- elif self._match_texts(("MAX", "MAXIMUM")):
- self._match_text_seq("DATABLOCKSIZE")
- return self.expression(exp.DataBlocksizeProperty, min=False)
-
- self._match_text_seq("DATABLOCKSIZE")
+ def _parse_datablocksize(
+ self,
+ default: t.Optional[bool] = None,
+ minimum: t.Optional[bool] = None,
+ maximum: t.Optional[bool] = None,
+ ) -> exp.Expression:
self._match(TokenType.EQ)
size = self._parse_number()
units = None
if self._match_texts(("BYTES", "KBYTES", "KILOBYTES")):
units = self._prev.text
- return self.expression(exp.DataBlocksizeProperty, size=size, units=units)
+ return self.expression(
+ exp.DataBlocksizeProperty,
+ size=size,
+ units=units,
+ default=default,
+ minimum=minimum,
+ maximum=maximum,
+ )
def _parse_blockcompression(self) -> exp.Expression:
- self._match_text_seq("BLOCKCOMPRESSION")
self._match(TokenType.EQ)
always = self._match_text_seq("ALWAYS")
manual = self._match_text_seq("MANUAL")
@@ -1516,7 +1549,7 @@ class Parser(metaclass=_Parser):
this=self._parse_schema() or self._parse_bracket(self._parse_field()),
)
- def _parse_withdata(self, no=False) -> exp.Expression:
+ def _parse_withdata(self, no: bool = False) -> exp.Expression:
if self._match_text_seq("AND", "STATISTICS"):
statistics = True
elif self._match_text_seq("AND", "NO", "STATISTICS"):
@@ -1526,13 +1559,17 @@ class Parser(metaclass=_Parser):
return self.expression(exp.WithDataProperty, no=no, statistics=statistics)
- def _parse_noprimaryindex(self) -> exp.Expression:
- self._match_text_seq("PRIMARY", "INDEX")
- return exp.NoPrimaryIndexProperty()
+ def _parse_no_property(self) -> t.Optional[exp.Property]:
+ if self._match_text_seq("PRIMARY", "INDEX"):
+ return exp.NoPrimaryIndexProperty()
+ return None
- def _parse_oncommit(self) -> exp.Expression:
- self._match_text_seq("COMMIT", "PRESERVE", "ROWS")
- return exp.OnCommitProperty()
+ def _parse_on_property(self) -> t.Optional[exp.Property]:
+ if self._match_text_seq("COMMIT", "PRESERVE", "ROWS"):
+ return exp.OnCommitProperty()
+ elif self._match_text_seq("COMMIT", "DELETE", "ROWS"):
+ return exp.OnCommitProperty(delete=True)
+ return None
def _parse_distkey(self) -> exp.Expression:
return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_id_var))
@@ -1587,10 +1624,6 @@ class Parser(metaclass=_Parser):
return self.expression(exp.ReturnsProperty, this=value, is_table=is_table)
- def _parse_temporary(self, global_=False) -> exp.Expression:
- self._match(TokenType.TEMPORARY) # in case calling from "GLOBAL"
- return self.expression(exp.TemporaryProperty, global_=global_)
-
def _parse_describe(self) -> exp.Expression:
kind = self._match_set(self.CREATABLES) and self._prev.text
this = self._parse_table()
@@ -1599,7 +1632,7 @@ class Parser(metaclass=_Parser):
def _parse_insert(self) -> exp.Expression:
overwrite = self._match(TokenType.OVERWRITE)
- local = self._match(TokenType.LOCAL)
+ local = self._match_text_seq("LOCAL")
alternative = None
if self._match_text_seq("DIRECTORY"):
@@ -1700,23 +1733,25 @@ class Parser(metaclass=_Parser):
return self.expression(exp.RowFormatDelimitedProperty, **kwargs) # type: ignore
- def _parse_load_data(self) -> exp.Expression:
- local = self._match(TokenType.LOCAL)
- self._match_text_seq("INPATH")
- inpath = self._parse_string()
- overwrite = self._match(TokenType.OVERWRITE)
- self._match_pair(TokenType.INTO, TokenType.TABLE)
+ def _parse_load(self) -> exp.Expression:
+ if self._match_text_seq("DATA"):
+ local = self._match_text_seq("LOCAL")
+ self._match_text_seq("INPATH")
+ inpath = self._parse_string()
+ overwrite = self._match(TokenType.OVERWRITE)
+ self._match_pair(TokenType.INTO, TokenType.TABLE)
- return self.expression(
- exp.LoadData,
- this=self._parse_table(schema=True),
- local=local,
- overwrite=overwrite,
- inpath=inpath,
- partition=self._parse_partition(),
- input_format=self._match_text_seq("INPUTFORMAT") and self._parse_string(),
- serde=self._match_text_seq("SERDE") and self._parse_string(),
- )
+ return self.expression(
+ exp.LoadData,
+ this=self._parse_table(schema=True),
+ local=local,
+ overwrite=overwrite,
+ inpath=inpath,
+ partition=self._parse_partition(),
+ input_format=self._match_text_seq("INPUTFORMAT") and self._parse_string(),
+ serde=self._match_text_seq("SERDE") and self._parse_string(),
+ )
+ return self._parse_as_command(self._prev)
def _parse_delete(self) -> exp.Expression:
self._match(TokenType.FROM)
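
A hedged sketch of the split behavior above: Hive's LOAD DATA still parses into exp.LoadData, while any other LOAD statement now falls back to a generic exp.Command:

    import sqlglot

    stmt = sqlglot.parse_one(
        "LOAD DATA LOCAL INPATH '/files/x.csv' OVERWRITE INTO TABLE db.t",
        read="hive",
    )
    # Expect an exp.LoadData node carrying local/overwrite/inpath arguments.
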
@@ -1735,7 +1770,7 @@ class Parser(metaclass=_Parser):
**{ # type: ignore
"this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS),
"expressions": self._match(TokenType.SET) and self._parse_csv(self._parse_equality),
- "from": self._parse_from(),
+ "from": self._parse_from(modifiers=True),
"where": self._parse_where(),
"returning": self._parse_returning(),
},
@@ -1752,12 +1787,12 @@ class Parser(metaclass=_Parser):
)
def _parse_cache(self) -> exp.Expression:
- lazy = self._match(TokenType.LAZY)
+ lazy = self._match_text_seq("LAZY")
self._match(TokenType.TABLE)
table = self._parse_table(schema=True)
options = []
- if self._match(TokenType.OPTIONS):
+ if self._match_text_seq("OPTIONS"):
self._match_l_paren()
k = self._parse_string()
self._match(TokenType.EQ)
@@ -1851,11 +1886,10 @@ class Parser(metaclass=_Parser):
if from_:
this.set("from", from_)
- self._parse_query_modifiers(this)
+ this = self._parse_query_modifiers(this)
elif (table or nested) and self._match(TokenType.L_PAREN):
this = self._parse_table() if table else self._parse_select(nested=True)
- self._parse_query_modifiers(this)
- this = self._parse_set_operations(this)
+ this = self._parse_set_operations(self._parse_query_modifiers(this))
self._match_r_paren()
# early return so that subquery unions aren't parsed again
@@ -1868,6 +1902,10 @@ class Parser(metaclass=_Parser):
expressions=self._parse_csv(self._parse_value),
alias=self._parse_table_alias(),
)
+ elif self._match(TokenType.PIVOT):
+ this = self._parse_simplified_pivot()
+ elif self._match(TokenType.FROM):
+ this = exp.select("*").from_(t.cast(exp.From, self._parse_from(skip_from_token=True)))
else:
this = None
@@ -1929,7 +1967,9 @@ class Parser(metaclass=_Parser):
def _parse_subquery(
self, this: t.Optional[exp.Expression], parse_alias: bool = True
- ) -> exp.Expression:
+ ) -> t.Optional[exp.Expression]:
+ if not this:
+ return None
return self.expression(
exp.Subquery,
this=this,
@@ -1937,35 +1977,16 @@ class Parser(metaclass=_Parser):
alias=self._parse_table_alias() if parse_alias else None,
)
- def _parse_query_modifiers(self, this: t.Optional[exp.Expression]) -> None:
- if not isinstance(this, self.MODIFIABLES):
- return
-
- table = isinstance(this, exp.Table)
-
- while True:
- join = self._parse_join()
- if join:
- this.append("joins", join)
-
- lateral = None
- if not join:
- lateral = self._parse_lateral()
- if lateral:
- this.append("laterals", lateral)
-
- comma = None if table else self._match(TokenType.COMMA)
- if comma:
- this.args["from"].append("expressions", self._parse_table())
-
- if not (lateral or join or comma):
- break
-
- for key, parser in self.QUERY_MODIFIER_PARSERS.items():
- expression = parser(self)
+ def _parse_query_modifiers(
+ self, this: t.Optional[exp.Expression]
+ ) -> t.Optional[exp.Expression]:
+ if isinstance(this, self.MODIFIABLES):
+ for key, parser in self.QUERY_MODIFIER_PARSERS.items():
+ expression = parser(self)
- if expression:
- this.set(key, expression)
+ if expression:
+ this.set(key, expression)
+ return this
def _parse_hint(self) -> t.Optional[exp.Expression]:
if self._match(TokenType.HINT):
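
Since _parse_query_modifiers now returns the (possibly annotated) expression, callers can chain it; each key in QUERY_MODIFIER_PARSERS becomes an arg on the modified node. A hedged sketch:

    import sqlglot

    ast = sqlglot.parse_one("SELECT * FROM t JOIN u ON t.id = u.id WHERE t.x > 0 LIMIT 5")
    assert ast.args.get("joins") and ast.args.get("where") and ast.args.get("limit")
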
@@ -1981,19 +2002,26 @@ class Parser(metaclass=_Parser):
return None
temp = self._match(TokenType.TEMPORARY)
- unlogged = self._match(TokenType.UNLOGGED)
+ unlogged = self._match_text_seq("UNLOGGED")
self._match(TokenType.TABLE)
return self.expression(
exp.Into, this=self._parse_table(schema=True), temporary=temp, unlogged=unlogged
)
- def _parse_from(self) -> t.Optional[exp.Expression]:
- if not self._match(TokenType.FROM):
+ def _parse_from(
+ self, modifiers: bool = False, skip_from_token: bool = False
+ ) -> t.Optional[exp.From]:
+ if not skip_from_token and not self._match(TokenType.FROM):
return None
+ comments = self._prev_comments
+ this = self._parse_table()
+
return self.expression(
- exp.From, comments=self._prev_comments, expressions=self._parse_csv(self._parse_table)
+ exp.From,
+ comments=comments,
+ this=self._parse_query_modifiers(this) if modifiers else this,
)
def _parse_match_recognize(self) -> t.Optional[exp.Expression]:
@@ -2136,6 +2164,9 @@ class Parser(metaclass=_Parser):
)
def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Expression]:
+ if self._match(TokenType.COMMA):
+ return self.expression(exp.Join, this=self._parse_table())
+
index = self._index
natural, side, kind = self._parse_join_side_and_kind()
hint = self._prev.text if self._match_texts(self.JOIN_HINTS) else None
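
With commas consumed inside _parse_join, comma-separated tables become join nodes rather than extra FROM expressions. A hedged sketch:

    import sqlglot

    ast = sqlglot.parse_one("SELECT * FROM a, b")
    # Expect a single entry in ast.args["joins"] wrapping table b on this build.
    print([join.sql() for join in ast.args.get("joins", [])])
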
@@ -2176,55 +2207,66 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Join, **kwargs) # type: ignore
- def _parse_index(self) -> exp.Expression:
- index = self._parse_id_var()
- self._match(TokenType.ON)
- self._match(TokenType.TABLE) # hive
+ def _parse_index(
+ self,
+ index: t.Optional[exp.Expression] = None,
+ ) -> t.Optional[exp.Expression]:
+ if index:
+ unique = None
+ primary = None
+ amp = None
- return self.expression(
- exp.Index,
- this=index,
- table=self.expression(exp.Table, this=self._parse_id_var()),
- columns=self._parse_expression(),
- )
+ self._match(TokenType.ON)
+ self._match(TokenType.TABLE) # hive
+ table = self._parse_table_parts(schema=True)
+ else:
+ unique = self._match(TokenType.UNIQUE)
+ primary = self._match_text_seq("PRIMARY")
+ amp = self._match_text_seq("AMP")
+ if not self._match(TokenType.INDEX):
+ return None
+ index = self._parse_id_var()
+ table = None
- def _parse_create_table_index(self) -> t.Optional[exp.Expression]:
- unique = self._match(TokenType.UNIQUE)
- primary = self._match_text_seq("PRIMARY")
- amp = self._match_text_seq("AMP")
- if not self._match(TokenType.INDEX):
- return None
- index = self._parse_id_var()
- columns = None
if self._match(TokenType.L_PAREN, advance=False):
- columns = self._parse_wrapped_csv(self._parse_column)
+ columns = self._parse_wrapped_csv(self._parse_ordered)
+ else:
+ columns = None
+
return self.expression(
exp.Index,
this=index,
+ table=table,
columns=columns,
unique=unique,
primary=primary,
amp=amp,
+ partition_by=self._parse_partition_by(),
)
- def _parse_table_parts(self, schema: bool = False) -> exp.Expression:
- catalog = None
- db = None
-
- table = (
+ def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]:
+ return (
(not schema and self._parse_function())
or self._parse_id_var(any_token=False)
or self._parse_string_as_identifier()
+ or self._parse_placeholder()
)
+ def _parse_table_parts(self, schema: bool = False) -> exp.Table:
+ catalog = None
+ db = None
+ table = self._parse_table_part(schema=schema)
+
while self._match(TokenType.DOT):
if catalog:
# This allows nesting the table in arbitrarily many dot expressions if needed
- table = self.expression(exp.Dot, this=table, expression=self._parse_id_var())
+ table = self.expression(
+ exp.Dot, this=table, expression=self._parse_table_part(schema=schema)
+ )
else:
catalog = db
db = table
- table = self._parse_id_var()
+ table = self._parse_table_part(schema=schema)
if not table:
self.raise_error(f"Expected table name but got {self._curr}")
@@ -2237,28 +2279,24 @@ class Parser(metaclass=_Parser):
self, schema: bool = False, alias_tokens: t.Optional[t.Collection[TokenType]] = None
) -> t.Optional[exp.Expression]:
lateral = self._parse_lateral()
-
if lateral:
return lateral
unnest = self._parse_unnest()
-
if unnest:
return unnest
values = self._parse_derived_table_values()
-
if values:
return values
subquery = self._parse_select(table=True)
-
if subquery:
if not subquery.args.get("pivots"):
subquery.set("pivots", self._parse_pivots())
return subquery
- this = self._parse_table_parts(schema=schema)
+ this: exp.Expression = self._parse_table_parts(schema=schema)
if schema:
return self._parse_schema(this=this)
@@ -2267,7 +2305,6 @@ class Parser(metaclass=_Parser):
table_sample = self._parse_table_sample()
alias = self._parse_table_alias(alias_tokens=alias_tokens or self.TABLE_ALIAS_TOKENS)
-
if alias:
this.set("alias", alias)
@@ -2352,9 +2389,9 @@ class Parser(metaclass=_Parser):
num = self._parse_number()
- if self._match(TokenType.BUCKET):
+ if self._match_text_seq("BUCKET"):
bucket_numerator = self._parse_number()
- self._match(TokenType.OUT_OF)
+ self._match_text_seq("OUT", "OF")
bucket_denominator = self._parse_number()
self._match(TokenType.ON)
bucket_field = self._parse_field()
@@ -2390,6 +2427,22 @@ class Parser(metaclass=_Parser):
def _parse_pivots(self) -> t.List[t.Optional[exp.Expression]]:
return list(iter(self._parse_pivot, None))
+ # https://duckdb.org/docs/sql/statements/pivot
+ def _parse_simplified_pivot(self) -> exp.Pivot:
+ def _parse_on() -> t.Optional[exp.Expression]:
+ this = self._parse_bitwise()
+ return self._parse_in(this) if self._match(TokenType.IN) else this
+
+ this = self._parse_table()
+ expressions = self._match(TokenType.ON) and self._parse_csv(_parse_on)
+ using = self._match(TokenType.USING) and self._parse_csv(
+ lambda: self._parse_alias(self._parse_function())
+ )
+ group = self._parse_group()
+ return self.expression(
+ exp.Pivot, this=this, expressions=expressions, using=using, group=group
+ )
+
def _parse_pivot(self) -> t.Optional[exp.Expression]:
index = self._index
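
A hedged sketch of the DuckDB simplified PIVOT statement this targets (see the linked docs):

    import sqlglot

    ast = sqlglot.parse_one(
        "PIVOT cities ON year USING SUM(population) GROUP BY country",
        read="duckdb",
    )
    # Expect an exp.Pivot whose expressions hold the ON list, plus using/group args.
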
@@ -2423,7 +2476,7 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.IN):
self.raise_error("Expecting IN")
- field = self._parse_in(value)
+ field = self._parse_in(value, alias=True)
self._match_r_paren()
@@ -2436,21 +2489,22 @@ class Parser(metaclass=_Parser):
names = self._pivot_column_names(t.cast(t.List[exp.Expression], expressions))
columns: t.List[exp.Expression] = []
- for col in pivot.args["field"].expressions:
+ for fld in pivot.args["field"].expressions:
+ field_name = fld.sql() if self.IDENTIFY_PIVOT_STRINGS else fld.alias_or_name
for name in names:
if self.PREFIXED_PIVOT_COLUMNS:
- name = f"{name}_{col.alias_or_name}" if name else col.alias_or_name
+ name = f"{name}_{field_name}" if name else field_name
else:
- name = f"{col.alias_or_name}_{name}" if name else col.alias_or_name
+ name = f"{field_name}_{name}" if name else field_name
- columns.append(exp.to_identifier(name, quoted=self.QUOTED_PIVOT_COLUMNS))
+ columns.append(exp.to_identifier(name))
pivot.set("columns", columns)
return pivot
- def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]:
- return [agg.alias for agg in pivot_columns]
+ def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]:
+ return [agg.alias for agg in aggregations]
def _parse_where(self, skip_where_token: bool = False) -> t.Optional[exp.Expression]:
if not skip_where_token and not self._match(TokenType.WHERE):
@@ -2477,6 +2531,7 @@ class Parser(metaclass=_Parser):
rollup = None
cube = None
+ totals = None
with_ = self._match(TokenType.WITH)
if self._match(TokenType.ROLLUP):
@@ -2487,7 +2542,11 @@ class Parser(metaclass=_Parser):
cube = with_ or self._parse_wrapped_csv(self._parse_column)
elements["cube"].extend(ensure_list(cube))
- if not (expressions or grouping_sets or rollup or cube):
+ if self._match_text_seq("TOTALS"):
+ totals = True
+ elements["totals"] = True # type: ignore
+
+ if not (grouping_sets or rollup or cube or totals):
break
return self.expression(exp.Group, **elements) # type: ignore
@@ -2527,9 +2586,9 @@ class Parser(metaclass=_Parser):
)
def _parse_sort(
- self, token_type: TokenType, exp_class: t.Type[exp.Expression]
+ self, exp_class: t.Type[exp.Expression], *texts: str
) -> t.Optional[exp.Expression]:
- if not self._match(token_type):
+ if not self._match_text_seq(*texts):
return None
return self.expression(exp_class, expressions=self._parse_csv(self._parse_ordered))
@@ -2537,8 +2596,8 @@ class Parser(metaclass=_Parser):
this = self._parse_conjunction()
self._match(TokenType.ASC)
is_desc = self._match(TokenType.DESC)
- is_nulls_first = self._match(TokenType.NULLS_FIRST)
- is_nulls_last = self._match(TokenType.NULLS_LAST)
+ is_nulls_first = self._match_text_seq("NULLS", "FIRST")
+ is_nulls_last = self._match_text_seq("NULLS", "LAST")
desc = is_desc or False
asc = not desc
nulls_first = is_nulls_first or False
@@ -2578,7 +2637,7 @@ class Parser(metaclass=_Parser):
self._match_set((TokenType.ROW, TokenType.ROWS))
- only = self._match(TokenType.ONLY)
+ only = self._match_text_seq("ONLY")
with_ties = self._match_text_seq("WITH", "TIES")
if only and with_ties:
@@ -2602,13 +2661,37 @@ class Parser(metaclass=_Parser):
self._match_set((TokenType.ROW, TokenType.ROWS))
return self.expression(exp.Offset, this=this, expression=count)
- def _parse_lock(self) -> t.Optional[exp.Expression]:
- if self._match_text_seq("FOR", "UPDATE"):
- return self.expression(exp.Lock, update=True)
- if self._match_text_seq("FOR", "SHARE"):
- return self.expression(exp.Lock, update=False)
+ def _parse_locks(self) -> t.List[exp.Expression]:
+ # Lists are invariant, so we need to use a type hint here
+ locks: t.List[exp.Expression] = []
- return None
+ while True:
+ if self._match_text_seq("FOR", "UPDATE"):
+ update = True
+ elif self._match_text_seq("FOR", "SHARE") or self._match_text_seq(
+ "LOCK", "IN", "SHARE", "MODE"
+ ):
+ update = False
+ else:
+ break
+
+ expressions = None
+ if self._match_text_seq("OF"):
+ expressions = self._parse_csv(lambda: self._parse_table(schema=True))
+
+ wait: t.Optional[bool | exp.Expression] = None
+ if self._match_text_seq("NOWAIT"):
+ wait = True
+ elif self._match_text_seq("WAIT"):
+ wait = self._parse_primary()
+ elif self._match_text_seq("SKIP", "LOCKED"):
+ wait = False
+
+ locks.append(
+ self.expression(exp.Lock, update=update, expressions=expressions, wait=wait)
+ )
+
+ return locks
def _parse_set_operations(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
if not self._match_set(self.SET_OPERATIONS):
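
A hedged sketch of the richer locking clauses _parse_locks accepts:

    import sqlglot

    ast = sqlglot.parse_one("SELECT * FROM t FOR UPDATE OF t SKIP LOCKED")
    # Expect ast.args["locks"][0] to be exp.Lock(update=True, wait=False) with
    # table t in its expressions, on this build.
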
@@ -2672,7 +2755,7 @@ class Parser(metaclass=_Parser):
def _parse_is(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
index = self._index - 1
negate = self._match(TokenType.NOT)
- if self._match(TokenType.DISTINCT_FROM):
+ if self._match_text_seq("DISTINCT", "FROM"):
klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ
return self.expression(klass, this=this, expression=self._parse_expression())
@@ -2684,12 +2767,12 @@ class Parser(metaclass=_Parser):
this = self.expression(exp.Is, this=this, expression=expression)
return self.expression(exp.Not, this=this) if negate else this
- def _parse_in(self, this: t.Optional[exp.Expression]) -> exp.Expression:
+ def _parse_in(self, this: t.Optional[exp.Expression], alias: bool = False) -> exp.In:
unnest = self._parse_unnest()
if unnest:
this = self.expression(exp.In, this=this, unnest=unnest)
elif self._match(TokenType.L_PAREN):
- expressions = self._parse_csv(self._parse_select_or_expression)
+ expressions = self._parse_csv(lambda: self._parse_select_or_expression(alias=alias))
if len(expressions) == 1 and isinstance(expressions[0], exp.Subqueryable):
this = self.expression(exp.In, this=this, query=expressions[0])
@@ -2722,15 +2805,19 @@ class Parser(metaclass=_Parser):
# Most dialects support, e.g., the form INTERVAL '5' day, thus we try to parse
# each INTERVAL expression into this canonical form so it's easy to transpile
- if this and isinstance(this, exp.Literal):
- if this.is_number:
- this = exp.Literal.string(this.name)
-
- # Try to not clutter Snowflake's multi-part intervals like INTERVAL '1 day, 1 year'
+ if this and this.is_number:
+ this = exp.Literal.string(this.name)
+ elif this and this.is_string:
parts = this.name.split()
- if not unit and len(parts) <= 2:
- this = exp.Literal.string(seq_get(parts, 0))
- unit = self.expression(exp.Var, this=seq_get(parts, 1))
+
+ if len(parts) == 2:
+ if unit:
+ # this is not actually a unit, it's something else
+ unit = None
+ self._retreat(self._index - 1)
+ else:
+ this = exp.Literal.string(parts[0])
+ unit = self.expression(exp.Var, this=parts[1])
return self.expression(exp.Interval, this=this, unit=unit)
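
A hedged sketch of the canonicalization described above, where both spellings should end up as a string literal plus a separate unit:

    import sqlglot

    a = sqlglot.parse_one("SELECT INTERVAL '5' day").expressions[0]
    b = sqlglot.parse_one("SELECT INTERVAL '5 day'").expressions[0]
    # Both should carry this=Literal(this='5') and unit=Var(this=day),
    # so the two render identically.
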
@@ -2783,13 +2870,22 @@ class Parser(metaclass=_Parser):
if parser:
return parser(self, this, data_type)
return self.expression(exp.Cast, this=this, to=data_type)
- if not data_type.args.get("expressions"):
+ if not data_type.expressions:
self._retreat(index)
return self._parse_column()
- return data_type
+ return self._parse_column_ops(data_type)
return this
+ def _parse_type_size(self) -> t.Optional[exp.Expression]:
+ this = self._parse_type()
+ if not this:
+ return None
+
+ return self.expression(
+ exp.DataTypeSize, this=this, expression=self._parse_var(any_token=True)
+ )
+
def _parse_types(self, check_func: bool = False) -> t.Optional[exp.Expression]:
index = self._index
@@ -2814,7 +2910,7 @@ class Parser(metaclass=_Parser):
elif nested:
expressions = self._parse_csv(self._parse_types)
else:
- expressions = self._parse_csv(self._parse_conjunction)
+ expressions = self._parse_csv(self._parse_type_size)
if not expressions or not self._match(TokenType.R_PAREN):
self._retreat(index)
@@ -2858,13 +2954,14 @@ class Parser(metaclass=_Parser):
value: t.Optional[exp.Expression] = None
if type_token in self.TIMESTAMPS:
- if self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ:
+ if self._match_text_seq("WITH", "TIME", "ZONE") or type_token == TokenType.TIMESTAMPTZ:
value = exp.DataType(this=exp.DataType.Type.TIMESTAMPTZ, expressions=expressions)
elif (
- self._match(TokenType.WITH_LOCAL_TIME_ZONE) or type_token == TokenType.TIMESTAMPLTZ
+ self._match_text_seq("WITH", "LOCAL", "TIME", "ZONE")
+ or type_token == TokenType.TIMESTAMPLTZ
):
value = exp.DataType(this=exp.DataType.Type.TIMESTAMPLTZ, expressions=expressions)
- elif self._match(TokenType.WITHOUT_TIME_ZONE):
+ elif self._match_text_seq("WITHOUT", "TIME", "ZONE"):
if type_token == TokenType.TIME:
value = exp.DataType(this=exp.DataType.Type.TIME, expressions=expressions)
else:
@@ -2909,7 +3006,7 @@ class Parser(metaclass=_Parser):
return self._parse_column_def(this)
def _parse_at_time_zone(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
- if not self._match(TokenType.AT_TIME_ZONE):
+ if not self._match_text_seq("AT", "TIME", "ZONE"):
return this
return self.expression(exp.AtTimeZone, this=this, zone=self._parse_unary())
@@ -2919,6 +3016,9 @@ class Parser(metaclass=_Parser):
this = self.expression(exp.Column, this=this)
elif not this:
return self._parse_bracket(this)
+ return self._parse_column_ops(this)
+
+ def _parse_column_ops(self, this: exp.Expression) -> exp.Expression:
this = self._parse_bracket(this)
while self._match_set(self.COLUMN_OPERATORS):
@@ -2929,7 +3029,7 @@ class Parser(metaclass=_Parser):
field = self._parse_types()
if not field:
self.raise_error("Expected type")
- elif op:
+ elif op and self._curr:
self._advance()
value = self._prev.text
field = (
@@ -2963,7 +3063,6 @@ class Parser(metaclass=_Parser):
else:
this = self.expression(exp.Dot, this=this, expression=field)
this = self._parse_bracket(this)
-
return this
def _parse_primary(self) -> t.Optional[exp.Expression]:
@@ -2989,12 +3088,9 @@ class Parser(metaclass=_Parser):
if query:
expressions = [query]
else:
- expressions = self._parse_csv(
- lambda: self._parse_alias(self._parse_conjunction(), explicit=True)
- )
+ expressions = self._parse_csv(self._parse_expression)
- this = seq_get(expressions, 0)
- self._parse_query_modifiers(this)
+ this = self._parse_query_modifiers(seq_get(expressions, 0))
if isinstance(this, exp.Subqueryable):
this = self._parse_set_operations(
@@ -3065,20 +3161,12 @@ class Parser(metaclass=_Parser):
functions = self.FUNCTIONS
function = functions.get(upper)
- args = self._parse_csv(self._parse_lambda)
- if function and not anonymous:
- # Clickhouse supports function calls like foo(x, y)(z), so for these we need to also parse the
- # second parameter list (i.e. "(z)") and the corresponding function will receive both arg lists.
- if count_params(function) == 2:
- params = None
- if self._match_pair(TokenType.R_PAREN, TokenType.L_PAREN):
- params = self._parse_csv(self._parse_lambda)
-
- this = function(args, params)
- else:
- this = function(args)
+ alias = upper in self.FUNCTIONS_WITH_ALIASED_ARGS
+ args = self._parse_csv(lambda: self._parse_lambda(alias=alias))
+ if function and not anonymous:
+ this = function(args)
self.validate_expression(this, args)
else:
this = self.expression(exp.Anonymous, this=this, expressions=args)
@@ -3113,9 +3201,6 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Identifier, this=token.text)
- def _parse_national(self, token: Token) -> exp.Expression:
- return self.expression(exp.National, this=exp.Literal.string(token.text))
-
def _parse_session_parameter(self) -> exp.Expression:
kind = None
this = self._parse_id_var() or self._parse_primary()
@@ -3126,7 +3211,7 @@ class Parser(metaclass=_Parser):
return self.expression(exp.SessionParameter, this=this, kind=kind)
- def _parse_lambda(self) -> t.Optional[exp.Expression]:
+ def _parse_lambda(self, alias: bool = False) -> t.Optional[exp.Expression]:
index = self._index
if self._match(TokenType.L_PAREN):
@@ -3149,7 +3234,7 @@ class Parser(metaclass=_Parser):
exp.Distinct, expressions=self._parse_csv(self._parse_conjunction)
)
else:
- this = self._parse_select_or_expression()
+ this = self._parse_select_or_expression(alias=alias)
if isinstance(this, exp.EQ):
left = this.this
@@ -3161,13 +3246,15 @@ class Parser(metaclass=_Parser):
def _parse_schema(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]:
index = self._index
- try:
- if self._parse_select(nested=True):
- return this
- except Exception:
- pass
- finally:
- self._retreat(index)
+ if not self.errors:
+ try:
+ if self._parse_select(nested=True):
+ return this
+ except ParseError:
+ pass
+ finally:
+ self.errors.clear()
+ self._retreat(index)
if not self._match(TokenType.L_PAREN):
return this
@@ -3227,13 +3314,18 @@ class Parser(metaclass=_Parser):
return self.expression(exp.CompressColumnConstraint, this=self._parse_bitwise())
def _parse_generated_as_identity(self) -> exp.Expression:
- if self._match(TokenType.BY_DEFAULT):
- this = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=False)
+ if self._match_text_seq("BY", "DEFAULT"):
+ on_null = self._match_pair(TokenType.ON, TokenType.NULL)
+ this = self.expression(
+ exp.GeneratedAsIdentityColumnConstraint, this=False, on_null=on_null
+ )
else:
self._match_text_seq("ALWAYS")
this = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=True)
- self._match_text_seq("AS", "IDENTITY")
+ self._match(TokenType.ALIAS)
+ identity = self._match_text_seq("IDENTITY")
+
if self._match(TokenType.L_PAREN):
if self._match_text_seq("START", "WITH"):
this.set("start", self._parse_bitwise())
@@ -3249,6 +3341,9 @@ class Parser(metaclass=_Parser):
elif self._match_text_seq("NO", "CYCLE"):
this.set("cycle", False)
+ if not identity:
+ this.set("expression", self._parse_bitwise())
+
self._match_r_paren()
return this
@@ -3307,9 +3402,10 @@ class Parser(metaclass=_Parser):
return self.CONSTRAINT_PARSERS[constraint](self)
def _parse_unique(self) -> exp.Expression:
- if not self._match(TokenType.L_PAREN, advance=False):
- return self.expression(exp.UniqueColumnConstraint)
- return self.expression(exp.Unique, expressions=self._parse_wrapped_id_vars())
+ self._match_text_seq("KEY")
+ return self.expression(
+ exp.UniqueColumnConstraint, this=self._parse_schema(self._parse_id_var(any_token=False))
+ )
def _parse_key_constraint_options(self) -> t.List[str]:
options = []
@@ -3321,9 +3417,9 @@ class Parser(metaclass=_Parser):
action = None
on = self._advance_any() and self._prev.text
- if self._match(TokenType.NO_ACTION):
+ if self._match_text_seq("NO", "ACTION"):
action = "NO ACTION"
- elif self._match(TokenType.CASCADE):
+ elif self._match_text_seq("CASCADE"):
action = "CASCADE"
elif self._match_pair(TokenType.SET, TokenType.NULL):
action = "SET NULL"
@@ -3348,7 +3444,7 @@ class Parser(metaclass=_Parser):
return options
- def _parse_references(self, match=True) -> t.Optional[exp.Expression]:
+ def _parse_references(self, match: bool = True) -> t.Optional[exp.Expression]:
if match and not self._match(TokenType.REFERENCES):
return None
@@ -3372,7 +3468,7 @@ class Parser(metaclass=_Parser):
kind = self._prev.text.lower()
- if self._match(TokenType.NO_ACTION):
+ if self._match_text_seq("NO", "ACTION"):
action = "NO ACTION"
elif self._match(TokenType.SET):
self._match_set((TokenType.NULL, TokenType.DEFAULT))
@@ -3396,11 +3492,19 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.L_PAREN, advance=False):
return self.expression(exp.PrimaryKeyColumnConstraint, desc=desc)
- expressions = self._parse_wrapped_id_vars()
+ expressions = self._parse_wrapped_csv(self._parse_field)
options = self._parse_key_constraint_options()
return self.expression(exp.PrimaryKey, expressions=expressions, options=options)
+ @t.overload
+ def _parse_bracket(self, this: exp.Expression) -> exp.Expression:
+ ...
+
+ @t.overload
def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
+ ...
+
+ def _parse_bracket(self, this):
if not self._match_set((TokenType.L_BRACKET, TokenType.L_BRACE)):
return this
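
The @t.overload pair above exists purely for type checkers: a non-optional argument yields a non-optional result, while Optional stays Optional; only the final untyped definition runs. The same pattern in isolation:

    import typing as t

    @t.overload
    def bracket(this: int) -> int: ...
    @t.overload
    def bracket(this: t.Optional[int]) -> t.Optional[int]: ...
    def bracket(this):  # single runtime implementation
        return this
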
@@ -3493,7 +3597,12 @@ class Parser(metaclass=_Parser):
this = self._parse_conjunction()
if not self._match(TokenType.ALIAS):
- self.raise_error("Expected AS after CAST")
+ if self._match(TokenType.COMMA):
+ return self.expression(
+ exp.CastToStrType, this=this, expression=self._parse_string()
+ )
+ else:
+ self.raise_error("Expected AS after CAST")
to = self._parse_types()
@@ -3524,7 +3633,7 @@ class Parser(metaclass=_Parser):
# Checks if we can parse an order clause: WITHIN GROUP (ORDER BY <order_by_expression_list> [ASC | DESC]).
# This is done "manually", instead of letting _parse_window parse it into an exp.WithinGroup node, so that
# the STRING_AGG call is parsed like in MySQL / SQLite and can thus be transpiled more easily to them.
- if not self._match(TokenType.WITHIN_GROUP):
+ if not self._match_text_seq("WITHIN", "GROUP"):
self._retreat(index)
this = exp.GroupConcat.from_arg_list(args)
self.validate_expression(this, args)
@@ -3674,6 +3783,27 @@ class Parser(metaclass=_Parser):
exp.MatchAgainst, this=this, expressions=expressions, modifier=modifier
)
+ # https://learn.microsoft.com/en-us/sql/t-sql/functions/openjson-transact-sql?view=sql-server-ver16
+ def _parse_open_json(self) -> exp.Expression:
+ this = self._parse_bitwise()
+ path = self._match(TokenType.COMMA) and self._parse_string()
+
+ def _parse_open_json_column_def() -> exp.Expression:
+ this = self._parse_field(any_token=True)
+ kind = self._parse_types()
+ path = self._parse_string()
+ as_json = self._match_pair(TokenType.ALIAS, TokenType.JSON)
+ return self.expression(
+ exp.OpenJSONColumnDef, this=this, kind=kind, path=path, as_json=as_json
+ )
+
+ expressions = None
+ if self._match_pair(TokenType.R_PAREN, TokenType.WITH):
+ self._match_l_paren()
+ expressions = self._parse_csv(_parse_open_json_column_def)
+
+ return self.expression(exp.OpenJSON, this=this, path=path, expressions=expressions)
+
def _parse_position(self, haystack_first: bool = False) -> exp.Expression:
args = self._parse_csv(self._parse_bitwise)
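
A hedged sketch of the T-SQL OPENJSON shape the new _parse_open_json targets (per the linked docs):

    import sqlglot

    ast = sqlglot.parse_one(
        "SELECT * FROM OPENJSON(@doc, '$.people') "
        "WITH (id INT '$.id', info NVARCHAR(4000) '$.info' AS JSON)",
        read="tsql",
    )
    # Expect the WITH list to parse into exp.OpenJSONColumnDef nodes,
    # with as_json=True on the info column.
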
@@ -3722,7 +3852,7 @@ class Parser(metaclass=_Parser):
position = None
collation = None
- if self._match_set(self.TRIM_TYPES):
+ if self._match_texts(self.TRIM_TYPES):
position = self._prev.text.upper()
expression = self._parse_bitwise()
@@ -3752,9 +3882,9 @@ class Parser(metaclass=_Parser):
def _parse_respect_or_ignore_nulls(
self, this: t.Optional[exp.Expression]
) -> t.Optional[exp.Expression]:
- if self._match(TokenType.IGNORE_NULLS):
+ if self._match_text_seq("IGNORE", "NULLS"):
return self.expression(exp.IgnoreNulls, this=this)
- if self._match(TokenType.RESPECT_NULLS):
+ if self._match_text_seq("RESPECT", "NULLS"):
return self.expression(exp.RespectNulls, this=this)
return this
@@ -3767,7 +3897,7 @@ class Parser(metaclass=_Parser):
# T-SQL allows the OVER (...) syntax after WITHIN GROUP.
# https://learn.microsoft.com/en-us/sql/t-sql/functions/percentile-disc-transact-sql?view=sql-server-ver16
- if self._match(TokenType.WITHIN_GROUP):
+ if self._match_text_seq("WITHIN", "GROUP"):
order = self._parse_wrapped(self._parse_order)
this = self.expression(exp.WithinGroup, this=this, expression=order)
@@ -3846,10 +3976,11 @@ class Parser(metaclass=_Parser):
return {
"value": (
- self._match_set((TokenType.UNBOUNDED, TokenType.CURRENT_ROW)) and self._prev.text
- )
- or self._parse_bitwise(),
- "side": self._match_set((TokenType.PRECEDING, TokenType.FOLLOWING)) and self._prev.text,
+ (self._match_text_seq("UNBOUNDED") and "UNBOUNDED")
+ or (self._match_text_seq("CURRENT", "ROW") and "CURRENT ROW")
+ or self._parse_bitwise()
+ ),
+ "side": self._match_texts(self.WINDOW_SIDES) and self._prev.text,
}
def _parse_alias(
@@ -3956,7 +4087,7 @@ class Parser(metaclass=_Parser):
def _parse_parameter(self) -> exp.Expression:
wrapped = self._match(TokenType.L_BRACE)
- this = self._parse_var() or self._parse_primary()
+ this = self._parse_var() or self._parse_identifier() or self._parse_primary()
self._match(TokenType.R_BRACE)
return self.expression(exp.Parameter, this=this, wrapped=wrapped)
@@ -4011,26 +4142,33 @@ class Parser(metaclass=_Parser):
return this
- def _parse_wrapped_id_vars(self) -> t.List[t.Optional[exp.Expression]]:
- return self._parse_wrapped_csv(self._parse_id_var)
+ def _parse_wrapped_id_vars(self, optional: bool = False) -> t.List[t.Optional[exp.Expression]]:
+ return self._parse_wrapped_csv(self._parse_id_var, optional=optional)
def _parse_wrapped_csv(
- self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA
+ self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA, optional: bool = False
) -> t.List[t.Optional[exp.Expression]]:
- return self._parse_wrapped(lambda: self._parse_csv(parse_method, sep=sep))
+ return self._parse_wrapped(
+ lambda: self._parse_csv(parse_method, sep=sep), optional=optional
+ )
- def _parse_wrapped(self, parse_method: t.Callable) -> t.Any:
- self._match_l_paren()
+ def _parse_wrapped(self, parse_method: t.Callable, optional: bool = False) -> t.Any:
+ wrapped = self._match(TokenType.L_PAREN)
+ if not wrapped and not optional:
+ self.raise_error("Expecting (")
parse_result = parse_method()
- self._match_r_paren()
+ if wrapped:
+ self._match_r_paren()
return parse_result
- def _parse_select_or_expression(self) -> t.Optional[exp.Expression]:
- return self._parse_select() or self._parse_set_operations(self._parse_expression())
+ def _parse_select_or_expression(self, alias: bool = False) -> t.Optional[exp.Expression]:
+ return self._parse_select() or self._parse_set_operations(
+ self._parse_expression() if alias else self._parse_conjunction()
+ )
def _parse_ddl_select(self) -> t.Optional[exp.Expression]:
- return self._parse_set_operations(
- self._parse_select(nested=True, parse_subquery_alias=False)
+ return self._parse_query_modifiers(
+ self._parse_set_operations(self._parse_select(nested=True, parse_subquery_alias=False))
)
def _parse_transaction(self) -> exp.Expression:
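
The optional-wrapping logic in _parse_wrapped above, as a minimal standalone sketch (hypothetical token list, not the real parser API):

    def parse_wrapped(tokens, parse_method, optional=False):
        # Consume "(" if present; only demand the closing ")" when the
        # opening one was actually consumed.
        wrapped = bool(tokens) and tokens[0] == "("
        if wrapped:
            tokens.pop(0)
        elif not optional:
            raise ValueError("Expecting (")
        result = parse_method(tokens)
        if wrapped and (not tokens or tokens.pop(0) != ")"):
            raise ValueError("Expecting )")
        return result

    print(parse_wrapped(["x"], lambda ts: ts.pop(0), optional=True))  # x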
@@ -4391,11 +4529,11 @@ class Parser(metaclass=_Parser):
return None
- def _match_l_paren(self, expression=None):
+ def _match_l_paren(self, expression: t.Optional[exp.Expression] = None) -> None:
if not self._match(TokenType.L_PAREN, expression=expression):
self.raise_error("Expecting (")
- def _match_r_paren(self, expression=None):
+ def _match_r_paren(self, expression: t.Optional[exp.Expression] = None) -> None:
if not self._match(TokenType.R_PAREN, expression=expression):
self.raise_error("Expecting )")
@@ -4420,6 +4558,16 @@ class Parser(metaclass=_Parser):
return True
+ @t.overload
+ def _replace_columns_with_dots(self, this: exp.Expression) -> exp.Expression:
+ ...
+
+ @t.overload
+ def _replace_columns_with_dots(
+ self, this: t.Optional[exp.Expression]
+ ) -> t.Optional[exp.Expression]:
+ ...
+
def _replace_columns_with_dots(self, this):
if isinstance(this, exp.Dot):
exp.replace_children(this, self._replace_columns_with_dots)
@@ -4433,9 +4581,15 @@ class Parser(metaclass=_Parser):
)
elif isinstance(this, exp.Identifier):
this = self.expression(exp.Var, this=this.name)
+
return this
- def _replace_lambda(self, node, lambda_variables):
+ def _replace_lambda(
+ self, node: t.Optional[exp.Expression], lambda_variables: t.Set[str]
+ ) -> t.Optional[exp.Expression]:
+ if not node:
+ return node
+
for column in node.find_all(exp.Column):
if column.parts[0].name in lambda_variables:
dot_or_id = column.to_dot() if column.table else column.this
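
The typed _replace_lambda above rewrites columns that shadow lambda variables; a hedged end-to-end example through a dialect with higher-order functions:

    import sqlglot

    # `x` in the lambda body resolves to the lambda argument, not to a
    # column of `t`.
    expression = sqlglot.parse_one("SELECT FILTER(arr, x -> x > 0) FROM t", read="presto")
    print(expression.sql(dialect="presto"))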
diff --git a/sqlglot/planner.py b/sqlglot/planner.py
index 5fd96ef..eccad35 100644
--- a/sqlglot/planner.py
+++ b/sqlglot/planner.py
@@ -1,11 +1,10 @@
from __future__ import annotations
-import itertools
import math
import typing as t
from sqlglot import alias, exp
-from sqlglot.errors import UnsupportedError
+from sqlglot.helper import name_sequence
from sqlglot.optimizer.eliminate_joins import join_condition
@@ -105,13 +104,7 @@ class Step:
from_ = expression.args.get("from")
if isinstance(expression, exp.Select) and from_:
- from_ = from_.expressions
- if len(from_) > 1:
- raise UnsupportedError(
- "Multi-from statements are unsupported. Run it through the optimizer"
- )
-
- step = Scan.from_expression(from_[0], ctes)
+ step = Scan.from_expression(from_.this, ctes)
elif isinstance(expression, exp.Union):
step = SetOperation.from_expression(expression, ctes)
else:
@@ -128,7 +121,7 @@ class Step:
projections = [] # final selects in this chain of steps representing a select
operands = {} # intermediate computations of agg funcs eg x + 1 in SUM(x + 1)
aggregations = []
- sequence = itertools.count()
+ next_operand_name = name_sequence("_a_")
def extract_agg_operands(expression):
for agg in expression.find_all(exp.AggFunc):
@@ -136,7 +129,7 @@ class Step:
if isinstance(operand, exp.Column):
continue
if operand not in operands:
- operands[operand] = f"_a_{next(sequence)}"
+ operands[operand] = next_operand_name()
operand.replace(exp.column(operands[operand], quoted=True))
for e in expression.expressions:
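
A hedged sketch of the name_sequence contract relied on above (the actual sqlglot.helper implementation may differ):

    import itertools

    def name_sequence(prefix):
        # Successive calls yield prefix0, prefix1, ...
        sequence = itertools.count()
        return lambda: f"{prefix}{next(sequence)}"

    next_operand_name = name_sequence("_a_")
    print(next_operand_name(), next_operand_name())  # _a_0 _a_1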
@@ -310,7 +303,7 @@ class Join(Step):
for join in joins:
source_key, join_key, condition = join_condition(join)
step.joins[join.this.alias_or_name] = {
- "side": join.side,
+ "side": join.side, # type: ignore
"join_key": join_key,
"source_key": source_key,
"condition": condition,
diff --git a/sqlglot/schema.py b/sqlglot/schema.py
index 5d60eb9..f1c4a09 100644
--- a/sqlglot/schema.py
+++ b/sqlglot/schema.py
@@ -5,6 +5,8 @@ import typing as t
import sqlglot
from sqlglot import expressions as exp
+from sqlglot._typing import T
+from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE
from sqlglot.errors import ParseError, SchemaError
from sqlglot.helper import dict_depth
from sqlglot.trie import in_trie, new_trie
@@ -17,62 +19,83 @@ if t.TYPE_CHECKING:
TABLE_ARGS = ("this", "db", "catalog")
-T = t.TypeVar("T")
-
class Schema(abc.ABC):
"""Abstract base class for database schemas"""
+ dialect: DialectType
+
@abc.abstractmethod
def add_table(
- self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
+ self,
+ table: exp.Table | str,
+ column_mapping: t.Optional[ColumnMapping] = None,
+ dialect: DialectType = None,
) -> None:
"""
Register or update a table. Some implementing classes may require column information to also be provided.
Args:
- table: table expression instance or string representing the table.
+ table: the `Table` expression instance or string representing the table.
column_mapping: a column mapping that describes the structure of the table.
+ dialect: the SQL dialect that will be used to parse `table` if it's a string.
"""
@abc.abstractmethod
- def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
+ def column_names(
+ self,
+ table: exp.Table | str,
+ only_visible: bool = False,
+ dialect: DialectType = None,
+ ) -> t.List[str]:
"""
Get the column names for a table.
Args:
table: the `Table` expression instance.
only_visible: whether to include invisible columns.
+ dialect: the SQL dialect that will be used to parse `table` if it's a string.
Returns:
The list of column names.
"""
@abc.abstractmethod
- def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType:
+ def get_column_type(
+ self,
+ table: exp.Table | str,
+ column: exp.Column,
+ dialect: DialectType = None,
+ ) -> exp.DataType:
"""
- Get the :class:`sqlglot.exp.DataType` type of a column in the schema.
+ Get the `sqlglot.exp.DataType` type of a column in the schema.
Args:
table: the source table.
column: the target column.
+ dialect: the SQL dialect that will be used to parse `table` if it's a string.
Returns:
The resulting column type.
"""
@property
+ @abc.abstractmethod
def supported_table_args(self) -> t.Tuple[str, ...]:
"""
        Table arguments this schema supports, e.g. `("this", "db", "catalog")`
"""
- raise NotImplementedError
+
+ @property
+ def empty(self) -> bool:
+ """Returns whether or not the schema is empty."""
+ return True
class AbstractMappingSchema(t.Generic[T]):
def __init__(
self,
- mapping: dict | None = None,
+ mapping: t.Optional[t.Dict] = None,
) -> None:
self.mapping = mapping or {}
self.mapping_trie = new_trie(
@@ -80,6 +103,10 @@ class AbstractMappingSchema(t.Generic[T]):
)
self._supported_table_args: t.Tuple[str, ...] = tuple()
+ @property
+ def empty(self) -> bool:
+ return not self.mapping
+
def _depth(self) -> int:
return dict_depth(self.mapping)
@@ -110,8 +137,10 @@ class AbstractMappingSchema(t.Generic[T]):
if value == 0:
return None
- elif value == 1:
+
+ if value == 1:
possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
+
if len(possibilities) == 1:
parts.extend(possibilities[0])
else:
@@ -119,12 +148,13 @@ class AbstractMappingSchema(t.Generic[T]):
if raise_on_missing:
raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
return None
- return self._nested_get(parts, raise_on_missing=raise_on_missing)
- def _nested_get(
+ return self.nested_get(parts, raise_on_missing=raise_on_missing)
+
+ def nested_get(
self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
) -> t.Optional[t.Any]:
- return _nested_get(
+ return nested_get(
d or self.mapping,
*zip(self.supported_table_args, reversed(parts)),
raise_on_missing=raise_on_missing,
@@ -136,17 +166,18 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
Schema based on a nested mapping.
Args:
- schema (dict): Mapping in one of the following forms:
+ schema: Mapping in one of the following forms:
1. {table: {col: type}}
2. {db: {table: {col: type}}}
3. {catalog: {db: {table: {col: type}}}}
4. None - Tables will be added later
- visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns
+ visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
are assumed to be visible. The nesting should mirror that of the schema:
1. {table: set(*cols)}}
2. {db: {table: set(*cols)}}}
3. {catalog: {db: {table: set(*cols)}}}}
- dialect (str): The dialect to be used for custom type mappings.
+ dialect: The dialect to be used for custom type mappings & parsing string arguments.
+ normalize: Whether to normalize identifier names according to the given dialect or not.
"""
def __init__(
@@ -154,10 +185,13 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
schema: t.Optional[t.Dict] = None,
visible: t.Optional[t.Dict] = None,
dialect: DialectType = None,
+ normalize: bool = True,
) -> None:
self.dialect = dialect
self.visible = visible or {}
+ self.normalize = normalize
self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
+
super().__init__(self._normalize(schema or {}))
@classmethod
@@ -179,7 +213,10 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
)
def add_table(
- self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
+ self,
+ table: exp.Table | str,
+ column_mapping: t.Optional[ColumnMapping] = None,
+ dialect: DialectType = None,
) -> None:
"""
Register or update a table. Updates are only performed if a new column mapping is provided.
@@ -187,10 +224,13 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
Args:
table: the `Table` expression instance or string representing the table.
column_mapping: a column mapping that describes the structure of the table.
+ dialect: the SQL dialect that will be used to parse `table` if it's a string.
"""
- normalized_table = self._normalize_table(self._ensure_table(table))
+ normalized_table = self._normalize_table(
+ self._ensure_table(table, dialect=dialect), dialect=dialect
+ )
normalized_column_mapping = {
- self._normalize_name(key): value
+ self._normalize_name(key, dialect=dialect): value
for key, value in ensure_column_mapping(column_mapping).items()
}
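
A hedged usage sketch of the dialect-aware add_table above (assumes Snowflake is registered in RESOLVES_IDENTIFIERS_AS_UPPERCASE, so unquoted identifiers normalize to uppercase):

    from sqlglot.schema import MappingSchema

    schema = MappingSchema(dialect="snowflake")
    schema.add_table("db.tbl", {"col_a": "INT"})
    print(schema.column_names("db.tbl"))  # expected: ['COL_A']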
@@ -200,38 +240,51 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
parts = self.table_parts(normalized_table)
- _nested_set(
- self.mapping,
- tuple(reversed(parts)),
- normalized_column_mapping,
- )
+ nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
new_trie([parts], self.mapping_trie)
- def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
- table_ = self._normalize_table(self._ensure_table(table))
- schema = self.find(table_)
+ def column_names(
+ self,
+ table: exp.Table | str,
+ only_visible: bool = False,
+ dialect: DialectType = None,
+ ) -> t.List[str]:
+ normalized_table = self._normalize_table(
+ self._ensure_table(table, dialect=dialect), dialect=dialect
+ )
+ schema = self.find(normalized_table)
if schema is None:
return []
if not only_visible or not self.visible:
return list(schema)
- visible = self._nested_get(self.table_parts(table_), self.visible)
- return [col for col in schema if col in visible] # type: ignore
+ visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
+ return [col for col in schema if col in visible]
- def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
- column_name = self._normalize_name(column if isinstance(column, str) else column.this)
- table_ = self._normalize_table(self._ensure_table(table))
+ def get_column_type(
+ self,
+ table: exp.Table | str,
+ column: exp.Column,
+ dialect: DialectType = None,
+ ) -> exp.DataType:
+ normalized_table = self._normalize_table(
+ self._ensure_table(table, dialect=dialect), dialect=dialect
+ )
+ normalized_column_name = self._normalize_name(
+ column if isinstance(column, str) else column.this, dialect=dialect
+ )
- table_schema = self.find(table_, raise_on_missing=False)
+ table_schema = self.find(normalized_table, raise_on_missing=False)
if table_schema:
- column_type = table_schema.get(column_name)
+ column_type = table_schema.get(normalized_column_name)
if isinstance(column_type, exp.DataType):
return column_type
elif isinstance(column_type, str):
- return self._to_data_type(column_type.upper())
+ return self._to_data_type(column_type.upper(), dialect=dialect)
+
raise SchemaError(f"Unknown column type '{column_type}'")
return exp.DataType.build("unknown")
@@ -250,81 +303,88 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
normalized_mapping: t.Dict = {}
for keys in flattened_schema:
- columns = _nested_get(schema, *zip(keys, keys))
+ columns = nested_get(schema, *zip(keys, keys))
assert columns is not None
- normalized_keys = [self._normalize_name(key) for key in keys]
+ normalized_keys = [self._normalize_name(key, dialect=self.dialect) for key in keys]
for column_name, column_type in columns.items():
- _nested_set(
+ nested_set(
normalized_mapping,
- normalized_keys + [self._normalize_name(column_name)],
+ normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)],
column_type,
)
return normalized_mapping
- def _normalize_table(self, table: exp.Table) -> exp.Table:
+ def _normalize_table(self, table: exp.Table, dialect: DialectType = None) -> exp.Table:
normalized_table = table.copy()
+
for arg in TABLE_ARGS:
value = normalized_table.args.get(arg)
if isinstance(value, (str, exp.Identifier)):
- normalized_table.set(arg, self._normalize_name(value))
+ normalized_table.set(
+ arg, exp.to_identifier(self._normalize_name(value, dialect=dialect))
+ )
return normalized_table
- def _normalize_name(self, name: str | exp.Identifier) -> str:
+ def _normalize_name(self, name: str | exp.Identifier, dialect: DialectType = None) -> str:
+ dialect = dialect or self.dialect
+
try:
- identifier = sqlglot.maybe_parse(name, dialect=self.dialect, into=exp.Identifier)
+ identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
except ParseError:
return name if isinstance(name, str) else name.name
- return identifier.name if identifier.quoted else identifier.name.lower()
+ name = identifier.name
+
+ if not self.normalize or identifier.quoted:
+ return name
+
+ return name.upper() if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE else name.lower()
def _depth(self) -> int:
# The columns themselves are a mapping, but we don't want to include those
return super()._depth() - 1
- def _ensure_table(self, table: exp.Table | str) -> exp.Table:
- if isinstance(table, exp.Table):
- return table
-
- table_ = sqlglot.parse_one(table, read=self.dialect, into=exp.Table)
- if not table_:
- raise SchemaError(f"Not a valid table '{table}'")
+ def _ensure_table(self, table: exp.Table | str, dialect: DialectType = None) -> exp.Table:
+ return exp.maybe_parse(table, into=exp.Table, dialect=dialect or self.dialect)
- return table_
-
- def _to_data_type(self, schema_type: str) -> exp.DataType:
+ def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
"""
- Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object.
+ Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.
Args:
schema_type: the type we want to convert.
+ dialect: the SQL dialect that will be used to parse `schema_type`, if needed.
Returns:
The resulting expression type.
"""
if schema_type not in self._type_mapping_cache:
+ dialect = dialect or self.dialect
+
try:
- expression = exp.maybe_parse(schema_type, into=exp.DataType, dialect=self.dialect)
- if expression is None:
- raise ValueError(f"Could not parse {schema_type}")
- self._type_mapping_cache[schema_type] = expression # type: ignore
+ expression = exp.DataType.build(schema_type, dialect=dialect)
+ self._type_mapping_cache[schema_type] = expression
except AttributeError:
- raise SchemaError(f"Failed to convert type {schema_type}")
+ in_dialect = f" in dialect {dialect}" if dialect else ""
+ raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")
return self._type_mapping_cache[schema_type]
-def ensure_schema(schema: t.Any, dialect: DialectType = None) -> Schema:
+def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
if isinstance(schema, Schema):
return schema
- return MappingSchema(schema, dialect=dialect)
+ return MappingSchema(schema, **kwargs)
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
- if isinstance(mapping, dict):
+ if mapping is None:
+ return {}
+ elif isinstance(mapping, dict):
return mapping
elif isinstance(mapping, str):
col_name_type_strs = [x.strip() for x in mapping.split(",")]
@@ -334,11 +394,10 @@ def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
}
# Check if mapping looks like a DataFrame StructType
elif hasattr(mapping, "simpleString"):
- return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping} # type: ignore
+ return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
elif isinstance(mapping, list):
return {x.strip(): None for x in mapping}
- elif mapping is None:
- return {}
+
raise ValueError(f"Invalid mapping provided: {type(mapping)}")
@@ -353,10 +412,11 @@ def flatten_schema(
tables.extend(flatten_schema(v, depth - 1, keys + [k]))
elif depth == 1:
tables.append(keys + [k])
+
return tables
-def _nested_get(
+def nested_get(
d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
"""
@@ -378,18 +438,19 @@ def _nested_get(
name = "table" if name == "this" else name
raise ValueError(f"Unknown {name}: {key}")
return None
+
return d
-def _nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
+def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
"""
    Set a value for a nested dictionary in place.
Example:
- >>> _nested_set({}, ["top_key", "second_key"], "value")
+ >>> nested_set({}, ["top_key", "second_key"], "value")
{'top_key': {'second_key': 'value'}}
- >>> _nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
+ >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
{'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Args:
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index 5e50b7c..ad329d2 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -51,7 +51,6 @@ class TokenType(AutoName):
DOLLAR = auto()
PARAMETER = auto()
SESSION_PARAMETER = auto()
- NATIONAL = auto()
DAMP = auto()
BLOCK_START = auto()
@@ -72,6 +71,8 @@ class TokenType(AutoName):
BIT_STRING = auto()
HEX_STRING = auto()
BYTE_STRING = auto()
+ NATIONAL_STRING = auto()
+ RAW_STRING = auto()
# types
BIT = auto()
@@ -110,6 +111,7 @@ class TokenType(AutoName):
TIMESTAMPTZ = auto()
TIMESTAMPLTZ = auto()
DATETIME = auto()
+ DATETIME64 = auto()
DATE = auto()
UUID = auto()
GEOGRAPHY = auto()
@@ -142,30 +144,22 @@ class TokenType(AutoName):
ARRAY = auto()
ASC = auto()
ASOF = auto()
- AT_TIME_ZONE = auto()
AUTO_INCREMENT = auto()
BEGIN = auto()
BETWEEN = auto()
- BOTH = auto()
- BUCKET = auto()
- BY_DEFAULT = auto()
CACHE = auto()
- CASCADE = auto()
CASE = auto()
CHARACTER_SET = auto()
- CLUSTER_BY = auto()
COLLATE = auto()
COMMAND = auto()
COMMENT = auto()
COMMIT = auto()
- COMPOUND = auto()
CONSTRAINT = auto()
CREATE = auto()
CROSS = auto()
CUBE = auto()
CURRENT_DATE = auto()
CURRENT_DATETIME = auto()
- CURRENT_ROW = auto()
CURRENT_TIME = auto()
CURRENT_TIMESTAMP = auto()
CURRENT_USER = auto()
@@ -174,8 +168,6 @@ class TokenType(AutoName):
DESC = auto()
DESCRIBE = auto()
DISTINCT = auto()
- DISTINCT_FROM = auto()
- DISTRIBUTE_BY = auto()
DIV = auto()
DROP = auto()
ELSE = auto()
@@ -189,7 +181,6 @@ class TokenType(AutoName):
FILTER = auto()
FINAL = auto()
FIRST = auto()
- FOLLOWING = auto()
FOR = auto()
FOREIGN_KEY = auto()
FORMAT = auto()
@@ -203,7 +194,6 @@ class TokenType(AutoName):
HAVING = auto()
HINT = auto()
IF = auto()
- IGNORE_NULLS = auto()
ILIKE = auto()
ILIKE_ANY = auto()
IN = auto()
@@ -222,36 +212,27 @@ class TokenType(AutoName):
KEEP = auto()
LANGUAGE = auto()
LATERAL = auto()
- LAZY = auto()
- LEADING = auto()
LEFT = auto()
LIKE = auto()
LIKE_ANY = auto()
LIMIT = auto()
- LOAD_DATA = auto()
- LOCAL = auto()
+ LOAD = auto()
+ LOCK = auto()
MAP = auto()
MATCH_RECOGNIZE = auto()
- MATERIALIZED = auto()
MERGE = auto()
MOD = auto()
NATURAL = auto()
NEXT = auto()
NEXT_VALUE_FOR = auto()
- NO_ACTION = auto()
NOTNULL = auto()
NULL = auto()
- NULLS_FIRST = auto()
- NULLS_LAST = auto()
OFFSET = auto()
ON = auto()
- ONLY = auto()
- OPTIONS = auto()
ORDER_BY = auto()
ORDERED = auto()
ORDINALITY = auto()
OUTER = auto()
- OUT_OF = auto()
OVER = auto()
OVERLAPS = auto()
OVERWRITE = auto()
@@ -261,7 +242,6 @@ class TokenType(AutoName):
PIVOT = auto()
PLACEHOLDER = auto()
PRAGMA = auto()
- PRECEDING = auto()
PRIMARY_KEY = auto()
PROCEDURE = auto()
PROPERTIES = auto()
@@ -271,7 +251,6 @@ class TokenType(AutoName):
RANGE = auto()
RECURSIVE = auto()
REPLACE = auto()
- RESPECT_NULLS = auto()
RETURNING = auto()
REFERENCES = auto()
RIGHT = auto()
@@ -280,28 +259,23 @@ class TokenType(AutoName):
ROLLUP = auto()
ROW = auto()
ROWS = auto()
- SEED = auto()
SELECT = auto()
SEMI = auto()
SEPARATOR = auto()
SERDE_PROPERTIES = auto()
SET = auto()
+ SETTINGS = auto()
SHOW = auto()
SIMILAR_TO = auto()
SOME = auto()
- SORTKEY = auto()
- SORT_BY = auto()
STRUCT = auto()
TABLE_SAMPLE = auto()
TEMPORARY = auto()
TOP = auto()
THEN = auto()
- TRAILING = auto()
TRUE = auto()
- UNBOUNDED = auto()
UNCACHE = auto()
UNION = auto()
- UNLOGGED = auto()
UNNEST = auto()
UNPIVOT = auto()
UPDATE = auto()
@@ -314,15 +288,11 @@ class TokenType(AutoName):
WHERE = auto()
WINDOW = auto()
WITH = auto()
- WITH_TIME_ZONE = auto()
- WITH_LOCAL_TIME_ZONE = auto()
- WITHIN_GROUP = auto()
- WITHOUT_TIME_ZONE = auto()
UNIQUE = auto()
class Token:
- __slots__ = ("token_type", "text", "line", "col", "end", "comments")
+ __slots__ = ("token_type", "text", "line", "col", "start", "end", "comments")
@classmethod
def number(cls, number: int) -> Token:
@@ -350,22 +320,28 @@ class Token:
text: str,
line: int = 1,
col: int = 1,
+ start: int = 0,
end: int = 0,
comments: t.List[str] = [],
) -> None:
+ """Token initializer.
+
+ Args:
+ token_type: The TokenType Enum.
+ text: The text of the token.
+ line: The line that the token ends on.
+ col: The column that the token ends on.
+ start: The start index of the token.
+            end: The ending index of the token.
+            comments: Comments attached to the token.
+        """
self.token_type = token_type
self.text = text
self.line = line
- size = len(text)
self.col = col
- self.end = end if end else size
+ self.start = start
+ self.end = end
self.comments = comments
- @property
- def start(self) -> int:
- """Returns the start of the token."""
- return self.end - len(self.text)
-
def __repr__(self) -> str:
attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__)
return f"<Token {attributes}>"
@@ -375,15 +351,31 @@ class _Tokenizer(type):
def __new__(cls, clsname, bases, attrs):
klass = super().__new__(cls, clsname, bases, attrs)
- klass._QUOTES = {
- f"{prefix}{s}": e
- for s, e in cls._delimeter_list_to_dict(klass.QUOTES).items()
- for prefix in (("",) if s[0].isalpha() else ("", "n", "N"))
+ def _convert_quotes(arr: t.List[str | t.Tuple[str, str]]) -> t.Dict[str, str]:
+ return dict(
+ (item, item) if isinstance(item, str) else (item[0], item[1]) for item in arr
+ )
+
+ def _quotes_to_format(
+ token_type: TokenType, arr: t.List[str | t.Tuple[str, str]]
+ ) -> t.Dict[str, t.Tuple[str, TokenType]]:
+ return {k: (v, token_type) for k, v in _convert_quotes(arr).items()}
+
+ klass._QUOTES = _convert_quotes(klass.QUOTES)
+ klass._IDENTIFIERS = _convert_quotes(klass.IDENTIFIERS)
+
+ klass._FORMAT_STRINGS = {
+ **{
+ p + s: (e, TokenType.NATIONAL_STRING)
+ for s, e in klass._QUOTES.items()
+ for p in ("n", "N")
+ },
+ **_quotes_to_format(TokenType.BIT_STRING, klass.BIT_STRINGS),
+ **_quotes_to_format(TokenType.BYTE_STRING, klass.BYTE_STRINGS),
+ **_quotes_to_format(TokenType.HEX_STRING, klass.HEX_STRINGS),
+ **_quotes_to_format(TokenType.RAW_STRING, klass.RAW_STRINGS),
}
- klass._BIT_STRINGS = cls._delimeter_list_to_dict(klass.BIT_STRINGS)
- klass._HEX_STRINGS = cls._delimeter_list_to_dict(klass.HEX_STRINGS)
- klass._BYTE_STRINGS = cls._delimeter_list_to_dict(klass.BYTE_STRINGS)
- klass._IDENTIFIERS = cls._delimeter_list_to_dict(klass.IDENTIFIERS)
+
klass._STRING_ESCAPES = set(klass.STRING_ESCAPES)
klass._IDENTIFIER_ESCAPES = set(klass.IDENTIFIER_ESCAPES)
klass._COMMENTS = dict(
@@ -393,23 +385,17 @@ class _Tokenizer(type):
klass.KEYWORD_TRIE = new_trie(
key.upper()
- for key in {
- **klass.KEYWORDS,
- **{comment: TokenType.COMMENT for comment in klass._COMMENTS},
- **{quote: TokenType.QUOTE for quote in klass._QUOTES},
- **{bit_string: TokenType.BIT_STRING for bit_string in klass._BIT_STRINGS},
- **{hex_string: TokenType.HEX_STRING for hex_string in klass._HEX_STRINGS},
- **{byte_string: TokenType.BYTE_STRING for byte_string in klass._BYTE_STRINGS},
- }
+ for key in (
+ *klass.KEYWORDS,
+ *klass._COMMENTS,
+ *klass._QUOTES,
+ *klass._FORMAT_STRINGS,
+ )
if " " in key or any(single in key for single in klass.SINGLE_TOKENS)
)
return klass
- @staticmethod
- def _delimeter_list_to_dict(list: t.List[str | t.Tuple[str, str]]) -> t.Dict[str, str]:
- return dict((item, item) if isinstance(item, str) else (item[0], item[1]) for item in list)
-
class Tokenizer(metaclass=_Tokenizer):
SINGLE_TOKENS = {
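
A hedged example of prefixed literals flowing through the single _FORMAT_STRINGS table (assumes MySQL defines hex strings and that the base tokenizer derives the N'' national prefix from its quotes, as the metaclass above does):

    import sqlglot

    print(sqlglot.transpile("SELECT N'abc', x'1F'", read="mysql")[0])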
@@ -450,6 +436,7 @@ class Tokenizer(metaclass=_Tokenizer):
BIT_STRINGS: t.List[str | t.Tuple[str, str]] = []
BYTE_STRINGS: t.List[str | t.Tuple[str, str]] = []
HEX_STRINGS: t.List[str | t.Tuple[str, str]] = []
+ RAW_STRINGS: t.List[str | t.Tuple[str, str]] = []
IDENTIFIERS: t.List[str | t.Tuple[str, str]] = ['"']
IDENTIFIER_ESCAPES = ['"']
QUOTES: t.List[t.Tuple[str, str] | str] = ["'"]
@@ -457,9 +444,7 @@ class Tokenizer(metaclass=_Tokenizer):
VAR_SINGLE_TOKENS: t.Set[str] = set()
_COMMENTS: t.Dict[str, str] = {}
- _BIT_STRINGS: t.Dict[str, str] = {}
- _BYTE_STRINGS: t.Dict[str, str] = {}
- _HEX_STRINGS: t.Dict[str, str] = {}
+ _FORMAT_STRINGS: t.Dict[str, t.Tuple[str, TokenType]] = {}
_IDENTIFIERS: t.Dict[str, str] = {}
_IDENTIFIER_ESCAPES: t.Set[str] = set()
_QUOTES: t.Dict[str, str] = {}
@@ -495,30 +480,22 @@ class Tokenizer(metaclass=_Tokenizer):
"ANY": TokenType.ANY,
"ASC": TokenType.ASC,
"AS": TokenType.ALIAS,
- "AT TIME ZONE": TokenType.AT_TIME_ZONE,
"AUTOINCREMENT": TokenType.AUTO_INCREMENT,
"AUTO_INCREMENT": TokenType.AUTO_INCREMENT,
"BEGIN": TokenType.BEGIN,
"BETWEEN": TokenType.BETWEEN,
- "BOTH": TokenType.BOTH,
- "BUCKET": TokenType.BUCKET,
- "BY DEFAULT": TokenType.BY_DEFAULT,
"CACHE": TokenType.CACHE,
"UNCACHE": TokenType.UNCACHE,
"CASE": TokenType.CASE,
- "CASCADE": TokenType.CASCADE,
"CHARACTER SET": TokenType.CHARACTER_SET,
- "CLUSTER BY": TokenType.CLUSTER_BY,
"COLLATE": TokenType.COLLATE,
"COLUMN": TokenType.COLUMN,
"COMMIT": TokenType.COMMIT,
- "COMPOUND": TokenType.COMPOUND,
"CONSTRAINT": TokenType.CONSTRAINT,
"CREATE": TokenType.CREATE,
"CROSS": TokenType.CROSS,
"CUBE": TokenType.CUBE,
"CURRENT_DATE": TokenType.CURRENT_DATE,
- "CURRENT ROW": TokenType.CURRENT_ROW,
"CURRENT_TIME": TokenType.CURRENT_TIME,
"CURRENT_TIMESTAMP": TokenType.CURRENT_TIMESTAMP,
"CURRENT_USER": TokenType.CURRENT_USER,
@@ -528,8 +505,6 @@ class Tokenizer(metaclass=_Tokenizer):
"DESC": TokenType.DESC,
"DESCRIBE": TokenType.DESCRIBE,
"DISTINCT": TokenType.DISTINCT,
- "DISTINCT FROM": TokenType.DISTINCT_FROM,
- "DISTRIBUTE BY": TokenType.DISTRIBUTE_BY,
"DIV": TokenType.DIV,
"DROP": TokenType.DROP,
"ELSE": TokenType.ELSE,
@@ -544,18 +519,18 @@ class Tokenizer(metaclass=_Tokenizer):
"FIRST": TokenType.FIRST,
"FULL": TokenType.FULL,
"FUNCTION": TokenType.FUNCTION,
- "FOLLOWING": TokenType.FOLLOWING,
"FOR": TokenType.FOR,
"FOREIGN KEY": TokenType.FOREIGN_KEY,
"FORMAT": TokenType.FORMAT,
"FROM": TokenType.FROM,
+ "GEOGRAPHY": TokenType.GEOGRAPHY,
+ "GEOMETRY": TokenType.GEOMETRY,
"GLOB": TokenType.GLOB,
"GROUP BY": TokenType.GROUP_BY,
"GROUPING SETS": TokenType.GROUPING_SETS,
"HAVING": TokenType.HAVING,
"IF": TokenType.IF,
"ILIKE": TokenType.ILIKE,
- "IGNORE NULLS": TokenType.IGNORE_NULLS,
"IN": TokenType.IN,
"INDEX": TokenType.INDEX,
"INET": TokenType.INET,
@@ -569,34 +544,25 @@ class Tokenizer(metaclass=_Tokenizer):
"JOIN": TokenType.JOIN,
"KEEP": TokenType.KEEP,
"LATERAL": TokenType.LATERAL,
- "LAZY": TokenType.LAZY,
- "LEADING": TokenType.LEADING,
"LEFT": TokenType.LEFT,
"LIKE": TokenType.LIKE,
"LIMIT": TokenType.LIMIT,
- "LOAD DATA": TokenType.LOAD_DATA,
- "LOCAL": TokenType.LOCAL,
- "MATERIALIZED": TokenType.MATERIALIZED,
+ "LOAD": TokenType.LOAD,
+ "LOCK": TokenType.LOCK,
"MERGE": TokenType.MERGE,
"NATURAL": TokenType.NATURAL,
"NEXT": TokenType.NEXT,
"NEXT VALUE FOR": TokenType.NEXT_VALUE_FOR,
- "NO ACTION": TokenType.NO_ACTION,
"NOT": TokenType.NOT,
"NOTNULL": TokenType.NOTNULL,
"NULL": TokenType.NULL,
- "NULLS FIRST": TokenType.NULLS_FIRST,
- "NULLS LAST": TokenType.NULLS_LAST,
"OBJECT": TokenType.OBJECT,
"OFFSET": TokenType.OFFSET,
"ON": TokenType.ON,
- "ONLY": TokenType.ONLY,
- "OPTIONS": TokenType.OPTIONS,
"OR": TokenType.OR,
"ORDER BY": TokenType.ORDER_BY,
"ORDINALITY": TokenType.ORDINALITY,
"OUTER": TokenType.OUTER,
- "OUT OF": TokenType.OUT_OF,
"OVER": TokenType.OVER,
"OVERLAPS": TokenType.OVERLAPS,
"OVERWRITE": TokenType.OVERWRITE,
@@ -607,7 +573,6 @@ class Tokenizer(metaclass=_Tokenizer):
"PERCENT": TokenType.PERCENT,
"PIVOT": TokenType.PIVOT,
"PRAGMA": TokenType.PRAGMA,
- "PRECEDING": TokenType.PRECEDING,
"PRIMARY KEY": TokenType.PRIMARY_KEY,
"PROCEDURE": TokenType.PROCEDURE,
"QUALIFY": TokenType.QUALIFY,
@@ -615,7 +580,6 @@ class Tokenizer(metaclass=_Tokenizer):
"RECURSIVE": TokenType.RECURSIVE,
"REGEXP": TokenType.RLIKE,
"REPLACE": TokenType.REPLACE,
- "RESPECT NULLS": TokenType.RESPECT_NULLS,
"REFERENCES": TokenType.REFERENCES,
"RIGHT": TokenType.RIGHT,
"RLIKE": TokenType.RLIKE,
@@ -624,25 +588,20 @@ class Tokenizer(metaclass=_Tokenizer):
"ROW": TokenType.ROW,
"ROWS": TokenType.ROWS,
"SCHEMA": TokenType.SCHEMA,
- "SEED": TokenType.SEED,
"SELECT": TokenType.SELECT,
"SEMI": TokenType.SEMI,
"SET": TokenType.SET,
+ "SETTINGS": TokenType.SETTINGS,
"SHOW": TokenType.SHOW,
"SIMILAR TO": TokenType.SIMILAR_TO,
"SOME": TokenType.SOME,
- "SORTKEY": TokenType.SORTKEY,
- "SORT BY": TokenType.SORT_BY,
"TABLE": TokenType.TABLE,
"TABLESAMPLE": TokenType.TABLE_SAMPLE,
"TEMP": TokenType.TEMPORARY,
"TEMPORARY": TokenType.TEMPORARY,
"THEN": TokenType.THEN,
"TRUE": TokenType.TRUE,
- "TRAILING": TokenType.TRAILING,
- "UNBOUNDED": TokenType.UNBOUNDED,
"UNION": TokenType.UNION,
- "UNLOGGED": TokenType.UNLOGGED,
"UNNEST": TokenType.UNNEST,
"UNPIVOT": TokenType.UNPIVOT,
"UPDATE": TokenType.UPDATE,
@@ -656,10 +615,6 @@ class Tokenizer(metaclass=_Tokenizer):
"WHERE": TokenType.WHERE,
"WINDOW": TokenType.WINDOW,
"WITH": TokenType.WITH,
- "WITH TIME ZONE": TokenType.WITH_TIME_ZONE,
- "WITH LOCAL TIME ZONE": TokenType.WITH_LOCAL_TIME_ZONE,
- "WITHIN GROUP": TokenType.WITHIN_GROUP,
- "WITHOUT TIME ZONE": TokenType.WITHOUT_TIME_ZONE,
"APPLY": TokenType.APPLY,
"ARRAY": TokenType.ARRAY,
"BIT": TokenType.BIT,
@@ -718,15 +673,6 @@ class Tokenizer(metaclass=_Tokenizer):
"STRUCT": TokenType.STRUCT,
"VARIANT": TokenType.VARIANT,
"ALTER": TokenType.ALTER,
- "ALTER AGGREGATE": TokenType.COMMAND,
- "ALTER DEFAULT": TokenType.COMMAND,
- "ALTER DOMAIN": TokenType.COMMAND,
- "ALTER ROLE": TokenType.COMMAND,
- "ALTER RULE": TokenType.COMMAND,
- "ALTER SEQUENCE": TokenType.COMMAND,
- "ALTER TYPE": TokenType.COMMAND,
- "ALTER USER": TokenType.COMMAND,
- "ALTER VIEW": TokenType.COMMAND,
"ANALYZE": TokenType.COMMAND,
"CALL": TokenType.COMMAND,
"COMMENT": TokenType.COMMENT,
@@ -790,7 +736,7 @@ class Tokenizer(metaclass=_Tokenizer):
self._start = 0
self._current = 0
self._line = 1
- self._col = 1
+ self._col = 0
self._comments: t.List[str] = []
self._char = ""
@@ -803,13 +749,12 @@ class Tokenizer(metaclass=_Tokenizer):
self.reset()
self.sql = sql
self.size = len(sql)
+
try:
self._scan()
except Exception as e:
- start = self._current - 50
- end = self._current + 50
- start = start if start > 0 else 0
- end = end if end < self.size else self.size - 1
+ start = max(self._current - 50, 0)
+ end = min(self._current + 50, self.size - 1)
context = self.sql[start:end]
raise ValueError(f"Error tokenizing '{context}'") from e
@@ -834,17 +779,17 @@ class Tokenizer(metaclass=_Tokenizer):
if until and until():
break
- if self.tokens:
+ if self.tokens and self._comments:
self.tokens[-1].comments.extend(self._comments)
def _chars(self, size: int) -> str:
if size == 1:
return self._char
+
start = self._current - 1
end = start + size
- if end <= self.size:
- return self.sql[start:end]
- return ""
+
+ return self.sql[start:end] if end <= self.size else ""
def _advance(self, i: int = 1, alnum: bool = False) -> None:
if self.WHITE_SPACE.get(self._char) is TokenType.BREAK:
@@ -859,6 +804,7 @@ class Tokenizer(metaclass=_Tokenizer):
self._peek = "" if self._end else self.sql[self._current]
if alnum and self._char.isalnum():
+ # Here we use local variables instead of attributes for better performance
_col = self._col
_current = self._current
_end = self._end
@@ -885,11 +831,12 @@ class Tokenizer(metaclass=_Tokenizer):
self.tokens.append(
Token(
token_type,
- self._text if text is None else text,
- self._line,
- self._col,
- self._current,
- self._comments,
+ text=self._text if text is None else text,
+ line=self._line,
+ col=self._col,
+ start=self._start,
+ end=self._current - 1,
+ comments=self._comments,
)
)
self._comments = []
@@ -929,6 +876,7 @@ class Tokenizer(metaclass=_Tokenizer):
break
if result == 2:
word = chars
+
size += 1
end = self._current - 1 + size
@@ -946,6 +894,7 @@ class Tokenizer(metaclass=_Tokenizer):
else:
skip = True
else:
+ char = ""
chars = " "
word = None if not single_token and chars[-1] not in self.WHITE_SPACE else word
@@ -959,8 +908,6 @@ class Tokenizer(metaclass=_Tokenizer):
if self._scan_string(word):
return
- if self._scan_formatted_string(word):
- return
if self._scan_comment(word):
return
@@ -1004,9 +951,9 @@ class Tokenizer(metaclass=_Tokenizer):
if self._char == "0":
peek = self._peek.upper()
if peek == "B":
- return self._scan_bits() if self._BIT_STRINGS else self._add(TokenType.NUMBER)
+ return self._scan_bits() if self.BIT_STRINGS else self._add(TokenType.NUMBER)
elif peek == "X":
- return self._scan_hex() if self._HEX_STRINGS else self._add(TokenType.NUMBER)
+ return self._scan_hex() if self.HEX_STRINGS else self._add(TokenType.NUMBER)
decimal = False
scientific = 0
@@ -1075,37 +1022,24 @@ class Tokenizer(metaclass=_Tokenizer):
return self._text
- def _scan_string(self, quote: str) -> bool:
- quote_end = self._QUOTES.get(quote)
- if quote_end is None:
- return False
+ def _scan_string(self, start: str) -> bool:
+ base = None
+ token_type = TokenType.STRING
- self._advance(len(quote))
- text = self._extract_string(quote_end)
- text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text
- self._add(TokenType.NATIONAL if quote[0].upper() == "N" else TokenType.STRING, text)
- return True
+ if start in self._QUOTES:
+ end = self._QUOTES[start]
+ elif start in self._FORMAT_STRINGS:
+ end, token_type = self._FORMAT_STRINGS[start]
- # X'1234', b'0110', E'\\\\\' etc.
- def _scan_formatted_string(self, string_start: str) -> bool:
- if string_start in self._HEX_STRINGS:
- delimiters = self._HEX_STRINGS
- token_type = TokenType.HEX_STRING
- base = 16
- elif string_start in self._BIT_STRINGS:
- delimiters = self._BIT_STRINGS
- token_type = TokenType.BIT_STRING
- base = 2
- elif string_start in self._BYTE_STRINGS:
- delimiters = self._BYTE_STRINGS
- token_type = TokenType.BYTE_STRING
- base = None
+ if token_type == TokenType.HEX_STRING:
+ base = 16
+ elif token_type == TokenType.BIT_STRING:
+ base = 2
else:
return False
- self._advance(len(string_start))
- string_end = delimiters[string_start]
- text = self._extract_string(string_end)
+ self._advance(len(start))
+ text = self._extract_string(end)
if base:
try:
@@ -1114,6 +1048,8 @@ class Tokenizer(metaclass=_Tokenizer):
raise RuntimeError(
f"Numeric string contains invalid characters from {self._line}:{self._start}"
)
+ else:
+ text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text
self._add(token_type, text)
return True
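
A minimal standalone sketch of the unified dispatch in _scan_string above: one table maps an opening delimiter to a (closing delimiter, kind) pair, replacing the separate bit/hex/byte dictionaries (delimiters here are illustrative):

    FORMAT_STRINGS = {
        "b'": ("'", "BIT_STRING"),
        "x'": ("'", "HEX_STRING"),
        "n'": ("'", "NATIONAL_STRING"),
    }

    def classify(start):
        if start == "'":
            return "'", "STRING"
        return FORMAT_STRINGS.get(start.lower())

    print(classify("X'"))  # ("'", 'HEX_STRING')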
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py
index 3643cd7..a1ec1bd 100644
--- a/sqlglot/transforms.py
+++ b/sqlglot/transforms.py
@@ -3,7 +3,7 @@ from __future__ import annotations
import typing as t
from sqlglot import expressions as exp
-from sqlglot.helper import find_new_name
+from sqlglot.helper import find_new_name, name_sequence
if t.TYPE_CHECKING:
from sqlglot.generator import Generator
@@ -63,16 +63,17 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
distinct_cols = expression.args["distinct"].pop().args["on"].expressions
outer_selects = expression.selects
row_number = find_new_name(expression.named_selects, "_row_number")
- window = exp.Window(
- this=exp.RowNumber(),
- partition_by=distinct_cols,
- )
+ window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
order = expression.args.get("order")
+
if order:
window.set("order", order.pop().copy())
+
window = exp.alias_(window, row_number)
expression.select(window, copy=False)
+
return exp.select(*outer_selects).from_(expression.subquery()).where(f'"{row_number}" = 1')
+
return expression
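
A hedged example of the eliminate_distinct_on transform shown above: DISTINCT ON becomes a ROW_NUMBER() window plus an outer filter, which dialects without DISTINCT ON can run:

    import sqlglot
    from sqlglot.transforms import eliminate_distinct_on

    expression = sqlglot.parse_one("SELECT DISTINCT ON (a) a, b FROM t ORDER BY a, b", read="postgres")
    print(eliminate_distinct_on(expression).sql())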
@@ -93,7 +94,7 @@ def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
for select in expression.selects:
if not select.alias_or_name:
alias = find_new_name(taken, "_c")
- select.replace(exp.alias_(select.copy(), alias))
+ select.replace(exp.alias_(select, alias))
taken.add(alias)
outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
@@ -102,8 +103,9 @@ def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
for expr in qualify_filters.find_all((exp.Window, exp.Column)):
if isinstance(expr, exp.Window):
alias = find_new_name(expression.named_selects, "_w")
- expression.select(exp.alias_(expr.copy(), alias), copy=False)
+ expression.select(exp.alias_(expr, alias), copy=False)
column = exp.column(alias)
+
if isinstance(expr.parent, exp.Qualify):
qualify_filters = column
else:
@@ -123,6 +125,7 @@ def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expr
"""
for node in expression.find_all(exp.DataType):
node.set("expressions", [e for e in node.expressions if isinstance(e, exp.DataType)])
+
return expression
@@ -147,6 +150,7 @@ def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
alias=exp.TableAlias(this=alias.this, columns=[column]), # type: ignore
),
)
+
return expression
@@ -156,7 +160,10 @@ def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
from sqlglot.optimizer.scope import build_scope
taken_select_names = set(expression.named_selects)
- taken_source_names = set(build_scope(expression).selected_sources)
+ scope = build_scope(expression)
+ if not scope:
+ return expression
+ taken_source_names = set(scope.selected_sources)
for select in expression.selects:
to_replace = select
@@ -226,6 +233,7 @@ def remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
else node,
copy=False,
)
+
return expression
@@ -242,12 +250,20 @@ def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expre
return expression
-def unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
- if isinstance(expression, exp.Pivot):
- expression.args["field"].transform(
- lambda node: exp.column(node.output_name) if isinstance(node, exp.Column) else node,
- copy=False,
- )
+def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
+ if isinstance(expression, exp.With) and expression.recursive:
+ next_name = name_sequence("_c_")
+
+ for cte in expression.expressions:
+ if not cte.args["alias"].columns:
+ query = cte.this
+ if isinstance(query, exp.Union):
+ query = query.this
+
+ cte.args["alias"].set(
+ "columns",
+ [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
+ )
return expression
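
A hedged example of the new add_recursive_cte_column_names transform: a recursive CTE without an explicit column list gets one synthesized from the anchor SELECT (unnamed projections would fall back to generated _c_<n> names):

    import sqlglot
    from sqlglot.transforms import add_recursive_cte_column_names

    sql = "WITH RECURSIVE t AS (SELECT 1 AS n UNION ALL SELECT n + 1 FROM t) SELECT * FROM t"
    expression = sqlglot.parse_one(sql)
    print(expression.transform(add_recursive_cte_column_names).sql())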