summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dialects')
-rw-r--r--sqlglot/dialects/bigquery.py97
-rw-r--r--sqlglot/dialects/clickhouse.py279
-rw-r--r--sqlglot/dialects/databricks.py2
-rw-r--r--sqlglot/dialects/dialect.py99
-rw-r--r--sqlglot/dialects/drill.py8
-rw-r--r--sqlglot/dialects/duckdb.py33
-rw-r--r--sqlglot/dialects/hive.py38
-rw-r--r--sqlglot/dialects/mysql.py64
-rw-r--r--sqlglot/dialects/oracle.py10
-rw-r--r--sqlglot/dialects/postgres.py45
-rw-r--r--sqlglot/dialects/presto.py94
-rw-r--r--sqlglot/dialects/redshift.py58
-rw-r--r--sqlglot/dialects/snowflake.py111
-rw-r--r--sqlglot/dialects/spark.py6
-rw-r--r--sqlglot/dialects/spark2.py96
-rw-r--r--sqlglot/dialects/sqlite.py19
-rw-r--r--sqlglot/dialects/starrocks.py10
-rw-r--r--sqlglot/dialects/tableau.py39
-rw-r--r--sqlglot/dialects/teradata.py14
-rw-r--r--sqlglot/dialects/trino.py2
-rw-r--r--sqlglot/dialects/tsql.py90
21 files changed, 744 insertions, 470 deletions
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 9705b35..1a58337 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -1,5 +1,3 @@
-"""Supports BigQuery Standard SQL."""
-
from __future__ import annotations
import re
@@ -18,11 +16,9 @@ from sqlglot.dialects.dialect import (
timestrtotime_sql,
ts_or_ds_to_date_sql,
)
-from sqlglot.helper import seq_get
+from sqlglot.helper import seq_get, split_num_words
from sqlglot.tokens import TokenType
-E = t.TypeVar("E", bound=exp.Expression)
-
def _date_add_sql(
data_type: str, kind: str
@@ -96,19 +92,12 @@ def _unqualify_unnest(expression: exp.Expression) -> exp.Expression:
These are added by the optimizer's qualify_column step.
"""
if isinstance(expression, exp.Select):
- unnests = {
- unnest.alias
- for unnest in expression.args.get("from", exp.From(expressions=[])).expressions
- if isinstance(unnest, exp.Unnest) and unnest.alias
- }
-
- if unnests:
- expression = expression.copy()
-
- for select in expression.expressions:
- for column in select.find_all(exp.Column):
- if column.table in unnests:
- column.set("table", None)
+ for unnest in expression.find_all(exp.Unnest):
+ if isinstance(unnest.parent, (exp.From, exp.Join)) and unnest.alias:
+ for select in expression.selects:
+ for column in select.find_all(exp.Column):
+ if column.table == unnest.alias:
+ column.set("table", None)
return expression
@@ -127,16 +116,20 @@ class BigQuery(Dialect):
}
class Tokenizer(tokens.Tokenizer):
- QUOTES = [
- (prefix + quote, quote) if prefix else quote
- for quote in ["'", '"', '"""', "'''"]
- for prefix in ["", "r", "R"]
- ]
+ QUOTES = ["'", '"', '"""', "'''"]
COMMENTS = ["--", "#", ("/*", "*/")]
IDENTIFIERS = ["`"]
STRING_ESCAPES = ["\\"]
+
HEX_STRINGS = [("0x", ""), ("0X", "")]
- BYTE_STRINGS = [("b'", "'"), ("B'", "'")]
+
+ BYTE_STRINGS = [
+ (prefix + q, q) for q in t.cast(t.List[str], QUOTES) for prefix in ("b", "B")
+ ]
+
+ RAW_STRINGS = [
+ (prefix + q, q) for q in t.cast(t.List[str], QUOTES) for prefix in ("r", "R")
+ ]
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
@@ -144,11 +137,11 @@ class BigQuery(Dialect):
"BEGIN": TokenType.COMMAND,
"BEGIN TRANSACTION": TokenType.BEGIN,
"CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
+ "BYTES": TokenType.BINARY,
"DECLARE": TokenType.COMMAND,
- "GEOGRAPHY": TokenType.GEOGRAPHY,
"FLOAT64": TokenType.DOUBLE,
"INT64": TokenType.BIGINT,
- "BYTES": TokenType.BINARY,
+ "RECORD": TokenType.STRUCT,
"NOT DETERMINISTIC": TokenType.VOLATILE,
"UNKNOWN": TokenType.NULL,
}
@@ -161,7 +154,7 @@ class BigQuery(Dialect):
LOG_DEFAULTS_TO_LN = True
FUNCTIONS = {
- **parser.Parser.FUNCTIONS, # type: ignore
+ **parser.Parser.FUNCTIONS,
"DATE_TRUNC": lambda args: exp.DateTrunc(
unit=exp.Literal.string(str(seq_get(args, 1))),
this=seq_get(args, 0),
@@ -191,28 +184,28 @@ class BigQuery(Dialect):
}
FUNCTION_PARSERS = {
- **parser.Parser.FUNCTION_PARSERS, # type: ignore
+ **parser.Parser.FUNCTION_PARSERS,
"ARRAY": lambda self: self.expression(exp.Array, expressions=[self._parse_statement()]),
}
FUNCTION_PARSERS.pop("TRIM")
NO_PAREN_FUNCTIONS = {
- **parser.Parser.NO_PAREN_FUNCTIONS, # type: ignore
+ **parser.Parser.NO_PAREN_FUNCTIONS,
TokenType.CURRENT_DATETIME: exp.CurrentDatetime,
}
NESTED_TYPE_TOKENS = {
- *parser.Parser.NESTED_TYPE_TOKENS, # type: ignore
+ *parser.Parser.NESTED_TYPE_TOKENS,
TokenType.TABLE,
}
ID_VAR_TOKENS = {
- *parser.Parser.ID_VAR_TOKENS, # type: ignore
+ *parser.Parser.ID_VAR_TOKENS,
TokenType.VALUES,
}
PROPERTY_PARSERS = {
- **parser.Parser.PROPERTY_PARSERS, # type: ignore
+ **parser.Parser.PROPERTY_PARSERS,
"NOT DETERMINISTIC": lambda self: self.expression(
exp.StabilityProperty, this=exp.Literal.string("VOLATILE")
),
@@ -220,19 +213,50 @@ class BigQuery(Dialect):
}
CONSTRAINT_PARSERS = {
- **parser.Parser.CONSTRAINT_PARSERS, # type: ignore
+ **parser.Parser.CONSTRAINT_PARSERS,
"OPTIONS": lambda self: exp.Properties(expressions=self._parse_with_property()),
}
+ def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]:
+ this = super()._parse_table_part(schema=schema)
+
+ # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#table_names
+ if isinstance(this, exp.Identifier):
+ table_name = this.name
+ while self._match(TokenType.DASH, advance=False) and self._next:
+ self._advance(2)
+ table_name += f"-{self._prev.text}"
+
+ this = exp.Identifier(this=table_name, quoted=this.args.get("quoted"))
+
+ return this
+
+ def _parse_table_parts(self, schema: bool = False) -> exp.Table:
+ table = super()._parse_table_parts(schema=schema)
+ if isinstance(table.this, exp.Identifier) and "." in table.name:
+ catalog, db, this, *rest = (
+ t.cast(t.Optional[exp.Expression], exp.to_identifier(x))
+ for x in split_num_words(table.name, ".", 3)
+ )
+
+ if rest and this:
+ this = exp.Dot.build(t.cast(t.List[exp.Expression], [this, *rest]))
+
+ table = exp.Table(this=this, db=db, catalog=catalog)
+
+ return table
+
class Generator(generator.Generator):
EXPLICIT_UNION = True
INTERVAL_ALLOWS_PLURAL_FORM = False
JOIN_HINTS = False
TABLE_HINTS = False
LIMIT_FETCH = "LIMIT"
+ RENAME_TABLE_WITH_DB = False
TRANSFORMS = {
- **generator.Generator.TRANSFORMS, # type: ignore
+ **generator.Generator.TRANSFORMS,
+ exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.AtTimeZone: lambda self, e: self.func(
"TIMESTAMP", self.func("DATETIME", e.this, e.args.get("zone"))
@@ -259,6 +283,7 @@ class BigQuery(Dialect):
exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"),
exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"),
exp.TimeStrToTime: timestrtotime_sql,
+ exp.TryCast: lambda self, e: f"SAFE_CAST({self.sql(e, 'this')} AS {self.sql(e, 'to')})",
exp.TsOrDsToDate: ts_or_ds_to_date_sql("bigquery"),
exp.TsOrDsAdd: _date_add_sql("DATE", "ADD"),
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
@@ -274,7 +299,7 @@ class BigQuery(Dialect):
}
TYPE_MAPPING = {
- **generator.Generator.TYPE_MAPPING, # type: ignore
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.BIGDECIMAL: "BIGNUMERIC",
exp.DataType.Type.BIGINT: "INT64",
exp.DataType.Type.BINARY: "BYTES",
@@ -297,7 +322,7 @@ class BigQuery(Dialect):
}
PROPERTIES_LOCATION = {
- **generator.Generator.PROPERTIES_LOCATION, # type: ignore
+ **generator.Generator.PROPERTIES_LOCATION,
exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index 2a49066..c8a9525 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -3,11 +3,16 @@ from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens
-from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql
+from sqlglot.dialects.dialect import (
+ Dialect,
+ inline_array_sql,
+ no_pivot_sql,
+ rename_func,
+ var_map_sql,
+)
from sqlglot.errors import ParseError
-from sqlglot.helper import ensure_list, seq_get
from sqlglot.parser import parse_var_map
-from sqlglot.tokens import TokenType
+from sqlglot.tokens import Token, TokenType
def _lower_func(sql: str) -> str:
@@ -28,65 +33,122 @@ class ClickHouse(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"ASOF": TokenType.ASOF,
- "GLOBAL": TokenType.GLOBAL,
- "DATETIME64": TokenType.DATETIME,
+ "ATTACH": TokenType.COMMAND,
+ "DATETIME64": TokenType.DATETIME64,
"FINAL": TokenType.FINAL,
"FLOAT32": TokenType.FLOAT,
"FLOAT64": TokenType.DOUBLE,
- "INT8": TokenType.TINYINT,
- "UINT8": TokenType.UTINYINT,
+ "GLOBAL": TokenType.GLOBAL,
+ "INT128": TokenType.INT128,
"INT16": TokenType.SMALLINT,
- "UINT16": TokenType.USMALLINT,
+ "INT256": TokenType.INT256,
"INT32": TokenType.INT,
- "UINT32": TokenType.UINT,
"INT64": TokenType.BIGINT,
- "UINT64": TokenType.UBIGINT,
- "INT128": TokenType.INT128,
+ "INT8": TokenType.TINYINT,
+ "MAP": TokenType.MAP,
+ "TUPLE": TokenType.STRUCT,
"UINT128": TokenType.UINT128,
- "INT256": TokenType.INT256,
+ "UINT16": TokenType.USMALLINT,
"UINT256": TokenType.UINT256,
- "TUPLE": TokenType.STRUCT,
+ "UINT32": TokenType.UINT,
+ "UINT64": TokenType.UBIGINT,
+ "UINT8": TokenType.UTINYINT,
}
class Parser(parser.Parser):
FUNCTIONS = {
- **parser.Parser.FUNCTIONS, # type: ignore
- "EXPONENTIALTIMEDECAYEDAVG": lambda params, args: exp.ExponentialTimeDecayedAvg(
- this=seq_get(args, 0),
- time=seq_get(args, 1),
- decay=seq_get(params, 0),
- ),
- "GROUPUNIQARRAY": lambda params, args: exp.GroupUniqArray(
- this=seq_get(args, 0), size=seq_get(params, 0)
- ),
- "HISTOGRAM": lambda params, args: exp.Histogram(
- this=seq_get(args, 0), bins=seq_get(params, 0)
- ),
+ **parser.Parser.FUNCTIONS,
+ "ANY": exp.AnyValue.from_arg_list,
"MAP": parse_var_map,
"MATCH": exp.RegexpLike.from_arg_list,
- "QUANTILE": lambda params, args: exp.Quantile(this=args, quantile=params),
- "QUANTILES": lambda params, args: exp.Quantiles(parameters=params, expressions=args),
- "QUANTILEIF": lambda params, args: exp.QuantileIf(parameters=params, expressions=args),
+ "UNIQ": exp.ApproxDistinct.from_arg_list,
+ }
+
+ FUNCTIONS_WITH_ALIASED_ARGS = {*parser.Parser.FUNCTIONS_WITH_ALIASED_ARGS, "TUPLE"}
+
+ FUNCTION_PARSERS = {
+ **parser.Parser.FUNCTION_PARSERS,
+ "QUANTILE": lambda self: self._parse_quantile(),
}
- FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy()
FUNCTION_PARSERS.pop("MATCH")
+ NO_PAREN_FUNCTION_PARSERS = parser.Parser.NO_PAREN_FUNCTION_PARSERS.copy()
+ NO_PAREN_FUNCTION_PARSERS.pop(TokenType.ANY)
+
RANGE_PARSERS = {
**parser.Parser.RANGE_PARSERS,
TokenType.GLOBAL: lambda self, this: self._match(TokenType.IN)
and self._parse_in(this, is_global=True),
}
- JOIN_KINDS = {*parser.Parser.JOIN_KINDS, TokenType.ANY, TokenType.ASOF} # type: ignore
+ # The PLACEHOLDER entry is popped because 1) it doesn't affect Clickhouse (it corresponds to
+ # the postgres-specific JSONBContains parser) and 2) it makes parsing the ternary op simpler.
+ COLUMN_OPERATORS = parser.Parser.COLUMN_OPERATORS.copy()
+ COLUMN_OPERATORS.pop(TokenType.PLACEHOLDER)
+
+ JOIN_KINDS = {
+ *parser.Parser.JOIN_KINDS,
+ TokenType.ANY,
+ TokenType.ASOF,
+ TokenType.ANTI,
+ TokenType.SEMI,
+ }
- TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY} # type: ignore
+ TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {
+ TokenType.ANY,
+ TokenType.ASOF,
+ TokenType.SEMI,
+ TokenType.ANTI,
+ TokenType.SETTINGS,
+ TokenType.FORMAT,
+ }
LOG_DEFAULTS_TO_LN = True
- def _parse_in(
- self, this: t.Optional[exp.Expression], is_global: bool = False
- ) -> exp.Expression:
+ QUERY_MODIFIER_PARSERS = {
+ **parser.Parser.QUERY_MODIFIER_PARSERS,
+ "settings": lambda self: self._parse_csv(self._parse_conjunction)
+ if self._match(TokenType.SETTINGS)
+ else None,
+ "format": lambda self: self._parse_id_var() if self._match(TokenType.FORMAT) else None,
+ }
+
+ def _parse_conjunction(self) -> t.Optional[exp.Expression]:
+ this = super()._parse_conjunction()
+
+ if self._match(TokenType.PLACEHOLDER):
+ return self.expression(
+ exp.If,
+ this=this,
+ true=self._parse_conjunction(),
+ false=self._match(TokenType.COLON) and self._parse_conjunction(),
+ )
+
+ return this
+
+ def _parse_placeholder(self) -> t.Optional[exp.Expression]:
+ """
+ Parse a placeholder expression like SELECT {abc: UInt32} or FROM {table: Identifier}
+ https://clickhouse.com/docs/en/sql-reference/syntax#defining-and-using-query-parameters
+ """
+ if not self._match(TokenType.L_BRACE):
+ return None
+
+ this = self._parse_id_var()
+ self._match(TokenType.COLON)
+ kind = self._parse_types(check_func=False) or (
+ self._match_text_seq("IDENTIFIER") and "Identifier"
+ )
+
+ if not kind:
+ self.raise_error("Expecting a placeholder type or 'Identifier' for tables")
+ elif not self._match(TokenType.R_BRACE):
+ self.raise_error("Expecting }")
+
+ return self.expression(exp.Placeholder, this=this, kind=kind)
+
+ def _parse_in(self, this: t.Optional[exp.Expression], is_global: bool = False) -> exp.In:
this = super()._parse_in(this)
this.set("is_global", is_global)
return this
@@ -120,81 +182,142 @@ class ClickHouse(Dialect):
return self.expression(exp.CTE, this=statement, alias=statement and statement.this)
+ def _parse_join_side_and_kind(
+ self,
+ ) -> t.Tuple[t.Optional[Token], t.Optional[Token], t.Optional[Token]]:
+ is_global = self._match(TokenType.GLOBAL) and self._prev
+ kind_pre = self._match_set(self.JOIN_KINDS, advance=False) and self._prev
+ if kind_pre:
+ kind = self._match_set(self.JOIN_KINDS) and self._prev
+ side = self._match_set(self.JOIN_SIDES) and self._prev
+ return is_global, side, kind
+ return (
+ is_global,
+ self._match_set(self.JOIN_SIDES) and self._prev,
+ self._match_set(self.JOIN_KINDS) and self._prev,
+ )
+
+ def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Expression]:
+ join = super()._parse_join(skip_join_token)
+
+ if join:
+ join.set("global", join.args.pop("natural", None))
+ return join
+
+ def _parse_function(
+ self, functions: t.Optional[t.Dict[str, t.Callable]] = None, anonymous: bool = False
+ ) -> t.Optional[exp.Expression]:
+ func = super()._parse_function(functions, anonymous)
+
+ if isinstance(func, exp.Anonymous):
+ params = self._parse_func_params(func)
+
+ if params:
+ return self.expression(
+ exp.ParameterizedAgg,
+ this=func.this,
+ expressions=func.expressions,
+ params=params,
+ )
+
+ return func
+
+ def _parse_func_params(
+ self, this: t.Optional[exp.Func] = None
+ ) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
+ if self._match_pair(TokenType.R_PAREN, TokenType.L_PAREN):
+ return self._parse_csv(self._parse_lambda)
+ if self._match(TokenType.L_PAREN):
+ params = self._parse_csv(self._parse_lambda)
+ self._match_r_paren(this)
+ return params
+ return None
+
+ def _parse_quantile(self) -> exp.Quantile:
+ this = self._parse_lambda()
+ params = self._parse_func_params()
+ if params:
+ return self.expression(exp.Quantile, this=params[0], quantile=this)
+ return self.expression(exp.Quantile, this=this, quantile=exp.Literal.number(0.5))
+
+ def _parse_wrapped_id_vars(
+ self, optional: bool = False
+ ) -> t.List[t.Optional[exp.Expression]]:
+ return super()._parse_wrapped_id_vars(optional=True)
+
class Generator(generator.Generator):
STRUCT_DELIMITER = ("(", ")")
TYPE_MAPPING = {
- **generator.Generator.TYPE_MAPPING, # type: ignore
- exp.DataType.Type.NULLABLE: "Nullable",
- exp.DataType.Type.DATETIME: "DateTime64",
- exp.DataType.Type.MAP: "Map",
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.ARRAY: "Array",
+ exp.DataType.Type.BIGINT: "Int64",
+ exp.DataType.Type.DATETIME64: "DateTime64",
+ exp.DataType.Type.DOUBLE: "Float64",
+ exp.DataType.Type.FLOAT: "Float32",
+ exp.DataType.Type.INT: "Int32",
+ exp.DataType.Type.INT128: "Int128",
+ exp.DataType.Type.INT256: "Int256",
+ exp.DataType.Type.MAP: "Map",
+ exp.DataType.Type.NULLABLE: "Nullable",
+ exp.DataType.Type.SMALLINT: "Int16",
exp.DataType.Type.STRUCT: "Tuple",
exp.DataType.Type.TINYINT: "Int8",
- exp.DataType.Type.UTINYINT: "UInt8",
- exp.DataType.Type.SMALLINT: "Int16",
- exp.DataType.Type.USMALLINT: "UInt16",
- exp.DataType.Type.INT: "Int32",
- exp.DataType.Type.UINT: "UInt32",
- exp.DataType.Type.BIGINT: "Int64",
exp.DataType.Type.UBIGINT: "UInt64",
- exp.DataType.Type.INT128: "Int128",
+ exp.DataType.Type.UINT: "UInt32",
exp.DataType.Type.UINT128: "UInt128",
- exp.DataType.Type.INT256: "Int256",
exp.DataType.Type.UINT256: "UInt256",
- exp.DataType.Type.FLOAT: "Float32",
- exp.DataType.Type.DOUBLE: "Float64",
+ exp.DataType.Type.USMALLINT: "UInt16",
+ exp.DataType.Type.UTINYINT: "UInt8",
}
TRANSFORMS = {
- **generator.Generator.TRANSFORMS, # type: ignore
+ **generator.Generator.TRANSFORMS,
+ exp.AnyValue: rename_func("any"),
+ exp.ApproxDistinct: rename_func("uniq"),
exp.Array: inline_array_sql,
- exp.ExponentialTimeDecayedAvg: lambda self, e: f"exponentialTimeDecayedAvg{self._param_args_sql(e, 'decay', ['this', 'time'])}",
+ exp.CastToStrType: rename_func("CAST"),
exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
- exp.GroupUniqArray: lambda self, e: f"groupUniqArray{self._param_args_sql(e, 'size', 'this')}",
- exp.Histogram: lambda self, e: f"histogram{self._param_args_sql(e, 'bins', 'this')}",
exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)),
- exp.Quantile: lambda self, e: f"quantile{self._param_args_sql(e, 'quantile', 'this')}",
- exp.Quantiles: lambda self, e: f"quantiles{self._param_args_sql(e, 'parameters', 'expressions')}",
- exp.QuantileIf: lambda self, e: f"quantileIf{self._param_args_sql(e, 'parameters', 'expressions')}",
+ exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
+ exp.Pivot: no_pivot_sql,
+ exp.Quantile: lambda self, e: self.func("quantile", e.args.get("quantile"))
+ + f"({self.sql(e, 'this')})",
exp.RegexpLike: lambda self, e: f"match({self.format_args(e.this, e.expression)})",
exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
}
PROPERTIES_LOCATION = {
- **generator.Generator.PROPERTIES_LOCATION, # type: ignore
+ **generator.Generator.PROPERTIES_LOCATION,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
+ exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
}
JOIN_HINTS = False
TABLE_HINTS = False
EXPLICIT_UNION = True
-
- def _param_args_sql(
- self,
- expression: exp.Expression,
- param_names: str | t.List[str],
- arg_names: str | t.List[str],
- ) -> str:
- params = self.format_args(
- *(
- arg
- for name in ensure_list(param_names)
- for arg in ensure_list(expression.args.get(name))
- )
- )
- args = self.format_args(
- *(
- arg
- for name in ensure_list(arg_names)
- for arg in ensure_list(expression.args.get(name))
- )
- )
- return f"({params})({args})"
+ GROUPINGS_SEP = ""
def cte_sql(self, expression: exp.CTE) -> str:
if isinstance(expression.this, exp.Alias):
return self.sql(expression, "this")
return super().cte_sql(expression)
+
+ def after_limit_modifiers(self, expression: exp.Expression) -> t.List[str]:
+ return super().after_limit_modifiers(expression) + [
+ self.seg("SETTINGS ") + self.expressions(expression, key="settings", flat=True)
+ if expression.args.get("settings")
+ else "",
+ self.seg("FORMAT ") + self.sql(expression, "format")
+ if expression.args.get("format")
+ else "",
+ ]
+
+ def parameterizedagg_sql(self, expression: exp.Anonymous) -> str:
+ params = self.expressions(expression, "params", flat=True)
+ return self.func(expression.name, *expression.expressions) + f"({params})"
+
+ def placeholder_sql(self, expression: exp.Placeholder) -> str:
+ return f"{{{expression.name}: {self.sql(expression, 'kind')}}}"
diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py
index 51112a0..2149aca 100644
--- a/sqlglot/dialects/databricks.py
+++ b/sqlglot/dialects/databricks.py
@@ -25,7 +25,7 @@ class Databricks(Spark):
class Generator(Spark.Generator):
TRANSFORMS = {
- **Spark.Generator.TRANSFORMS, # type: ignore
+ **Spark.Generator.TRANSFORMS,
exp.DateAdd: generate_date_delta_with_unit_sql,
exp.DateDiff: generate_date_delta_with_unit_sql,
exp.JSONExtract: lambda self, e: self.binary(e, ":"),
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 71269f2..890a3c3 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -8,10 +8,16 @@ from sqlglot.generator import Generator
from sqlglot.helper import flatten, seq_get
from sqlglot.parser import Parser
from sqlglot.time import format_time
-from sqlglot.tokens import Token, Tokenizer
+from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie
-E = t.TypeVar("E", bound=exp.Expression)
+if t.TYPE_CHECKING:
+ from sqlglot._typing import E
+
+
+# Only Snowflake is currently known to resolve unquoted identifiers as uppercase.
+# https://docs.snowflake.com/en/sql-reference/identifiers-syntax
+RESOLVES_IDENTIFIERS_AS_UPPERCASE = {"snowflake"}
class Dialects(str, Enum):
@@ -42,6 +48,19 @@ class Dialects(str, Enum):
class _Dialect(type):
classes: t.Dict[str, t.Type[Dialect]] = {}
+ def __eq__(cls, other: t.Any) -> bool:
+ if cls is other:
+ return True
+ if isinstance(other, str):
+ return cls is cls.get(other)
+ if isinstance(other, Dialect):
+ return cls is type(other)
+
+ return False
+
+ def __hash__(cls) -> int:
+ return hash(cls.__name__.lower())
+
@classmethod
def __getitem__(cls, key: str) -> t.Type[Dialect]:
return cls.classes[key]
@@ -70,17 +89,20 @@ class _Dialect(type):
klass.tokenizer_class._IDENTIFIERS.items()
)[0]
- klass.bit_start, klass.bit_end = seq_get(
- list(klass.tokenizer_class._BIT_STRINGS.items()), 0
- ) or (None, None)
-
- klass.hex_start, klass.hex_end = seq_get(
- list(klass.tokenizer_class._HEX_STRINGS.items()), 0
- ) or (None, None)
+ def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
+ return next(
+ (
+ (s, e)
+ for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
+ if t == token_type
+ ),
+ (None, None),
+ )
- klass.byte_start, klass.byte_end = seq_get(
- list(klass.tokenizer_class._BYTE_STRINGS.items()), 0
- ) or (None, None)
+ klass.bit_start, klass.bit_end = get_start_end(TokenType.BIT_STRING)
+ klass.hex_start, klass.hex_end = get_start_end(TokenType.HEX_STRING)
+ klass.byte_start, klass.byte_end = get_start_end(TokenType.BYTE_STRING)
+ klass.raw_start, klass.raw_end = get_start_end(TokenType.RAW_STRING)
return klass
@@ -110,6 +132,12 @@ class Dialect(metaclass=_Dialect):
parser_class = None
generator_class = None
+ def __eq__(self, other: t.Any) -> bool:
+ return type(self) == other
+
+ def __hash__(self) -> int:
+ return hash(type(self))
+
@classmethod
def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
if not dialect:
@@ -192,6 +220,8 @@ class Dialect(metaclass=_Dialect):
"hex_end": self.hex_end,
"byte_start": self.byte_start,
"byte_end": self.byte_end,
+ "raw_start": self.raw_start,
+ "raw_end": self.raw_end,
"identifier_start": self.identifier_start,
"identifier_end": self.identifier_end,
"string_escape": self.tokenizer_class.STRING_ESCAPES[0],
@@ -275,7 +305,7 @@ def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
self.unsupported("PIVOT unsupported")
- return self.sql(expression)
+ return ""
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
@@ -328,7 +358,7 @@ def var_map_sql(
def format_time_lambda(
exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
-) -> t.Callable[[t.Sequence], E]:
+) -> t.Callable[[t.List], E]:
"""Helper used for time expressions.
Args:
@@ -340,7 +370,7 @@ def format_time_lambda(
A callable that can be used to return the appropriately formatted time expression.
"""
- def _format_time(args: t.Sequence):
+ def _format_time(args: t.List):
return exp_class(
this=seq_get(args, 0),
format=Dialect[dialect].format_time(
@@ -377,12 +407,12 @@ def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
def parse_date_delta(
exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
-) -> t.Callable[[t.Sequence], E]:
- def inner_func(args: t.Sequence) -> E:
+) -> t.Callable[[t.List], E]:
+ def inner_func(args: t.List) -> E:
unit_based = len(args) == 3
this = args[2] if unit_based else seq_get(args, 0)
unit = args[0] if unit_based else exp.Literal.string("DAY")
- unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit
+ unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
return inner_func
@@ -390,8 +420,8 @@ def parse_date_delta(
def parse_date_delta_with_interval(
expression_class: t.Type[E],
-) -> t.Callable[[t.Sequence], t.Optional[E]]:
- def func(args: t.Sequence) -> t.Optional[E]:
+) -> t.Callable[[t.List], t.Optional[E]]:
+ def func(args: t.List) -> t.Optional[E]:
if len(args) < 2:
return None
@@ -409,7 +439,7 @@ def parse_date_delta_with_interval(
return func
-def date_trunc_to_time(args: t.Sequence) -> exp.DateTrunc | exp.TimestampTrunc:
+def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
unit = seq_get(args, 0)
this = seq_get(args, 1)
@@ -424,7 +454,7 @@ def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
)
-def locate_to_strposition(args: t.Sequence) -> exp.Expression:
+def locate_to_strposition(args: t.List) -> exp.Expression:
return exp.StrPosition(
this=seq_get(args, 1),
substr=seq_get(args, 0),
@@ -483,7 +513,7 @@ def trim_sql(self: Generator, expression: exp.Trim) -> str:
return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
-def str_to_time_sql(self, expression: exp.Expression) -> str:
+def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
return self.func("STRPTIME", expression.this, self.format_time(expression))
@@ -496,3 +526,26 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
return f"CAST({self.sql(expression, 'this')} AS DATE)"
return _ts_or_ds_to_date_sql
+
+
+# Spark, DuckDB use (almost) the same naming scheme for the output columns of the PIVOT operator
+def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
+ names = []
+ for agg in aggregations:
+ if isinstance(agg, exp.Alias):
+ names.append(agg.alias)
+ else:
+ """
+ This case corresponds to aggregations without aliases being used as suffixes
+ (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
+ be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
+ Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
+ """
+ agg_all_unquoted = agg.transform(
+ lambda node: exp.Identifier(this=node.name, quoted=False)
+ if isinstance(node, exp.Identifier)
+ else node
+ )
+ names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
+
+ return names
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
index 7ad555e..924b979 100644
--- a/sqlglot/dialects/drill.py
+++ b/sqlglot/dialects/drill.py
@@ -95,7 +95,7 @@ class Drill(Dialect):
STRICT_CAST = False
FUNCTIONS = {
- **parser.Parser.FUNCTIONS, # type: ignore
+ **parser.Parser.FUNCTIONS,
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "drill"),
"TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list,
"TO_CHAR": format_time_lambda(exp.TimeToStr, "drill"),
@@ -108,7 +108,7 @@ class Drill(Dialect):
TABLE_HINTS = False
TYPE_MAPPING = {
- **generator.Generator.TYPE_MAPPING, # type: ignore
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.SMALLINT: "INTEGER",
exp.DataType.Type.TINYINT: "INTEGER",
@@ -121,13 +121,13 @@ class Drill(Dialect):
}
PROPERTIES_LOCATION = {
- **generator.Generator.PROPERTIES_LOCATION, # type: ignore
+ **generator.Generator.PROPERTIES_LOCATION,
exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
TRANSFORMS = {
- **generator.Generator.TRANSFORMS, # type: ignore
+ **generator.Generator.TRANSFORMS,
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.ArrayContains: rename_func("REPEATED_CONTAINS"),
exp.ArraySize: rename_func("REPEATED_COUNT"),
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index bce956e..662882d 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -11,9 +11,9 @@ from sqlglot.dialects.dialect import (
datestrtodate_sql,
format_time_lambda,
no_comment_column_constraint_sql,
- no_pivot_sql,
no_properties_sql,
no_safe_divide_sql,
+ pivot_column_names,
rename_func,
str_position_sql,
str_to_time_sql,
@@ -31,10 +31,11 @@ def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> s
return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
-def _date_add_sql(self: generator.Generator, expression: exp.DateAdd) -> str:
+def _date_delta_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
this = self.sql(expression, "this")
unit = self.sql(expression, "unit").strip("'") or "DAY"
- return f"{this} + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
+ op = "+" if isinstance(expression, exp.DateAdd) else "-"
+ return f"{this} {op} {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
def _array_sort_sql(self: generator.Generator, expression: exp.ArraySort) -> str:
@@ -50,11 +51,11 @@ def _sort_array_sql(self: generator.Generator, expression: exp.SortArray) -> str
return f"ARRAY_SORT({this})"
-def _sort_array_reverse(args: t.Sequence) -> exp.Expression:
+def _sort_array_reverse(args: t.List) -> exp.Expression:
return exp.SortArray(this=seq_get(args, 0), asc=exp.false())
-def _parse_date_diff(args: t.Sequence) -> exp.Expression:
+def _parse_date_diff(args: t.List) -> exp.Expression:
return exp.DateDiff(
this=seq_get(args, 2),
expression=seq_get(args, 1),
@@ -89,11 +90,14 @@ def _regexp_extract_sql(self: generator.Generator, expression: exp.RegexpExtract
class DuckDB(Dialect):
+ null_ordering = "nulls_are_last"
+
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"~": TokenType.RLIKE,
":=": TokenType.EQ,
+ "//": TokenType.DIV,
"ATTACH": TokenType.COMMAND,
"BINARY": TokenType.VARBINARY,
"BPCHAR": TokenType.TEXT,
@@ -104,6 +108,7 @@ class DuckDB(Dialect):
"INT1": TokenType.TINYINT,
"LOGICAL": TokenType.BOOLEAN,
"NUMERIC": TokenType.DOUBLE,
+ "PIVOT_WIDER": TokenType.PIVOT,
"SIGNED": TokenType.INT,
"STRING": TokenType.VARCHAR,
"UBIGINT": TokenType.UBIGINT,
@@ -114,8 +119,7 @@ class DuckDB(Dialect):
class Parser(parser.Parser):
FUNCTIONS = {
- **parser.Parser.FUNCTIONS, # type: ignore
- "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
+ **parser.Parser.FUNCTIONS,
"ARRAY_LENGTH": exp.ArraySize.from_arg_list,
"ARRAY_SORT": exp.SortArray.from_arg_list,
"ARRAY_REVERSE_SORT": _sort_array_reverse,
@@ -152,11 +156,17 @@ class DuckDB(Dialect):
TokenType.UTINYINT,
}
+ def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]:
+ if len(aggregations) == 1:
+ return super()._pivot_column_names(aggregations)
+ return pivot_column_names(aggregations, dialect="duckdb")
+
class Generator(generator.Generator):
JOIN_HINTS = False
TABLE_HINTS = False
LIMIT_FETCH = "LIMIT"
STRUCT_DELIMITER = ("(", ")")
+ RENAME_TABLE_WITH_DB = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@@ -175,7 +185,8 @@ class DuckDB(Dialect):
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.DataType: _datatype_sql,
- exp.DateAdd: _date_add_sql,
+ exp.DateAdd: _date_delta_sql,
+ exp.DateSub: _date_delta_sql,
exp.DateDiff: lambda self, e: self.func(
"DATE_DIFF", f"'{e.args.get('unit', 'day')}'", e.expression, e.this
),
@@ -183,13 +194,13 @@ class DuckDB(Dialect):
exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.dateint_format}) AS INT)",
exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.dateint_format}) AS DATE)",
exp.Explode: rename_func("UNNEST"),
+ exp.IntDiv: lambda self, e: self.binary(e, "//"),
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONBExtract: arrow_json_extract_sql,
exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
exp.LogicalOr: rename_func("BOOL_OR"),
exp.LogicalAnd: rename_func("BOOL_AND"),
- exp.Pivot: no_pivot_sql,
exp.Properties: no_properties_sql,
exp.RegexpExtract: _regexp_extract_sql,
exp.RegexpLike: rename_func("REGEXP_MATCHES"),
@@ -232,11 +243,11 @@ class DuckDB(Dialect):
STAR_MAPPING = {**generator.Generator.STAR_MAPPING, "except": "EXCLUDE"}
PROPERTIES_LOCATION = {
- **generator.Generator.PROPERTIES_LOCATION, # type: ignore
+ **generator.Generator.PROPERTIES_LOCATION,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
def tablesample_sql(
- self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
+ self, expression: exp.TableSample, seed_prefix: str = "SEED", sep: str = " AS "
) -> str:
return super().tablesample_sql(expression, seed_prefix="REPEATABLE", sep=sep)
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 871a180..fbd626a 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -147,13 +147,6 @@ def _to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str
return f"TO_DATE({this})"
-def _index_sql(self: generator.Generator, expression: exp.Index) -> str:
- this = self.sql(expression, "this")
- table = self.sql(expression, "table")
- columns = self.sql(expression, "columns")
- return f"{this} ON TABLE {table} {columns}"
-
-
class Hive(Dialect):
alias_post_tablesample = True
@@ -225,8 +218,7 @@ class Hive(Dialect):
STRICT_CAST = False
FUNCTIONS = {
- **parser.Parser.FUNCTIONS, # type: ignore
- "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
+ **parser.Parser.FUNCTIONS,
"BASE64": exp.ToBase64.from_arg_list,
"COLLECT_LIST": exp.ArrayAgg.from_arg_list,
"DATE_ADD": lambda args: exp.TsOrDsAdd(
@@ -271,21 +263,29 @@ class Hive(Dialect):
}
PROPERTY_PARSERS = {
- **parser.Parser.PROPERTY_PARSERS, # type: ignore
+ **parser.Parser.PROPERTY_PARSERS,
"WITH SERDEPROPERTIES": lambda self: exp.SerdeProperties(
expressions=self._parse_wrapped_csv(self._parse_property)
),
}
+ QUERY_MODIFIER_PARSERS = {
+ **parser.Parser.QUERY_MODIFIER_PARSERS,
+ "distribute": lambda self: self._parse_sort(exp.Distribute, "DISTRIBUTE", "BY"),
+ "sort": lambda self: self._parse_sort(exp.Sort, "SORT", "BY"),
+ "cluster": lambda self: self._parse_sort(exp.Cluster, "CLUSTER", "BY"),
+ }
+
class Generator(generator.Generator):
LIMIT_FETCH = "LIMIT"
TABLESAMPLE_WITH_METHOD = False
TABLESAMPLE_SIZE_IS_PERCENT = True
JOIN_HINTS = False
TABLE_HINTS = False
+ INDEX_ON = "ON TABLE"
TYPE_MAPPING = {
- **generator.Generator.TYPE_MAPPING, # type: ignore
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TEXT: "STRING",
exp.DataType.Type.DATETIME: "TIMESTAMP",
exp.DataType.Type.VARBINARY: "BINARY",
@@ -294,7 +294,7 @@ class Hive(Dialect):
}
TRANSFORMS = {
- **generator.Generator.TRANSFORMS, # type: ignore
+ **generator.Generator.TRANSFORMS,
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.Select: transforms.preprocess(
[
@@ -319,7 +319,6 @@ class Hive(Dialect):
exp.FileFormatProperty: lambda self, e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}",
exp.FromBase64: rename_func("UNBASE64"),
exp.If: if_sql,
- exp.Index: _index_sql,
exp.ILike: no_ilike_sql,
exp.JSONExtract: rename_func("GET_JSON_OBJECT"),
exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"),
@@ -342,7 +341,6 @@ class Hive(Dialect):
exp.StrToTime: _str_to_time_sql,
exp.StrToUnix: _str_to_unix_sql,
exp.StructExtract: struct_extract_sql,
- exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'this')}",
exp.TimeStrToDate: rename_func("TO_DATE"),
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
@@ -363,14 +361,13 @@ class Hive(Dialect):
exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"),
exp.NumberToStr: rename_func("FORMAT_NUMBER"),
exp.LastDateOfMonth: rename_func("LAST_DAY"),
- exp.National: lambda self, e: self.sql(e, "this"),
+ exp.National: lambda self, e: self.national_sql(e, prefix=""),
}
PROPERTIES_LOCATION = {
- **generator.Generator.PROPERTIES_LOCATION, # type: ignore
+ **generator.Generator.PROPERTIES_LOCATION,
exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA,
exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
- exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
@@ -396,3 +393,10 @@ class Hive(Dialect):
expression = exp.DataType.build(expression.this)
return super().datatype_sql(expression)
+
+ def after_having_modifiers(self, expression: exp.Expression) -> t.List[str]:
+ return super().after_having_modifiers(expression) + [
+ self.sql(expression, "distribute"),
+ self.sql(expression, "sort"),
+ self.sql(expression, "cluster"),
+ ]
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 5342624..2b41860 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -1,5 +1,7 @@
from __future__ import annotations
+import typing as t
+
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
@@ -11,6 +13,7 @@ from sqlglot.dialects.dialect import (
min_or_least,
no_ilike_sql,
no_paren_current_date_sql,
+ no_pivot_sql,
no_tablesample_sql,
no_trycast_sql,
parse_date_delta_with_interval,
@@ -21,14 +24,14 @@ from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
-def _show_parser(*args, **kwargs):
- def _parse(self):
+def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[MySQL.Parser], exp.Show]:
+ def _parse(self: MySQL.Parser) -> exp.Show:
return self._parse_show_mysql(*args, **kwargs)
return _parse
-def _date_trunc_sql(self, expression):
+def _date_trunc_sql(self: generator.Generator, expression: exp.DateTrunc) -> str:
expr = self.sql(expression, "this")
unit = expression.text("unit")
@@ -54,17 +57,17 @@ def _date_trunc_sql(self, expression):
return f"STR_TO_DATE({concat}, '{date_format}')"
-def _str_to_date(args):
+def _str_to_date(args: t.List) -> exp.StrToDate:
date_format = MySQL.format_time(seq_get(args, 1))
return exp.StrToDate(this=seq_get(args, 0), format=date_format)
-def _str_to_date_sql(self, expression):
+def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate | exp.StrToTime) -> str:
date_format = self.format_time(expression)
return f"STR_TO_DATE({self.sql(expression.this)}, {date_format})"
-def _trim_sql(self, expression):
+def _trim_sql(self: generator.Generator, expression: exp.Trim) -> str:
target = self.sql(expression, "this")
trim_type = self.sql(expression, "position")
remove_chars = self.sql(expression, "expression")
@@ -79,8 +82,8 @@ def _trim_sql(self, expression):
return f"TRIM({trim_type}{remove_chars}{from_part}{target})"
-def _date_add_sql(kind):
- def func(self, expression):
+def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]:
+ def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
this = self.sql(expression, "this")
unit = expression.text("unit").upper() or "DAY"
return (
@@ -175,10 +178,10 @@ class MySQL(Dialect):
COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SHOW}
class Parser(parser.Parser):
- FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA, TokenType.DATABASE} # type: ignore
+ FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA, TokenType.DATABASE}
FUNCTIONS = {
- **parser.Parser.FUNCTIONS, # type: ignore
+ **parser.Parser.FUNCTIONS,
"DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"),
"DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
@@ -191,7 +194,7 @@ class MySQL(Dialect):
}
FUNCTION_PARSERS = {
- **parser.Parser.FUNCTION_PARSERS, # type: ignore
+ **parser.Parser.FUNCTION_PARSERS,
"GROUP_CONCAT": lambda self: self.expression(
exp.GroupConcat,
this=self._parse_lambda(),
@@ -199,13 +202,8 @@ class MySQL(Dialect):
),
}
- PROPERTY_PARSERS = {
- **parser.Parser.PROPERTY_PARSERS, # type: ignore
- "ENGINE": lambda self: self._parse_property_assignment(exp.EngineProperty),
- }
-
STATEMENT_PARSERS = {
- **parser.Parser.STATEMENT_PARSERS, # type: ignore
+ **parser.Parser.STATEMENT_PARSERS,
TokenType.SHOW: lambda self: self._parse_show(),
}
@@ -286,7 +284,13 @@ class MySQL(Dialect):
LOG_DEFAULTS_TO_LN = True
- def _parse_show_mysql(self, this, target=False, full=None, global_=None):
+ def _parse_show_mysql(
+ self,
+ this: str,
+ target: bool | str = False,
+ full: t.Optional[bool] = None,
+ global_: t.Optional[bool] = None,
+ ) -> exp.Show:
if target:
if isinstance(target, str):
self._match_text_seq(target)
@@ -342,10 +346,12 @@ class MySQL(Dialect):
offset=offset,
limit=limit,
mutex=mutex,
- **{"global": global_},
+ **{"global": global_}, # type: ignore
)
- def _parse_oldstyle_limit(self):
+ def _parse_oldstyle_limit(
+ self,
+ ) -> t.Tuple[t.Optional[exp.Expression], t.Optional[exp.Expression]]:
limit = None
offset = None
if self._match_text_seq("LIMIT"):
@@ -355,23 +361,20 @@ class MySQL(Dialect):
elif len(parts) == 2:
limit = parts[1]
offset = parts[0]
+
return offset, limit
- def _parse_set_item_charset(self, kind):
+ def _parse_set_item_charset(self, kind: str) -> exp.Expression:
this = self._parse_string() or self._parse_id_var()
+ return self.expression(exp.SetItem, this=this, kind=kind)
- return self.expression(
- exp.SetItem,
- this=this,
- kind=kind,
- )
-
- def _parse_set_item_names(self):
+ def _parse_set_item_names(self) -> exp.Expression:
charset = self._parse_string() or self._parse_id_var()
if self._match_text_seq("COLLATE"):
collate = self._parse_string() or self._parse_id_var()
else:
collate = None
+
return self.expression(
exp.SetItem,
this=charset,
@@ -386,7 +389,7 @@ class MySQL(Dialect):
TABLE_HINTS = False
TRANSFORMS = {
- **generator.Generator.TRANSFORMS, # type: ignore
+ **generator.Generator.TRANSFORMS,
exp.CurrentDate: no_paren_current_date_sql,
exp.DateDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression),
exp.DateAdd: _date_add_sql("ADD"),
@@ -403,6 +406,7 @@ class MySQL(Dialect):
exp.Min: min_or_least,
exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
+ exp.Pivot: no_pivot_sql,
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.StrPosition: strposition_to_locate_sql,
exp.StrToDate: _str_to_date_sql,
@@ -422,7 +426,7 @@ class MySQL(Dialect):
TYPE_MAPPING.pop(exp.DataType.Type.LONGBLOB)
PROPERTIES_LOCATION = {
- **generator.Generator.PROPERTIES_LOCATION, # type: ignore
+ **generator.Generator.PROPERTIES_LOCATION,
exp.TransientProperty: exp.Properties.Location.UNSUPPORTED,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index c8af1c6..7722753 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -8,7 +8,7 @@ from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
-def _parse_xml_table(self) -> exp.XMLTable:
+def _parse_xml_table(self: parser.Parser) -> exp.XMLTable:
this = self._parse_string()
passing = None
@@ -66,7 +66,7 @@ class Oracle(Dialect):
WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER, TokenType.KEEP}
FUNCTIONS = {
- **parser.Parser.FUNCTIONS, # type: ignore
+ **parser.Parser.FUNCTIONS,
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
}
@@ -107,7 +107,7 @@ class Oracle(Dialect):
TABLE_HINTS = False
TYPE_MAPPING = {
- **generator.Generator.TYPE_MAPPING, # type: ignore
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TINYINT: "NUMBER",
exp.DataType.Type.SMALLINT: "NUMBER",
exp.DataType.Type.INT: "NUMBER",
@@ -122,7 +122,7 @@ class Oracle(Dialect):
}
TRANSFORMS = {
- **generator.Generator.TRANSFORMS, # type: ignore
+ **generator.Generator.TRANSFORMS,
exp.DateStrToDate: lambda self, e: self.func(
"TO_DATE", e.this, exp.Literal.string("YYYY-MM-DD")
),
@@ -143,7 +143,7 @@ class Oracle(Dialect):
}
PROPERTIES_LOCATION = {
- **generator.Generator.PROPERTIES_LOCATION, # type: ignore
+ **generator.Generator.PROPERTIES_LOCATION,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 2132778..ab61880 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -12,6 +12,7 @@ from sqlglot.dialects.dialect import (
max_or_greatest,
min_or_least,
no_paren_current_date_sql,
+ no_pivot_sql,
no_tablesample_sql,
no_trycast_sql,
rename_func,
@@ -33,8 +34,8 @@ DATE_DIFF_FACTOR = {
}
-def _date_add_sql(kind):
- def func(self, expression):
+def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]:
+ def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
from sqlglot.optimizer.simplify import simplify
this = self.sql(expression, "this")
@@ -51,7 +52,7 @@ def _date_add_sql(kind):
return func
-def _date_diff_sql(self, expression):
+def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
unit = expression.text("unit").upper()
factor = DATE_DIFF_FACTOR.get(unit)
@@ -77,7 +78,7 @@ def _date_diff_sql(self, expression):
return f"CAST({unit} AS BIGINT)"
-def _substring_sql(self, expression):
+def _substring_sql(self: generator.Generator, expression: exp.Substring) -> str:
this = self.sql(expression, "this")
start = self.sql(expression, "start")
length = self.sql(expression, "length")
@@ -88,7 +89,7 @@ def _substring_sql(self, expression):
return f"SUBSTRING({this}{from_part}{for_part})"
-def _string_agg_sql(self, expression):
+def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> str:
expression = expression.copy()
separator = expression.args.get("separator") or exp.Literal.string(",")
@@ -102,13 +103,13 @@ def _string_agg_sql(self, expression):
return f"STRING_AGG({self.format_args(this, separator)}{order})"
-def _datatype_sql(self, expression):
+def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
if expression.this == exp.DataType.Type.ARRAY:
return f"{self.expressions(expression, flat=True)}[]"
return self.datatype_sql(expression)
-def _auto_increment_to_serial(expression):
+def _auto_increment_to_serial(expression: exp.Expression) -> exp.Expression:
auto = expression.find(exp.AutoIncrementColumnConstraint)
if auto:
@@ -126,7 +127,7 @@ def _auto_increment_to_serial(expression):
return expression
-def _serial_to_generated(expression):
+def _serial_to_generated(expression: exp.Expression) -> exp.Expression:
kind = expression.args["kind"]
if kind.this == exp.DataType.Type.SERIAL:
@@ -144,6 +145,7 @@ def _serial_to_generated(expression):
constraints = expression.args["constraints"]
generated = exp.ColumnConstraint(kind=exp.GeneratedAsIdentityColumnConstraint(this=False))
notnull = exp.ColumnConstraint(kind=exp.NotNullColumnConstraint())
+
if notnull not in constraints:
constraints.insert(0, notnull)
if generated not in constraints:
@@ -152,7 +154,7 @@ def _serial_to_generated(expression):
return expression
-def _generate_series(args):
+def _generate_series(args: t.List) -> exp.Expression:
# The goal is to convert step values like '1 day' or INTERVAL '1 day' into INTERVAL '1' day
step = seq_get(args, 2)
@@ -168,11 +170,12 @@ def _generate_series(args):
return exp.GenerateSeries.from_arg_list(args)
-def _to_timestamp(args):
+def _to_timestamp(args: t.List) -> exp.Expression:
# TO_TIMESTAMP accepts either a single double argument or (text, text)
if len(args) == 1:
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TABLE
return exp.UnixToTime.from_arg_list(args)
+
# https://www.postgresql.org/docs/current/functions-formatting.html
return format_time_lambda(exp.StrToTime, "postgres")(args)
@@ -255,7 +258,7 @@ class Postgres(Dialect):
STRICT_CAST = False
FUNCTIONS = {
- **parser.Parser.FUNCTIONS, # type: ignore
+ **parser.Parser.FUNCTIONS,
"DATE_TRUNC": lambda args: exp.TimestampTrunc(
this=seq_get(args, 1), unit=seq_get(args, 0)
),
@@ -271,7 +274,7 @@ class Postgres(Dialect):
}
BITWISE = {
- **parser.Parser.BITWISE, # type: ignore
+ **parser.Parser.BITWISE,
TokenType.HASH: exp.BitwiseXor,
}
@@ -280,7 +283,7 @@ class Postgres(Dialect):
}
RANGE_PARSERS = {
- **parser.Parser.RANGE_PARSERS, # type: ignore
+ **parser.Parser.RANGE_PARSERS,
TokenType.DAMP: binary_range_parser(exp.ArrayOverlaps),
TokenType.AT_GT: binary_range_parser(exp.ArrayContains),
TokenType.LT_AT: binary_range_parser(exp.ArrayContained),
@@ -303,14 +306,14 @@ class Postgres(Dialect):
return self.expression(exp.Extract, this=part, expression=value)
class Generator(generator.Generator):
- INTERVAL_ALLOWS_PLURAL_FORM = False
+ SINGLE_STRING_INTERVAL = True
LOCKING_READS_SUPPORTED = True
JOIN_HINTS = False
TABLE_HINTS = False
PARAMETER_TOKEN = "$"
TYPE_MAPPING = {
- **generator.Generator.TYPE_MAPPING, # type: ignore
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TINYINT: "SMALLINT",
exp.DataType.Type.FLOAT: "REAL",
exp.DataType.Type.DOUBLE: "DOUBLE PRECISION",
@@ -320,14 +323,9 @@ class Postgres(Dialect):
}
TRANSFORMS = {
- **generator.Generator.TRANSFORMS, # type: ignore
+ **generator.Generator.TRANSFORMS,
exp.BitwiseXor: lambda self, e: self.binary(e, "#"),
- exp.ColumnDef: transforms.preprocess(
- [
- _auto_increment_to_serial,
- _serial_to_generated,
- ],
- ),
+ exp.ColumnDef: transforms.preprocess([_auto_increment_to_serial, _serial_to_generated]),
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONBExtract: lambda self, e: self.binary(e, "#>"),
@@ -348,6 +346,7 @@ class Postgres(Dialect):
exp.ArrayContains: lambda self, e: self.binary(e, "@>"),
exp.ArrayContained: lambda self, e: self.binary(e, "<@"),
exp.Merge: transforms.preprocess([transforms.remove_target_from_merge]),
+ exp.Pivot: no_pivot_sql,
exp.RegexpLike: lambda self, e: self.binary(e, "~"),
exp.RegexpILike: lambda self, e: self.binary(e, "~*"),
exp.StrPosition: str_position_sql,
@@ -369,7 +368,7 @@ class Postgres(Dialect):
}
PROPERTIES_LOCATION = {
- **generator.Generator.PROPERTIES_LOCATION, # type: ignore
+ **generator.Generator.PROPERTIES_LOCATION,
exp.TransientProperty: exp.Properties.Location.UNSUPPORTED,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 6133a27..52a04a4 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import (
format_time_lambda,
if_sql,
no_ilike_sql,
+ no_pivot_sql,
no_safe_divide_sql,
rename_func,
struct_extract_sql,
@@ -127,39 +128,12 @@ def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> s
)
-def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) -> str:
- start = expression.args["start"]
- end = expression.args["end"]
- step = expression.args.get("step")
-
- target_type = None
-
- if isinstance(start, exp.Cast):
- target_type = start.to
- elif isinstance(end, exp.Cast):
- target_type = end.to
-
- if target_type and target_type.this == exp.DataType.Type.TIMESTAMP:
- to = target_type.copy()
-
- if target_type is start.to:
- end = exp.Cast(this=end, to=to)
- else:
- start = exp.Cast(this=start, to=to)
-
- sql = self.func("SEQUENCE", start, end, step)
- if isinstance(expression.parent, exp.Table):
- sql = f"UNNEST({sql})"
-
- return sql
-
-
def _ensure_utf8(charset: exp.Literal) -> None:
if charset.name.lower() != "utf-8":
raise UnsupportedError(f"Unsupported charset {charset}")
-def _approx_percentile(args: t.Sequence) -> exp.Expression:
+def _approx_percentile(args: t.List) -> exp.Expression:
if len(args) == 4:
return exp.ApproxQuantile(
this=seq_get(args, 0),
@@ -176,7 +150,7 @@ def _approx_percentile(args: t.Sequence) -> exp.Expression:
return exp.ApproxQuantile.from_arg_list(args)
-def _from_unixtime(args: t.Sequence) -> exp.Expression:
+def _from_unixtime(args: t.List) -> exp.Expression:
if len(args) == 3:
return exp.UnixToTime(
this=seq_get(args, 0),
@@ -191,22 +165,39 @@ def _from_unixtime(args: t.Sequence) -> exp.Expression:
return exp.UnixToTime.from_arg_list(args)
+def _unnest_sequence(expression: exp.Expression) -> exp.Expression:
+ if isinstance(expression, exp.Table):
+ if isinstance(expression.this, exp.GenerateSeries):
+ unnest = exp.Unnest(expressions=[expression.this])
+
+ if expression.alias:
+ return exp.alias_(
+ unnest,
+ alias="_u",
+ table=[expression.alias],
+ copy=False,
+ )
+ return unnest
+ return expression
+
+
class Presto(Dialect):
index_offset = 1
null_ordering = "nulls_are_last"
- time_format = MySQL.time_format # type: ignore
- time_mapping = MySQL.time_mapping # type: ignore
+ time_format = MySQL.time_format
+ time_mapping = MySQL.time_mapping
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"START": TokenType.BEGIN,
+ "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
"ROW": TokenType.STRUCT,
}
class Parser(parser.Parser):
FUNCTIONS = {
- **parser.Parser.FUNCTIONS, # type: ignore
+ **parser.Parser.FUNCTIONS,
"APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list,
"APPROX_PERCENTILE": _approx_percentile,
"CARDINALITY": exp.ArraySize.from_arg_list,
@@ -252,13 +243,13 @@ class Presto(Dialect):
STRUCT_DELIMITER = ("(", ")")
PROPERTIES_LOCATION = {
- **generator.Generator.PROPERTIES_LOCATION, # type: ignore
+ **generator.Generator.PROPERTIES_LOCATION,
exp.LocationProperty: exp.Properties.Location.UNSUPPORTED,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
TYPE_MAPPING = {
- **generator.Generator.TYPE_MAPPING, # type: ignore
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.FLOAT: "REAL",
exp.DataType.Type.BINARY: "VARBINARY",
@@ -268,8 +259,9 @@ class Presto(Dialect):
}
TRANSFORMS = {
- **generator.Generator.TRANSFORMS, # type: ignore
+ **generator.Generator.TRANSFORMS,
exp.ApproxDistinct: _approx_distinct_sql,
+ exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
exp.ArrayConcat: rename_func("CONCAT"),
exp.ArrayContains: rename_func("CONTAINS"),
@@ -293,7 +285,7 @@ class Presto(Dialect):
exp.Decode: _decode_sql,
exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)",
exp.Encode: _encode_sql,
- exp.GenerateSeries: _sequence_sql,
+ exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.Hex: rename_func("TO_HEX"),
exp.If: if_sql,
@@ -301,10 +293,10 @@ class Presto(Dialect):
exp.Initcap: _initcap_sql,
exp.Lateral: _explode_to_unnest_sql,
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
- exp.LogicalOr: rename_func("BOOL_OR"),
exp.LogicalAnd: rename_func("BOOL_AND"),
+ exp.LogicalOr: rename_func("BOOL_OR"),
+ exp.Pivot: no_pivot_sql,
exp.Quantile: _quantile_sql,
- exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
exp.SafeDivide: no_safe_divide_sql,
exp.Schema: _schema_sql,
exp.Select: transforms.preprocess(
@@ -320,8 +312,7 @@ class Presto(Dialect):
exp.StrToTime: _str_to_time_sql,
exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))",
exp.StructExtract: struct_extract_sql,
- exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT='{e.name.upper()}'",
- exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
+ exp.Table: transforms.preprocess([_unnest_sequence]),
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToDate: timestrtotime_sql,
exp.TimeStrToTime: timestrtotime_sql,
@@ -336,6 +327,7 @@ class Presto(Dialect):
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
exp.VariancePop: rename_func("VAR_POP"),
+ exp.With: transforms.preprocess([transforms.add_recursive_cte_column_names]),
exp.WithinGroup: transforms.preprocess(
[transforms.remove_within_group_for_percentiles]
),
@@ -351,3 +343,25 @@ class Presto(Dialect):
modes = expression.args.get("modes")
modes = f" {', '.join(modes)}" if modes else ""
return f"START TRANSACTION{modes}"
+
+ def generateseries_sql(self, expression: exp.GenerateSeries) -> str:
+ start = expression.args["start"]
+ end = expression.args["end"]
+ step = expression.args.get("step")
+
+ if isinstance(start, exp.Cast):
+ target_type = start.to
+ elif isinstance(end, exp.Cast):
+ target_type = end.to
+ else:
+ target_type = None
+
+ if target_type and target_type.is_type(exp.DataType.Type.TIMESTAMP):
+ to = target_type.copy()
+
+ if target_type is start.to:
+ end = exp.Cast(this=end, to=to)
+ else:
+ start = exp.Cast(this=start, to=to)
+
+ return self.func("SEQUENCE", start, end, step)
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index 1b7cf31..55e393a 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -8,21 +8,21 @@ from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
-def _json_sql(self, e) -> str:
- return f'{self.sql(e, "this")}."{e.expression.name}"'
+def _json_sql(self: Postgres.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar) -> str:
+ return f'{self.sql(expression, "this")}."{expression.expression.name}"'
class Redshift(Postgres):
time_format = "'YYYY-MM-DD HH:MI:SS'"
time_mapping = {
- **Postgres.time_mapping, # type: ignore
+ **Postgres.time_mapping,
"MON": "%b",
"HH": "%H",
}
class Parser(Postgres.Parser):
FUNCTIONS = {
- **Postgres.Parser.FUNCTIONS, # type: ignore
+ **Postgres.Parser.FUNCTIONS,
"DATEADD": lambda args: exp.DateAdd(
this=seq_get(args, 2),
expression=seq_get(args, 1),
@@ -45,7 +45,7 @@ class Redshift(Postgres):
isinstance(this, exp.DataType)
and this.this == exp.DataType.Type.VARCHAR
and this.expressions
- and this.expressions[0] == exp.column("MAX")
+ and this.expressions[0].this == exp.column("MAX")
):
this.set("expressions", [exp.Var(this="MAX")])
@@ -57,9 +57,7 @@ class Redshift(Postgres):
STRING_ESCAPES = ["\\"]
KEYWORDS = {
- **Postgres.Tokenizer.KEYWORDS, # type: ignore
- "GEOMETRY": TokenType.GEOMETRY,
- "GEOGRAPHY": TokenType.GEOGRAPHY,
+ **Postgres.Tokenizer.KEYWORDS,
"HLLSKETCH": TokenType.HLLSKETCH,
"SUPER": TokenType.SUPER,
"SYSDATE": TokenType.CURRENT_TIMESTAMP,
@@ -76,22 +74,22 @@ class Redshift(Postgres):
class Generator(Postgres.Generator):
LOCKING_READS_SUPPORTED = False
- SINGLE_STRING_INTERVAL = True
+ RENAME_TABLE_WITH_DB = False
TYPE_MAPPING = {
- **Postgres.Generator.TYPE_MAPPING, # type: ignore
+ **Postgres.Generator.TYPE_MAPPING,
exp.DataType.Type.BINARY: "VARBYTE",
exp.DataType.Type.VARBINARY: "VARBYTE",
exp.DataType.Type.INT: "INTEGER",
}
PROPERTIES_LOCATION = {
- **Postgres.Generator.PROPERTIES_LOCATION, # type: ignore
+ **Postgres.Generator.PROPERTIES_LOCATION,
exp.LikeProperty: exp.Properties.Location.POST_WITH,
}
TRANSFORMS = {
- **Postgres.Generator.TRANSFORMS, # type: ignore
+ **Postgres.Generator.TRANSFORMS,
exp.CurrentTimestamp: lambda self, e: "SYSDATE",
exp.DateAdd: lambda self, e: self.func(
"DATEADD", exp.var(e.text("unit") or "day"), e.expression, e.this
@@ -107,10 +105,13 @@ class Redshift(Postgres):
exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
}
+ # Postgres maps exp.Pivot to no_pivot_sql, but Redshift support pivots
+ TRANSFORMS.pop(exp.Pivot)
+
# Redshift uses the POW | POWER (expr1, expr2) syntax instead of expr1 ^ expr2 (postgres)
TRANSFORMS.pop(exp.Pow)
- RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot"}
+ RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot", "type"}
def values_sql(self, expression: exp.Values) -> str:
"""
@@ -120,37 +121,36 @@ class Redshift(Postgres):
evaluate the expression. You may need to increase `sys.setrecursionlimit` to run and it can also be
very slow.
"""
- if not isinstance(expression.unnest().parent, exp.From):
+
+ # The VALUES clause is still valid in an `INSERT INTO ..` statement, for example
+ if not expression.find_ancestor(exp.From, exp.Join):
return super().values_sql(expression)
- rows = [tuple_exp.expressions for tuple_exp in expression.expressions]
+
+ column_names = expression.alias and expression.args["alias"].columns
+
selects = []
+ rows = [tuple_exp.expressions for tuple_exp in expression.expressions]
+
for i, row in enumerate(rows):
- if i == 0 and expression.alias:
+ if i == 0 and column_names:
row = [
exp.alias_(value, column_name)
- for value, column_name in zip(row, expression.args["alias"].args["columns"])
+ for value, column_name in zip(row, column_names)
]
+
selects.append(exp.Select(expressions=row))
- subquery_expression = selects[0]
+
+ subquery_expression: exp.Select | exp.Union = selects[0]
if len(selects) > 1:
for select in selects[1:]:
subquery_expression = exp.union(subquery_expression, select, distinct=False)
+
return self.subquery_sql(subquery_expression.subquery(expression.alias))
def with_properties(self, properties: exp.Properties) -> str:
"""Redshift doesn't have `WITH` as part of their with_properties so we remove it"""
return self.properties(properties, prefix=" ", suffix="")
- def renametable_sql(self, expression: exp.RenameTable) -> str:
- """Redshift only supports defining the table name itself (not the db) when renaming tables"""
- expression = expression.copy()
- target_table = expression.this
- for arg in target_table.args:
- if arg != "this":
- target_table.set(arg, None)
- this = self.sql(expression, "this")
- return f"RENAME TO {this}"
-
def datatype_sql(self, expression: exp.DataType) -> str:
"""
Redshift converts the `TEXT` data type to `VARCHAR(255)` by default when people more generally mean
@@ -162,6 +162,8 @@ class Redshift(Postgres):
expression = expression.copy()
expression.set("this", exp.DataType.Type.VARCHAR)
precision = expression.args.get("expressions")
+
if not precision:
expression.append("expressions", exp.Var(this="MAX"))
+
return super().datatype_sql(expression)
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 70dcaa9..756e8e9 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -18,7 +18,7 @@ from sqlglot.dialects.dialect import (
var_map_sql,
)
from sqlglot.expressions import Literal
-from sqlglot.helper import flatten, seq_get
+from sqlglot.helper import seq_get
from sqlglot.parser import binary_range_parser
from sqlglot.tokens import TokenType
@@ -30,7 +30,7 @@ def _check_int(s: str) -> bool:
# from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html
-def _snowflake_to_timestamp(args: t.Sequence) -> t.Union[exp.StrToTime, exp.UnixToTime]:
+def _snowflake_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime]:
if len(args) == 2:
first_arg, second_arg = args
if second_arg.is_string:
@@ -52,8 +52,12 @@ def _snowflake_to_timestamp(args: t.Sequence) -> t.Union[exp.StrToTime, exp.Unix
return exp.UnixToTime(this=first_arg, scale=timescale)
+ from sqlglot.optimizer.simplify import simplify_literals
+
+ # The first argument might be an expression like 40 * 365 * 86400, so we try to
+ # reduce it using `simplify_literals` first and then check if it's a Literal.
first_arg = seq_get(args, 0)
- if not isinstance(first_arg, Literal):
+ if not isinstance(simplify_literals(first_arg, root=True), Literal):
# case: <variant_expr>
return format_time_lambda(exp.StrToTime, "snowflake", default=True)(args)
@@ -69,6 +73,19 @@ def _snowflake_to_timestamp(args: t.Sequence) -> t.Union[exp.StrToTime, exp.Unix
return exp.UnixToTime.from_arg_list(args)
+def _parse_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]:
+ expression = parser.parse_var_map(args)
+
+ if isinstance(expression, exp.StarMap):
+ return expression
+
+ return exp.Struct(
+ expressions=[
+ t.cast(exp.Condition, k).eq(v) for k, v in zip(expression.keys, expression.values)
+ ]
+ )
+
+
def _unix_to_time_sql(self: generator.Generator, expression: exp.UnixToTime) -> str:
scale = expression.args.get("scale")
timestamp = self.sql(expression, "this")
@@ -116,7 +133,7 @@ def _parse_date_part(self: parser.Parser) -> t.Optional[exp.Expression]:
# https://docs.snowflake.com/en/sql-reference/functions/div0
-def _div0_to_if(args: t.Sequence) -> exp.Expression:
+def _div0_to_if(args: t.List) -> exp.Expression:
cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0))
true = exp.Literal.number(0)
false = exp.Div(this=seq_get(args, 0), expression=seq_get(args, 1))
@@ -124,13 +141,13 @@ def _div0_to_if(args: t.Sequence) -> exp.Expression:
# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
-def _zeroifnull_to_if(args: t.Sequence) -> exp.Expression:
+def _zeroifnull_to_if(args: t.List) -> exp.Expression:
cond = exp.Is(this=seq_get(args, 0), expression=exp.Null())
return exp.If(this=cond, true=exp.Literal.number(0), false=seq_get(args, 0))
# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
-def _nullifzero_to_if(args: t.Sequence) -> exp.Expression:
+def _nullifzero_to_if(args: t.List) -> exp.Expression:
cond = exp.EQ(this=seq_get(args, 0), expression=exp.Literal.number(0))
return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0))
@@ -143,6 +160,12 @@ def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
return self.datatype_sql(expression)
+def _parse_convert_timezone(args: t.List) -> exp.Expression:
+ if len(args) == 3:
+ return exp.Anonymous(this="CONVERT_TIMEZONE", expressions=args)
+ return exp.AtTimeZone(this=seq_get(args, 1), zone=seq_get(args, 0))
+
+
class Snowflake(Dialect):
null_ordering = "nulls_are_large"
time_format = "'yyyy-mm-dd hh24:mi:ss'"
@@ -177,17 +200,14 @@ class Snowflake(Dialect):
}
class Parser(parser.Parser):
- QUOTED_PIVOT_COLUMNS = True
+ IDENTIFY_PIVOT_STRINGS = True
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"ARRAYAGG": exp.ArrayAgg.from_arg_list,
"ARRAY_CONSTRUCT": exp.Array.from_arg_list,
"ARRAY_TO_STRING": exp.ArrayJoin.from_arg_list,
- "CONVERT_TIMEZONE": lambda args: exp.AtTimeZone(
- this=seq_get(args, 1),
- zone=seq_get(args, 0),
- ),
+ "CONVERT_TIMEZONE": _parse_convert_timezone,
"DATE_TRUNC": date_trunc_to_time,
"DATEADD": lambda args: exp.DateAdd(
this=seq_get(args, 2),
@@ -202,7 +222,7 @@ class Snowflake(Dialect):
"DIV0": _div0_to_if,
"IFF": exp.If.from_arg_list,
"NULLIFZERO": _nullifzero_to_if,
- "OBJECT_CONSTRUCT": parser.parse_var_map,
+ "OBJECT_CONSTRUCT": _parse_object_construct,
"RLIKE": exp.RegexpLike.from_arg_list,
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
"TO_ARRAY": exp.Array.from_arg_list,
@@ -224,7 +244,7 @@ class Snowflake(Dialect):
}
COLUMN_OPERATORS = {
- **parser.Parser.COLUMN_OPERATORS, # type: ignore
+ **parser.Parser.COLUMN_OPERATORS,
TokenType.COLON: lambda self, this, path: self.expression(
exp.Bracket,
this=this,
@@ -232,14 +252,16 @@ class Snowflake(Dialect):
),
}
+ TIMESTAMPS = parser.Parser.TIMESTAMPS.copy() - {TokenType.TIME}
+
RANGE_PARSERS = {
- **parser.Parser.RANGE_PARSERS, # type: ignore
+ **parser.Parser.RANGE_PARSERS,
TokenType.LIKE_ANY: binary_range_parser(exp.LikeAny),
TokenType.ILIKE_ANY: binary_range_parser(exp.ILikeAny),
}
ALTER_PARSERS = {
- **parser.Parser.ALTER_PARSERS, # type: ignore
+ **parser.Parser.ALTER_PARSERS,
"UNSET": lambda self: self._parse_alter_table_set_tag(unset=True),
"SET": lambda self: self._parse_alter_table_set_tag(),
}
@@ -256,17 +278,20 @@ class Snowflake(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
+ "CHAR VARYING": TokenType.VARCHAR,
+ "CHARACTER VARYING": TokenType.VARCHAR,
"EXCLUDE": TokenType.EXCEPT,
"ILIKE ANY": TokenType.ILIKE_ANY,
"LIKE ANY": TokenType.LIKE_ANY,
"MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
+ "MINUS": TokenType.EXCEPT,
+ "NCHAR VARYING": TokenType.VARCHAR,
"PUT": TokenType.COMMAND,
"RENAME": TokenType.REPLACE,
"TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
"TIMESTAMP_NTZ": TokenType.TIMESTAMP,
"TIMESTAMP_TZ": TokenType.TIMESTAMPTZ,
"TIMESTAMPNTZ": TokenType.TIMESTAMP,
- "MINUS": TokenType.EXCEPT,
"SAMPLE": TokenType.TABLE_SAMPLE,
}
@@ -285,7 +310,7 @@ class Snowflake(Dialect):
TABLE_HINTS = False
TRANSFORMS = {
- **generator.Generator.TRANSFORMS, # type: ignore
+ **generator.Generator.TRANSFORMS,
exp.Array: inline_array_sql,
exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.ArrayJoin: rename_func("ARRAY_TO_STRING"),
@@ -299,6 +324,7 @@ class Snowflake(Dialect):
exp.DateStrToDate: datestrtodate_sql,
exp.DataType: _datatype_sql,
exp.DayOfWeek: rename_func("DAYOFWEEK"),
+ exp.Extract: rename_func("DATE_PART"),
exp.If: rename_func("IFF"),
exp.LogicalAnd: rename_func("BOOLAND_AGG"),
exp.LogicalOr: rename_func("BOOLOR_AGG"),
@@ -312,6 +338,10 @@ class Snowflake(Dialect):
"POSITION", e.args.get("substr"), e.this, e.args.get("position")
),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.Struct: lambda self, e: self.func(
+ "OBJECT_CONSTRUCT",
+ *(arg for expression in e.expressions for arg in expression.flatten()),
+ ),
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
exp.TimeToStr: lambda self, e: self.func(
@@ -326,7 +356,7 @@ class Snowflake(Dialect):
}
TYPE_MAPPING = {
- **generator.Generator.TYPE_MAPPING, # type: ignore
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",
}
@@ -336,7 +366,7 @@ class Snowflake(Dialect):
}
PROPERTIES_LOCATION = {
- **generator.Generator.PROPERTIES_LOCATION, # type: ignore
+ **generator.Generator.PROPERTIES_LOCATION,
exp.SetProperty: exp.Properties.Location.UNSUPPORTED,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
@@ -351,53 +381,10 @@ class Snowflake(Dialect):
self.unsupported("INTERSECT with All is not supported in Snowflake")
return super().intersect_op(expression)
- def values_sql(self, expression: exp.Values) -> str:
- """Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted.
-
- We also want to make sure that after we find matches where we need to unquote a column that we prevent users
- from adding quotes to the column by using the `identify` argument when generating the SQL.
- """
- alias = expression.args.get("alias")
- if alias and alias.args.get("columns"):
- expression = expression.transform(
- lambda node: exp.Identifier(**{**node.args, "quoted": False})
- if isinstance(node, exp.Identifier)
- and isinstance(node.parent, exp.TableAlias)
- and node.arg_key == "columns"
- else node,
- )
- return self.no_identify(lambda: super(self.__class__, self).values_sql(expression))
- return super().values_sql(expression)
-
def settag_sql(self, expression: exp.SetTag) -> str:
action = "UNSET" if expression.args.get("unset") else "SET"
return f"{action} TAG {self.expressions(expression)}"
- def select_sql(self, expression: exp.Select) -> str:
- """Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted and also
- that all columns in a SELECT are unquoted. We also want to make sure that after we find matches where we need
- to unquote a column that we prevent users from adding quotes to the column by using the `identify` argument when
- generating the SQL.
-
- Note: We make an assumption that any columns referenced in a VALUES expression should be unquoted throughout the
- expression. This might not be true in a case where the same column name can be sourced from another table that can
- properly quote but should be true in most cases.
- """
- values_identifiers = set(
- flatten(
- (v.args.get("alias") or exp.Alias()).args.get("columns", [])
- for v in expression.find_all(exp.Values)
- )
- )
- if values_identifiers:
- expression = expression.transform(
- lambda node: exp.Identifier(**{**node.args, "quoted": False})
- if isinstance(node, exp.Identifier) and node in values_identifiers
- else node,
- )
- return self.no_identify(lambda: super(self.__class__, self).select_sql(expression))
- return super().select_sql(expression)
-
def describe_sql(self, expression: exp.Describe) -> str:
# Default to table if kind is unknown
kind_value = expression.args.get("kind") or "TABLE"
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 939f2fd..b7d1641 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -7,10 +7,10 @@ from sqlglot.dialects.spark2 import Spark2
from sqlglot.helper import seq_get
-def _parse_datediff(args: t.Sequence) -> exp.Expression:
+def _parse_datediff(args: t.List) -> exp.Expression:
"""
Although Spark docs don't mention the "unit" argument, Spark3 added support for
- it at some point. Databricks also supports this variation (see below).
+ it at some point. Databricks also supports this variant (see below).
For example, in spark-sql (v3.3.1):
- SELECT DATEDIFF('2020-01-01', '2020-01-05') results in -4
@@ -36,7 +36,7 @@ def _parse_datediff(args: t.Sequence) -> exp.Expression:
class Spark(Spark2):
class Parser(Spark2.Parser):
FUNCTIONS = {
- **Spark2.Parser.FUNCTIONS, # type: ignore
+ **Spark2.Parser.FUNCTIONS,
"DATEDIFF": _parse_datediff,
}
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
index 584671f..912b86b 100644
--- a/sqlglot/dialects/spark2.py
+++ b/sqlglot/dialects/spark2.py
@@ -3,7 +3,12 @@ from __future__ import annotations
import typing as t
from sqlglot import exp, parser, transforms
-from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func, trim_sql
+from sqlglot.dialects.dialect import (
+ create_with_partitions_sql,
+ pivot_column_names,
+ rename_func,
+ trim_sql,
+)
from sqlglot.dialects.hive import Hive
from sqlglot.helper import seq_get
@@ -26,7 +31,7 @@ def _map_sql(self: Hive.Generator, expression: exp.Map) -> str:
return f"MAP_FROM_ARRAYS({keys}, {values})"
-def _parse_as_cast(to_type: str) -> t.Callable[[t.Sequence], exp.Expression]:
+def _parse_as_cast(to_type: str) -> t.Callable[[t.List], exp.Expression]:
return lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build(to_type))
@@ -53,10 +58,56 @@ def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str:
raise ValueError("Improper scale for timestamp")
+def _unalias_pivot(expression: exp.Expression) -> exp.Expression:
+ """
+ Spark doesn't allow PIVOT aliases, so we need to remove them and possibly wrap a
+ pivoted source in a subquery with the same alias to preserve the query's semantics.
+
+ Example:
+ >>> from sqlglot import parse_one
+ >>> expr = parse_one("SELECT piv.x FROM tbl PIVOT (SUM(a) FOR b IN ('x')) piv")
+ >>> print(_unalias_pivot(expr).sql(dialect="spark"))
+ SELECT piv.x FROM (SELECT * FROM tbl PIVOT(SUM(a) FOR b IN ('x'))) AS piv
+ """
+ if isinstance(expression, exp.From) and expression.this.args.get("pivots"):
+ pivot = expression.this.args["pivots"][0]
+ if pivot.alias:
+ alias = pivot.args["alias"].pop()
+ return exp.From(
+ this=expression.this.replace(
+ exp.select("*").from_(expression.this.copy()).subquery(alias=alias)
+ )
+ )
+
+ return expression
+
+
+def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
+ """
+ Spark doesn't allow the column referenced in the PIVOT's field to be qualified,
+ so we need to unqualify it.
+
+ Example:
+ >>> from sqlglot import parse_one
+ >>> expr = parse_one("SELECT * FROM tbl PIVOT (SUM(tbl.sales) FOR tbl.quarter IN ('Q1', 'Q2'))")
+ >>> print(_unqualify_pivot_columns(expr).sql(dialect="spark"))
+ SELECT * FROM tbl PIVOT(SUM(tbl.sales) FOR quarter IN ('Q1', 'Q1'))
+ """
+ if isinstance(expression, exp.Pivot):
+ expression.args["field"].transform(
+ lambda node: exp.column(node.output_name, quoted=node.this.quoted)
+ if isinstance(node, exp.Column)
+ else node,
+ copy=False,
+ )
+
+ return expression
+
+
class Spark2(Hive):
class Parser(Hive.Parser):
FUNCTIONS = {
- **Hive.Parser.FUNCTIONS, # type: ignore
+ **Hive.Parser.FUNCTIONS,
"MAP_FROM_ARRAYS": exp.Map.from_arg_list,
"TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
"LEFT": lambda args: exp.Substring(
@@ -110,7 +161,7 @@ class Spark2(Hive):
}
FUNCTION_PARSERS = {
- **parser.Parser.FUNCTION_PARSERS, # type: ignore
+ **parser.Parser.FUNCTION_PARSERS,
"BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
"BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
"MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
@@ -131,43 +182,21 @@ class Spark2(Hive):
kind="COLUMNS",
)
- def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]:
- # Spark doesn't add a suffix to the pivot columns when there's a single aggregation
- if len(pivot_columns) == 1:
+ def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]:
+ if len(aggregations) == 1:
return [""]
-
- names = []
- for agg in pivot_columns:
- if isinstance(agg, exp.Alias):
- names.append(agg.alias)
- else:
- """
- This case corresponds to aggregations without aliases being used as suffixes
- (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
- be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
- Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
-
- Moreover, function names are lowercased in order to mimic Spark's naming scheme.
- """
- agg_all_unquoted = agg.transform(
- lambda node: exp.Identifier(this=node.name, quoted=False)
- if isinstance(node, exp.Identifier)
- else node
- )
- names.append(agg_all_unquoted.sql(dialect="spark", normalize_functions="lower"))
-
- return names
+ return pivot_column_names(aggregations, dialect="spark")
class Generator(Hive.Generator):
TYPE_MAPPING = {
- **Hive.Generator.TYPE_MAPPING, # type: ignore
+ **Hive.Generator.TYPE_MAPPING,
exp.DataType.Type.TINYINT: "BYTE",
exp.DataType.Type.SMALLINT: "SHORT",
exp.DataType.Type.BIGINT: "LONG",
}
PROPERTIES_LOCATION = {
- **Hive.Generator.PROPERTIES_LOCATION, # type: ignore
+ **Hive.Generator.PROPERTIES_LOCATION,
exp.EngineProperty: exp.Properties.Location.UNSUPPORTED,
exp.AutoIncrementProperty: exp.Properties.Location.UNSUPPORTED,
exp.CharacterSetProperty: exp.Properties.Location.UNSUPPORTED,
@@ -175,7 +204,7 @@ class Spark2(Hive):
}
TRANSFORMS = {
- **Hive.Generator.TRANSFORMS, # type: ignore
+ **Hive.Generator.TRANSFORMS,
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
@@ -188,11 +217,12 @@ class Spark2(Hive):
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
+ exp.From: transforms.preprocess([_unalias_pivot]),
exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.LogicalOr: rename_func("BOOL_OR"),
exp.Map: _map_sql,
- exp.Pivot: transforms.preprocess([transforms.unqualify_pivot_columns]),
+ exp.Pivot: transforms.preprocess([_unqualify_pivot_columns]),
exp.Reduce: rename_func("AGGREGATE"),
exp.StrToDate: _str_to_date,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index f2efe32..56e7773 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -7,6 +7,7 @@ from sqlglot.dialects.dialect import (
arrow_json_extract_sql,
count_if_to_sum,
no_ilike_sql,
+ no_pivot_sql,
no_tablesample_sql,
no_trycast_sql,
rename_func,
@@ -14,7 +15,7 @@ from sqlglot.dialects.dialect import (
from sqlglot.tokens import TokenType
-def _date_add_sql(self, expression):
+def _date_add_sql(self: generator.Generator, expression: exp.DateAdd) -> str:
modifier = expression.expression
modifier = modifier.name if modifier.is_string else self.sql(modifier)
unit = expression.args.get("unit")
@@ -67,7 +68,7 @@ class SQLite(Dialect):
class Parser(parser.Parser):
FUNCTIONS = {
- **parser.Parser.FUNCTIONS, # type: ignore
+ **parser.Parser.FUNCTIONS,
"EDITDIST3": exp.Levenshtein.from_arg_list,
}
@@ -76,7 +77,7 @@ class SQLite(Dialect):
TABLE_HINTS = False
TYPE_MAPPING = {
- **generator.Generator.TYPE_MAPPING, # type: ignore
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.BOOLEAN: "INTEGER",
exp.DataType.Type.TINYINT: "INTEGER",
exp.DataType.Type.SMALLINT: "INTEGER",
@@ -98,7 +99,7 @@ class SQLite(Dialect):
}
TRANSFORMS = {
- **generator.Generator.TRANSFORMS, # type: ignore
+ **generator.Generator.TRANSFORMS,
exp.CountIf: count_if_to_sum,
exp.Create: transforms.preprocess([_transform_create]),
exp.CurrentDate: lambda *_: "CURRENT_DATE",
@@ -114,6 +115,7 @@ class SQLite(Dialect):
exp.Levenshtein: rename_func("EDITDIST3"),
exp.LogicalOr: rename_func("MAX"),
exp.LogicalAnd: rename_func("MIN"),
+ exp.Pivot: no_pivot_sql,
exp.Select: transforms.preprocess(
[transforms.eliminate_distinct_on, transforms.eliminate_qualify]
),
@@ -163,12 +165,15 @@ class SQLite(Dialect):
return f"CAST({sql} AS INTEGER)"
# https://www.sqlite.org/lang_aggfunc.html#group_concat
- def groupconcat_sql(self, expression):
+ def groupconcat_sql(self, expression: exp.GroupConcat) -> str:
this = expression.this
distinct = expression.find(exp.Distinct)
+
if distinct:
this = distinct.expressions[0]
- distinct = "DISTINCT "
+ distinct_sql = "DISTINCT "
+ else:
+ distinct_sql = ""
if isinstance(expression.this, exp.Order):
self.unsupported("SQLite GROUP_CONCAT doesn't support ORDER BY.")
@@ -176,7 +181,7 @@ class SQLite(Dialect):
this = expression.this.this
separator = expression.args.get("separator")
- return f"GROUP_CONCAT({distinct or ''}{self.format_args(this, separator)})"
+ return f"GROUP_CONCAT({distinct_sql}{self.format_args(this, separator)})"
def least_sql(self, expression: exp.Least) -> str:
if len(expression.expressions) > 1:
diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py
index 895588a..0390113 100644
--- a/sqlglot/dialects/starrocks.py
+++ b/sqlglot/dialects/starrocks.py
@@ -11,25 +11,24 @@ from sqlglot.helper import seq_get
class StarRocks(MySQL):
- class Parser(MySQL.Parser): # type: ignore
+ class Parser(MySQL.Parser):
FUNCTIONS = {
**MySQL.Parser.FUNCTIONS,
- "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
"DATE_TRUNC": lambda args: exp.TimestampTrunc(
this=seq_get(args, 1), unit=seq_get(args, 0)
),
}
- class Generator(MySQL.Generator): # type: ignore
+ class Generator(MySQL.Generator):
TYPE_MAPPING = {
- **MySQL.Generator.TYPE_MAPPING, # type: ignore
+ **MySQL.Generator.TYPE_MAPPING,
exp.DataType.Type.TEXT: "STRING",
exp.DataType.Type.TIMESTAMP: "DATETIME",
exp.DataType.Type.TIMESTAMPTZ: "DATETIME",
}
TRANSFORMS = {
- **MySQL.Generator.TRANSFORMS, # type: ignore
+ **MySQL.Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.JSONExtractScalar: arrow_json_extract_sql,
exp.JSONExtract: arrow_json_extract_sql,
@@ -43,4 +42,5 @@ class StarRocks(MySQL):
exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
}
+
TRANSFORMS.pop(exp.DateTrunc)
diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py
index 51e685b..d5fba17 100644
--- a/sqlglot/dialects/tableau.py
+++ b/sqlglot/dialects/tableau.py
@@ -4,41 +4,38 @@ from sqlglot import exp, generator, parser, transforms
from sqlglot.dialects.dialect import Dialect
-def _if_sql(self, expression):
- return f"IF {self.sql(expression, 'this')} THEN {self.sql(expression, 'true')} ELSE {self.sql(expression, 'false')} END"
-
-
-def _coalesce_sql(self, expression):
- return f"IFNULL({self.sql(expression, 'this')}, {self.expressions(expression)})"
-
-
-def _count_sql(self, expression):
- this = expression.this
- if isinstance(this, exp.Distinct):
- return f"COUNTD({self.expressions(this, flat=True)})"
- return f"COUNT({self.sql(expression, 'this')})"
-
-
class Tableau(Dialect):
class Generator(generator.Generator):
JOIN_HINTS = False
TABLE_HINTS = False
TRANSFORMS = {
- **generator.Generator.TRANSFORMS, # type: ignore
- exp.If: _if_sql,
- exp.Coalesce: _coalesce_sql,
- exp.Count: _count_sql,
+ **generator.Generator.TRANSFORMS,
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
}
PROPERTIES_LOCATION = {
- **generator.Generator.PROPERTIES_LOCATION, # type: ignore
+ **generator.Generator.PROPERTIES_LOCATION,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
+ def if_sql(self, expression: exp.If) -> str:
+ this = self.sql(expression, "this")
+ true = self.sql(expression, "true")
+ false = self.sql(expression, "false")
+ return f"IF {this} THEN {true} ELSE {false} END"
+
+ def coalesce_sql(self, expression: exp.Coalesce) -> str:
+ return f"IFNULL({self.sql(expression, 'this')}, {self.expressions(expression)})"
+
+ def count_sql(self, expression: exp.Count) -> str:
+ this = expression.this
+ if isinstance(this, exp.Distinct):
+ return f"COUNTD({self.expressions(this, flat=True)})"
+ return f"COUNT({self.sql(expression, 'this')})"
+
class Parser(parser.Parser):
FUNCTIONS = {
- **parser.Parser.FUNCTIONS, # type: ignore
+ **parser.Parser.FUNCTIONS,
"COUNTD": lambda args: exp.Count(this=exp.Distinct(expressions=args)),
}
diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py
index a79eaeb..9b39178 100644
--- a/sqlglot/dialects/teradata.py
+++ b/sqlglot/dialects/teradata.py
@@ -75,12 +75,12 @@ class Teradata(Dialect):
FUNC_TOKENS.remove(TokenType.REPLACE)
STATEMENT_PARSERS = {
- **parser.Parser.STATEMENT_PARSERS, # type: ignore
+ **parser.Parser.STATEMENT_PARSERS,
TokenType.REPLACE: lambda self: self._parse_create(),
}
FUNCTION_PARSERS = {
- **parser.Parser.FUNCTION_PARSERS, # type: ignore
+ **parser.Parser.FUNCTION_PARSERS,
"RANGE_N": lambda self: self._parse_rangen(),
"TRANSLATE": lambda self: self._parse_translate(self.STRICT_CAST),
}
@@ -106,7 +106,7 @@ class Teradata(Dialect):
exp.Update,
**{ # type: ignore
"this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS),
- "from": self._parse_from(),
+ "from": self._parse_from(modifiers=True),
"expressions": self._match(TokenType.SET)
and self._parse_csv(self._parse_equality),
"where": self._parse_where(),
@@ -135,13 +135,15 @@ class Teradata(Dialect):
TABLE_HINTS = False
TYPE_MAPPING = {
- **generator.Generator.TYPE_MAPPING, # type: ignore
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.GEOMETRY: "ST_GEOMETRY",
}
PROPERTIES_LOCATION = {
- **generator.Generator.PROPERTIES_LOCATION, # type: ignore
- exp.PartitionedByProperty: exp.Properties.Location.POST_INDEX,
+ **generator.Generator.PROPERTIES_LOCATION,
+ exp.OnCommitProperty: exp.Properties.Location.POST_INDEX,
+ exp.PartitionedByProperty: exp.Properties.Location.POST_EXPRESSION,
+ exp.StabilityProperty: exp.Properties.Location.POST_CREATE,
}
TRANSFORMS = {
diff --git a/sqlglot/dialects/trino.py b/sqlglot/dialects/trino.py
index c7b34fe..af0f78d 100644
--- a/sqlglot/dialects/trino.py
+++ b/sqlglot/dialects/trino.py
@@ -7,7 +7,7 @@ from sqlglot.dialects.presto import Presto
class Trino(Presto):
class Generator(Presto.Generator):
TRANSFORMS = {
- **Presto.Generator.TRANSFORMS, # type: ignore
+ **Presto.Generator.TRANSFORMS,
exp.ArraySum: lambda self, e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
}
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 03de99c..f6ad888 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -16,6 +16,9 @@ from sqlglot.helper import seq_get
from sqlglot.time import format_time
from sqlglot.tokens import TokenType
+if t.TYPE_CHECKING:
+ from sqlglot._typing import E
+
FULL_FORMAT_TIME_MAPPING = {
"weekday": "%A",
"dw": "%A",
@@ -50,13 +53,17 @@ DATE_FMT_RE = re.compile("([dD]{1,2})|([mM]{1,2})|([yY]{1,4})|([hH]{1,2})|([sS]{
TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"}
-def _format_time_lambda(exp_class, full_format_mapping=None, default=None):
- def _format_time(args):
+def _format_time_lambda(
+ exp_class: t.Type[E], full_format_mapping: t.Optional[bool] = None
+) -> t.Callable[[t.List], E]:
+ def _format_time(args: t.List) -> E:
+ assert len(args) == 2
+
return exp_class(
- this=seq_get(args, 1),
+ this=args[1],
format=exp.Literal.string(
format_time(
- seq_get(args, 0).name or (TSQL.time_format if default is True else default),
+ args[0].name,
{**TSQL.time_mapping, **FULL_FORMAT_TIME_MAPPING}
if full_format_mapping
else TSQL.time_mapping,
@@ -67,13 +74,17 @@ def _format_time_lambda(exp_class, full_format_mapping=None, default=None):
return _format_time
-def _parse_format(args):
- fmt = seq_get(args, 1)
- number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.this)
+def _parse_format(args: t.List) -> exp.Expression:
+ assert len(args) == 2
+
+ fmt = args[1]
+ number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.name)
+
if number_fmt:
- return exp.NumberToStr(this=seq_get(args, 0), format=fmt)
+ return exp.NumberToStr(this=args[0], format=fmt)
+
return exp.TimeToStr(
- this=seq_get(args, 0),
+ this=args[0],
format=exp.Literal.string(
format_time(fmt.name, TSQL.format_time_mapping)
if len(fmt.name) == 1
@@ -82,7 +93,7 @@ def _parse_format(args):
)
-def _parse_eomonth(args):
+def _parse_eomonth(args: t.List) -> exp.Expression:
date = seq_get(args, 0)
month_lag = seq_get(args, 1)
unit = DATE_DELTA_INTERVAL.get("month")
@@ -96,7 +107,7 @@ def _parse_eomonth(args):
return exp.LastDateOfMonth(this=exp.DateAdd(this=date, expression=month_lag, unit=unit))
-def _parse_hashbytes(args):
+def _parse_hashbytes(args: t.List) -> exp.Expression:
kind, data = args
kind = kind.name.upper() if kind.is_string else ""
@@ -110,40 +121,47 @@ def _parse_hashbytes(args):
return exp.SHA2(this=data, length=exp.Literal.number(256))
if kind == "SHA2_512":
return exp.SHA2(this=data, length=exp.Literal.number(512))
+
return exp.func("HASHBYTES", *args)
-def generate_date_delta_with_unit_sql(self, e):
- func = "DATEADD" if isinstance(e, exp.DateAdd) else "DATEDIFF"
- return self.func(func, e.text("unit"), e.expression, e.this)
+def generate_date_delta_with_unit_sql(
+ self: generator.Generator, expression: exp.DateAdd | exp.DateDiff
+) -> str:
+ func = "DATEADD" if isinstance(expression, exp.DateAdd) else "DATEDIFF"
+ return self.func(func, expression.text("unit"), expression.expression, expression.this)
-def _format_sql(self, e):
+def _format_sql(self: generator.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str:
fmt = (
- e.args["format"]
- if isinstance(e, exp.NumberToStr)
- else exp.Literal.string(format_time(e.text("format"), TSQL.inverse_time_mapping))
+ expression.args["format"]
+ if isinstance(expression, exp.NumberToStr)
+ else exp.Literal.string(
+ format_time(
+ expression.text("format"), t.cast(t.Dict[str, str], TSQL.inverse_time_mapping)
+ )
+ )
)
- return self.func("FORMAT", e.this, fmt)
+ return self.func("FORMAT", expression.this, fmt)
-def _string_agg_sql(self, e):
- e = e.copy()
+def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> str:
+ expression = expression.copy()
- this = e.this
- distinct = e.find(exp.Distinct)
+ this = expression.this
+ distinct = expression.find(exp.Distinct)
if distinct:
# exp.Distinct can appear below an exp.Order or an exp.GroupConcat expression
self.unsupported("T-SQL STRING_AGG doesn't support DISTINCT.")
this = distinct.pop().expressions[0]
order = ""
- if isinstance(e.this, exp.Order):
- if e.this.this:
- this = e.this.this.pop()
- order = f" WITHIN GROUP ({self.sql(e.this)[1:]})" # Order has a leading space
+ if isinstance(expression.this, exp.Order):
+ if expression.this.this:
+ this = expression.this.this.pop()
+ order = f" WITHIN GROUP ({self.sql(expression.this)[1:]})" # Order has a leading space
- separator = e.args.get("separator") or exp.Literal.string(",")
+ separator = expression.args.get("separator") or exp.Literal.string(",")
return f"STRING_AGG({self.format_args(this, separator)}){order}"
@@ -292,7 +310,7 @@ class TSQL(Dialect):
class Parser(parser.Parser):
FUNCTIONS = {
- **parser.Parser.FUNCTIONS, # type: ignore
+ **parser.Parser.FUNCTIONS,
"CHARINDEX": lambda args: exp.StrPosition(
this=seq_get(args, 1),
substr=seq_get(args, 0),
@@ -332,13 +350,13 @@ class TSQL(Dialect):
DataType.Type.NCHAR,
}
- RETURNS_TABLE_TOKENS = parser.Parser.ID_VAR_TOKENS - { # type: ignore
+ RETURNS_TABLE_TOKENS = parser.Parser.ID_VAR_TOKENS - {
TokenType.TABLE,
- *parser.Parser.TYPE_TOKENS, # type: ignore
+ *parser.Parser.TYPE_TOKENS,
}
STATEMENT_PARSERS = {
- **parser.Parser.STATEMENT_PARSERS, # type: ignore
+ **parser.Parser.STATEMENT_PARSERS,
TokenType.END: lambda self: self._parse_command(),
}
@@ -377,7 +395,7 @@ class TSQL(Dialect):
return system_time
- def _parse_table_parts(self, schema: bool = False) -> exp.Expression:
+ def _parse_table_parts(self, schema: bool = False) -> exp.Table:
table = super()._parse_table_parts(schema=schema)
table.set("system_time", self._parse_system_time())
return table
@@ -450,7 +468,7 @@ class TSQL(Dialect):
LOCKING_READS_SUPPORTED = True
TYPE_MAPPING = {
- **generator.Generator.TYPE_MAPPING, # type: ignore
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.DECIMAL: "NUMERIC",
exp.DataType.Type.DATETIME: "DATETIME2",
@@ -458,7 +476,7 @@ class TSQL(Dialect):
}
TRANSFORMS = {
- **generator.Generator.TRANSFORMS, # type: ignore
+ **generator.Generator.TRANSFORMS,
exp.DateAdd: generate_date_delta_with_unit_sql,
exp.DateDiff: generate_date_delta_with_unit_sql,
exp.CurrentDate: rename_func("GETDATE"),
@@ -480,7 +498,7 @@ class TSQL(Dialect):
TRANSFORMS.pop(exp.ReturnsProperty)
PROPERTIES_LOCATION = {
- **generator.Generator.PROPERTIES_LOCATION, # type: ignore
+ **generator.Generator.PROPERTIES_LOCATION,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}