Diffstat (limited to 'sqlglot')
-rw-r--r--  sqlglot/__init__.py                        |   9
-rw-r--r--  sqlglot/__main__.py                        |   2
-rw-r--r--  sqlglot/dataframe/sql/dataframe.py         |  39
-rw-r--r--  sqlglot/dataframe/sql/functions.py         |   2
-rw-r--r--  sqlglot/dataframe/sql/session.py           |   4
-rw-r--r--  sqlglot/dialects/bigquery.py               | 117
-rw-r--r--  sqlglot/dialects/clickhouse.py             |  36
-rw-r--r--  sqlglot/dialects/databricks.py             |  13
-rw-r--r--  sqlglot/dialects/dialect.py                | 220
-rw-r--r--  sqlglot/dialects/doris.py                  |   3
-rw-r--r--  sqlglot/dialects/drill.py                  |   3
-rw-r--r--  sqlglot/dialects/duckdb.py                 |  71
-rw-r--r--  sqlglot/dialects/hive.py                   |  68
-rw-r--r--  sqlglot/dialects/mysql.py                  |  29
-rw-r--r--  sqlglot/dialects/oracle.py                 |  33
-rw-r--r--  sqlglot/dialects/postgres.py               |  97
-rw-r--r--  sqlglot/dialects/presto.py                 |  72
-rw-r--r--  sqlglot/dialects/redshift.py               |  50
-rw-r--r--  sqlglot/dialects/snowflake.py              | 115
-rw-r--r--  sqlglot/dialects/spark.py                  |   8
-rw-r--r--  sqlglot/dialects/spark2.py                 |  24
-rw-r--r--  sqlglot/dialects/sqlite.py                 |   6
-rw-r--r--  sqlglot/dialects/teradata.py               |  13
-rw-r--r--  sqlglot/dialects/tsql.py                   | 118
-rw-r--r--  sqlglot/executor/env.py                    |  12
-rw-r--r--  sqlglot/executor/python.py                 |  18
-rw-r--r--  sqlglot/executor/table.py                  |   8
-rw-r--r--  sqlglot/expressions.py                     | 360
-rw-r--r--  sqlglot/generator.py                       | 344
-rw-r--r--  sqlglot/helper.py                          |  27
-rw-r--r--  sqlglot/lineage.py                         | 105
-rw-r--r--  sqlglot/optimizer/annotate_types.py        | 110
-rw-r--r--  sqlglot/optimizer/canonicalize.py          |  85
-rw-r--r--  sqlglot/optimizer/merge_subqueries.py      |   4
-rw-r--r--  sqlglot/optimizer/normalize_identifiers.py |   6
-rw-r--r--  sqlglot/optimizer/optimizer.py             |   4
-rw-r--r--  sqlglot/optimizer/qualify_columns.py       |  47
-rw-r--r--  sqlglot/optimizer/qualify_tables.py        |  15
-rw-r--r--  sqlglot/optimizer/scope.py                 |   2
-rw-r--r--  sqlglot/optimizer/simplify.py              |  73
-rw-r--r--  sqlglot/parser.py                          | 297
-rw-r--r--  sqlglot/schema.py                          |  48
-rw-r--r--  sqlglot/time.py                            |   4
-rw-r--r--  sqlglot/tokens.py                          |  66
-rw-r--r--  sqlglot/transforms.py                      |  62
45 files changed, 1983 insertions, 866 deletions
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py
index 35feaad..6cf9949 100644
--- a/sqlglot/__init__.py
+++ b/sqlglot/__init__.py
@@ -22,6 +22,7 @@ from sqlglot.expressions import (
Expression as Expression,
alias_ as alias,
and_ as and_,
+ case as case,
cast as cast,
column as column,
condition as condition,
@@ -82,8 +83,7 @@ def parse(
Returns:
The resulting syntax tree collection.
"""
- dialect = Dialect.get_or_raise(read or dialect)()
- return dialect.parse(sql, **opts)
+ return Dialect.get_or_raise(read or dialect).parse(sql, **opts)
@t.overload
@@ -117,7 +117,7 @@ def parse_one(
The syntax tree for the first parsed statement.
"""
- dialect = Dialect.get_or_raise(read or dialect)()
+ dialect = Dialect.get_or_raise(read or dialect)
if into:
result = dialect.parse_into(into, sql, **opts)
@@ -157,7 +157,8 @@ def transpile(
The list of transpiled SQL statements.
"""
write = (read if write is None else write) if identity else write
+ write = Dialect.get_or_raise(write)
return [
- Dialect.get_or_raise(write)().generate(expression, copy=False, **opts) if expression else ""
+ write.generate(expression, copy=False, **opts) if expression else ""
for expression in parse(sql, read, error_level=error_level)
]
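
A minimal sketch of the new calling convention implied by these hunks: Dialect.get_or_raise now returns a Dialect *instance* rather than a class, so the trailing () disappears at every call site.

    from sqlglot.dialects.dialect import Dialect

    dialect = Dialect.get_or_raise("duckdb")  # an instance, no extra () needed anymore
    expressions = dialect.parse("SELECT 1")
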
diff --git a/sqlglot/__main__.py b/sqlglot/__main__.py
index 4a2820b..5a77409 100644
--- a/sqlglot/__main__.py
+++ b/sqlglot/__main__.py
@@ -81,7 +81,7 @@ if args.parse:
)
]
elif args.tokenize:
- objs = sqlglot.Dialect.get_or_raise(args.read)().tokenize(sql)
+ objs = sqlglot.Dialect.get_or_raise(args.read).tokenize(sql)
else:
objs = sqlglot.transpile(
sql,
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
index f515608..68d36fe 100644
--- a/sqlglot/dataframe/sql/dataframe.py
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -297,27 +297,26 @@ class DataFrame:
select_expressions.append(expression_select_pair) # type: ignore
return select_expressions
- def sql(
- self, dialect: t.Optional[DialectType] = None, optimize: bool = True, **kwargs
- ) -> t.List[str]:
+ def sql(self, dialect: DialectType = None, optimize: bool = True, **kwargs) -> t.List[str]:
from sqlglot.dataframe.sql.session import SparkSession
- if dialect and Dialect.get_or_raise(dialect)() != SparkSession().dialect:
- logger.warning(
- f"The recommended way of defining a dialect is by doing `SparkSession.builder.config('sqlframe.dialect', '{dialect}').getOrCreate()`. It is no longer needed then when calling `sql`. If you run into issues try updating your query to use this pattern."
- )
+ dialect = Dialect.get_or_raise(dialect or SparkSession().dialect)
+
df = self._resolve_pending_hints()
select_expressions = df._get_select_expressions()
output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}
+
for expression_type, select_expression in select_expressions:
select_expression = select_expression.transform(replace_id_value, replacement_mapping)
if optimize:
- quote_identifiers(select_expression)
+ quote_identifiers(select_expression, dialect=dialect)
select_expression = t.cast(
- exp.Select, optimize_func(select_expression, dialect=SparkSession().dialect)
+ exp.Select, optimize_func(select_expression, dialect=dialect)
)
+
select_expression = df._replace_cte_names_with_hashes(select_expression)
+
expression: t.Union[exp.Select, exp.Cache, exp.Drop]
if expression_type == exp.Cache:
cache_table_name = df._create_hash_from_expression(select_expression)
@@ -330,13 +329,12 @@ class DataFrame:
sqlglot.schema.add_table(
cache_table_name,
{
- expression.alias_or_name: expression.type.sql(
- dialect=SparkSession().dialect
- )
+ expression.alias_or_name: expression.type.sql(dialect=dialect)
for expression in select_expression.expressions
},
- dialect=SparkSession().dialect,
+ dialect=dialect,
)
+
cache_storage_level = select_expression.args["cache_storage_level"]
options = [
exp.Literal.string("storageLevel"),
@@ -345,6 +343,7 @@ class DataFrame:
expression = exp.Cache(
this=cache_table, expression=select_expression, lazy=True, options=options
)
+
# We will drop the "view" if it exists before running the cache table
output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
elif expression_type == exp.Create:
@@ -355,18 +354,17 @@ class DataFrame:
select_without_ctes = select_expression.copy()
select_without_ctes.set("with", None)
expression.set("expression", select_without_ctes)
+
if select_expression.ctes:
expression.set("with", exp.With(expressions=select_expression.ctes))
elif expression_type == exp.Select:
expression = select_expression
else:
raise ValueError(f"Invalid expression type: {expression_type}")
+
output_expressions.append(expression)
- return [
- expression.sql(**{"dialect": SparkSession().dialect, **kwargs})
- for expression in output_expressions
- ]
+ return [expression.sql(dialect=dialect, **kwargs) for expression in output_expressions]
def copy(self, **kwargs) -> DataFrame:
return DataFrame(**object_to_dict(self, **kwargs))
@@ -542,12 +540,7 @@ class DataFrame:
"""
columns = self._ensure_and_normalize_cols(cols)
pre_ordered_col_indexes = [
- x
- for x in [
- i if isinstance(col.expression, exp.Ordered) else None
- for i, col in enumerate(columns)
- ]
- if x is not None
+ i for i, col in enumerate(columns) if isinstance(col.expression, exp.Ordered)
]
if ascending is None:
ascending = [True] * len(columns)
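
A hypothetical usage sketch of the reworked DataFrame.sql: the dialect argument is resolved once via Dialect.get_or_raise and threaded through quoting, optimization and generation (the createDataFrame call and the row literal are illustrative, not taken from this diff).

    from sqlglot.dataframe.sql.session import SparkSession

    spark = SparkSession()
    df = spark.createDataFrame([{"a": 1}])
    print(df.sql(dialect="databricks")[0])
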
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
index a424ea4..6671c5b 100644
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -306,7 +306,7 @@ def collect_list(col: ColumnOrName) -> Column:
def collect_set(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.SetAgg)
+ return Column.invoke_expression_over_column(col, expression.ArrayUniqueAgg)
def hypot(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column:
diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py
index 531ee17..4a33ef9 100644
--- a/sqlglot/dataframe/sql/session.py
+++ b/sqlglot/dataframe/sql/session.py
@@ -28,7 +28,7 @@ class SparkSession:
self.known_sequence_ids = set()
self.name_to_sequence_id_mapping = defaultdict(list)
self.incrementing_id = 1
- self.dialect = Dialect.get_or_raise(self.DEFAULT_DIALECT)()
+ self.dialect = Dialect.get_or_raise(self.DEFAULT_DIALECT)
def __new__(cls, *args, **kwargs) -> SparkSession:
if cls._instance is None:
@@ -182,7 +182,7 @@ class SparkSession:
def getOrCreate(self) -> SparkSession:
spark = SparkSession()
- spark.dialect = Dialect.get_or_raise(self.dialect)()
+ spark.dialect = Dialect.get_or_raise(self.dialect)
return spark
@classproperty
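
Sketch of the builder pattern these hunks keep supporting: the configured dialect string is resolved to an instance in getOrCreate (the 'sqlframe.dialect' key comes from the warning message removed in the dataframe.py hunk above).

    from sqlglot.dataframe.sql.session import SparkSession

    spark = SparkSession.builder.config("sqlframe.dialect", "databricks").getOrCreate()
    print(spark.dialect)  # a Databricks dialect instance
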
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index fc9a3ae..2a9dde9 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -8,6 +8,7 @@ from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot._typing import E
from sqlglot.dialects.dialect import (
Dialect,
+ NormalizationStrategy,
arg_max_or_min_no_count,
binary_from_function,
date_add_interval_sql,
@@ -23,6 +24,7 @@ from sqlglot.dialects.dialect import (
regexp_replace_sql,
rename_func,
timestrtotime_sql,
+ ts_or_ds_add_cast,
ts_or_ds_to_date_sql,
)
from sqlglot.helper import seq_get, split_num_words
@@ -174,6 +176,44 @@ def _parse_to_hex(args: t.List) -> exp.Hex | exp.MD5:
return exp.MD5(this=arg.this) if isinstance(arg, exp.MD5Digest) else exp.Hex(this=arg)
+def _array_contains_sql(self: BigQuery.Generator, expression: exp.ArrayContains) -> str:
+ return self.sql(
+ exp.Exists(
+ this=exp.select("1")
+ .from_(exp.Unnest(expressions=[expression.left]).as_("_unnest", table=["_col"]))
+ .where(exp.column("_col").eq(expression.right))
+ )
+ )
+
+
+def _ts_or_ds_add_sql(self: BigQuery.Generator, expression: exp.TsOrDsAdd) -> str:
+ return date_add_interval_sql("DATE", "ADD")(self, ts_or_ds_add_cast(expression))
+
+
+def _ts_or_ds_diff_sql(self: BigQuery.Generator, expression: exp.TsOrDsDiff) -> str:
+ expression.this.replace(exp.cast(expression.this, "TIMESTAMP", copy=True))
+ expression.expression.replace(exp.cast(expression.expression, "TIMESTAMP", copy=True))
+ unit = expression.args.get("unit") or "DAY"
+ return self.func("DATE_DIFF", expression.this, expression.expression, unit)
+
+
+def _unix_to_time_sql(self: BigQuery.Generator, expression: exp.UnixToTime) -> str:
+ scale = expression.args.get("scale")
+ timestamp = self.sql(expression, "this")
+ if scale in (None, exp.UnixToTime.SECONDS):
+ return f"TIMESTAMP_SECONDS({timestamp})"
+ if scale == exp.UnixToTime.MILLIS:
+ return f"TIMESTAMP_MILLIS({timestamp})"
+ if scale == exp.UnixToTime.MICROS:
+ return f"TIMESTAMP_MICROS({timestamp})"
+ if scale == exp.UnixToTime.NANOS:
+ # We need to cast to INT64 because that's what BQ expects
+ return f"TIMESTAMP_MICROS(CAST({timestamp} / 1000 AS INT64))"
+
+ self.unsupported(f"Unsupported scale for timestamp: {scale}.")
+ return ""
+
+
class BigQuery(Dialect):
UNNEST_COLUMN_ONLY = True
SUPPORTS_USER_DEFINED_TYPES = False
@@ -181,7 +221,7 @@ class BigQuery(Dialect):
LOG_BASE_FIRST = False
# https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#case_sensitivity
- RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
+ NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE
# bigquery udfs are case sensitive
NORMALIZE_FUNCTIONS = False
@@ -220,8 +260,7 @@ class BigQuery(Dialect):
# https://cloud.google.com/bigquery/docs/querying-partitioned-tables#query_an_ingestion-time_partitioned_table
PSEUDOCOLUMNS = {"_PARTITIONTIME", "_PARTITIONDATE"}
- @classmethod
- def normalize_identifier(cls, expression: E) -> E:
+ def normalize_identifier(self, expression: E) -> E:
if isinstance(expression, exp.Identifier):
parent = expression.parent
while isinstance(parent, exp.Dot):
@@ -265,7 +304,6 @@ class BigQuery(Dialect):
"DECLARE": TokenType.COMMAND,
"FLOAT64": TokenType.DOUBLE,
"FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT,
- "INT64": TokenType.BIGINT,
"MODEL": TokenType.MODEL,
"NOT DETERMINISTIC": TokenType.VOLATILE,
"RECORD": TokenType.STRUCT,
@@ -316,6 +354,15 @@ class BigQuery(Dialect):
"TIME_SUB": parse_date_delta_with_interval(exp.TimeSub),
"TIMESTAMP_ADD": parse_date_delta_with_interval(exp.TimestampAdd),
"TIMESTAMP_SUB": parse_date_delta_with_interval(exp.TimestampSub),
+ "TIMESTAMP_MICROS": lambda args: exp.UnixToTime(
+ this=seq_get(args, 0), scale=exp.UnixToTime.MICROS
+ ),
+ "TIMESTAMP_MILLIS": lambda args: exp.UnixToTime(
+ this=seq_get(args, 0), scale=exp.UnixToTime.MILLIS
+ ),
+ "TIMESTAMP_SECONDS": lambda args: exp.UnixToTime(
+ this=seq_get(args, 0), scale=exp.UnixToTime.SECONDS
+ ),
"TO_JSON_STRING": exp.JSONFormat.from_arg_list,
}
@@ -358,6 +405,24 @@ class BigQuery(Dialect):
NULL_TOKENS = {TokenType.NULL, TokenType.UNKNOWN}
+ STATEMENT_PARSERS = {
+ **parser.Parser.STATEMENT_PARSERS,
+ TokenType.END: lambda self: self._parse_as_command(self._prev),
+ TokenType.FOR: lambda self: self._parse_for_in(),
+ }
+
+ BRACKET_OFFSETS = {
+ "OFFSET": (0, False),
+ "ORDINAL": (1, False),
+ "SAFE_OFFSET": (0, True),
+ "SAFE_ORDINAL": (1, True),
+ }
+
+ def _parse_for_in(self) -> exp.ForIn:
+ this = self._parse_range()
+ self._match_text_seq("DO")
+ return self.expression(exp.ForIn, this=this, expression=self._parse_statement())
+
def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]:
this = super()._parse_table_part(schema=schema) or self._parse_number()
@@ -419,6 +484,26 @@ class BigQuery(Dialect):
return json_object
+ def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
+ bracket = super()._parse_bracket(this)
+
+ if this is bracket:
+ return bracket
+
+ if isinstance(bracket, exp.Bracket):
+ for expression in bracket.expressions:
+ name = expression.name.upper()
+
+ if name not in self.BRACKET_OFFSETS:
+ break
+
+ offset, safe = self.BRACKET_OFFSETS[name]
+ bracket.set("offset", offset)
+ bracket.set("safe", safe)
+ expression.replace(expression.expressions[0])
+
+ return bracket
+
class Generator(generator.Generator):
EXPLICIT_UNION = True
INTERVAL_ALLOWS_PLURAL_FORM = False
@@ -430,12 +515,14 @@ class BigQuery(Dialect):
NVL2_SUPPORTED = False
UNNEST_WITH_ORDINALITY = False
COLLATE_IS_FUNC = True
+ LIMIT_ONLY_LITERALS = True
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
exp.ArgMax: arg_max_or_min_no_count("MAX_BY"),
exp.ArgMin: arg_max_or_min_no_count("MIN_BY"),
+ exp.ArrayContains: _array_contains_sql,
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]),
exp.CollateProperty: lambda self, e: f"DEFAULT COLLATE {self.sql(e, 'this')}"
@@ -498,10 +585,13 @@ class BigQuery(Dialect):
exp.TimestampAdd: date_add_interval_sql("TIMESTAMP", "ADD"),
exp.TimestampSub: date_add_interval_sql("TIMESTAMP", "SUB"),
exp.TimeStrToTime: timestrtotime_sql,
+ exp.TimeToStr: lambda self, e: f"FORMAT_DATE({self.format_time(e)}, {self.sql(e, 'this')})",
exp.Trim: lambda self, e: self.func(f"TRIM", e.this, e.expression),
- exp.TsOrDsAdd: date_add_interval_sql("DATE", "ADD"),
+ exp.TsOrDsAdd: _ts_or_ds_add_sql,
+ exp.TsOrDsDiff: _ts_or_ds_diff_sql,
exp.TsOrDsToDate: ts_or_ds_to_date_sql("bigquery"),
exp.Unhex: rename_func("FROM_HEX"),
+ exp.UnixToTime: _unix_to_time_sql,
exp.Values: _derived_table_values_to_unnest,
exp.VariancePop: rename_func("VAR_POP"),
}
@@ -671,6 +761,23 @@ class BigQuery(Dialect):
return inline_array_sql(self, expression)
+ def bracket_sql(self, expression: exp.Bracket) -> str:
+ expressions = expression.expressions
+ expressions_sql = ", ".join(self.sql(e) for e in expressions)
+ offset = expression.args.get("offset")
+
+ if offset == 0:
+ expressions_sql = f"OFFSET({expressions_sql})"
+ elif offset == 1:
+ expressions_sql = f"ORDINAL({expressions_sql})"
+ else:
+ self.unsupported(f"Unsupported array offset: {offset}")
+
+ if expression.args.get("safe"):
+ expressions_sql = f"SAFE_{expressions_sql}"
+
+ return f"{self.sql(expression, 'this')}[{expressions_sql}]"
+
def transaction_sql(self, *_) -> str:
return "BEGIN TRANSACTION"
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index 394a922..da182aa 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -35,8 +35,8 @@ def _quantile_sql(self, e):
class ClickHouse(Dialect):
NORMALIZE_FUNCTIONS: bool | str = False
NULL_ORDERING = "nulls_are_last"
- STRICT_STRING_CONCAT = True
SUPPORTS_USER_DEFINED_TYPES = False
+ SAFE_DIVISION = True
ESCAPE_SEQUENCES = {
"\\0": "\0",
@@ -63,11 +63,7 @@ class ClickHouse(Dialect):
"FLOAT32": TokenType.FLOAT,
"FLOAT64": TokenType.DOUBLE,
"GLOBAL": TokenType.GLOBAL,
- "INT16": TokenType.SMALLINT,
"INT256": TokenType.INT256,
- "INT32": TokenType.INT,
- "INT64": TokenType.BIGINT,
- "INT8": TokenType.TINYINT,
"LOWCARDINALITY": TokenType.LOWCARDINALITY,
"MAP": TokenType.MAP,
"NESTED": TokenType.NESTED,
@@ -112,6 +108,7 @@ class ClickHouse(Dialect):
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS,
+ "ARRAYJOIN": lambda self: self.expression(exp.Explode, this=self._parse_expression()),
"QUANTILE": lambda self: self._parse_quantile(),
}
@@ -223,12 +220,13 @@ class ClickHouse(Dialect):
except ParseError:
# WITH <expression> AS <identifier>
self._retreat(index)
- statement = self._parse_statement()
- if statement and isinstance(statement.this, exp.Alias):
- self.raise_error("Expected CTE to have alias")
-
- return self.expression(exp.CTE, this=statement, alias=statement and statement.this)
+ return self.expression(
+ exp.CTE,
+ this=self._parse_field(),
+ alias=self._parse_table_alias(),
+ scalar=True,
+ )
def _parse_join_parts(
self,
@@ -385,9 +383,11 @@ class ClickHouse(Dialect):
exp.DateDiff: lambda self, e: self.func(
"DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
),
+ exp.Explode: rename_func("arrayJoin"),
exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
exp.IsNan: rename_func("isNaN"),
exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)),
+ exp.Nullif: rename_func("nullIf"),
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.Pivot: no_pivot_sql,
exp.Quantile: _quantile_sql,
@@ -459,19 +459,11 @@ class ClickHouse(Dialect):
return super().datatype_sql(expression)
- def safeconcat_sql(self, expression: exp.SafeConcat) -> str:
- # Clickhouse errors out if we try to cast a NULL value to TEXT
- return self.func(
- "CONCAT",
- *[
- exp.func("if", e.is_(exp.null()), e, exp.cast(e, "text"))
- for e in t.cast(t.List[exp.Condition], expression.expressions)
- ],
- )
-
def cte_sql(self, expression: exp.CTE) -> str:
- if isinstance(expression.this, exp.Alias):
- return self.sql(expression, "this")
+ if expression.args.get("scalar"):
+ this = self.sql(expression, "this")
+ alias = self.sql(expression, "alias")
+ return f"{this} AS {alias}"
return super().cte_sql(expression)
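
Sketch of the scalar-CTE handling: WITH <expression> AS <identifier> is now stored with scalar=True and re-rendered as <this> AS <alias> instead of raising "Expected CTE to have alias".

    import sqlglot

    sql = "WITH 1 AS x SELECT x"
    print(sqlglot.transpile(sql, read="clickhouse", write="clickhouse")[0])
    # roughly: WITH 1 AS x SELECT x
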
diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py
index b777db0..1c10a8b 100644
--- a/sqlglot/dialects/databricks.py
+++ b/sqlglot/dialects/databricks.py
@@ -1,13 +1,18 @@
from __future__ import annotations
from sqlglot import exp, transforms
-from sqlglot.dialects.dialect import parse_date_delta, timestamptrunc_sql
+from sqlglot.dialects.dialect import (
+ date_delta_sql,
+ parse_date_delta,
+ timestamptrunc_sql,
+)
from sqlglot.dialects.spark import Spark
-from sqlglot.dialects.tsql import generate_date_delta_with_unit_sql
from sqlglot.tokens import TokenType
class Databricks(Spark):
+ SAFE_DIVISION = False
+
class Parser(Spark.Parser):
LOG_DEFAULTS_TO_LN = True
STRICT_CAST = True
@@ -27,8 +32,8 @@ class Databricks(Spark):
class Generator(Spark.Generator):
TRANSFORMS = {
**Spark.Generator.TRANSFORMS,
- exp.DateAdd: generate_date_delta_with_unit_sql,
- exp.DateDiff: generate_date_delta_with_unit_sql,
+ exp.DateAdd: date_delta_sql("DATEADD"),
+ exp.DateDiff: date_delta_sql("DATEDIFF"),
exp.DatetimeAdd: lambda self, e: self.func(
"TIMESTAMPADD", e.text("unit"), e.expression, e.this
),
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 21e7889..c7cea64 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -1,14 +1,14 @@
from __future__ import annotations
import typing as t
-from enum import Enum
+from enum import Enum, auto
from functools import reduce
from sqlglot import exp
from sqlglot._typing import E
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
-from sqlglot.helper import flatten, seq_get
+from sqlglot.helper import AutoName, flatten, seq_get
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
@@ -16,6 +16,9 @@ from sqlglot.trie import new_trie
B = t.TypeVar("B", bound=exp.Binary)
+DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
+DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
+
class Dialects(str, Enum):
DIALECT = ""
@@ -43,6 +46,15 @@ class Dialects(str, Enum):
Doris = "doris"
+class NormalizationStrategy(str, AutoName):
+ """Specifies the strategy according to which identifiers should be normalized."""
+
+ LOWERCASE = auto() # Unquoted identifiers are lowercased
+ UPPERCASE = auto() # Unquoted identifiers are uppercased
+ CASE_SENSITIVE = auto() # Always case-sensitive, regardless of quotes
+ CASE_INSENSITIVE = auto() # Always case-insensitive, regardless of quotes
+
+
class _Dialect(type):
classes: t.Dict[str, t.Type[Dialect]] = {}
@@ -106,26 +118,8 @@ class _Dialect(type):
klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
- dialect_properties = {
- **{
- k: v
- for k, v in vars(klass).items()
- if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__")
- },
- "TOKENIZER_CLASS": klass.tokenizer_class,
- }
-
if enum not in ("", "bigquery"):
- dialect_properties["SELECT_KINDS"] = ()
-
- # Pass required dialect properties to the tokenizer, parser and generator classes
- for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class):
- for name, value in dialect_properties.items():
- if hasattr(subclass, name):
- setattr(subclass, name, value)
-
- if not klass.STRICT_STRING_CONCAT and klass.DPIPE_IS_STRING_CONCAT:
- klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe
+ klass.generator_class.SELECT_KINDS = ()
if not klass.SUPPORTS_SEMI_ANTI_JOIN:
klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
@@ -133,8 +127,6 @@ class _Dialect(type):
TokenType.SEMI,
}
- klass.generator_class.can_identify = klass.can_identify
-
return klass
@@ -148,9 +140,8 @@ class Dialect(metaclass=_Dialect):
# Determines whether or not the table alias comes after tablesample
ALIAS_POST_TABLESAMPLE = False
- # Determines whether or not unquoted identifiers are resolved as uppercase
- # When set to None, it means that the dialect treats all identifiers as case-insensitive
- RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False
+ # Specifies the strategy according to which identifiers should be normalized.
+ NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
# Determines whether or not an unquoted identifier can start with a digit
IDENTIFIERS_CAN_START_WITH_DIGIT = False
@@ -177,6 +168,18 @@ class Dialect(metaclass=_Dialect):
# Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
NULL_ORDERING = "nulls_are_small"
+ # Whether the behavior of a / b depends on the types of a and b.
+ # False means a / b is always float division.
+ # True means a / b is integer division if both a and b are integers.
+ TYPED_DIVISION = False
+
+ # False means 1 / 0 throws an error.
+ # True means 1 / 0 returns null.
+ SAFE_DIVISION = False
+
+ # A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string
+ CONCAT_COALESCE = False
+
DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
@@ -197,7 +200,8 @@ class Dialect(metaclass=_Dialect):
# Such columns may be excluded from SELECT * queries, for example
PSEUDOCOLUMNS: t.Set[str] = set()
- # Autofilled
+ # --- Autofilled ---
+
tokenizer_class = Tokenizer
parser_class = Parser
generator_class = Generator
@@ -211,26 +215,61 @@ class Dialect(metaclass=_Dialect):
INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
- def __eq__(self, other: t.Any) -> bool:
- return type(self) == other
+ # Delimiters for quotes, identifiers and the corresponding escape characters
+ QUOTE_START = "'"
+ QUOTE_END = "'"
+ IDENTIFIER_START = '"'
+ IDENTIFIER_END = '"'
- def __hash__(self) -> int:
- return hash(type(self))
+ # Delimiters for bit, hex and byte literals
+ BIT_START: t.Optional[str] = None
+ BIT_END: t.Optional[str] = None
+ HEX_START: t.Optional[str] = None
+ HEX_END: t.Optional[str] = None
+ BYTE_START: t.Optional[str] = None
+ BYTE_END: t.Optional[str] = None
@classmethod
- def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
+ def get_or_raise(cls, dialect: DialectType) -> Dialect:
+ """
+ Look up a dialect in the global dialect registry and return it if it exists.
+
+ Args:
+ dialect: The target dialect. If this is a string, it can be optionally followed by
+ additional key-value pairs that are separated by commas and are used to specify
+ dialect settings, such as whether the dialect's identifiers are case-sensitive.
+
+ Example:
+ >>> dialect = Dialect.get_or_raise("duckdb")
+ >>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
+
+ Returns:
+ The corresponding Dialect instance.
+ """
+
if not dialect:
- return cls
+ return cls()
if isinstance(dialect, _Dialect):
- return dialect
+ return dialect()
if isinstance(dialect, Dialect):
- return dialect.__class__
+ return dialect
+ if isinstance(dialect, str):
+ try:
+ dialect_name, *kv_pairs = dialect.split(",")
+ kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
+ except ValueError:
+ raise ValueError(
+ f"Invalid dialect format: '{dialect}'. "
+ "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
+ )
+
+ result = cls.get(dialect_name.strip())
+ if not result:
+ raise ValueError(f"Unknown dialect '{dialect_name}'.")
- result = cls.get(dialect)
- if not result:
- raise ValueError(f"Unknown dialect '{dialect}'")
+ return result(**kwargs)
- return result
+ raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
@classmethod
def format_time(
@@ -247,36 +286,71 @@ class Dialect(metaclass=_Dialect):
return expression
- @classmethod
- def normalize_identifier(cls, expression: E) -> E:
+ def __init__(self, **kwargs) -> None:
+ normalization_strategy = kwargs.get("normalization_strategy")
+
+ if normalization_strategy is None:
+ self.normalization_strategy = self.NORMALIZATION_STRATEGY
+ else:
+ self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
+
+ def __eq__(self, other: t.Any) -> bool:
+ # Does not currently take dialect state into account
+ return type(self) == other
+
+ def __hash__(self) -> int:
+ # Does not currently take dialect state into account
+ return hash(type(self))
+
+ def normalize_identifier(self, expression: E) -> E:
"""
- Normalizes an unquoted identifier to either lower or upper case, thus essentially
- making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
- they will be normalized to lowercase regardless of being quoted or not.
+ Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
+
+ For example, an identifier like FoO would be resolved as foo in Postgres, because it
+ lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
+ it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive,
+ and so any normalization would be prohibited in order to avoid "breaking" the identifier.
+
+ There are also dialects like Spark, which are case-insensitive even when quotes are
+ present, and dialects like MySQL, whose resolution rules match those employed by the
+ underlying operating system, for example they may always be case-sensitive in Linux.
+
+ Finally, the normalization behavior of some engines can even be controlled through flags,
+ like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
+
+ SQLGlot aims to understand and handle all of these different behaviors gracefully, so
+ that it can analyze queries in the optimizer and successfully capture their semantics.
"""
- if isinstance(expression, exp.Identifier) and (
- not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
+ if (
+ isinstance(expression, exp.Identifier)
+ and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
+ and (
+ not expression.quoted
+ or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
+ )
):
expression.set(
"this",
expression.this.upper()
- if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
+ if self.normalization_strategy is NormalizationStrategy.UPPERCASE
else expression.this.lower(),
)
return expression
- @classmethod
- def case_sensitive(cls, text: str) -> bool:
+ def case_sensitive(self, text: str) -> bool:
"""Checks if text contains any case sensitive characters, based on the dialect's rules."""
- if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
+ if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
return False
- unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
+ unsafe = (
+ str.islower
+ if self.normalization_strategy is NormalizationStrategy.UPPERCASE
+ else str.isupper
+ )
return any(unsafe(char) for char in text)
- @classmethod
- def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
+ def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
"""Checks if text can be identified given an identify option.
Args:
@@ -292,17 +366,16 @@ class Dialect(metaclass=_Dialect):
return True
if identify == "safe":
- return not cls.case_sensitive(text)
+ return not self.case_sensitive(text)
return False
- @classmethod
- def quote_identifier(cls, expression: E, identify: bool = True) -> E:
+ def quote_identifier(self, expression: E, identify: bool = True) -> E:
if isinstance(expression, exp.Identifier):
name = expression.this
expression.set(
"quoted",
- identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
+ identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
)
return expression
@@ -330,14 +403,14 @@ class Dialect(metaclass=_Dialect):
@property
def tokenizer(self) -> Tokenizer:
if not hasattr(self, "_tokenizer"):
- self._tokenizer = self.tokenizer_class()
+ self._tokenizer = self.tokenizer_class(dialect=self)
return self._tokenizer
def parser(self, **opts) -> Parser:
- return self.parser_class(**opts)
+ return self.parser_class(dialect=self, **opts)
def generator(self, **opts) -> Generator:
- return self.generator_class(**opts)
+ return self.generator_class(dialect=self, **opts)
DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
@@ -713,7 +786,7 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
return _ts_or_ds_to_date_sql
-def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
+def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
@@ -821,3 +894,28 @@ def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | ex
return self.func(name, expression.this, expression.expression)
return _arg_max_or_min_sql
+
+
+def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
+ this = expression.this.copy()
+
+ return_type = expression.return_type
+ if return_type.is_type(exp.DataType.Type.DATE):
+ # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
+ # can truncate timestamp strings, because some dialects can't cast them to DATE
+ this = exp.cast(this, exp.DataType.Type.TIMESTAMP)
+
+ expression.this.replace(exp.cast(this, return_type))
+ return expression
+
+
+def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
+ def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
+ if cast and isinstance(expression, exp.TsOrDsAdd):
+ expression = ts_or_ds_add_cast(expression)
+
+ return self.func(
+ name, exp.var(expression.text("unit") or "day"), expression.expression, expression.this
+ )
+
+ return _delta_sql
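
Sketch of the new per-instance settings: key-value pairs after the dialect name are parsed by get_or_raise and forwarded to Dialect.__init__, which currently understands normalization_strategy.

    from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

    mysql = Dialect.get_or_raise("mysql, normalization_strategy = lowercase")
    assert mysql.normalization_strategy is NormalizationStrategy.LOWERCASE
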
diff --git a/sqlglot/dialects/doris.py b/sqlglot/dialects/doris.py
index bd7e0f2..11af17b 100644
--- a/sqlglot/dialects/doris.py
+++ b/sqlglot/dialects/doris.py
@@ -19,6 +19,7 @@ class Doris(MySQL):
class Parser(MySQL.Parser):
FUNCTIONS = {
**MySQL.Parser.FUNCTIONS,
+ "COLLECT_SET": exp.ArrayUniqueAgg.from_arg_list,
"DATE_TRUNC": parse_timestamp_trunc,
"REGEXP": exp.RegexpLike.from_arg_list,
}
@@ -47,7 +48,7 @@ class Doris(MySQL):
exp.JSONExtract: arrow_json_extract_sql,
exp.RegexpLike: rename_func("REGEXP"),
exp.RegexpSplit: rename_func("SPLIT_BY_STRING"),
- exp.SetAgg: rename_func("COLLECT_SET"),
+ exp.ArrayUniqueAgg: rename_func("COLLECT_SET"),
exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Split: rename_func("SPLIT_BY_STRING"),
exp.TimeStrToDate: rename_func("TO_DATE"),
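
Sketch of the rename's effect: COLLECT_SET now parses into the new exp.ArrayUniqueAgg node (formerly exp.SetAgg) and still prints as COLLECT_SET in dialects that use that name.

    import sqlglot

    print(sqlglot.transpile("SELECT COLLECT_SET(x) FROM t", read="hive", write="doris")[0])
    # SELECT COLLECT_SET(x) FROM t
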
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
index 42453fd..70c96f8 100644
--- a/sqlglot/dialects/drill.py
+++ b/sqlglot/dialects/drill.py
@@ -43,6 +43,8 @@ class Drill(Dialect):
TIME_FORMAT = "'yyyy-MM-dd HH:mm:ss'"
SUPPORTS_USER_DEFINED_TYPES = False
SUPPORTS_SEMI_ANTI_JOIN = False
+ TYPED_DIVISION = True
+ CONCAT_COALESCE = True
TIME_MAPPING = {
"y": "%Y",
@@ -83,7 +85,6 @@ class Drill(Dialect):
class Parser(parser.Parser):
STRICT_CAST = False
- CONCAT_NULL_OUTPUTS_STRING = True
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index d8d9f90..b94e3a6 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -2,9 +2,10 @@ from __future__ import annotations
import typing as t
-from sqlglot import exp, generator, parser, tokens
+from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
+ NormalizationStrategy,
approx_count_distinct_sql,
arg_max_or_min_no_count,
arrow_json_extract_scalar_sql,
@@ -36,7 +37,8 @@ from sqlglot.tokens import TokenType
def _ts_or_ds_add_sql(self: DuckDB.Generator, expression: exp.TsOrDsAdd) -> str:
this = self.sql(expression, "this")
unit = self.sql(expression, "unit").strip("'") or "DAY"
- return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
+ interval = self.sql(exp.Interval(this=expression.expression, unit=unit))
+ return f"CAST({this} AS {self.sql(expression.return_type)}) + {interval}"
def _date_delta_sql(self: DuckDB.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
@@ -84,7 +86,8 @@ def _parse_date_diff(args: t.List) -> exp.Expression:
def _struct_sql(self: DuckDB.Generator, expression: exp.Struct) -> str:
args = [
- f"'{e.name or e.this.name}': {self.sql(e, 'expression')}" for e in expression.expressions
+ f"'{e.name or e.this.name}': {self.sql(e.expressions[0]) if isinstance(e, exp.Bracket) else self.sql(e, 'expression')}"
+ for e in expression.expressions
]
return f"{{{', '.join(args)}}}"
@@ -105,17 +108,35 @@ def _json_format_sql(self: DuckDB.Generator, expression: exp.JSONFormat) -> str:
return f"CAST({sql} AS TEXT)"
+def _unix_to_time_sql(self: DuckDB.Generator, expression: exp.UnixToTime) -> str:
+ scale = expression.args.get("scale")
+ timestamp = self.sql(expression, "this")
+ if scale in (None, exp.UnixToTime.SECONDS):
+ return f"TO_TIMESTAMP({timestamp})"
+ if scale == exp.UnixToTime.MILLIS:
+ return f"EPOCH_MS({timestamp})"
+ if scale == exp.UnixToTime.MICROS:
+ return f"MAKE_TIMESTAMP({timestamp})"
+ if scale == exp.UnixToTime.NANOS:
+ return f"TO_TIMESTAMP({timestamp} / 1000000000)"
+
+ self.unsupported(f"Unsupported scale for timestamp: {scale}.")
+ return ""
+
+
class DuckDB(Dialect):
NULL_ORDERING = "nulls_are_last"
SUPPORTS_USER_DEFINED_TYPES = False
+ SAFE_DIVISION = True
+ INDEX_OFFSET = 1
+ CONCAT_COALESCE = True
# https://duckdb.org/docs/sql/introduction.html#creating-a-new-table
- RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
+ NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
- ":=": TokenType.EQ,
"//": TokenType.DIV,
"ATTACH": TokenType.COMMAND,
"BINARY": TokenType.VARBINARY,
@@ -124,8 +145,6 @@ class DuckDB(Dialect):
"CHAR": TokenType.TEXT,
"CHARACTER VARYING": TokenType.TEXT,
"EXCLUDE": TokenType.EXCEPT,
- "HUGEINT": TokenType.INT128,
- "INT1": TokenType.TINYINT,
"LOGICAL": TokenType.BOOLEAN,
"PIVOT_WIDER": TokenType.PIVOT,
"SIGNED": TokenType.INT,
@@ -141,8 +160,6 @@ class DuckDB(Dialect):
}
class Parser(parser.Parser):
- CONCAT_NULL_OUTPUTS_STRING = True
-
BITWISE = {
**parser.Parser.BITWISE,
TokenType.TILDA: exp.RegexpLike,
@@ -150,6 +167,7 @@ class DuckDB(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
+ "ARRAY_HAS": exp.ArrayContains.from_arg_list,
"ARRAY_LENGTH": exp.ArraySize.from_arg_list,
"ARRAY_SORT": exp.SortArray.from_arg_list,
"ARRAY_REVERSE_SORT": _sort_array_reverse,
@@ -157,13 +175,23 @@ class DuckDB(Dialect):
"DATE_DIFF": _parse_date_diff,
"DATE_TRUNC": date_trunc_to_time,
"DATETRUNC": date_trunc_to_time,
+ "DECODE": lambda args: exp.Decode(
+ this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
+ ),
+ "ENCODE": lambda args: exp.Encode(
+ this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
+ ),
"EPOCH": exp.TimeToUnix.from_arg_list,
"EPOCH_MS": lambda args: exp.UnixToTime(
- this=exp.Div(this=seq_get(args, 0), expression=exp.Literal.number(1000))
+ this=seq_get(args, 0), scale=exp.UnixToTime.MILLIS
),
+ "LIST_HAS": exp.ArrayContains.from_arg_list,
"LIST_REVERSE_SORT": _sort_array_reverse,
"LIST_SORT": exp.SortArray.from_arg_list,
"LIST_VALUE": exp.Array.from_arg_list,
+ "MAKE_TIMESTAMP": lambda args: exp.UnixToTime(
+ this=seq_get(args, 0), scale=exp.UnixToTime.MICROS
+ ),
"MEDIAN": lambda args: exp.PercentileCont(
this=seq_get(args, 0), expression=exp.Literal.number(0.5)
),
@@ -192,15 +220,8 @@ class DuckDB(Dialect):
"XOR": binary_from_function(exp.BitwiseXor),
}
- FUNCTION_PARSERS = {
- **parser.Parser.FUNCTION_PARSERS,
- "DECODE": lambda self: self.expression(
- exp.Decode, this=self._parse_conjunction(), charset=exp.Literal.string("utf-8")
- ),
- "ENCODE": lambda self: self.expression(
- exp.Encode, this=self._parse_conjunction(), charset=exp.Literal.string("utf-8")
- ),
- }
+ FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy()
+ FUNCTION_PARSERS.pop("DECODE", None)
TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - {
TokenType.SEMI,
@@ -277,6 +298,7 @@ class DuckDB(Dialect):
exp.Encode: lambda self, e: encode_decode_sql(self, e, "ENCODE", replace=False),
exp.Explode: rename_func("UNNEST"),
exp.IntDiv: lambda self, e: self.binary(e, "//"),
+ exp.IsInf: rename_func("ISINF"),
exp.IsNan: rename_func("ISNAN"),
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
@@ -294,6 +316,9 @@ class DuckDB(Dialect):
exp.ParseJSON: rename_func("JSON"),
exp.PercentileCont: rename_func("QUANTILE_CONT"),
exp.PercentileDisc: rename_func("QUANTILE_DISC"),
+ # DuckDB doesn't allow qualified columns inside of PIVOT expressions.
+ # See: https://github.com/duckdb/duckdb/blob/671faf92411182f81dce42ac43de8bfb05d9909e/src/planner/binder/tableref/bind_pivot.cpp#L61-L62
+ exp.Pivot: transforms.preprocess([transforms.unqualify_columns]),
exp.Properties: no_properties_sql,
exp.RegexpExtract: regexp_extract_sql,
exp.RegexpReplace: lambda self, e: self.func(
@@ -322,9 +347,15 @@ class DuckDB(Dialect):
exp.TimeToUnix: rename_func("EPOCH"),
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS TEXT), '-', ''), 1, 8) AS INT)",
exp.TsOrDsAdd: _ts_or_ds_add_sql,
+ exp.TsOrDsDiff: lambda self, e: self.func(
+ "DATE_DIFF",
+ f"'{e.args.get('unit') or 'day'}'",
+ exp.cast(e.expression, "TIMESTAMP"),
+ exp.cast(e.this, "TIMESTAMP"),
+ ),
exp.TsOrDsToDate: ts_or_ds_to_date_sql("duckdb"),
exp.UnixToStr: lambda self, e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})",
- exp.UnixToTime: rename_func("TO_TIMESTAMP"),
+ exp.UnixToTime: _unix_to_time_sql,
exp.UnixToTimeStr: lambda self, e: f"CAST(TO_TIMESTAMP({self.sql(e, 'this')}) AS TEXT)",
exp.VariancePop: rename_func("VAR_POP"),
exp.WeekOfYear: rename_func("WEEKOFYEAR"),
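
Sketch of the scale-aware epoch handling: EPOCH_MS now parses with an explicit millisecond scale, so targets can pick the matching conversion (e.g. BigQuery's TIMESTAMP_MILLIS per the hunks earlier in this diff).

    import sqlglot

    print(sqlglot.transpile("SELECT EPOCH_MS(1679756400000)", read="duckdb", write="bigquery")[0])
    # roughly: SELECT TIMESTAMP_MILLIS(1679756400000)
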
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 3b1c8de..0723e37 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -4,10 +4,13 @@ import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
+ DATE_ADD_OR_SUB,
Dialect,
+ NormalizationStrategy,
approx_count_distinct_sql,
arg_max_or_min_no_count,
create_with_partitions_sql,
+ datestrtodate_sql,
format_time_lambda,
if_sql,
is_parse_json,
@@ -76,7 +79,10 @@ def _create_sql(self, expression: exp.Create) -> str:
return create_with_partitions_sql(self, expression)
-def _add_date_sql(self: Hive.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
+def _add_date_sql(self: Hive.Generator, expression: DATE_ADD_OR_SUB) -> str:
+ if isinstance(expression, exp.TsOrDsAdd) and not expression.unit:
+ return self.func("DATE_ADD", expression.this, expression.expression)
+
unit = expression.text("unit").upper()
func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1))
@@ -95,7 +101,7 @@ def _add_date_sql(self: Hive.Generator, expression: exp.DateAdd | exp.DateSub) -
return self.func(func, expression.this, modified_increment)
-def _date_diff_sql(self: Hive.Generator, expression: exp.DateDiff) -> str:
+def _date_diff_sql(self: Hive.Generator, expression: exp.DateDiff | exp.TsOrDsDiff) -> str:
unit = expression.text("unit").upper()
factor = TIME_DIFF_FACTOR.get(unit)
@@ -111,25 +117,31 @@ def _date_diff_sql(self: Hive.Generator, expression: exp.DateDiff) -> str:
multiplier_sql = f" / {multiplier}" if multiplier > 1 else ""
diff_sql = f"{sql_func}({self.format_args(expression.this, expression.expression)})"
- if months_between:
- # MONTHS_BETWEEN returns a float, so we need to truncate the fractional part
- diff_sql = f"CAST({diff_sql} AS INT)"
+ if months_between or multiplier_sql:
+ # MONTHS_BETWEEN returns a float, so we need to truncate the fractional part.
+ # For the same reason, we want to truncate if there's a divisor present.
+ diff_sql = f"CAST({diff_sql}{multiplier_sql} AS INT)"
- return f"{diff_sql}{multiplier_sql}"
+ return diff_sql
def _json_format_sql(self: Hive.Generator, expression: exp.JSONFormat) -> str:
this = expression.this
- if is_parse_json(this) and this.this.is_string:
- # Since FROM_JSON requires a nested type, we always wrap the json string with
- # an array to ensure that "naked" strings like "'a'" will be handled correctly
- wrapped_json = exp.Literal.string(f"[{this.this.name}]")
- from_json = self.func("FROM_JSON", wrapped_json, self.func("SCHEMA_OF_JSON", wrapped_json))
- to_json = self.func("TO_JSON", from_json)
+ if is_parse_json(this):
+ if this.this.is_string:
+ # Since FROM_JSON requires a nested type, we always wrap the json string with
+ # an array to ensure that "naked" strings like "'a'" will be handled correctly
+ wrapped_json = exp.Literal.string(f"[{this.this.name}]")
+
+ from_json = self.func(
+ "FROM_JSON", wrapped_json, self.func("SCHEMA_OF_JSON", wrapped_json)
+ )
+ to_json = self.func("TO_JSON", from_json)
- # This strips the [, ] delimiters of the dummy array printed by TO_JSON
- return self.func("REGEXP_EXTRACT", to_json, "'^.(.*).$'", "1")
+ # This strips the [, ] delimiters of the dummy array printed by TO_JSON
+ return self.func("REGEXP_EXTRACT", to_json, "'^.(.*).$'", "1")
+ return self.sql(this)
return self.func("TO_JSON", this, expression.args.get("options"))
@@ -175,6 +187,8 @@ def _to_date_sql(self: Hive.Generator, expression: exp.TsOrDsToDate) -> str:
time_format = self.format_time(expression)
if time_format and time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT):
return f"TO_DATE({this}, {time_format})"
+ if isinstance(expression.this, exp.TsOrDsToDate):
+ return this
return f"TO_DATE({this})"
@@ -182,9 +196,10 @@ class Hive(Dialect):
ALIAS_POST_TABLESAMPLE = True
IDENTIFIERS_CAN_START_WITH_DIGIT = True
SUPPORTS_USER_DEFINED_TYPES = False
+ SAFE_DIVISION = True
# https://spark.apache.org/docs/latest/sql-ref-identifier.html#description
- RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
+ NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE
TIME_MAPPING = {
"y": "%Y",
@@ -241,10 +256,10 @@ class Hive(Dialect):
"ADD JAR": TokenType.COMMAND,
"ADD JARS": TokenType.COMMAND,
"MSCK REPAIR": TokenType.COMMAND,
- "REFRESH": TokenType.COMMAND,
- "WITH SERDEPROPERTIES": TokenType.SERDE_PROPERTIES,
+ "REFRESH": TokenType.REFRESH,
"TIMESTAMP AS OF": TokenType.TIMESTAMP_SNAPSHOT,
"VERSION AS OF": TokenType.VERSION_SNAPSHOT,
+ "WITH SERDEPROPERTIES": TokenType.SERDE_PROPERTIES,
}
NUMERIC_LITERALS = {
@@ -264,7 +279,7 @@ class Hive(Dialect):
**parser.Parser.FUNCTIONS,
"BASE64": exp.ToBase64.from_arg_list,
"COLLECT_LIST": exp.ArrayAgg.from_arg_list,
- "COLLECT_SET": exp.SetAgg.from_arg_list,
+ "COLLECT_SET": exp.ArrayUniqueAgg.from_arg_list,
"DATE_ADD": lambda args: exp.TsOrDsAdd(
this=seq_get(args, 0), expression=seq_get(args, 1), unit=exp.Literal.string("DAY")
),
@@ -411,7 +426,13 @@ class Hive(Dialect):
INDEX_ON = "ON TABLE"
EXTRACT_ALLOWS_QUOTES = False
NVL2_SUPPORTED = False
- SUPPORTS_NESTED_CTES = False
+
+ EXPRESSIONS_WITHOUT_NESTED_CTES = {
+ exp.Insert,
+ exp.Select,
+ exp.Subquery,
+ exp.Union,
+ }
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@@ -445,7 +466,7 @@ class Hive(Dialect):
exp.With: no_recursive_cte_sql,
exp.DateAdd: _add_date_sql,
exp.DateDiff: _date_diff_sql,
- exp.DateStrToDate: rename_func("TO_DATE"),
+ exp.DateStrToDate: datestrtodate_sql,
exp.DateSub: _add_date_sql,
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.DATEINT_FORMAT}) AS INT)",
exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.DATEINT_FORMAT})",
@@ -477,7 +498,7 @@ class Hive(Dialect):
exp.Right: right_to_substring_sql,
exp.SafeDivide: no_safe_divide_sql,
exp.SchemaCommentProperty: lambda self, e: self.naked_property(e),
- exp.SetAgg: rename_func("COLLECT_SET"),
+ exp.ArrayUniqueAgg: rename_func("COLLECT_SET"),
exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))",
exp.StrPosition: strposition_to_locate_sql,
exp.StrToDate: _str_to_date_sql,
@@ -491,7 +512,8 @@ class Hive(Dialect):
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
exp.ToBase64: rename_func("BASE64"),
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS STRING), '-', ''), 1, 8) AS INT)",
- exp.TsOrDsAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
+ exp.TsOrDsAdd: _add_date_sql,
+ exp.TsOrDsDiff: _date_diff_sql,
exp.TsOrDsToDate: _to_date_sql,
exp.TryCast: no_trycast_sql,
exp.UnixToStr: lambda self, e: self.func(
@@ -571,6 +593,8 @@ class Hive(Dialect):
and not expression.expressions
):
expression = exp.DataType.build("text")
+ elif expression.is_type(exp.DataType.Type.TEXT) and expression.expressions:
+ expression.set("this", exp.DataType.Type.VARCHAR)
elif expression.this in exp.DataType.TEMPORAL_TYPES:
expression = exp.DataType.build(expression.this)
elif expression.is_type("float"):
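
Sketch of the truncation fix in _date_diff_sql: month-level diffs now wrap MONTHS_BETWEEN in a CAST so the fractional part of the float result is dropped.

    from sqlglot import exp
    from sqlglot.dialects.hive import Hive

    diff = exp.DateDiff(this=exp.column("a"), expression=exp.column("b"), unit=exp.var("MONTH"))
    print(Hive().generate(diff))  # roughly: CAST(MONTHS_BETWEEN(a, b) AS INT)
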
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index c78aa9e..cfc6e83 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -5,6 +5,7 @@ import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
+ NormalizationStrategy,
arrow_json_extract_scalar_sql,
date_add_interval_sql,
datestrtodate_sql,
@@ -150,10 +151,18 @@ class MySQL(Dialect):
# https://dev.mysql.com/doc/refman/8.0/en/identifiers.html
IDENTIFIERS_CAN_START_WITH_DIGIT = True
+ # We default to treating all identifiers as case-sensitive, since it matches MySQL's
+ # behavior on Linux systems. For MacOS and Windows systems, one can override this
+ # setting by specifying `dialect="mysql, normalization_strategy = lowercase"`.
+ #
+ # See also https://dev.mysql.com/doc/refman/8.2/en/identifier-case-sensitivity.html
+ NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_SENSITIVE
+
TIME_FORMAT = "'%Y-%m-%d %T'"
DPIPE_IS_STRING_CONCAT = False
SUPPORTS_USER_DEFINED_TYPES = False
SUPPORTS_SEMI_ANTI_JOIN = False
+ SAFE_DIVISION = True
# https://prestodb.io/docs/current/functions/datetime.html#mysql-date-functions
TIME_MAPPING = {
@@ -264,11 +273,6 @@ class MySQL(Dialect):
TokenType.DPIPE: exp.Or,
}
- # MySQL uses || as a synonym to the logical OR operator
- # https://dev.mysql.com/doc/refman/8.0/en/logical-operators.html#operator_or
- BITWISE = parser.Parser.BITWISE.copy()
- BITWISE.pop(TokenType.DPIPE)
-
TABLE_ALIAS_TOKENS = (
parser.Parser.TABLE_ALIAS_TOKENS - parser.Parser.TABLE_INDEX_HINT_TOKENS
)
@@ -451,7 +455,7 @@ class MySQL(Dialect):
self, kind: t.Optional[str] = None
) -> exp.IndexColumnConstraint:
if kind:
- self._match_texts({"INDEX", "KEY"})
+ self._match_texts(("INDEX", "KEY"))
this = self._parse_id_var(any_token=False)
index_type = self._match(TokenType.USING) and self._advance_any() and self._prev.text
@@ -514,7 +518,7 @@ class MySQL(Dialect):
log = self._parse_string() if self._match_text_seq("IN") else None
- if this in {"BINLOG EVENTS", "RELAYLOG EVENTS"}:
+ if this in ("BINLOG EVENTS", "RELAYLOG EVENTS"):
position = self._parse_number() if self._match_text_seq("FROM") else None
db = None
else:
@@ -671,6 +675,7 @@ class MySQL(Dialect):
exp.Trim: _trim_sql,
exp.TryCast: no_trycast_sql,
exp.TsOrDsAdd: _date_add_sql("ADD"),
+ exp.TsOrDsDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression),
exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
exp.Week: _remove_ts_or_ds_to_date(),
exp.WeekOfYear: _remove_ts_or_ds_to_date(rename_func("WEEKOFYEAR")),
@@ -763,7 +768,7 @@ class MySQL(Dialect):
target = self.sql(expression, "target")
target = f" {target}" if target else ""
- if expression.name in {"COLUMNS", "INDEX"}:
+ if expression.name in ("COLUMNS", "INDEX"):
target = f" FROM{target}"
elif expression.name == "GRANTS":
target = f" FOR{target}"
@@ -796,6 +801,14 @@ class MySQL(Dialect):
return f"SHOW{full}{global_}{this}{target}{types}{db}{query}{log}{position}{channel}{mutex_or_status}{like}{where}{offset}{limit}"
+ def altercolumn_sql(self, expression: exp.AlterColumn) -> str:
+ dtype = self.sql(expression, "dtype")
+ if not dtype:
+ return super().altercolumn_sql(expression)
+
+ this = self.sql(expression, "this")
+ return f"MODIFY COLUMN {this} {dtype}"
+
def _prefixed_sql(self, prefix: str, expression: exp.Expression, arg: str) -> str:
sql = self.sql(expression, arg)
return f" {prefix} {sql}" if sql else ""
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index 6bdd8d6..51dbd53 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -3,7 +3,14 @@ from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
-from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func, trim_sql
+from sqlglot.dialects.dialect import (
+ Dialect,
+ NormalizationStrategy,
+ format_time_lambda,
+ no_ilike_sql,
+ rename_func,
+ trim_sql,
+)
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@@ -30,12 +37,25 @@ def _parse_xml_table(self: Oracle.Parser) -> exp.XMLTable:
return self.expression(exp.XMLTable, this=this, passing=passing, columns=columns, by_ref=by_ref)
+def to_char(args: t.List) -> exp.TimeToStr | exp.ToChar:
+ this = seq_get(args, 0)
+
+ if this and not this.type:
+ from sqlglot.optimizer.annotate_types import annotate_types
+
+ annotate_types(this)
+ if this.is_type(*exp.DataType.TEMPORAL_TYPES):
+ return format_time_lambda(exp.TimeToStr, "oracle", default=True)(args)
+
+ return exp.ToChar.from_arg_list(args)
+
+
class Oracle(Dialect):
ALIAS_POST_TABLESAMPLE = True
LOCKING_READS_SUPPORTED = True
# See section 8: https://docs.oracle.com/cd/A97630_01/server.920/a96540/sql_elements9a.htm
- RESOLVES_IDENTIFIERS_AS_UPPERCASE = True
+ NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE
# https://docs.oracle.com/database/121/SQLRF/sql_elements004.htm#SQLRF00212
# https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes
@@ -64,11 +84,13 @@ class Oracle(Dialect):
}
class Parser(parser.Parser):
+ ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False
WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER, TokenType.KEEP}
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
+ "TO_CHAR": to_char,
}
FUNCTION_PARSERS: t.Dict[str, t.Callable] = {
@@ -130,6 +152,7 @@ class Oracle(Dialect):
TABLE_HINTS = False
COLUMN_JOIN_MARKS_SUPPORTED = True
DATA_TYPE_SPECIFIERS_ALLOWED = True
+ ALTER_TABLE_INCLUDE_COLUMN_KEYWORD = False
LIMIT_FETCH = "FETCH"
@@ -192,6 +215,12 @@ class Oracle(Dialect):
)
return f"XMLTABLE({self.sep('')}{self.indent(this + passing + by_ref + columns)}{self.seg(')', sep='')}"
+ def add_column_sql(self, expression: exp.AlterTable) -> str:
+ actions = self.expressions(expression, key="actions", flat=True)
+ if len(expression.args.get("actions", [])) > 1:
+ return f"ADD ({actions})"
+ return f"ADD {actions}"
+
class Tokenizer(tokens.Tokenizer):
VAR_SINGLE_TOKENS = {"@", "$", "#"}
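
Sketch of the TO_CHAR dispatch added above: when the first argument annotates to a temporal type, the call is parsed as exp.TimeToStr; otherwise it stays an exp.ToChar.

    import sqlglot
    from sqlglot import exp

    ast = sqlglot.parse_one("SELECT TO_CHAR(CAST(x AS DATE), 'YYYY-MM-DD') FROM t", read="oracle")
    print(ast.find(exp.TimeToStr) is not None)  # True
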
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 27c6851..fefddee 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -4,6 +4,7 @@ import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
+ DATE_ADD_OR_SUB,
Dialect,
any_value_to_max_sql,
arrow_json_extract_scalar_sql,
@@ -25,6 +26,7 @@ from sqlglot.dialects.dialect import (
timestamptrunc_sql,
timestrtotime_sql,
trim_sql,
+ ts_or_ds_add_cast,
ts_or_ds_to_date_sql,
)
from sqlglot.helper import seq_get
@@ -41,8 +43,11 @@ DATE_DIFF_FACTOR = {
}
-def _date_add_sql(kind: str) -> t.Callable[[Postgres.Generator, exp.DateAdd | exp.DateSub], str]:
- def func(self: Postgres.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
+def _date_add_sql(kind: str) -> t.Callable[[Postgres.Generator, DATE_ADD_OR_SUB], str]:
+ def func(self: Postgres.Generator, expression: DATE_ADD_OR_SUB) -> str:
+ if isinstance(expression, exp.TsOrDsAdd):
+ expression = ts_or_ds_add_cast(expression)
+
this = self.sql(expression, "this")
unit = expression.args.get("unit")
@@ -60,8 +65,8 @@ def _date_diff_sql(self: Postgres.Generator, expression: exp.DateDiff) -> str:
unit = expression.text("unit").upper()
factor = DATE_DIFF_FACTOR.get(unit)
- end = f"CAST({expression.this} AS TIMESTAMP)"
- start = f"CAST({expression.expression} AS TIMESTAMP)"
+ end = f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)"
+ start = f"CAST({self.sql(expression, 'expression')} AS TIMESTAMP)"
if factor is not None:
return f"CAST(EXTRACT(epoch FROM {end} - {start}){factor} AS BIGINT)"
@@ -69,7 +74,7 @@ def _date_diff_sql(self: Postgres.Generator, expression: exp.DateDiff) -> str:
age = f"AGE({end}, {start})"
if unit == "WEEK":
- unit = f"EXTRACT(year FROM {age}) * 48 + EXTRACT(month FROM {age}) * 4 + EXTRACT(day FROM {age}) / 7"
+ unit = f"EXTRACT(days FROM ({end} - {start})) / 7"
elif unit == "MONTH":
unit = f"EXTRACT(year FROM {age}) * 12 + EXTRACT(month FROM {age})"
elif unit == "QUARTER":
@@ -183,37 +188,43 @@ def _to_timestamp(args: t.List) -> exp.Expression:
return format_time_lambda(exp.StrToTime, "postgres")(args)
-def _remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
- """Remove table refs from columns in when statements."""
- if isinstance(expression, exp.Merge):
- alias = expression.this.args.get("alias")
+def _merge_sql(self: Postgres.Generator, expression: exp.Merge) -> str:
+ def _remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
+ """Remove table refs from columns in when statements."""
+ if isinstance(expression, exp.Merge):
+ alias = expression.this.args.get("alias")
- normalize = (
- lambda identifier: Postgres.normalize_identifier(identifier).name
- if identifier
- else None
- )
+ normalize = (
+ lambda identifier: self.dialect.normalize_identifier(identifier).name
+ if identifier
+ else None
+ )
- targets = {normalize(expression.this.this)}
+ targets = {normalize(expression.this.this)}
- if alias:
- targets.add(normalize(alias.this))
+ if alias:
+ targets.add(normalize(alias.this))
- for when in expression.expressions:
- when.transform(
- lambda node: exp.column(node.this)
- if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
- else node,
- copy=False,
- )
+ for when in expression.expressions:
+ when.transform(
+ lambda node: exp.column(node.this)
+ if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
+ else node,
+ copy=False,
+ )
- return expression
+ return expression
+
+ return transforms.preprocess([_remove_target_from_merge])(self, expression)
class Postgres(Dialect):
INDEX_OFFSET = 1
+ TYPED_DIVISION = True
+ CONCAT_COALESCE = True
NULL_ORDERING = "nulls_are_large"
TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'"
+
TIME_MAPPING = {
"AM": "%p",
"PM": "%p",
@@ -263,6 +274,7 @@ class Postgres(Dialect):
"BEGIN TRANSACTION": TokenType.BEGIN,
"BIGSERIAL": TokenType.BIGSERIAL,
"CHARACTER VARYING": TokenType.VARCHAR,
+ "CONSTRAINT TRIGGER": TokenType.COMMAND,
"DECLARE": TokenType.COMMAND,
"DO": TokenType.COMMAND,
"HSTORE": TokenType.HSTORE,
@@ -277,6 +289,7 @@ class Postgres(Dialect):
"TEMP": TokenType.TEMPORARY,
"CSTRING": TokenType.PSEUDO_TYPE,
"OID": TokenType.OBJECT_IDENTIFIER,
+ "OPERATOR": TokenType.OPERATOR,
"REGCLASS": TokenType.OBJECT_IDENTIFIER,
"REGCOLLATION": TokenType.OBJECT_IDENTIFIER,
"REGCONFIG": TokenType.OBJECT_IDENTIFIER,
@@ -298,8 +311,6 @@ class Postgres(Dialect):
VAR_SINGLE_TOKENS = {"$"}
class Parser(parser.Parser):
- CONCAT_NULL_OUTPUTS_STRING = True
-
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"DATE_TRUNC": parse_timestamp_trunc,
@@ -326,12 +337,13 @@ class Postgres(Dialect):
RANGE_PARSERS = {
**parser.Parser.RANGE_PARSERS,
+ TokenType.AT_GT: binary_range_parser(exp.ArrayContains),
TokenType.DAMP: binary_range_parser(exp.ArrayOverlaps),
TokenType.DAT: lambda self, this: self.expression(
exp.MatchAgainst, this=self._parse_bitwise(), expressions=[this]
),
- TokenType.AT_GT: binary_range_parser(exp.ArrayContains),
TokenType.LT_AT: binary_range_parser(exp.ArrayContained),
+ TokenType.OPERATOR: lambda self, this: self._parse_operator(this),
}
STATEMENT_PARSERS = {
@@ -339,11 +351,28 @@ class Postgres(Dialect):
TokenType.END: lambda self: self._parse_commit_or_rollback(),
}
- def _parse_factor(self) -> t.Optional[exp.Expression]:
- return self._parse_tokens(self._parse_exponent, self.FACTOR)
+ def _parse_operator(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
+ while True:
+ if not self._match(TokenType.L_PAREN):
+ break
+
+ op = ""
+ while self._curr and not self._match(TokenType.R_PAREN):
+ op += self._curr.text
+ self._advance()
+
+ this = self.expression(
+ exp.Operator,
+ comments=self._prev_comments,
+ this=this,
+ operator=op,
+ expression=self._parse_bitwise(),
+ )
+
+ if not self._match(TokenType.OPERATOR):
+ break
- def _parse_exponent(self) -> t.Optional[exp.Expression]:
- return self._parse_tokens(self._parse_unary, self.EXPONENT)
+ return this
def _parse_date_part(self) -> exp.Expression:
part = self._parse_type()
@@ -405,7 +434,7 @@ class Postgres(Dialect):
exp.Max: max_or_greatest,
exp.MapFromEntries: no_map_from_entries_sql,
exp.Min: min_or_least,
- exp.Merge: transforms.preprocess([_remove_target_from_merge]),
+ exp.Merge: _merge_sql,
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.PercentileCont: transforms.preprocess(
[transforms.add_within_group_for_percentiles]
@@ -434,6 +463,8 @@ class Postgres(Dialect):
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.Trim: trim_sql,
exp.TryCast: no_trycast_sql,
+ exp.TsOrDsAdd: _date_add_sql("+"),
+ exp.TsOrDsDiff: _date_diff_sql,
exp.TsOrDsToDate: ts_or_ds_to_date_sql("postgres"),
exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})",
exp.VariancePop: rename_func("VAR_POP"),
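
A minimal sketch of the new OPERATOR support (this assumes a generator counterpart for exp.Operator lands in the same change; the round-tripped text is an assumption):

    import sqlglot

    # _parse_operator accumulates the token text between the parentheses
    # and keeps parsing the right-hand operand.
    sql = "SELECT a OPERATOR(pg_catalog.+) b FROM t"
    print(sqlglot.parse_one(sql, read="postgres").sql(dialect="postgres"))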
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index ded3655..10a6074 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -5,9 +5,11 @@ import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
+ NormalizationStrategy,
binary_from_function,
bool_xor_sql,
date_trunc_to_time,
+ datestrtodate_sql,
encode_decode_sql,
format_time_lambda,
if_sql,
@@ -22,6 +24,7 @@ from sqlglot.dialects.dialect import (
struct_extract_sql,
timestamptrunc_sql,
timestrtotime_sql,
+ ts_or_ds_add_cast,
)
from sqlglot.dialects.mysql import MySQL
from sqlglot.helper import apply_index_offset, seq_get
@@ -95,17 +98,16 @@ def _ts_or_ds_to_date_sql(self: Presto.Generator, expression: exp.TsOrDsToDate)
def _ts_or_ds_add_sql(self: Presto.Generator, expression: exp.TsOrDsAdd) -> str:
- this = expression.this
+ expression = ts_or_ds_add_cast(expression)
+ unit = exp.Literal.string(expression.text("unit") or "day")
+ return self.func("DATE_ADD", unit, expression.expression, expression.this)
- if not isinstance(this, exp.CurrentDate):
- this = exp.cast(exp.cast(expression.this, "TIMESTAMP", copy=True), "DATE")
- return self.func(
- "DATE_ADD",
- exp.Literal.string(expression.text("unit") or "day"),
- expression.expression,
- this,
- )
+def _ts_or_ds_diff_sql(self: Presto.Generator, expression: exp.TsOrDsDiff) -> str:
+ this = exp.cast(expression.this, "TIMESTAMP")
+ expr = exp.cast(expression.expression, "TIMESTAMP")
+ unit = exp.Literal.string(expression.text("unit") or "day")
+ return self.func("DATE_DIFF", unit, expr, this)
def _approx_percentile(args: t.List) -> exp.Expression:
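
A sketch of the rewritten TS_OR_DS_ADD generation (the literal input and expected output are assumptions):

    import sqlglot

    # ts_or_ds_add_cast casts the first argument according to the
    # expression's return type before DATE_ADD is emitted.
    sql = "SELECT TS_OR_DS_ADD('2023-01-01', 1, 'day')"
    print(sqlglot.transpile(sql, write="presto")[0])
    # along the lines of: SELECT DATE_ADD('day', 1, CAST(CAST('2023-01-01' AS TIMESTAMP) AS DATE))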
@@ -136,11 +138,11 @@ def _from_unixtime(args: t.List) -> exp.Expression:
return exp.UnixToTime.from_arg_list(args)
-def _parse_element_at(args: t.List) -> exp.SafeBracket:
+def _parse_element_at(args: t.List) -> exp.Bracket:
this = seq_get(args, 0)
index = seq_get(args, 1)
assert isinstance(this, exp.Expression) and isinstance(index, exp.Expression)
- return exp.SafeBracket(this=this, expressions=apply_index_offset(this, [index], -1))
+ return exp.Bracket(this=this, expressions=[index], offset=1, safe=True)
def _unnest_sequence(expression: exp.Expression) -> exp.Expression:
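
A sketch of how ELEMENT_AT now round-trips through the safe Bracket representation (the expected output is an assumption):

    import sqlglot

    # safe=True plus offset=1 lets bracket_sql (further below) re-emit
    # ELEMENT_AT with the original one-based index.
    print(sqlglot.transpile("SELECT ELEMENT_AT(arr, 1)", read="presto")[0])
    # e.g. SELECT ELEMENT_AT(arr, 1)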
@@ -168,6 +170,22 @@ def _first_last_sql(self: Presto.Generator, expression: exp.First | exp.Last) ->
return rename_func("ARBITRARY")(self, expression)
+def _unix_to_time_sql(self: Presto.Generator, expression: exp.UnixToTime) -> str:
+ scale = expression.args.get("scale")
+ timestamp = self.sql(expression, "this")
+ if scale in (None, exp.UnixToTime.SECONDS):
+ return rename_func("FROM_UNIXTIME")(self, expression)
+ if scale == exp.UnixToTime.MILLIS:
+ return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000)"
+ if scale == exp.UnixToTime.MICROS:
+ return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000000)"
+ if scale == exp.UnixToTime.NANOS:
+ return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000000000)"
+
+ self.unsupported(f"Unsupported scale for timestamp: {scale}.")
+ return ""
+
+
class Presto(Dialect):
INDEX_OFFSET = 1
NULL_ORDERING = "nulls_are_last"
@@ -175,11 +193,12 @@ class Presto(Dialect):
TIME_MAPPING = MySQL.TIME_MAPPING
STRICT_STRING_CONCAT = True
SUPPORTS_SEMI_ANTI_JOIN = False
+ TYPED_DIVISION = True
# https://github.com/trinodb/trino/issues/17
# https://github.com/trinodb/trino/issues/12289
# https://github.com/prestodb/presto/issues/2863
- RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
+ NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
@@ -229,6 +248,7 @@ class Presto(Dialect):
),
"ROW": exp.Struct.from_arg_list,
"SEQUENCE": exp.GenerateSeries.from_arg_list,
+ "SET_AGG": exp.ArrayUniqueAgg.from_arg_list,
"SPLIT_TO_MAP": exp.StrToMap.from_arg_list,
"STRPOS": lambda args: exp.StrPosition(
this=seq_get(args, 0), substr=seq_get(args, 1), instance=seq_get(args, 2)
@@ -253,6 +273,7 @@ class Presto(Dialect):
NVL2_SUPPORTED = False
STRUCT_DELIMITER = ("(", ")")
LIMIT_ONLY_LITERALS = True
+ SUPPORTS_SINGLE_ARG_CONCAT = False
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION,
@@ -284,6 +305,7 @@ class Presto(Dialect):
exp.ArrayConcat: rename_func("CONCAT"),
exp.ArrayContains: rename_func("CONTAINS"),
exp.ArraySize: rename_func("CARDINALITY"),
+ exp.ArrayUniqueAgg: rename_func("SET_AGG"),
exp.BitwiseAnd: lambda self, e: f"BITWISE_AND({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.BitwiseLeftShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_LEFT({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.BitwiseNot: lambda self, e: f"BITWISE_NOT({self.sql(e, 'this')})",
@@ -298,7 +320,7 @@ class Presto(Dialect):
exp.DateDiff: lambda self, e: self.func(
"DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
),
- exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.DATE_FORMAT}) AS DATE)",
+ exp.DateStrToDate: datestrtodate_sql,
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)",
exp.DateSub: lambda self, e: self.func(
"DATE_ADD",
@@ -330,9 +352,6 @@ class Presto(Dialect):
exp.Quantile: _quantile_sql,
exp.RegexpExtract: regexp_extract_sql,
exp.Right: right_to_substring_sql,
- exp.SafeBracket: lambda self, e: self.func(
- "ELEMENT_AT", e.this, seq_get(apply_index_offset(e.this, e.expressions, 1), 0)
- ),
exp.SafeDivide: no_safe_divide_sql,
exp.Schema: _schema_sql,
exp.Select: transforms.preprocess(
@@ -361,10 +380,11 @@ class Presto(Dialect):
exp.TryCast: transforms.preprocess([transforms.epoch_cast_to_ts]),
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
exp.TsOrDsAdd: _ts_or_ds_add_sql,
+ exp.TsOrDsDiff: _ts_or_ds_diff_sql,
exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
exp.Unhex: rename_func("FROM_HEX"),
exp.UnixToStr: lambda self, e: f"DATE_FORMAT(FROM_UNIXTIME({self.sql(e, 'this')}), {self.format_time(e)})",
- exp.UnixToTime: rename_func("FROM_UNIXTIME"),
+ exp.UnixToTime: _unix_to_time_sql,
exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
exp.VariancePop: rename_func("VAR_POP"),
exp.With: transforms.preprocess([transforms.add_recursive_cte_column_names]),
@@ -374,8 +394,24 @@ class Presto(Dialect):
exp.Xor: bool_xor_sql,
}
+ def bracket_sql(self, expression: exp.Bracket) -> str:
+ if expression.args.get("safe"):
+ return self.func(
+ "ELEMENT_AT",
+ expression.this,
+ seq_get(
+ apply_index_offset(
+ expression.this,
+ expression.expressions,
+ 1 - expression.args.get("offset", 0),
+ ),
+ 0,
+ ),
+ )
+ return super().bracket_sql(expression)
+
def struct_sql(self, expression: exp.Struct) -> str:
- if any(isinstance(arg, (exp.EQ, exp.Slice)) for arg in expression.expressions):
+ if any(isinstance(arg, self.KEY_VALUE_DEFINITONS) for arg in expression.expressions):
self.unsupported("Struct with key-value definitions is unsupported.")
return self.function_fallback_sql(expression)
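
The new scale handling can be sketched directly from the expression node (the column name is illustrative):

    from sqlglot import exp

    # Millisecond-scale epochs are divided down to seconds before
    # FROM_UNIXTIME, per _unix_to_time_sql above.
    e = exp.UnixToTime(this=exp.column("ts"), scale=exp.UnixToTime.MILLIS)
    print(e.sql(dialect="presto"))
    # FROM_UNIXTIME(CAST(ts AS DOUBLE) / 1000)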
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index 6c7ba35..7382e7c 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -4,8 +4,10 @@ import typing as t
from sqlglot import exp, transforms
from sqlglot.dialects.dialect import (
+ NormalizationStrategy,
concat_to_dpipe_sql,
concat_ws_to_dpipe_sql,
+ date_delta_sql,
generatedasidentitycolumnconstraint_sql,
rename_func,
ts_or_ds_to_date_sql,
@@ -14,30 +16,28 @@ from sqlglot.dialects.postgres import Postgres
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
+if t.TYPE_CHECKING:
+ from sqlglot._typing import E
+
def _json_sql(self: Redshift.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar) -> str:
return f'{self.sql(expression, "this")}."{expression.expression.name}"'
-def _parse_date_add(args: t.List) -> exp.DateAdd:
- return exp.DateAdd(
- this=exp.TsOrDsToDate(this=seq_get(args, 2)),
- expression=seq_get(args, 1),
- unit=seq_get(args, 0),
- )
+def _parse_date_delta(expr_type: t.Type[E]) -> t.Callable[[t.List], E]:
+ def _parse_delta(args: t.List) -> E:
+ expr = expr_type(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0))
+ if expr_type is exp.TsOrDsAdd:
+ expr.set("return_type", exp.DataType.build("TIMESTAMP"))
+ return expr
-def _parse_datediff(args: t.List) -> exp.DateDiff:
- return exp.DateDiff(
- this=exp.TsOrDsToDate(this=seq_get(args, 2)),
- expression=exp.TsOrDsToDate(this=seq_get(args, 1)),
- unit=seq_get(args, 0),
- )
+ return _parse_delta
class Redshift(Postgres):
# https://docs.aws.amazon.com/redshift/latest/dg/r_names.html
- RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
+ NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE
SUPPORTS_USER_DEFINED_TYPES = False
INDEX_OFFSET = 0
@@ -52,15 +52,16 @@ class Redshift(Postgres):
class Parser(Postgres.Parser):
FUNCTIONS = {
**Postgres.Parser.FUNCTIONS,
- "ADD_MONTHS": lambda args: exp.DateAdd(
- this=exp.TsOrDsToDate(this=seq_get(args, 0)),
+ "ADD_MONTHS": lambda args: exp.TsOrDsAdd(
+ this=seq_get(args, 0),
expression=seq_get(args, 1),
unit=exp.var("month"),
+ return_type=exp.DataType.build("TIMESTAMP"),
),
- "DATEADD": _parse_date_add,
- "DATE_ADD": _parse_date_add,
- "DATEDIFF": _parse_datediff,
- "DATE_DIFF": _parse_datediff,
+ "DATEADD": _parse_date_delta(exp.TsOrDsAdd),
+ "DATE_ADD": _parse_date_delta(exp.TsOrDsAdd),
+ "DATEDIFF": _parse_date_delta(exp.TsOrDsDiff),
+ "DATE_DIFF": _parse_date_delta(exp.TsOrDsDiff),
"LISTAGG": exp.GroupConcat.from_arg_list,
"STRTOL": exp.FromBase.from_arg_list,
}
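
A sketch of the unified date-delta parsing (the expected round trip is an assumption):

    import sqlglot

    # DATEADD now parses into TsOrDsAdd with a TIMESTAMP return_type, so
    # the same text round-trips back through date_delta_sql below.
    sql = "SELECT DATEADD(month, 1, '2023-01-01')"
    print(sqlglot.transpile(sql, read="redshift")[0])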
@@ -169,12 +170,8 @@ class Redshift(Postgres):
exp.ConcatWs: concat_ws_to_dpipe_sql,
exp.ApproxDistinct: lambda self, e: f"APPROXIMATE COUNT(DISTINCT {self.sql(e, 'this')})",
exp.CurrentTimestamp: lambda self, e: "SYSDATE",
- exp.DateAdd: lambda self, e: self.func(
- "DATEADD", exp.var(e.text("unit") or "day"), e.expression, e.this
- ),
- exp.DateDiff: lambda self, e: self.func(
- "DATEDIFF", exp.var(e.text("unit") or "day"), e.expression, e.this
- ),
+ exp.DateAdd: date_delta_sql("DATEADD"),
+ exp.DateDiff: date_delta_sql("DATEDIFF"),
exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})",
exp.DistStyleProperty: lambda self, e: self.naked_property(e),
exp.FromBase: rename_func("STRTOL"),
@@ -183,11 +180,12 @@ class Redshift(Postgres):
exp.JSONExtractScalar: _json_sql,
exp.GroupConcat: rename_func("LISTAGG"),
exp.ParseJSON: rename_func("JSON_PARSE"),
- exp.SafeConcat: concat_to_dpipe_sql,
exp.Select: transforms.preprocess(
[transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
),
exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
+ exp.TsOrDsAdd: date_delta_sql("DATEADD"),
+ exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
exp.TsOrDsToDate: ts_or_ds_to_date_sql("redshift"),
}
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 01f7512..cdbc071 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -3,9 +3,12 @@ from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
+from sqlglot._typing import E
from sqlglot.dialects.dialect import (
Dialect,
+ NormalizationStrategy,
binary_from_function,
+ date_delta_sql,
date_trunc_to_time,
datestrtodate_sql,
format_time_lambda,
@@ -21,7 +24,6 @@ from sqlglot.dialects.dialect import (
)
from sqlglot.expressions import Literal
from sqlglot.helper import seq_get
-from sqlglot.parser import binary_range_parser
from sqlglot.tokens import TokenType
@@ -50,7 +52,7 @@ def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime,
elif second_arg.name == "3":
timescale = exp.UnixToTime.MILLIS
elif second_arg.name == "9":
- timescale = exp.UnixToTime.MICROS
+ timescale = exp.UnixToTime.NANOS
return exp.UnixToTime(this=first_arg, scale=timescale)
@@ -95,14 +97,17 @@ def _parse_datediff(args: t.List) -> exp.DateDiff:
def _unix_to_time_sql(self: Snowflake.Generator, expression: exp.UnixToTime) -> str:
scale = expression.args.get("scale")
timestamp = self.sql(expression, "this")
- if scale in [None, exp.UnixToTime.SECONDS]:
+ if scale in (None, exp.UnixToTime.SECONDS):
return f"TO_TIMESTAMP({timestamp})"
if scale == exp.UnixToTime.MILLIS:
return f"TO_TIMESTAMP({timestamp}, 3)"
if scale == exp.UnixToTime.MICROS:
+ return f"TO_TIMESTAMP({timestamp} / 1000, 3)"
+ if scale == exp.UnixToTime.NANOS:
return f"TO_TIMESTAMP({timestamp}, 9)"
- raise ValueError("Improper scale for timestamp")
+ self.unsupported(f"Unsupported scale for timestamp: {scale}.")
+ return ""
# https://docs.snowflake.com/en/sql-reference/functions/date_part.html
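
A sketch of the corrected scale handling (the column name is illustrative):

    from sqlglot import exp

    # Nanosecond epochs now map to TO_TIMESTAMP(..., 9); previously the
    # 9-digit case was mislabeled as MICROS.
    e = exp.UnixToTime(this=exp.column("ts"), scale=exp.UnixToTime.NANOS)
    print(e.sql(dialect="snowflake"))
    # TO_TIMESTAMP(ts, 9)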
@@ -201,7 +206,7 @@ def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[Snowflake.Parser]
class Snowflake(Dialect):
# https://docs.snowflake.com/en/sql-reference/identifiers-syntax
- RESOLVES_IDENTIFIERS_AS_UPPERCASE = True
+ NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE
NULL_ORDERING = "nulls_are_large"
TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'"
SUPPORTS_USER_DEFINED_TYPES = False
@@ -236,6 +241,18 @@ class Snowflake(Dialect):
"ff6": "%f",
}
+ def quote_identifier(self, expression: E, identify: bool = True) -> E:
+ # This disables quoting DUAL in SELECT ... FROM DUAL, because Snowflake treats an
+ # unquoted DUAL keyword in a special way and does not map it to a user-defined table
+ if (
+ isinstance(expression, exp.Identifier)
+ and isinstance(expression.parent, exp.Table)
+ and expression.name.lower() == "dual"
+ ):
+ return t.cast(E, expression)
+
+ return super().quote_identifier(expression, identify=identify)
+
class Parser(parser.Parser):
IDENTIFY_PIVOT_STRINGS = True
@@ -245,6 +262,9 @@ class Snowflake(Dialect):
**parser.Parser.FUNCTIONS,
"ARRAYAGG": exp.ArrayAgg.from_arg_list,
"ARRAY_CONSTRUCT": exp.Array.from_arg_list,
+ "ARRAY_CONTAINS": lambda args: exp.ArrayContains(
+ this=seq_get(args, 1), expression=seq_get(args, 0)
+ ),
"ARRAY_GENERATE_RANGE": lambda args: exp.GenerateSeries(
            # ARRAY_GENERATE_RANGE has an exclusive end; we normalize it to be inclusive
start=seq_get(args, 0),
@@ -296,8 +316,8 @@ class Snowflake(Dialect):
RANGE_PARSERS = {
**parser.Parser.RANGE_PARSERS,
- TokenType.LIKE_ANY: binary_range_parser(exp.LikeAny),
- TokenType.ILIKE_ANY: binary_range_parser(exp.ILikeAny),
+ TokenType.LIKE_ANY: parser.binary_range_parser(exp.LikeAny),
+ TokenType.ILIKE_ANY: parser.binary_range_parser(exp.ILikeAny),
}
ALTER_PARSERS = {
@@ -317,6 +337,11 @@ class Snowflake(Dialect):
TokenType.SHOW: lambda self: self._parse_show(),
}
+ PROPERTY_PARSERS = {
+ **parser.Parser.PROPERTY_PARSERS,
+ "LOCATION": lambda self: self._parse_location(),
+ }
+
SHOW_PARSERS = {
"PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
"TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
@@ -349,7 +374,7 @@ class Snowflake(Dialect):
table: t.Optional[exp.Expression] = None
if self._match_text_seq("@"):
table_name = "@"
- while True:
+ while self._curr:
self._advance()
table_name += self._prev.text
if not self._match_set(self.STAGED_FILE_SINGLE_TOKENS, advance=False):
@@ -411,6 +436,20 @@ class Snowflake(Dialect):
self._match_text_seq("WITH")
return self.expression(exp.SwapTable, this=self._parse_table(schema=True))
+ def _parse_location(self) -> exp.LocationProperty:
+ self._match(TokenType.EQ)
+
+ parts = [self._parse_var(any_token=True)]
+
+ while self._match(TokenType.SLASH):
+ if self._curr and self._prev.end + 1 == self._curr.start:
+ parts.append(self._parse_var(any_token=True))
+ else:
+ parts.append(exp.Var(this=""))
+ return self.expression(
+ exp.LocationProperty, this=exp.var("/".join(str(p) for p in parts))
+ )
+
class Tokenizer(tokens.Tokenizer):
STRING_ESCAPES = ["\\", "'"]
HEX_STRINGS = [("x'", "'"), ("X'", "'")]
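
A sketch of the new LOCATION property parser (the DDL, the stage path, and the round-tripped text are all assumptions):

    import sqlglot

    # _parse_location stitches slash-separated stage path segments back
    # into a single Var.
    ddl = "CREATE TABLE t (c INT) LOCATION=@my_stage/some/path"
    print(sqlglot.parse_one(ddl, read="snowflake").sql(dialect="snowflake"))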
@@ -457,6 +496,7 @@ class Snowflake(Dialect):
AGGREGATE_FILTER_SUPPORTED = False
SUPPORTS_TABLE_COPY = False
COLLATE_IS_FUNC = True
+ LIMIT_ONLY_LITERALS = True
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@@ -464,15 +504,14 @@ class Snowflake(Dialect):
exp.ArgMin: rename_func("MIN_BY"),
exp.Array: inline_array_sql,
exp.ArrayConcat: rename_func("ARRAY_CAT"),
+ exp.ArrayContains: lambda self, e: self.func("ARRAY_CONTAINS", e.expression, e.this),
exp.ArrayJoin: rename_func("ARRAY_TO_STRING"),
exp.AtTimeZone: lambda self, e: self.func(
"CONVERT_TIMEZONE", e.args.get("zone"), e.this
),
exp.BitwiseXor: rename_func("BITXOR"),
- exp.DateAdd: lambda self, e: self.func("DATEADD", e.text("unit"), e.expression, e.this),
- exp.DateDiff: lambda self, e: self.func(
- "DATEDIFF", e.text("unit"), e.expression, e.this
- ),
+ exp.DateAdd: date_delta_sql("DATEADD"),
+ exp.DateDiff: date_delta_sql("DATEDIFF"),
exp.DateStrToDate: datestrtodate_sql,
exp.DataType: _datatype_sql,
exp.DayOfMonth: rename_func("DAYOFMONTH"),
@@ -501,10 +540,11 @@ class Snowflake(Dialect):
exp.Select: transforms.preprocess(
[
transforms.eliminate_distinct_on,
- transforms.explode_to_unnest(0),
+ transforms.explode_to_unnest(),
transforms.eliminate_semi_and_anti_joins,
]
),
+ exp.SHA: rename_func("SHA1"),
exp.StarMap: rename_func("OBJECT_CONSTRUCT"),
exp.StartsWith: rename_func("STARTSWITH"),
exp.StrPosition: lambda self, e: self.func(
@@ -524,6 +564,8 @@ class Snowflake(Dialect):
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
+ exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True),
+ exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"),
exp.UnixToTime: _unix_to_time_sql,
exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
@@ -547,6 +589,20 @@ class Snowflake(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
+ def trycast_sql(self, expression: exp.TryCast) -> str:
+ value = expression.this
+
+ if value.type is None:
+ from sqlglot.optimizer.annotate_types import annotate_types
+
+ value = annotate_types(value)
+
+ if value.is_type(*exp.DataType.TEXT_TYPES, exp.DataType.Type.UNKNOWN):
+ return super().trycast_sql(expression)
+
+ # TRY_CAST only works for string values in Snowflake
+ return self.cast_sql(expression)
+
def log_sql(self, expression: exp.Log) -> str:
if not expression.expression:
return self.func("LN", expression.this)
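
A sketch of the TRY_CAST fallback (the expected output is an assumption):

    import sqlglot

    # A literal integer is annotated as a non-string type, so its TRY_CAST
    # degrades to CAST; an unknown-typed column keeps TRY_CAST.
    sql = "SELECT TRY_CAST(x AS TEXT), TRY_CAST(1 AS TEXT)"
    print(sqlglot.transpile(sql, read="snowflake")[0])
    # e.g. SELECT TRY_CAST(x AS VARCHAR), CAST(1 AS VARCHAR)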
@@ -554,24 +610,28 @@ class Snowflake(Dialect):
return super().log_sql(expression)
def unnest_sql(self, expression: exp.Unnest) -> str:
- selects = ["value"]
unnest_alias = expression.args.get("alias")
-
offset = expression.args.get("offset")
- if offset:
- if unnest_alias:
- unnest_alias.append("columns", offset.pop())
-
- selects.append("index")
- subquery = exp.Subquery(
- this=exp.select(*selects).from_(
- f"TABLE(FLATTEN(INPUT => {self.sql(expression.expressions[0])}))"
- ),
- )
+ columns = [
+ exp.to_identifier("seq"),
+ exp.to_identifier("key"),
+ exp.to_identifier("path"),
+ offset.pop() if isinstance(offset, exp.Expression) else exp.to_identifier("index"),
+ seq_get(unnest_alias.columns if unnest_alias else [], 0)
+ or exp.to_identifier("value"),
+ exp.to_identifier("this"),
+ ]
+
+ if unnest_alias:
+ unnest_alias.set("columns", columns)
+ else:
+ unnest_alias = exp.TableAlias(this="_u", columns=columns)
+
+ explode = f"TABLE(FLATTEN(INPUT => {self.sql(expression.expressions[0])}))"
alias = self.sql(unnest_alias)
alias = f" AS {alias}" if alias else ""
- return f"{self.sql(subquery)}{alias}"
+ return f"{explode}{alias}"
def show_sql(self, expression: exp.Show) -> str:
scope = self.sql(expression, "scope")
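
A sketch of the flattened UNNEST output (the dialect pair and exact alias text are assumptions):

    import sqlglot

    # UNNEST is now rendered directly as TABLE(FLATTEN(...)) with the full
    # FLATTEN column list spliced into the alias.
    sql = "SELECT * FROM UNNEST(x) AS t(c)"
    print(sqlglot.transpile(sql, read="presto", write="snowflake")[0])
    # e.g. SELECT * FROM TABLE(FLATTEN(INPUT => x)) AS t(seq, key, path, index, c, this)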
@@ -632,3 +692,6 @@ class Snowflake(Dialect):
def swaptable_sql(self, expression: exp.SwapTable) -> str:
this = self.sql(expression, "this")
return f"SWAP WITH {this}"
+
+ def with_properties(self, properties: exp.Properties) -> str:
+ return self.properties(properties, wrapped=False, prefix=self.seg(""), sep=" ")
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 1abfce6..ba73ac0 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -56,15 +56,17 @@ class Spark(Spark2):
def _parse_generated_as_identity(
self,
- ) -> exp.GeneratedAsIdentityColumnConstraint | exp.ComputedColumnConstraint:
+ ) -> (
+ exp.GeneratedAsIdentityColumnConstraint
+ | exp.ComputedColumnConstraint
+ | exp.GeneratedAsRowColumnConstraint
+ ):
this = super()._parse_generated_as_identity()
if this.expression:
return self.expression(exp.ComputedColumnConstraint, this=this.expression)
return this
class Generator(Spark2.Generator):
- SUPPORTS_NESTED_CTES = True
-
TYPE_MAPPING = {
**Spark2.Generator.TYPE_MAPPING,
exp.DataType.Type.MONEY: "DECIMAL(15, 4)",
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
index da84bd8..aa09f53 100644
--- a/sqlglot/dialects/spark2.py
+++ b/sqlglot/dialects/spark2.py
@@ -48,8 +48,11 @@ def _unix_to_time_sql(self: Spark2.Generator, expression: exp.UnixToTime) -> str
return f"TIMESTAMP_MILLIS({timestamp})"
if scale == exp.UnixToTime.MICROS:
return f"TIMESTAMP_MICROS({timestamp})"
+ if scale == exp.UnixToTime.NANOS:
+ return f"TIMESTAMP_SECONDS({timestamp} / 1000000000)"
- raise ValueError("Improper scale for timestamp")
+ self.unsupported(f"Unsupported scale for timestamp: {scale}.")
+ return ""
def _unalias_pivot(expression: exp.Expression) -> exp.Expression:
@@ -119,7 +122,11 @@ class Spark2(Hive):
"DOUBLE": _parse_as_cast("double"),
"FLOAT": _parse_as_cast("float"),
"FROM_UTC_TIMESTAMP": lambda args: exp.AtTimeZone(
- this=exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("timestamp")),
+ this=exp.cast_unless(
+ seq_get(args, 0) or exp.Var(this=""),
+ exp.DataType.build("timestamp"),
+ exp.DataType.build("timestamp"),
+ ),
zone=seq_get(args, 1),
),
"IIF": exp.If.from_arg_list,
@@ -224,6 +231,19 @@ class Spark2(Hive):
WRAP_DERIVED_VALUES = False
CREATE_FUNCTION_RETURN_AS = False
+ def struct_sql(self, expression: exp.Struct) -> str:
+ args = []
+ for arg in expression.expressions:
+ if isinstance(arg, self.KEY_VALUE_DEFINITONS):
+ if isinstance(arg, exp.Bracket):
+ args.append(exp.alias_(arg.this, arg.expressions[0].name))
+ else:
+ args.append(exp.alias_(arg.expression, arg.this.name))
+ else:
+ args.append(arg)
+
+ return self.func("STRUCT", *args)
+
def temporary_storage_provider(self, expression: exp.Create) -> exp.Create:
# spark2, spark, Databricks require a storage provider for temporary tables
provider = exp.FileFormatProperty(this=exp.Literal.string("parquet"))
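
A sketch of the key-value struct handling (the DuckDB source literal and expected output are assumptions):

    import sqlglot

    # Key-value definitions are converted into aliased STRUCT arguments
    # instead of being rejected.
    print(sqlglot.transpile("SELECT {'a': 1}", read="duckdb", write="spark2")[0])
    # e.g. SELECT STRUCT(1 AS a)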
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index 1fa730d..e55a3b8 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -5,6 +5,7 @@ import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
+ NormalizationStrategy,
any_value_to_max_sql,
arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
@@ -63,8 +64,10 @@ def _transform_create(expression: exp.Expression) -> exp.Expression:
class SQLite(Dialect):
# https://sqlite.org/forum/forumpost/5e575586ac5c711b?raw
- RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
+ NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE
SUPPORTS_SEMI_ANTI_JOIN = False
+ TYPED_DIVISION = True
+ SAFE_DIVISION = True
class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ['"', ("[", "]"), "`"]
@@ -124,7 +127,6 @@ class SQLite(Dialect):
exp.LogicalOr: rename_func("MAX"),
exp.LogicalAnd: rename_func("MIN"),
exp.Pivot: no_pivot_sql,
- exp.SafeConcat: concat_to_dpipe_sql,
exp.Select: transforms.preprocess(
[
transforms.eliminate_distinct_on,
diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py
index e8162c2..141d9c0 100644
--- a/sqlglot/dialects/teradata.py
+++ b/sqlglot/dialects/teradata.py
@@ -9,6 +9,7 @@ from sqlglot.tokens import TokenType
class Teradata(Dialect):
SUPPORTS_SEMI_ANTI_JOIN = False
+ TYPED_DIVISION = True
TIME_MAPPING = {
"Y": "%Y",
@@ -33,8 +34,10 @@ class Teradata(Dialect):
class Tokenizer(tokens.Tokenizer):
# https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Comparison-Operators-and-Functions/Comparison-Operators/ANSI-Compliance
+ # https://docs.teradata.com/r/SQL-Functions-Operators-Expressions-and-Predicates/June-2017/Arithmetic-Trigonometric-Hyperbolic-Operators/Functions
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
+ "**": TokenType.DSTAR,
"^=": TokenType.NEQ,
"BYTEINT": TokenType.SMALLINT,
"COLLECT": TokenType.COMMAND,
@@ -112,10 +115,16 @@ class Teradata(Dialect):
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS,
+ # https://docs.teradata.com/r/SQL-Functions-Operators-Expressions-and-Predicates/June-2017/Data-Type-Conversions/TRYCAST
+ "TRYCAST": parser.Parser.FUNCTION_PARSERS["TRY_CAST"],
"RANGE_N": lambda self: self._parse_rangen(),
"TRANSLATE": lambda self: self._parse_translate(self.STRICT_CAST),
}
+ EXPONENT = {
+ TokenType.DSTAR: exp.Pow,
+ }
+
def _parse_translate(self, strict: bool) -> exp.Expression:
this = self._parse_conjunction()
@@ -177,6 +186,7 @@ class Teradata(Dialect):
exp.ArgMin: rename_func("MIN_BY"),
exp.Max: max_or_greatest,
exp.Min: min_or_least,
+ exp.Pow: lambda self, e: self.binary(e, "**"),
exp.Select: transforms.preprocess(
[transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
),
@@ -192,6 +202,9 @@ class Teradata(Dialect):
return super().cast_sql(expression, safe_prefix=safe_prefix)
+ def trycast_sql(self, expression: exp.TryCast) -> str:
+ return self.cast_sql(expression, safe_prefix="TRY")
+
def tablesample_sql(
self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
) -> str:
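
A sketch of the new exponent operator and TRYCAST handling (the expected outputs are assumptions):

    import sqlglot

    # exp.Pow is rendered with Teradata's ** operator...
    print(sqlglot.transpile("SELECT POWER(x, 2)", write="teradata")[0])
    # e.g. SELECT x ** 2

    # ...and TRYCAST parses like TRY_CAST, round-tripping via the TRY prefix.
    print(sqlglot.transpile("SELECT TRYCAST('1' AS INT)", read="teradata")[0])
    # e.g. SELECT TRYCAST('1' AS INT)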
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index a281297..c3d4f0a 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -7,7 +7,9 @@ import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
+ NormalizationStrategy,
any_value_to_max_sql,
+ date_delta_sql,
generatedasidentitycolumnconstraint_sql,
max_or_greatest,
min_or_least,
@@ -135,11 +137,7 @@ def _parse_hashbytes(args: t.List) -> exp.Expression:
return exp.func("HASHBYTES", *args)
-def generate_date_delta_with_unit_sql(
- self: TSQL.Generator, expression: exp.DateAdd | exp.DateDiff
-) -> str:
- func = "DATEADD" if isinstance(expression, exp.DateAdd) else "DATEDIFF"
- return self.func(func, expression.text("unit"), expression.expression, expression.this)
+DATEPART_ONLY_FORMATS = {"dw", "hour", "quarter"}
def _format_sql(self: TSQL.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str:
@@ -153,6 +151,11 @@ def _format_sql(self: TSQL.Generator, expression: exp.NumberToStr | exp.TimeToSt
)
)
)
+
+    # These date parts have no FORMAT pattern, so we fall back to DATEPART
+ if fmt.name.lower() in DATEPART_ONLY_FORMATS:
+ return self.func("DATEPART", fmt.name, expression.this)
+
return self.func("FORMAT", expression.this, fmt, expression.args.get("culture"))
@@ -202,18 +205,50 @@ def _parse_date_delta(
return inner_func
+def qualify_derived_table_outputs(expression: exp.Expression) -> exp.Expression:
+ """Ensures all (unnamed) output columns are aliased for CTEs and Subqueries."""
+ alias = expression.args.get("alias")
+
+ if (
+ isinstance(expression, (exp.CTE, exp.Subquery))
+ and isinstance(alias, exp.TableAlias)
+ and not alias.columns
+ ):
+ from sqlglot.optimizer.qualify_columns import qualify_outputs
+
+    # We keep track of the unaliased column projection indexes instead of the expressions
+    # themselves, because the latter are replaced by new nodes when the aliases are added,
+    # so we would no longer be able to reach the newly created Alias parents
+ subqueryable = expression.this
+ unaliased_column_indexes = (
+ i
+ for i, c in enumerate(subqueryable.selects)
+ if isinstance(c, exp.Column) and not c.alias
+ )
+
+ qualify_outputs(subqueryable)
+
+ # Preserve the quoting information of columns for newly added Alias nodes
+ subqueryable_selects = subqueryable.selects
+ for select_index in unaliased_column_indexes:
+ alias = subqueryable_selects[select_index]
+ column = alias.this
+ if isinstance(column.this, exp.Identifier):
+ alias.args["alias"].set("quoted", column.this.quoted)
+
+ return expression
+
+
class TSQL(Dialect):
- RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
- NULL_ORDERING = "nulls_are_small"
+ NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE
TIME_FORMAT = "'yyyy-mm-dd hh:mm:ss'"
SUPPORTS_SEMI_ANTI_JOIN = False
LOG_BASE_FIRST = False
+ TYPED_DIVISION = True
+ CONCAT_COALESCE = True
TIME_MAPPING = {
"year": "%Y",
- "qq": "%q",
- "q": "%q",
- "quarter": "%q",
"dayofyear": "%j",
"day": "%d",
"dy": "%d",
@@ -320,6 +355,7 @@ class TSQL(Dialect):
IDENTIFIERS = ['"', ("[", "]")]
QUOTES = ["'", '"']
HEX_STRINGS = [("0x", ""), ("0X", "")]
+ VAR_SINGLE_TOKENS = {"@", "$", "#"}
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
@@ -403,9 +439,7 @@ class TSQL(Dialect):
LOG_DEFAULTS_TO_LN = True
- CONCAT_NULL_OUTPUTS_STRING = True
-
- ALTER_TABLE_ADD_COLUMN_KEYWORD = False
+ ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False
def _parse_projections(self) -> t.List[exp.Expression]:
"""
@@ -433,7 +467,7 @@ class TSQL(Dialect):
"""
rollback = self._prev.token_type == TokenType.ROLLBACK
- self._match_texts({"TRAN", "TRANSACTION"})
+ self._match_texts(("TRAN", "TRANSACTION"))
this = self._parse_id_var()
if rollback:
@@ -579,23 +613,35 @@ class TSQL(Dialect):
return super()._parse_if()
def _parse_unique(self) -> exp.UniqueColumnConstraint:
- return self.expression(
- exp.UniqueColumnConstraint,
- this=None
- if self._curr and self._curr.text.upper() in {"CLUSTERED", "NONCLUSTERED"}
- else self._parse_schema(self._parse_id_var(any_token=False)),
- )
+ if self._match_texts(("CLUSTERED", "NONCLUSTERED")):
+ this = self.CONSTRAINT_PARSERS[self._prev.text.upper()](self)
+ else:
+ this = self._parse_schema(self._parse_id_var(any_token=False))
+
+ return self.expression(exp.UniqueColumnConstraint, this=this)
class Generator(generator.Generator):
LIMIT_IS_TOP = True
QUERY_HINTS = False
RETURNING_END = False
NVL2_SUPPORTED = False
- ALTER_TABLE_ADD_COLUMN_KEYWORD = False
+ ALTER_TABLE_INCLUDE_COLUMN_KEYWORD = False
LIMIT_FETCH = "FETCH"
COMPUTED_COLUMN_WITH_TYPE = False
- SUPPORTS_NESTED_CTES = False
CTE_RECURSIVE_KEYWORD_REQUIRED = False
+ ENSURE_BOOLS = True
+ NULL_ORDERING_SUPPORTED = False
+ SUPPORTS_SINGLE_ARG_CONCAT = False
+
+ EXPRESSIONS_WITHOUT_NESTED_CTES = {
+ exp.Delete,
+ exp.Insert,
+ exp.Merge,
+ exp.Select,
+ exp.Subquery,
+ exp.Union,
+ exp.Update,
+ }
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@@ -614,14 +660,16 @@ class TSQL(Dialect):
**generator.Generator.TRANSFORMS,
exp.AnyValue: any_value_to_max_sql,
exp.AutoIncrementColumnConstraint: lambda *_: "IDENTITY",
- exp.DateAdd: generate_date_delta_with_unit_sql,
- exp.DateDiff: generate_date_delta_with_unit_sql,
+ exp.DateAdd: date_delta_sql("DATEADD"),
+ exp.DateDiff: date_delta_sql("DATEDIFF"),
+ exp.CTE: transforms.preprocess([qualify_derived_table_outputs]),
exp.CurrentDate: rename_func("GETDATE"),
exp.CurrentTimestamp: rename_func("GETDATE"),
exp.Extract: rename_func("DATEPART"),
exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql,
exp.GroupConcat: _string_agg_sql,
exp.If: rename_func("IIF"),
+ exp.Length: rename_func("LEN"),
exp.Max: max_or_greatest,
exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this),
exp.Min: min_or_least,
@@ -633,15 +681,16 @@ class TSQL(Dialect):
transforms.eliminate_qualify,
]
),
+ exp.Subquery: transforms.preprocess([qualify_derived_table_outputs]),
exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this),
exp.SHA2: lambda self, e: self.func(
- "HASHBYTES",
- exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"),
- e.this,
+ "HASHBYTES", exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), e.this
),
exp.TemporaryProperty: lambda self, e: "",
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToStr: _format_sql,
+ exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True),
+ exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
exp.TsOrDsToDate: ts_or_ds_to_date_sql("tsql"),
}
@@ -690,8 +739,21 @@ class TSQL(Dialect):
table = expression.find(exp.Table)
+ # Convert CTAS statement to SELECT .. INTO ..
if kind == "TABLE" and expression.expression:
- sql = f"SELECT * INTO {self.sql(table)} FROM ({self.sql(expression.expression)}) AS temp"
+ ctas_with = expression.expression.args.get("with")
+ if ctas_with:
+ ctas_with = ctas_with.pop()
+
+ subquery = expression.expression
+ if isinstance(subquery, exp.Subqueryable):
+ subquery = subquery.subquery()
+
+ select_into = exp.select("*").from_(exp.alias_(subquery, "temp", table=True))
+ select_into.set("into", exp.Into(this=table))
+ select_into.set("with", ctas_with)
+
+ sql = self.sql(select_into)
if exists:
identifier = self.sql(exp.Literal.string(exp.table_name(table) if table else ""))
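
A sketch of the CTAS rewrite (the expected output is an assumption):

    import sqlglot

    # CREATE TABLE ... AS becomes SELECT * INTO ..., with any CTE hoisted
    # out of the subquery.
    sql = "CREATE TABLE t AS SELECT a FROM s"
    print(sqlglot.transpile(sql, write="tsql")[0])
    # e.g. SELECT * INTO t FROM (SELECT a FROM s) AS temp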
diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py
index bf2941c..b79a551 100644
--- a/sqlglot/executor/env.py
+++ b/sqlglot/executor/env.py
@@ -139,10 +139,16 @@ def interval(this, unit):
return datetime.timedelta(**{unit: float(this)})
+@null_if_any("this", "expression")
+def arrayjoin(this, expression, null=None):
+    # Substitute the optional `null` placeholder for NULL elements, then drop any still missing
+    return expression.join(x for x in (x if x is not None else null for x in this) if x is not None)
+
+
ENV = {
"exp": exp,
# aggs
"ARRAYAGG": list,
+ "ARRAYUNIQUEAGG": filter_nulls(lambda acc: list(set(acc))),
"AVG": filter_nulls(statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean), # type: ignore
"COUNT": filter_nulls(lambda acc: sum(1 for _ in acc), False),
"MAX": filter_nulls(max),
@@ -152,6 +158,7 @@ ENV = {
"ABS": null_if_any(lambda this: abs(this)),
"ADD": null_if_any(lambda e, this: e + this),
"ARRAYANY": null_if_any(lambda arr, func: any(func(e) for e in arr)),
+ "ARRAYJOIN": arrayjoin,
"BETWEEN": null_if_any(lambda this, low, high: low <= this and this <= high),
"BITWISEAND": null_if_any(lambda this, e: this & e),
"BITWISELEFTSHIFT": null_if_any(lambda this, e: this << e),
@@ -203,4 +210,9 @@ ENV = {
"CURRENTDATE": datetime.date.today,
"STRFTIME": null_if_any(lambda fmt, arg: datetime.datetime.fromisoformat(arg).strftime(fmt)),
"TRIM": null_if_any(lambda this, e=None: this.strip(e)),
+ "STRUCT": lambda *args: {
+ args[x]: args[x + 1]
+ for x in range(0, len(args), 2)
+ if (args[x + 1] is not None and args[x] is not None)
+ },
}
diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py
index d2ae79d..e1e597d 100644
--- a/sqlglot/executor/python.py
+++ b/sqlglot/executor/python.py
@@ -397,6 +397,20 @@ def _lambda_sql(self, e: exp.Lambda) -> str:
return f"lambda {self.expressions(e, flat=True)}: {self.sql(e, 'this')}"
+def _div_sql(self: generator.Generator, e: exp.Div) -> str:
+ denominator = self.sql(e, "expression")
+
+ if e.args.get("safe"):
+ denominator += " or None"
+
+ sql = f"DIV({self.sql(e, 'this')}, {denominator})"
+
+ if e.args.get("typed"):
+ sql = f"int({sql})"
+
+ return sql
+
+
class Python(Dialect):
class Tokenizer(tokens.Tokenizer):
STRING_ESCAPES = ["\\"]
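
A sketch of the safe/typed flags reaching the executor (the table data and resulting row are assumptions):

    from sqlglot.executor import execute

    # With a TYPED_DIVISION dialect such as postgres, integer division is
    # wrapped in int(...) by _div_sql, truncating the result.
    tables = {"t": [{"x": 7, "y": 2}]}
    print(execute("SELECT x / y AS q FROM t", read="postgres", tables=tables).rows)
    # e.g. [(3,)]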
@@ -413,7 +427,11 @@ class Python(Dialect):
exp.Boolean: lambda self, e: "True" if e.this else "False",
exp.Cast: lambda self, e: f"CAST({self.sql(e.this)}, exp.DataType.Type.{e.args['to']})",
exp.Column: lambda self, e: f"scope[{self.sql(e, 'table') or None}][{self.sql(e.this)}]",
+ exp.Concat: lambda self, e: self.func(
+ "SAFECONCAT" if e.args.get("safe") else "CONCAT", *e.expressions
+ ),
exp.Distinct: lambda self, e: f"set({self.sql(e, 'this')})",
+ exp.Div: _div_sql,
exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})",
exp.In: lambda self, e: f"{self.sql(e, 'this')} in {{{self.expressions(e, flat=True)}}}",
exp.Interval: lambda self, e: f"INTERVAL({self.sql(e.this)}, '{self.sql(e.unit)}')",
diff --git a/sqlglot/executor/table.py b/sqlglot/executor/table.py
index 7931535..87699f8 100644
--- a/sqlglot/executor/table.py
+++ b/sqlglot/executor/table.py
@@ -120,20 +120,22 @@ def _ensure_tables(d: t.Optional[t.Dict], dialect: DialectType = None) -> t.Dict
depth = dict_depth(d)
if depth > 1:
return {
- normalize_name(k, dialect=dialect, is_table=True): _ensure_tables(v, dialect=dialect)
+ normalize_name(k, dialect=dialect, is_table=True).name: _ensure_tables(
+ v, dialect=dialect
+ )
for k, v in d.items()
}
result = {}
for table_name, table in d.items():
- table_name = normalize_name(table_name, dialect=dialect)
+ table_name = normalize_name(table_name, dialect=dialect).name
if isinstance(table, Table):
result[table_name] = table
else:
table = [
{
- normalize_name(column_name, dialect=dialect): value
+ normalize_name(column_name, dialect=dialect).name: value
for column_name, value in row.items()
}
for row in table
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 99ebfb3..99722be 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -53,6 +53,7 @@ class _Expression(type):
SQLGLOT_META = "sqlglot.meta"
+TABLE_PARTS = ("this", "db", "catalog")
class Expression(metaclass=_Expression):
@@ -134,7 +135,7 @@ class Expression(metaclass=_Expression):
return self.args.get("expression")
@property
- def expressions(self):
+ def expressions(self) -> t.List[t.Any]:
"""
Retrieves the argument with key "expressions".
"""
@@ -238,6 +239,9 @@ class Expression(metaclass=_Expression):
dtype = DataType.build(dtype)
self._type = dtype # type: ignore
+ def is_type(self, *dtypes) -> bool:
+ return self.type is not None and self.type.is_type(*dtypes)
+
@property
def meta(self) -> t.Dict[str, t.Any]:
if self._meta is None:
@@ -481,7 +485,7 @@ class Expression(metaclass=_Expression):
def flatten(self, unnest=True):
"""
- Returns a generator which yields child nodes who's parents are the same class.
+ Returns a generator which yields child nodes whose parents are the same class.
A AND B AND C -> [A, B, C]
"""
@@ -508,7 +512,7 @@ class Expression(metaclass=_Expression):
"""
from sqlglot.dialects import Dialect
- return Dialect.get_or_raise(dialect)().generate(self, **opts)
+ return Dialect.get_or_raise(dialect).generate(self, **opts)
def _to_s(self, hide_missing: bool = True, level: int = 0) -> str:
indent = "" if not level else "\n"
@@ -821,6 +825,12 @@ class Expression(metaclass=_Expression):
def rlike(self, other: ExpOrStr) -> RegexpLike:
return self._binop(RegexpLike, other)
+ def div(self, other: ExpOrStr, typed: bool = False, safe: bool = False) -> Div:
+ div = self._binop(Div, other)
+ div.args["typed"] = typed
+ div.args["safe"] = safe
+ return div
+
def __lt__(self, other: t.Any) -> LT:
return self._binop(LT, other)
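
A minimal sketch of the new builder:

    from sqlglot import exp

    # The typed/safe flags are attached directly to the resulting Div node.
    d = exp.column("x").div(exp.column("y"), typed=True, safe=True)
    print(d.sql(), d.args["typed"], d.args["safe"])
    # x / y True True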
@@ -1000,7 +1010,6 @@ class UDTF(DerivedTable, Unionable):
class Cache(Expression):
arg_types = {
- "with": False,
"this": True,
"lazy": False,
"options": False,
@@ -1012,6 +1021,10 @@ class Uncache(Expression):
arg_types = {"this": True, "exists": False}
+class Refresh(Expression):
+ pass
+
+
class DDL(Expression):
@property
def ctes(self):
@@ -1033,6 +1046,43 @@ class DDL(Expression):
return []
+class DML(Expression):
+ def returning(
+ self,
+ expression: ExpOrStr,
+ dialect: DialectType = None,
+ copy: bool = True,
+ **opts,
+ ) -> DML:
+ """
+ Set the RETURNING expression. Not supported by all dialects.
+
+ Example:
+ >>> delete("tbl").returning("*", dialect="postgres").sql()
+ 'DELETE FROM tbl RETURNING *'
+
+ Args:
+ expression: the SQL code strings to parse.
+ If an `Expression` instance is passed, it will be used as-is.
+ dialect: the dialect used to parse the input expressions.
+ copy: if `False`, modify this expression instance in-place.
+ opts: other options to use to parse the input expressions.
+
+ Returns:
+            DML: the modified expression.
+ """
+ return _apply_builder(
+ expression=expression,
+ instance=self,
+ arg="returning",
+ prefix="RETURNING",
+ dialect=dialect,
+ copy=copy,
+ into=Returning,
+ **opts,
+ )
+
+
class Create(DDL):
arg_types = {
"with": False,
@@ -1133,8 +1183,10 @@ class WithinGroup(Expression):
arg_types = {"this": True, "expression": False}
+# ClickHouse supports scalar CTEs
+# https://clickhouse.com/docs/en/sql-reference/statements/select/with
class CTE(DerivedTable):
- arg_types = {"this": True, "alias": True}
+ arg_types = {"this": True, "alias": True, "scalar": False}
class TableAlias(Expression):
@@ -1297,6 +1349,10 @@ class AutoIncrementColumnConstraint(ColumnConstraintKind):
pass
+class PeriodForSystemTimeConstraint(ColumnConstraintKind):
+ arg_types = {"this": True, "expression": True}
+
+
class CaseSpecificColumnConstraint(ColumnConstraintKind):
arg_types = {"not_": True}
@@ -1351,6 +1407,10 @@ class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind):
}
+class GeneratedAsRowColumnConstraint(ColumnConstraintKind):
+ arg_types = {"start": True, "hidden": False}
+
+
# https://dev.mysql.com/doc/refman/8.0/en/create-table.html
class IndexColumnConstraint(ColumnConstraintKind):
arg_types = {
@@ -1383,6 +1443,11 @@ class OnUpdateColumnConstraint(ColumnConstraintKind):
pass
+# https://docs.snowflake.com/en/sql-reference/sql/create-external-table#optional-parameters
+class TransformColumnConstraint(ColumnConstraintKind):
+ pass
+
+
class PrimaryKeyColumnConstraint(ColumnConstraintKind):
arg_types = {"desc": False}
@@ -1413,7 +1478,7 @@ class Constraint(Expression):
arg_types = {"this": True, "expressions": True}
-class Delete(Expression):
+class Delete(DML):
arg_types = {
"with": False,
"this": False,
@@ -1496,41 +1561,6 @@ class Delete(Expression):
**opts,
)
- def returning(
- self,
- expression: ExpOrStr,
- dialect: DialectType = None,
- copy: bool = True,
- **opts,
- ) -> Delete:
- """
- Set the RETURNING expression. Not supported by all dialects.
-
- Example:
- >>> delete("tbl").returning("*", dialect="postgres").sql()
- 'DELETE FROM tbl RETURNING *'
-
- Args:
- expression: the SQL code strings to parse.
- If an `Expression` instance is passed, it will be used as-is.
- dialect: the dialect used to parse the input expressions.
- copy: if `False`, modify this expression instance in-place.
- opts: other options to use to parse the input expressions.
-
- Returns:
- Delete: the modified expression.
- """
- return _apply_builder(
- expression=expression,
- instance=self,
- arg="returning",
- prefix="RETURNING",
- dialect=dialect,
- copy=copy,
- into=Returning,
- **opts,
- )
-
class Drop(Expression):
arg_types = {
@@ -1648,7 +1678,7 @@ class Index(Expression):
}
-class Insert(DDL):
+class Insert(DDL, DML):
arg_types = {
"with": False,
"this": True,
@@ -2259,6 +2289,11 @@ class WithJournalTableProperty(Property):
arg_types = {"this": True}
+class WithSystemVersioningProperty(Property):
+ # this -> history table name, expression -> data consistency check
+ arg_types = {"this": False, "expression": False}
+
+
class Properties(Expression):
arg_types = {"expressions": True}
@@ -3663,6 +3698,7 @@ class DataType(Expression):
Type.BIGINT,
Type.INT128,
Type.INT256,
+ Type.BIT,
}
FLOAT_TYPES = {
@@ -3692,7 +3728,7 @@ class DataType(Expression):
@classmethod
def build(
cls,
- dtype: str | DataType | DataType.Type,
+ dtype: DATA_TYPE,
dialect: DialectType = None,
udt: bool = False,
**kwargs,
@@ -3733,7 +3769,7 @@ class DataType(Expression):
return DataType(**{**data_type_exp.args, **kwargs})
- def is_type(self, *dtypes: str | DataType | DataType.Type) -> bool:
+ def is_type(self, *dtypes: DATA_TYPE) -> bool:
"""
Checks whether this DataType matches one of the provided data types. Nested types or precision
will be compared using "structural equivalence" semantics, so e.g. array<int> != array<float>.
@@ -3761,6 +3797,9 @@ class DataType(Expression):
return False
+DATA_TYPE = t.Union[str, DataType, DataType.Type]
+
+
# https://www.postgresql.org/docs/15/datatype-pseudo.html
class PseudoType(DataType):
arg_types = {"this": True}
@@ -3868,7 +3907,7 @@ class BitwiseXor(Binary):
class Div(Binary):
- pass
+ arg_types = {"this": True, "expression": True, "typed": False, "safe": False}
class Overlaps(Binary):
@@ -3892,13 +3931,25 @@ class Dot(Binary):
return t.cast(Dot, reduce(lambda x, y: Dot(this=x, expression=y), expressions))
+ @property
+ def parts(self) -> t.List[Expression]:
+ """Return the parts of a table / column in order catalog, db, table."""
+ this, *parts = self.flatten()
-class DPipe(Binary):
- pass
+ parts.reverse()
+ for arg in ("this", "table", "db", "catalog"):
+ part = this.args.get(arg)
-class SafeDPipe(DPipe):
- pass
+ if isinstance(part, Expression):
+ parts.append(part)
+
+ parts.reverse()
+ return parts
+
+
+class DPipe(Binary):
+ arg_types = {"this": True, "expression": True, "safe": False}
class EQ(Binary, Predicate):
@@ -3913,6 +3964,11 @@ class NullSafeNEQ(Binary, Predicate):
pass
+# Represents e.g. := in DuckDB, which is mostly used for setting parameters
+class PropertyEQ(Binary):
+ pass
+
+
class Distance(Binary):
pass
@@ -3981,6 +4037,11 @@ class NEQ(Binary, Predicate):
pass
+# https://www.postgresql.org/docs/current/ddl-schemas.html#DDL-SCHEMAS-PATH
+class Operator(Binary):
+ arg_types = {"this": True, "operator": True, "expression": True}
+
+
class SimilarTo(Binary, Predicate):
pass
@@ -4048,7 +4109,8 @@ class Between(Predicate):
class Bracket(Condition):
- arg_types = {"this": True, "expressions": True}
+ # https://cloud.google.com/bigquery/docs/reference/standard-sql/operators#array_subscript_operator
+ arg_types = {"this": True, "expressions": True, "offset": False, "safe": False}
@property
def output_name(self) -> str:
@@ -4058,10 +4120,6 @@ class Bracket(Condition):
return super().output_name
-class SafeBracket(Bracket):
- """Represents array lookup where OOB index yields NULL instead of causing a failure."""
-
-
class Distinct(Expression):
arg_types = {"expressions": False, "on": False}
@@ -4077,6 +4135,11 @@ class In(Predicate):
}
+# https://cloud.google.com/bigquery/docs/reference/standard-sql/procedural-language#for-in
+class ForIn(Expression):
+ arg_types = {"this": True, "expression": True}
+
+
class TimeUnit(Expression):
"""Automatically converts unit arg into a var."""
@@ -4248,8 +4311,9 @@ class Array(Func):
# https://docs.snowflake.com/en/sql-reference/functions/to_char
+# https://docs.oracle.com/en/database/oracle/oracle-database/23/sqlrf/TO_CHAR-number.html
class ToChar(Func):
- arg_types = {"this": True, "format": False}
+ arg_types = {"this": True, "format": False, "nlsparam": False}
class GenerateSeries(Func):
@@ -4260,6 +4324,10 @@ class ArrayAgg(AggFunc):
pass
+class ArrayUniqueAgg(AggFunc):
+ pass
+
+
class ArrayAll(Func):
arg_types = {"this": True, "expression": True}
@@ -4358,7 +4426,7 @@ class Cast(Func):
def output_name(self) -> str:
return self.name
- def is_type(self, *dtypes: str | DataType | DataType.Type) -> bool:
+ def is_type(self, *dtypes: DATA_TYPE) -> bool:
"""
Checks whether this Cast's DataType matches one of the provided data types. Nested types
like arrays or structs will be compared using "structural equivalence" semantics, so e.g.
@@ -4403,14 +4471,10 @@ class Chr(Func):
class Concat(Func):
- arg_types = {"expressions": True}
+ arg_types = {"expressions": True, "safe": False, "coalesce": False}
is_var_len_args = True
-class SafeConcat(Concat):
- pass
-
-
class ConcatWs(Concat):
_sql_names = ["CONCAT_WS"]
@@ -4643,6 +4707,10 @@ class If(Func):
arg_types = {"this": True, "true": True, "false": False}
+class Nullif(Func):
+ arg_types = {"this": True, "expression": True}
+
+
class Initcap(Func):
arg_types = {"this": True, "expression": False}
@@ -4651,6 +4719,10 @@ class IsNan(Func):
_sql_names = ["IS_NAN", "ISNAN"]
+class IsInf(Func):
+ _sql_names = ["IS_INF", "ISINF"]
+
+
class FormatJson(Expression):
pass
@@ -4970,10 +5042,6 @@ class SafeDivide(Func):
arg_types = {"this": True, "expression": True}
-class SetAgg(AggFunc):
- pass
-
-
class SHA(Func):
_sql_names = ["SHA", "SHA1"]
@@ -5118,6 +5186,15 @@ class Trim(Func):
class TsOrDsAdd(Func, TimeUnit):
+ # return_type is used to correctly cast the arguments of this expression when transpiling it
+ arg_types = {"this": True, "expression": True, "unit": False, "return_type": False}
+
+ @property
+ def return_type(self) -> DataType:
+ return DataType.build(self.args.get("return_type") or DataType.Type.DATE)
+
+
+class TsOrDsDiff(Func, TimeUnit):
arg_types = {"this": True, "expression": True, "unit": False}
@@ -5149,6 +5226,7 @@ class UnixToTime(Func):
SECONDS = Literal.string("seconds")
MILLIS = Literal.string("millis")
MICROS = Literal.string("micros")
+ NANOS = Literal.string("nanos")
class UnixToTimeStr(Func):
@@ -5202,6 +5280,7 @@ def _norm_arg(arg):
ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func))
+FUNCTION_BY_NAME = {name: func for func in ALL_FUNCTIONS for name in func.sql_names()}
# Helpers
@@ -5693,7 +5772,9 @@ def delete(
if where:
delete_expr = delete_expr.where(where, dialect=dialect, copy=False, **opts)
if returning:
- delete_expr = delete_expr.returning(returning, dialect=dialect, copy=False, **opts)
+ delete_expr = t.cast(
+ Delete, delete_expr.returning(returning, dialect=dialect, copy=False, **opts)
+ )
return delete_expr
@@ -5702,6 +5783,7 @@ def insert(
into: ExpOrStr,
columns: t.Optional[t.Sequence[ExpOrStr]] = None,
overwrite: t.Optional[bool] = None,
+ returning: t.Optional[ExpOrStr] = None,
dialect: DialectType = None,
copy: bool = True,
**opts,
@@ -5718,6 +5800,7 @@ def insert(
into: the tbl to insert data to.
columns: optionally the table's column names.
overwrite: whether to INSERT OVERWRITE or not.
+    returning: an optional SQL expression to parse into a RETURNING clause.
dialect: the dialect used to parse the input expressions.
copy: whether or not to copy the expression.
**opts: other options to use to parse the input expressions.
@@ -5739,7 +5822,12 @@ def insert(
**opts,
)
- return Insert(this=this, expression=expr, overwrite=overwrite)
+ insert = Insert(this=this, expression=expr, overwrite=overwrite)
+
+ if returning:
+ insert = t.cast(Insert, insert.returning(returning, dialect=dialect, copy=False, **opts))
+
+ return insert
def condition(
@@ -5913,7 +6001,7 @@ def to_identifier(name, quoted=None, copy=True):
return identifier
-def parse_identifier(name: str, dialect: DialectType = None) -> Identifier:
+def parse_identifier(name: str | Identifier, dialect: DialectType = None) -> Identifier:
"""
Parses a given string into an identifier.
@@ -5965,7 +6053,7 @@ def to_table(sql_path: None, **kwargs) -> None:
def to_table(
- sql_path: t.Optional[str | Table], dialect: DialectType = None, **kwargs
+ sql_path: t.Optional[str | Table], dialect: DialectType = None, copy: bool = True, **kwargs
) -> t.Optional[Table]:
"""
Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional.
@@ -5974,13 +6062,14 @@ def to_table(
Args:
sql_path: a `[catalog].[schema].[table]` string.
dialect: the source dialect according to which the table name will be parsed.
+ copy: Whether or not to copy a table if it is passed in.
kwargs: the kwargs to instantiate the resulting `Table` expression with.
Returns:
A table expression.
"""
if sql_path is None or isinstance(sql_path, Table):
- return sql_path
+ return maybe_copy(sql_path, copy=copy)
if not isinstance(sql_path, str):
raise ValueError(f"Invalid type provided for a table: {type(sql_path)}")
@@ -6123,7 +6212,7 @@ def column(
)
-def cast(expression: ExpOrStr, to: str | DataType | DataType.Type, **opts) -> Cast:
+def cast(expression: ExpOrStr, to: DATA_TYPE, **opts) -> Cast:
"""Cast an expression to a data type.
Example:
@@ -6335,12 +6424,15 @@ def column_table_names(expression: Expression, exclude: str = "") -> t.Set[str]:
}
-def table_name(table: Table | str, dialect: DialectType = None) -> str:
+def table_name(table: Table | str, dialect: DialectType = None, identify: bool = False) -> str:
"""Get the full name of a table as a string.
Args:
table: Table expression node or string.
dialect: The dialect to generate the table name for.
+ identify: Determines when an identifier should be quoted. Possible values are:
+ False (default): Never quote, except in cases where it's mandatory by the dialect.
+ True: Always quote.
Examples:
>>> from sqlglot import exp, parse_one
@@ -6358,37 +6450,68 @@ def table_name(table: Table | str, dialect: DialectType = None) -> str:
return ".".join(
part.sql(dialect=dialect, identify=True)
- if not SAFE_IDENTIFIER_RE.match(part.name)
+ if identify or not SAFE_IDENTIFIER_RE.match(part.name)
else part.name
for part in table.parts
)
-def replace_tables(expression: E, mapping: t.Dict[str, str], copy: bool = True) -> E:
+def normalize_table_name(table: str | Table, dialect: DialectType = None, copy: bool = True) -> str:
+ """Returns a case normalized table name without quotes.
+
+ Args:
+ table: the table to normalize
+ dialect: the dialect to use for normalization rules
+ copy: whether or not to copy the expression.
+
+ Examples:
+ >>> normalize_table_name("`A-B`.c", dialect="bigquery")
+ 'A-B.c'
+ """
+ from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
+
+ return ".".join(
+ p.name
+ for p in normalize_identifiers(
+ to_table(table, dialect=dialect, copy=copy), dialect=dialect
+ ).parts
+ )
+
+
+def replace_tables(
+ expression: E, mapping: t.Dict[str, str], dialect: DialectType = None, copy: bool = True
+) -> E:
"""Replace all tables in expression according to the mapping.
Args:
expression: expression node to be transformed and replaced.
mapping: mapping of table names.
+        dialect: the dialect used to normalize the table names in the mapping.
copy: whether or not to copy the expression.
Examples:
>>> from sqlglot import exp, parse_one
>>> replace_tables(parse_one("select * from a.b"), {"a.b": "c"}).sql()
- 'SELECT * FROM c'
+ 'SELECT * FROM c /* a.b */'
Returns:
The mapped expression.
"""
+ mapping = {normalize_table_name(k, dialect=dialect): v for k, v in mapping.items()}
+
def _replace_tables(node: Expression) -> Expression:
if isinstance(node, Table):
- new_name = mapping.get(table_name(node))
+ original = normalize_table_name(node, dialect=dialect)
+ new_name = mapping.get(original)
+
if new_name:
- return to_table(
+ table = to_table(
new_name,
- **{k: v for k, v in node.args.items() if k not in ("this", "db", "catalog")},
+ **{k: v for k, v in node.args.items() if k not in TABLE_PARTS},
)
+ table.add_comments([original])
+ return table
return node
return expression.transform(_replace_tables, copy=copy)
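
Sketch of the normalization step added above: mapping keys and table nodes are both normalized before matching, and the original name is kept as a comment (illustrative output):

    from sqlglot import exp, parse_one

    expr = parse_one("SELECT * FROM A.B")
    # "A.B" and the unquoted A.B in the query both normalize to "a.b" under the default dialect
    print(exp.replace_tables(expr, {"A.B": "c"}).sql())  # SELECT * FROM c /* a.b */
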
@@ -6431,7 +6554,10 @@ def replace_placeholders(expression: Expression, *args, **kwargs) -> Expression:
def expand(
- expression: Expression, sources: t.Dict[str, Subqueryable], copy: bool = True
+ expression: Expression,
+ sources: t.Dict[str, Subqueryable],
+ dialect: DialectType = None,
+ copy: bool = True,
) -> Expression:
"""Transforms an expression by expanding all referenced sources into subqueries.
@@ -6446,15 +6572,17 @@ def expand(
Args:
expression: The expression to expand.
        sources: A dictionary mapping table names to Subqueryable expressions.
+        dialect: The dialect used to normalize the keys of the sources dict.
copy: Whether or not to copy the expression during transformation. Defaults to True.
Returns:
The transformed expression.
"""
+ sources = {normalize_table_name(k, dialect=dialect): v for k, v in sources.items()}
def _expand(node: Expression):
if isinstance(node, Table):
- name = table_name(node)
+ name = normalize_table_name(node, dialect=dialect)
source = sources.get(name)
if source:
subquery = source.subquery(node.alias or name)
@@ -6465,7 +6593,7 @@ def expand(
return expression.transform(_expand, copy=copy)
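
Usage sketch (hypothetical sources; output shape is illustrative):

    from sqlglot import exp, parse_one

    expr = parse_one("SELECT * FROM x")
    expanded = exp.expand(expr, {"x": parse_one("SELECT a FROM y")})
    print(expanded.sql())  # e.g. SELECT * FROM (SELECT a FROM y) AS x
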
-def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func:
+def func(name: str, *args, copy: bool = True, dialect: DialectType = None, **kwargs) -> Func:
"""
Returns a Func expression.
@@ -6479,6 +6607,7 @@ def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func:
Args:
name: the name of the function to build.
args: the args used to instantiate the function of interest.
+ copy: whether or not to copy the argument expressions.
dialect: the source dialect.
kwargs: the kwargs used to instantiate the function of interest.
@@ -6494,14 +6623,29 @@ def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func:
from sqlglot.dialects.dialect import Dialect
- converted: t.List[Expression] = [maybe_parse(arg, dialect=dialect) for arg in args]
- kwargs = {key: maybe_parse(value, dialect=dialect) for key, value in kwargs.items()}
+ dialect = Dialect.get_or_raise(dialect)
- parser = Dialect.get_or_raise(dialect)().parser()
- from_args_list = parser.FUNCTIONS.get(name.upper())
+ converted: t.List[Expression] = [maybe_parse(arg, dialect=dialect, copy=copy) for arg in args]
+ kwargs = {key: maybe_parse(value, dialect=dialect, copy=copy) for key, value in kwargs.items()}
- if from_args_list:
- function = from_args_list(converted) if converted else from_args_list.__self__(**kwargs) # type: ignore
+ constructor = dialect.parser_class.FUNCTIONS.get(name.upper())
+ if constructor:
+ if converted:
+ if "dialect" in constructor.__code__.co_varnames:
+ function = constructor(converted, dialect=dialect)
+ else:
+ function = constructor(converted)
+ elif constructor.__name__ == "from_arg_list":
+ function = constructor.__self__(**kwargs) # type: ignore
+ else:
+ constructor = FUNCTION_BY_NAME.get(name.upper())
+ if constructor:
+ function = constructor(**kwargs)
+ else:
+ raise ValueError(
+ f"Unable to convert '{name}' into a Func. Either manually construct "
+ "the Func expression of interest or parse the function call."
+ )
else:
kwargs = kwargs or {"expressions": converted}
function = Anonymous(this=name, **kwargs)
@@ -6512,6 +6656,48 @@ def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func:
return function
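
A small sketch of both code paths (function names are arbitrary):

    from sqlglot import exp

    # A known name resolves through the dialect parser's FUNCTIONS registry...
    print(exp.func("coalesce", exp.column("a"), "b").sql())  # COALESCE(a, b)
    # ...while an unknown name falls back to an Anonymous function.
    print(exp.func("my_udf", 1, 2).sql())  # MY_UDF(1, 2)
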
+def case(
+ expression: t.Optional[ExpOrStr] = None,
+ **opts,
+) -> Case:
+ """
+ Initialize a CASE statement.
+
+ Example:
+ case().when("a = 1", "foo").else_("bar")
+
+ Args:
+ expression: Optionally, the input expression (not all dialects support this)
+ **opts: Extra keyword arguments for parsing `expression`
+ """
+ if expression is not None:
+ this = maybe_parse(expression, **opts)
+ else:
+ this = None
+ return Case(this=this, ifs=[])
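
Usage sketch for the new builder (literals are hypothetical):

    from sqlglot import exp

    expr = exp.case().when("x = 1", "'one'").else_("'other'")
    print(expr.sql())  # CASE WHEN x = 1 THEN 'one' ELSE 'other' END
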
+
+
+def cast_unless(
+ expression: ExpOrStr,
+ to: DATA_TYPE,
+ *types: DATA_TYPE,
+ **opts: t.Any,
+) -> Expression | Cast:
+ """
+ Cast an expression to a data type unless it is a specified type.
+
+ Args:
+ expression: The expression to cast.
+ to: The data type to cast to.
+        *types: The types to exclude from casting.
+ **opts: Extra keyword arguments for parsing `expression`
+ """
+ expr = maybe_parse(expression, **opts)
+ if expr.is_type(*types):
+ return expr
+ return cast(expr, to, **opts)
+
+
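
A minimal sketch of cast_unless (column names are hypothetical):

    from sqlglot import exp

    already_int = exp.cast(exp.column("x"), "int")
    assert exp.cast_unless(already_int, "int", "int") is already_int  # excluded type, no recast
    print(exp.cast_unless(exp.cast(exp.column("y"), "text"), "int", "int").sql())
    # CAST(CAST(y AS TEXT) AS INT)
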
def true() -> Boolean:
"""
Returns a true Boolean expression.
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 4916cf8..f3f9060 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -9,10 +9,11 @@ from sqlglot import exp
from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages
from sqlglot.helper import apply_index_offset, csv, seq_get
from sqlglot.time import format_time
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.tokens import TokenType
if t.TYPE_CHECKING:
from sqlglot._typing import E
+ from sqlglot.dialects.dialect import DialectType
logger = logging.getLogger("sqlglot")
@@ -58,9 +59,6 @@ class Generator:
exp.DateAdd: lambda self, e: self.func(
"DATE_ADD", e.this, e.expression, exp.Literal.string(e.text("unit"))
),
- exp.TsOrDsAdd: lambda self, e: self.func(
- "TS_OR_DS_ADD", e.this, e.expression, exp.Literal.string(e.text("unit"))
- ),
exp.CaseSpecificColumnConstraint: lambda self, e: f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC",
exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}",
exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args.get('default') else ''}CHARACTER SET={self.sql(e, 'this')}",
@@ -108,9 +106,6 @@ class Generator:
exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}",
}
- # Whether the base comes first
- LOG_BASE_FIRST = True
-
# Whether or not null ordering is supported in order by
NULL_ORDERING_SUPPORTED = True
@@ -201,7 +196,7 @@ class Generator:
VALUES_AS_TABLE = True
# Whether or not the word COLUMN is included when adding a column with ALTER TABLE
- ALTER_TABLE_ADD_COLUMN_KEYWORD = True
+ ALTER_TABLE_INCLUDE_COLUMN_KEYWORD = True
# UNNEST WITH ORDINALITY (presto) instead of UNNEST WITH OFFSET (bigquery)
UNNEST_WITH_ORDINALITY = True
@@ -212,9 +207,6 @@ class Generator:
# Whether or not JOIN sides (LEFT, RIGHT) are supported in conjunction with SEMI/ANTI join kinds
SEMI_ANTI_JOIN_WITH_SIDE = True
- # Whether or not session variables / parameters are supported, e.g. @x in T-SQL
- SUPPORTS_PARAMETERS = True
-
# Whether or not to include the type of a computed column in the CREATE DDL
COMPUTED_COLUMN_WITH_TYPE = True
@@ -230,12 +222,15 @@ class Generator:
# Whether or not data types support additional specifiers like e.g. CHAR or BYTE (oracle)
DATA_TYPE_SPECIFIERS_ALLOWED = False
- # Whether or not nested CTEs (e.g. defined inside of subqueries) are allowed
- SUPPORTS_NESTED_CTES = True
+    # Whether or not conditions require boolean expressions, e.g. WHERE x = 0 vs WHERE x
+ ENSURE_BOOLS = False
# Whether or not the "RECURSIVE" keyword is required when defining recursive CTEs
CTE_RECURSIVE_KEYWORD_REQUIRED = True
+    # Whether or not CONCAT requires more than one argument
+ SUPPORTS_SINGLE_ARG_CONCAT = True
+
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
@@ -335,6 +330,7 @@ class Generator:
exp.VolatileProperty: exp.Properties.Location.POST_CREATE,
exp.WithDataProperty: exp.Properties.Location.POST_EXPRESSION,
exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME,
+ exp.WithSystemVersioningProperty: exp.Properties.Location.POST_SCHEMA,
}
# Keywords that can't be used as unquoted identifier names
@@ -368,36 +364,12 @@ class Generator:
exp.Paren,
)
- SENTINEL_LINE_BREAK = "__SQLGLOT__LB__"
+    # Expression types whose nested CTEs need to be bubbled up to them (to the top level)
+ EXPRESSIONS_WITHOUT_NESTED_CTES: t.Set[t.Type[exp.Expression]] = set()
+
+ KEY_VALUE_DEFINITONS = (exp.Bracket, exp.EQ, exp.PropertyEQ, exp.Slice)
- # Autofilled
- INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
- INVERSE_TIME_TRIE: t.Dict = {}
- INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
- INDEX_OFFSET = 0
- UNNEST_COLUMN_ONLY = False
- ALIAS_POST_TABLESAMPLE = False
- IDENTIFIERS_CAN_START_WITH_DIGIT = False
- STRICT_STRING_CONCAT = False
- NORMALIZE_FUNCTIONS: bool | str = "upper"
- NULL_ORDERING = "nulls_are_small"
-
- can_identify: t.Callable[[str, str | bool], bool]
-
- # Delimiters for quotes, identifiers and the corresponding escape characters
- QUOTE_START = "'"
- QUOTE_END = "'"
- IDENTIFIER_START = '"'
- IDENTIFIER_END = '"'
- TOKENIZER_CLASS = Tokenizer
-
- # Delimiters for bit, hex, byte and raw literals
- BIT_START: t.Optional[str] = None
- BIT_END: t.Optional[str] = None
- HEX_START: t.Optional[str] = None
- HEX_END: t.Optional[str] = None
- BYTE_START: t.Optional[str] = None
- BYTE_END: t.Optional[str] = None
+ SENTINEL_LINE_BREAK = "__SQLGLOT__LB__"
__slots__ = (
"pretty",
@@ -411,6 +383,7 @@ class Generator:
"leading_comma",
"max_text_width",
"comments",
+ "dialect",
"unsupported_messages",
"_escaped_quote_end",
"_escaped_identifier_end",
@@ -429,8 +402,10 @@ class Generator:
leading_comma: bool = False,
max_text_width: int = 80,
comments: bool = True,
+ dialect: DialectType = None,
):
import sqlglot
+ from sqlglot.dialects import Dialect
self.pretty = pretty if pretty is not None else sqlglot.pretty
self.identify = identify
@@ -442,16 +417,19 @@ class Generator:
self.leading_comma = leading_comma
self.max_text_width = max_text_width
self.comments = comments
+ self.dialect = Dialect.get_or_raise(dialect)
# This is both a Dialect property and a Generator argument, so we prioritize the latter
self.normalize_functions = (
- self.NORMALIZE_FUNCTIONS if normalize_functions is None else normalize_functions
+ self.dialect.NORMALIZE_FUNCTIONS if normalize_functions is None else normalize_functions
)
self.unsupported_messages: t.List[str] = []
- self._escaped_quote_end: str = self.TOKENIZER_CLASS.STRING_ESCAPES[0] + self.QUOTE_END
+ self._escaped_quote_end: str = (
+ self.dialect.tokenizer_class.STRING_ESCAPES[0] + self.dialect.QUOTE_END
+ )
self._escaped_identifier_end: str = (
- self.TOKENIZER_CLASS.IDENTIFIER_ESCAPES[0] + self.IDENTIFIER_END
+ self.dialect.tokenizer_class.IDENTIFIER_ESCAPES[0] + self.dialect.IDENTIFIER_END
)
def generate(self, expression: exp.Expression, copy: bool = True) -> str:
@@ -469,23 +447,14 @@ class Generator:
if copy:
expression = expression.copy()
- # Some dialects only support CTEs at the top level expression, so we need to bubble up nested
- # CTEs to that level in order to produce a syntactically valid expression. This transformation
- # happens here to minimize code duplication, since many expressions support CTEs.
- if (
- not self.SUPPORTS_NESTED_CTES
- and isinstance(expression, exp.Expression)
- and not expression.parent
- and "with" in expression.arg_types
- and any(node.parent is not expression for node in expression.find_all(exp.With))
- ):
- from sqlglot.transforms import move_ctes_to_top_level
-
- expression = move_ctes_to_top_level(expression)
+ expression = self.preprocess(expression)
self.unsupported_messages = []
sql = self.sql(expression).strip()
+ if self.pretty:
+ sql = sql.replace(self.SENTINEL_LINE_BREAK, "\n")
+
if self.unsupported_level == ErrorLevel.IGNORE:
return sql
@@ -495,10 +464,26 @@ class Generator:
elif self.unsupported_level == ErrorLevel.RAISE and self.unsupported_messages:
raise UnsupportedError(concat_messages(self.unsupported_messages, self.max_unsupported))
- if self.pretty:
- sql = sql.replace(self.SENTINEL_LINE_BREAK, "\n")
return sql
+ def preprocess(self, expression: exp.Expression) -> exp.Expression:
+ """Apply generic preprocessing transformations to a given expression."""
+ if (
+ not expression.parent
+ and type(expression) in self.EXPRESSIONS_WITHOUT_NESTED_CTES
+ and any(node.parent is not expression for node in expression.find_all(exp.With))
+ ):
+ from sqlglot.transforms import move_ctes_to_top_level
+
+ expression = move_ctes_to_top_level(expression)
+
+ if self.ENSURE_BOOLS:
+ from sqlglot.transforms import ensure_bools
+
+ expression = ensure_bools(expression)
+
+ return expression
+
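
A hypothetical subclass showing how a dialect generator would opt into the new hooks (MyGenerator is illustrative, not a real dialect):

    from sqlglot import exp
    from sqlglot.generator import Generator

    class MyGenerator(Generator):
        # bubble nested CTEs up to the top level for INSERT statements
        EXPRESSIONS_WITHOUT_NESTED_CTES = {exp.Insert}
        # rewrite integer predicates into booleans during preprocess()
        ENSURE_BOOLS = True
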
def unsupported(self, message: str) -> None:
if self.unsupported_level == ErrorLevel.IMMEDIATE:
raise UnsupportedError(message)
@@ -752,9 +737,24 @@ class Generator:
return f"GENERATED{this} AS {expr}{sequence_opts}"
+ def generatedasrowcolumnconstraint_sql(
+ self, expression: exp.GeneratedAsRowColumnConstraint
+ ) -> str:
+ start = "START" if expression.args["start"] else "END"
+ hidden = " HIDDEN" if expression.args.get("hidden") else ""
+ return f"GENERATED ALWAYS AS ROW {start}{hidden}"
+
+ def periodforsystemtimeconstraint_sql(
+ self, expression: exp.PeriodForSystemTimeConstraint
+ ) -> str:
+ return f"PERIOD FOR SYSTEM_TIME ({self.sql(expression, 'this')}, {self.sql(expression, 'expression')})"
+
def notnullcolumnconstraint_sql(self, expression: exp.NotNullColumnConstraint) -> str:
return f"{'' if expression.args.get('allow_null') else 'NOT '}NULL"
+ def transformcolumnconstraint_sql(self, expression: exp.TransformColumnConstraint) -> str:
+ return f"AS {self.sql(expression, 'this')}"
+
def primarykeycolumnconstraint_sql(self, expression: exp.PrimaryKeyColumnConstraint) -> str:
desc = expression.args.get("desc")
if desc is not None:
@@ -900,32 +900,32 @@ class Generator:
columns = self.expressions(expression, key="columns", flat=True)
columns = f"({columns})" if columns else ""
- if not alias and not self.UNNEST_COLUMN_ONLY:
+ if not alias and not self.dialect.UNNEST_COLUMN_ONLY:
alias = "_t"
return f"{alias}{columns}"
def bitstring_sql(self, expression: exp.BitString) -> str:
this = self.sql(expression, "this")
- if self.BIT_START:
- return f"{self.BIT_START}{this}{self.BIT_END}"
+ if self.dialect.BIT_START:
+ return f"{self.dialect.BIT_START}{this}{self.dialect.BIT_END}"
return f"{int(this, 2)}"
def hexstring_sql(self, expression: exp.HexString) -> str:
this = self.sql(expression, "this")
- if self.HEX_START:
- return f"{self.HEX_START}{this}{self.HEX_END}"
+ if self.dialect.HEX_START:
+ return f"{self.dialect.HEX_START}{this}{self.dialect.HEX_END}"
return f"{int(this, 16)}"
def bytestring_sql(self, expression: exp.ByteString) -> str:
this = self.sql(expression, "this")
- if self.BYTE_START:
- return f"{self.BYTE_START}{this}{self.BYTE_END}"
+ if self.dialect.BYTE_START:
+ return f"{self.dialect.BYTE_START}{this}{self.dialect.BYTE_END}"
return this
def rawstring_sql(self, expression: exp.RawString) -> str:
string = self.escape_str(expression.this.replace("\\", "\\\\"))
- return f"{self.QUOTE_START}{string}{self.QUOTE_END}"
+ return f"{self.dialect.QUOTE_START}{string}{self.dialect.QUOTE_END}"
def datatypeparam_sql(self, expression: exp.DataTypeParam) -> str:
this = self.sql(expression, "this")
@@ -1065,14 +1065,14 @@ class Generator:
text = expression.name
lower = text.lower()
text = lower if self.normalize and not expression.quoted else text
- text = text.replace(self.IDENTIFIER_END, self._escaped_identifier_end)
+ text = text.replace(self.dialect.IDENTIFIER_END, self._escaped_identifier_end)
if (
expression.quoted
- or self.can_identify(text, self.identify)
+ or self.dialect.can_identify(text, self.identify)
or lower in self.RESERVED_KEYWORDS
- or (not self.IDENTIFIERS_CAN_START_WITH_DIGIT and text[:1].isdigit())
+ or (not self.dialect.IDENTIFIERS_CAN_START_WITH_DIGIT and text[:1].isdigit())
):
- text = f"{self.IDENTIFIER_START}{text}{self.IDENTIFIER_END}"
+ text = f"{self.dialect.IDENTIFIER_START}{text}{self.dialect.IDENTIFIER_END}"
return text
def inputoutputformat_sql(self, expression: exp.InputOutputFormat) -> str:
@@ -1121,7 +1121,7 @@ class Generator:
expressions = self.expressions(properties, sep=sep, indent=False)
if expressions:
expressions = self.wrap(expressions) if wrapped else expressions
- return f"{prefix}{' ' if prefix and prefix != ' ' else ''}{expressions}{suffix}"
+ return f"{prefix}{' ' if prefix.strip() else ''}{expressions}{suffix}"
return ""
def with_properties(self, properties: exp.Properties) -> str:
@@ -1286,6 +1286,21 @@ class Generator:
statistics_sql = f" AND {'NO ' if not statistics else ''}STATISTICS"
return f"{data_sql}{statistics_sql}"
+ def withsystemversioningproperty_sql(self, expression: exp.WithSystemVersioningProperty) -> str:
+ sql = "WITH(SYSTEM_VERSIONING=ON"
+
+ if expression.this:
+ history_table = self.sql(expression, "this")
+ sql = f"{sql}(HISTORY_TABLE={history_table}"
+
+ if expression.expression:
+ data_consistency_check = self.sql(expression, "expression")
+ sql = f"{sql}, DATA_CONSISTENCY_CHECK={data_consistency_check}"
+
+ sql = f"{sql})"
+
+ return f"{sql})"
+
def insert_sql(self, expression: exp.Insert) -> str:
overwrite = expression.args.get("overwrite")
@@ -1387,13 +1402,13 @@ class Generator:
def table_sql(self, expression: exp.Table, sep: str = " AS ") -> str:
table = ".".join(
- part
- for part in [
- self.sql(expression, "catalog"),
- self.sql(expression, "db"),
- self.sql(expression, "this"),
- ]
- if part
+ self.sql(part)
+ for part in (
+ expression.args.get("catalog"),
+ expression.args.get("db"),
+ expression.args.get("this"),
+ )
+ if part is not None
)
version = self.sql(expression, "version")
@@ -1426,7 +1441,7 @@ class Generator:
def tablesample_sql(
self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
) -> str:
- if self.ALIAS_POST_TABLESAMPLE and expression.this.alias:
+ if self.dialect.ALIAS_POST_TABLESAMPLE and expression.this and expression.this.alias:
table = expression.this.copy()
table.set("alias", None)
this = self.sql(table)
@@ -1676,12 +1691,16 @@ class Generator:
def limit_sql(self, expression: exp.Limit, top: bool = False) -> str:
this = self.sql(expression, "this")
- args = ", ".join(
- self.sql(self._simplify_unless_literal(e) if self.LIMIT_ONLY_LITERALS else e)
+
+ args = [
+ self._simplify_unless_literal(e) if self.LIMIT_ONLY_LITERALS else e
for e in (expression.args.get(k) for k in ("offset", "expression"))
if e
- )
- return f"{this}{self.seg('TOP' if top else 'LIMIT')} {args}"
+ ]
+
+ args_sql = ", ".join(self.sql(e) for e in args)
+ args_sql = f"({args_sql})" if any(top and not e.is_number for e in args) else args_sql
+ return f"{this}{self.seg('TOP' if top else 'LIMIT')} {args_sql}"
def offset_sql(self, expression: exp.Offset) -> str:
this = self.sql(expression, "this")
@@ -1732,13 +1751,13 @@ class Generator:
def literal_sql(self, expression: exp.Literal) -> str:
text = expression.this or ""
if expression.is_string:
- text = f"{self.QUOTE_START}{self.escape_str(text)}{self.QUOTE_END}"
+ text = f"{self.dialect.QUOTE_START}{self.escape_str(text)}{self.dialect.QUOTE_END}"
return text
def escape_str(self, text: str) -> str:
- text = text.replace(self.QUOTE_END, self._escaped_quote_end)
- if self.INVERSE_ESCAPE_SEQUENCES:
- text = "".join(self.INVERSE_ESCAPE_SEQUENCES.get(ch, ch) for ch in text)
+ text = text.replace(self.dialect.QUOTE_END, self._escaped_quote_end)
+ if self.dialect.INVERSE_ESCAPE_SEQUENCES:
+ text = "".join(self.dialect.INVERSE_ESCAPE_SEQUENCES.get(ch, ch) for ch in text)
elif self.pretty:
text = text.replace("\n", self.SENTINEL_LINE_BREAK)
return text
@@ -1782,9 +1801,11 @@ class Generator:
nulls_first = expression.args.get("nulls_first")
nulls_last = not nulls_first
- nulls_are_large = self.NULL_ORDERING == "nulls_are_large"
- nulls_are_small = self.NULL_ORDERING == "nulls_are_small"
- nulls_are_last = self.NULL_ORDERING == "nulls_are_last"
+ nulls_are_large = self.dialect.NULL_ORDERING == "nulls_are_large"
+ nulls_are_small = self.dialect.NULL_ORDERING == "nulls_are_small"
+ nulls_are_last = self.dialect.NULL_ORDERING == "nulls_are_last"
+
+ this = self.sql(expression, "this")
sort_order = " DESC" if desc else (" ASC" if desc is False else "")
nulls_sort_change = ""
@@ -1799,13 +1820,13 @@ class Generator:
):
nulls_sort_change = " NULLS LAST"
+ # If the NULLS FIRST/LAST clause is unsupported, we add another sort key to simulate it
if nulls_sort_change and not self.NULL_ORDERING_SUPPORTED:
- self.unsupported(
- "Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect"
- )
+ null_sort_order = " DESC" if nulls_sort_change == " NULLS FIRST" else ""
+ this = f"CASE WHEN {this} IS NULL THEN 1 ELSE 0 END{null_sort_order}, {this}"
nulls_sort_change = ""
- return f"{self.sql(expression, 'this')}{sort_order}{nulls_sort_change}"
+ return f"{this}{sort_order}{nulls_sort_change}"
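
Illustration of the fallback (MySQL lacks NULLS FIRST/LAST, so the rewrite above should yield an extra sort key; output is illustrative):

    import sqlglot

    print(sqlglot.transpile("SELECT * FROM t ORDER BY x NULLS LAST", write="mysql")[0])
    # SELECT * FROM t ORDER BY CASE WHEN x IS NULL THEN 1 ELSE 0 END, x
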
def matchrecognize_sql(self, expression: exp.MatchRecognize) -> str:
partition = self.partition_by_sql(expression)
@@ -1933,10 +1954,13 @@ class Generator:
)
kind = ""
+ # We use LIMIT_IS_TOP as a proxy for whether DISTINCT should go first because tsql and Teradata
+ # are the only dialects that use LIMIT_IS_TOP and both place DISTINCT first.
+ top_distinct = f"{distinct}{hint}{top}" if self.LIMIT_IS_TOP else f"{top}{hint}{distinct}"
expressions = f"{self.sep()}{expressions}" if expressions else expressions
sql = self.query_modifiers(
expression,
- f"SELECT{top}{hint}{distinct}{kind}{expressions}",
+ f"SELECT{top_distinct}{kind}{expressions}",
self.sql(expression, "into", comment=False),
self.sql(expression, "from", comment=False),
)
@@ -1961,7 +1985,7 @@ class Generator:
def parameter_sql(self, expression: exp.Parameter) -> str:
this = self.sql(expression, "this")
- return f"{self.PARAMETER_TOKEN}{this}" if self.SUPPORTS_PARAMETERS else this
+ return f"{self.PARAMETER_TOKEN}{this}"
def sessionparameter_sql(self, expression: exp.SessionParameter) -> str:
this = self.sql(expression, "this")
@@ -2009,7 +2033,7 @@ class Generator:
if alias and isinstance(offset, exp.Expression):
alias.append("columns", offset)
- if alias and self.UNNEST_COLUMN_ONLY:
+ if alias and self.dialect.UNNEST_COLUMN_ONLY:
columns = alias.columns
alias = self.sql(columns[0]) if columns else ""
else:
@@ -2080,14 +2104,14 @@ class Generator:
return f"{this} BETWEEN {low} AND {high}"
def bracket_sql(self, expression: exp.Bracket) -> str:
- expressions = apply_index_offset(expression.this, expression.expressions, self.INDEX_OFFSET)
+ expressions = apply_index_offset(
+ expression.this,
+ expression.expressions,
+ self.dialect.INDEX_OFFSET - expression.args.get("offset", 0),
+ )
expressions_sql = ", ".join(self.sql(e) for e in expressions)
-
return f"{self.sql(expression, 'this')}[{expressions_sql}]"
- def safebracket_sql(self, expression: exp.SafeBracket) -> str:
- return self.bracket_sql(expression)
-
def all_sql(self, expression: exp.All) -> str:
return f"ALL {self.wrap(expression)}"
@@ -2145,12 +2169,33 @@ class Generator:
else:
return self.func("TRIM", expression.this, expression.expression)
- def safeconcat_sql(self, expression: exp.SafeConcat) -> str:
- expressions = expression.expressions
- if self.STRICT_STRING_CONCAT:
- expressions = (exp.cast(e, "text") for e in expressions)
+ def convert_concat_args(self, expression: exp.Concat | exp.ConcatWs) -> t.List[exp.Expression]:
+ args = expression.expressions
+ if isinstance(expression, exp.ConcatWs):
+ args = args[1:] # Skip the delimiter
+
+ if self.dialect.STRICT_STRING_CONCAT and expression.args.get("safe"):
+ args = [exp.cast(e, "text") for e in args]
+
+ if not self.dialect.CONCAT_COALESCE and expression.args.get("coalesce"):
+ args = [exp.func("coalesce", e, exp.Literal.string("")) for e in args]
+
+ return args
+
+ def concat_sql(self, expression: exp.Concat) -> str:
+ expressions = self.convert_concat_args(expression)
+
+ # Some dialects don't allow a single-argument CONCAT call
+ if not self.SUPPORTS_SINGLE_ARG_CONCAT and len(expressions) == 1:
+ return self.sql(expressions[0])
+
return self.func("CONCAT", *expressions)
+ def concatws_sql(self, expression: exp.ConcatWs) -> str:
+ return self.func(
+ "CONCAT_WS", seq_get(expression.expressions, 0), *self.convert_concat_args(expression)
+ )
+
def check_sql(self, expression: exp.Check) -> str:
this = self.sql(expression, key="this")
return f"CHECK ({this})"
@@ -2493,14 +2538,7 @@ class Generator:
actions = expression.args["actions"]
if isinstance(actions[0], exp.ColumnDef):
- if self.ALTER_TABLE_ADD_COLUMN_KEYWORD:
- actions = self.expressions(
- expression,
- key="actions",
- prefix="ADD COLUMN ",
- )
- else:
- actions = f"ADD {self.expressions(expression, key='actions')}"
+ actions = self.add_column_sql(expression)
elif isinstance(actions[0], exp.Schema):
actions = self.expressions(expression, key="actions", prefix="ADD COLUMNS ")
elif isinstance(actions[0], exp.Delete):
@@ -2512,6 +2550,15 @@ class Generator:
only = " ONLY" if expression.args.get("only") else ""
return f"ALTER TABLE{exists}{only} {self.sql(expression, 'this')} {actions}"
+ def add_column_sql(self, expression: exp.AlterTable) -> str:
+ if self.ALTER_TABLE_INCLUDE_COLUMN_KEYWORD:
+ return self.expressions(
+ expression,
+ key="actions",
+ prefix="ADD COLUMN ",
+ )
+ return f"ADD {self.expressions(expression, key='actions', flat=True)}"
+
def droppartition_sql(self, expression: exp.DropPartition) -> str:
expressions = self.expressions(expression)
exists = " IF EXISTS " if expression.args.get("exists") else " "
@@ -2551,14 +2598,31 @@ class Generator:
)
def dpipe_sql(self, expression: exp.DPipe) -> str:
- return self.binary(expression, "||")
-
- def safedpipe_sql(self, expression: exp.SafeDPipe) -> str:
- if self.STRICT_STRING_CONCAT:
+ if self.dialect.STRICT_STRING_CONCAT and expression.args.get("safe"):
return self.func("CONCAT", *(exp.cast(e, "text") for e in expression.flatten()))
- return self.dpipe_sql(expression)
+ return self.binary(expression, "||")
def div_sql(self, expression: exp.Div) -> str:
+ l, r = expression.left, expression.right
+
+ if not self.dialect.SAFE_DIVISION and expression.args.get("safe"):
+ r.replace(exp.Nullif(this=r.copy(), expression=exp.Literal.number(0)))
+
+ if self.dialect.TYPED_DIVISION and not expression.args.get("typed"):
+ if not l.is_type(*exp.DataType.FLOAT_TYPES) and not r.is_type(
+ *exp.DataType.FLOAT_TYPES
+ ):
+ l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DOUBLE))
+
+ elif not self.dialect.TYPED_DIVISION and expression.args.get("typed"):
+ if l.is_type(*exp.DataType.INTEGER_TYPES) and r.is_type(*exp.DataType.INTEGER_TYPES):
+ return self.sql(
+ exp.cast(
+ l / r,
+ to=exp.DataType.Type.BIGINT,
+ )
+ )
+
return self.binary(expression, "/")
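
A sketch of the `safe` path: MySQL division yields NULL on a zero divisor, so a target dialect without that semantic should guard the divisor (output is illustrative):

    import sqlglot

    print(sqlglot.transpile("SELECT a / b FROM t", read="mysql", write="snowflake")[0])
    # SELECT a / NULLIF(b, 0) FROM t
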
def overlaps_sql(self, expression: exp.Overlaps) -> str:
@@ -2573,6 +2637,9 @@ class Generator:
def eq_sql(self, expression: exp.EQ) -> str:
return self.binary(expression, "=")
+ def propertyeq_sql(self, expression: exp.PropertyEQ) -> str:
+ return self.binary(expression, ":=")
+
def escape_sql(self, expression: exp.Escape) -> str:
return self.binary(expression, "ESCAPE")
@@ -2641,10 +2708,13 @@ class Generator:
return self.cast_sql(expression, safe_prefix="TRY_")
def log_sql(self, expression: exp.Log) -> str:
- args = list(expression.args.values())
- if not self.LOG_BASE_FIRST:
- args.reverse()
- return self.func("LOG", *args)
+ this = expression.this
+ expr = expression.expression
+
+ if not self.dialect.LOG_BASE_FIRST:
+ this, expr = expr, this
+
+ return self.func("LOG", this, expr)
def use_sql(self, expression: exp.Use) -> str:
kind = self.sql(expression, "kind")
@@ -2696,7 +2766,9 @@ class Generator:
def format_time(self, expression: exp.Expression) -> t.Optional[str]:
return format_time(
- self.sql(expression, "format"), self.INVERSE_TIME_MAPPING, self.INVERSE_TIME_TRIE
+ self.sql(expression, "format"),
+ self.dialect.INVERSE_TIME_MAPPING,
+ self.dialect.INVERSE_TIME_TRIE,
)
def expressions(
@@ -2963,6 +3035,19 @@ class Generator:
parameters = self.sql(expression, "params_struct")
return self.func("PREDICT", model, table, parameters or None)
+ def forin_sql(self, expression: exp.ForIn) -> str:
+ this = self.sql(expression, "this")
+ expression_sql = self.sql(expression, "expression")
+ return f"FOR {this} DO {expression_sql}"
+
+ def refresh_sql(self, expression: exp.Refresh) -> str:
+ this = self.sql(expression, "this")
+ table = "" if isinstance(expression.this, exp.Literal) else "TABLE "
+ return f"REFRESH {table}{this}"
+
+ def operator_sql(self, expression: exp.Operator) -> str:
+ return self.binary(expression, f"OPERATOR({self.sql(expression, 'operator')})")
+
def _simplify_unless_literal(self, expression: E) -> E:
if not isinstance(expression, exp.Literal):
from sqlglot.optimizer.simplify import simplify
@@ -2970,3 +3055,10 @@ class Generator:
expression = simplify(expression)
return expression
+
+ def _ensure_string_if_null(self, values: t.List[exp.Expression]) -> t.List[exp.Expression]:
+ return [
+ exp.func("COALESCE", exp.cast(value, "text"), exp.Literal.string(""))
+ for value in values
+ if value
+ ]
diff --git a/sqlglot/helper.py b/sqlglot/helper.py
index ee41557..349c8c8 100644
--- a/sqlglot/helper.py
+++ b/sqlglot/helper.py
@@ -1,5 +1,6 @@
from __future__ import annotations
+import datetime
import inspect
import logging
import re
@@ -283,7 +284,7 @@ def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
file = open_file(read_csv.name)
delimiter = ","
- args = iter(arg.name for arg in args)
+ args = iter(arg.name for arg in args) # type: ignore
for k, v in zip(args, args):
if k == "delimiter":
delimiter = v
@@ -463,3 +464,27 @@ def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]:
merged.append((start, end))
return merged
+
+
+def is_iso_date(text: str) -> bool:
+ try:
+ datetime.date.fromisoformat(text)
+ return True
+ except ValueError:
+ return False
+
+
+def is_iso_datetime(text: str) -> bool:
+ try:
+ datetime.datetime.fromisoformat(text)
+ return True
+ except ValueError:
+ return False
+
+
+# Interval units that operate on date components
+DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"}
+
+
+def is_date_unit(expression: t.Optional[exp.Expression]) -> bool:
+ return expression is not None and expression.name.lower() in DATE_UNITS
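
Quick sanity checks for the new helpers (note that an ISO date is also an ISO datetime):

    from sqlglot import exp
    from sqlglot.helper import is_date_unit, is_iso_date, is_iso_datetime

    assert is_iso_date("2023-01-01") and not is_iso_date("2023-01-01 00:00:00")
    assert is_iso_datetime("2023-01-01") and is_iso_datetime("2023-01-01 00:00:00")
    assert is_date_unit(exp.var("WEEK")) and not is_date_unit(exp.var("HOUR"))
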
diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py
index 011a6b8..abcc10f 100644
--- a/sqlglot/lineage.py
+++ b/sqlglot/lineage.py
@@ -6,7 +6,7 @@ from dataclasses import dataclass, field
from sqlglot import Schema, exp, maybe_parse
from sqlglot.errors import SqlglotError
-from sqlglot.optimizer import Scope, build_scope, qualify
+from sqlglot.optimizer import Scope, build_scope, find_all_in_scope, qualify
if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType
@@ -29,8 +29,38 @@ class Node:
else:
yield d
- def to_html(self, **opts) -> LineageHTML:
- return LineageHTML(self, **opts)
+ def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML:
+ nodes = {}
+ edges = []
+
+ for node in self.walk():
+ if isinstance(node.expression, exp.Table):
+ label = f"FROM {node.expression.this}"
+ title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>"
+ group = 1
+ else:
+ label = node.expression.sql(pretty=True, dialect=dialect)
+ source = node.source.transform(
+ lambda n: exp.Tag(this=n, prefix="<b>", postfix="</b>")
+ if n is node.expression
+ else n,
+ copy=False,
+ ).sql(pretty=True, dialect=dialect)
+ title = f"<pre>{source}</pre>"
+ group = 0
+
+ node_id = id(node)
+
+ nodes[node_id] = {
+ "id": node_id,
+ "label": label,
+ "title": title,
+ "group": group,
+ }
+
+ for d in node.downstream:
+ edges.append({"from": node_id, "to": id(d)})
+ return GraphHTML(nodes, edges, **opts)
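
Usage sketch (schema and query are hypothetical; str() of the result embeds a vis.js network):

    from sqlglot.lineage import lineage

    node = lineage("y", "SELECT x AS y FROM t", schema={"t": {"x": "int"}})
    html = node.to_html(dialect="duckdb")
    print(str(html)[:60])  # an HTML document wrapping the nodes/edges built above
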
def lineage(
@@ -64,6 +94,7 @@ def lineage(
k: t.cast(exp.Subqueryable, maybe_parse(v, dialect=dialect))
for k, v in sources.items()
},
+ dialect=dialect,
)
qualified = qualify.qualify(
@@ -129,17 +160,6 @@ def lineage(
return upstream
- subquery = select.unalias()
-
- if isinstance(subquery, exp.Subquery):
- upstream = upstream or Node(name="SUBQUERY", source=scope.expression, expression=select)
- scope = t.cast(Scope, build_scope(subquery.unnest()))
-
- for select in subquery.named_selects:
- to_node(select, scope=scope, upstream=upstream)
-
- return upstream
-
if isinstance(scope.expression, exp.Select):
# For better ergonomics in our node labels, replace the full select with
# a version that has only the column we care about.
@@ -156,16 +176,28 @@ def lineage(
expression=select,
alias=alias or "",
)
+
if upstream:
upstream.downstream.append(node)
+ subquery_scopes = {
+ id(subquery_scope.expression): subquery_scope
+ for subquery_scope in scope.subquery_scopes
+ }
+
+ for subquery in find_all_in_scope(select, exp.Subqueryable):
+ subquery_scope = subquery_scopes[id(subquery)]
+
+ for name in subquery.named_selects:
+ to_node(name, scope=subquery_scope, upstream=node)
+
# if the select is a star add all scope sources as downstreams
if select.is_star:
for source in scope.sources.values():
node.downstream.append(Node(name=select.sql(), source=source, expression=source))
# Find all columns that went into creating this one to list their lineage nodes.
- source_columns = set(select.find_all(exp.Column))
+ source_columns = set(find_all_in_scope(select, exp.Column))
    # If the source is a UDTF, find the columns used in the UDTF to generate the table
if isinstance(source, exp.UDTF):
@@ -192,20 +224,15 @@ def lineage(
return to_node(column if isinstance(column, str) else column.name, scope)
-class LineageHTML:
+class GraphHTML:
"""Node to HTML generator using vis.js.
https://visjs.github.io/vis-network/docs/network/
"""
def __init__(
- self,
- node: Node,
- dialect: DialectType = None,
- imports: bool = True,
- **opts: t.Any,
+ self, nodes: t.Dict, edges: t.List, imports: bool = True, options: t.Optional[t.Dict] = None
):
- self.node = node
self.imports = imports
self.options = {
@@ -235,39 +262,11 @@ class LineageHTML:
"maximum": 300,
},
},
- **opts,
+ **(options or {}),
}
- self.nodes = {}
- self.edges = []
-
- for node in node.walk():
- if isinstance(node.expression, exp.Table):
- label = f"FROM {node.expression.this}"
- title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>"
- group = 1
- else:
- label = node.expression.sql(pretty=True, dialect=dialect)
- source = node.source.transform(
- lambda n: exp.Tag(this=n, prefix="<b>", postfix="</b>")
- if n is node.expression
- else n,
- copy=False,
- ).sql(pretty=True, dialect=dialect)
- title = f"<pre>{source}</pre>"
- group = 0
-
- node_id = id(node)
-
- self.nodes[node_id] = {
- "id": node_id,
- "label": label,
- "title": title,
- "group": group,
- }
-
- for d in node.downstream:
- self.edges.append({"from": node_id, "to": id(d)})
+ self.nodes = nodes
+ self.edges = edges
def __str__(self):
nodes = json.dumps(list(self.nodes.values()))
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index 69d4567..7b990f1 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -1,12 +1,18 @@
from __future__ import annotations
-import datetime
import functools
import typing as t
from sqlglot import exp
from sqlglot._typing import E
-from sqlglot.helper import ensure_list, seq_get, subclasses
+from sqlglot.helper import (
+ ensure_list,
+ is_date_unit,
+ is_iso_date,
+ is_iso_datetime,
+ seq_get,
+ subclasses,
+)
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import Schema, ensure_schema
@@ -20,10 +26,6 @@ if t.TYPE_CHECKING:
]
-# Interval units that operate on date components
-DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"}
-
-
def annotate_types(
expression: E,
schema: t.Optional[t.Dict | Schema] = None,
@@ -60,43 +62,22 @@ def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[Type
return lambda self, e: self._annotate_with_type(e, data_type)
-def _is_iso_date(text: str) -> bool:
- try:
- datetime.date.fromisoformat(text)
- return True
- except ValueError:
- return False
-
-
-def _is_iso_datetime(text: str) -> bool:
- try:
- datetime.datetime.fromisoformat(text)
- return True
- except ValueError:
- return False
-
-
-def _coerce_literal_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
+def _coerce_date_literal(l: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.DataType.Type:
date_text = l.name
- unit = r.text("unit").lower()
-
- is_iso_date = _is_iso_date(date_text)
+ is_iso_date_ = is_iso_date(date_text)
- if is_iso_date and unit in DATE_UNITS:
- l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DATE))
+ if is_iso_date_ and is_date_unit(unit):
return exp.DataType.Type.DATE
# An ISO date is also an ISO datetime, but not vice versa
- if is_iso_date or _is_iso_datetime(date_text):
- l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DATETIME))
+ if is_iso_date_ or is_iso_datetime(date_text):
return exp.DataType.Type.DATETIME
return exp.DataType.Type.UNKNOWN
-def _coerce_date_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
- unit = r.text("unit").lower()
- if unit not in DATE_UNITS:
+def _coerce_date(l: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.DataType.Type:
+ if not is_date_unit(unit):
return exp.DataType.Type.DATETIME
return l.type.this if l.type else exp.DataType.Type.UNKNOWN
@@ -171,7 +152,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.Date,
exp.DateFromParts,
exp.DateStrToDate,
- exp.DateTrunc,
exp.DiToDate,
exp.StrToDate,
exp.TimeStrToDate,
@@ -185,6 +165,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.DataType.Type.DOUBLE: {
exp.ApproxQuantile,
exp.Avg,
+ exp.Div,
exp.Exp,
exp.Ln,
exp.Log,
@@ -203,8 +184,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
},
exp.DataType.Type.INT: {
exp.Ceil,
- exp.DateDiff,
exp.DatetimeDiff,
+ exp.DateDiff,
exp.Extract,
exp.TimestampDiff,
exp.TimeDiff,
@@ -240,8 +221,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.GroupConcat,
exp.Initcap,
exp.Lower,
- exp.SafeConcat,
- exp.SafeDPipe,
exp.Substring,
exp.TimeToStr,
exp.TimeToTimeStr,
@@ -267,6 +246,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
for expr_type in expressions
},
+ exp.Abs: lambda self, e: self._annotate_by_args(e, "this"),
exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
@@ -276,9 +256,11 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
- exp.DateAdd: lambda self, e: self._annotate_dateadd(e),
- exp.DateSub: lambda self, e: self._annotate_dateadd(e),
+ exp.DateAdd: lambda self, e: self._annotate_timeunit(e),
+ exp.DateSub: lambda self, e: self._annotate_timeunit(e),
+ exp.DateTrunc: lambda self, e: self._annotate_timeunit(e),
exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
+ exp.Div: lambda self, e: self._annotate_div(e),
exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
@@ -288,6 +270,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
+ exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
@@ -306,13 +289,27 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
BINARY_COERCIONS: BinaryCoercions = {
**swap_all(
{
- (t, exp.DataType.Type.INTERVAL): _coerce_literal_and_interval
+ (t, exp.DataType.Type.INTERVAL): lambda l, r: _coerce_date_literal(
+ l, r.args.get("unit")
+ )
for t in exp.DataType.TEXT_TYPES
}
),
**swap_all(
{
- (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): _coerce_date_and_interval,
+ # text + numeric will yield the numeric type to match most dialects' semantics
+ (text, numeric): lambda l, r: t.cast(
+ exp.DataType.Type, l.type if l.type in exp.DataType.NUMERIC_TYPES else r.type
+ )
+ for text in exp.DataType.TEXT_TYPES
+ for numeric in exp.DataType.NUMERIC_TYPES
+ }
+ ),
+ **swap_all(
+ {
+ (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): lambda l, r: _coerce_date(
+ l, r.args.get("unit")
+ ),
}
),
}
@@ -511,18 +508,17 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
return expression
- def _annotate_dateadd(self, expression: exp.IntervalOp) -> exp.IntervalOp:
+ def _annotate_timeunit(
+ self, expression: exp.TimeUnit | exp.DateTrunc
+ ) -> exp.TimeUnit | exp.DateTrunc:
self._annotate_args(expression)
if expression.this.type.this in exp.DataType.TEXT_TYPES:
- datatype = _coerce_literal_and_interval(expression.this, expression.interval())
- elif (
- expression.this.type.is_type(exp.DataType.Type.DATE)
- and expression.text("unit").lower() not in DATE_UNITS
- ):
- datatype = exp.DataType.Type.DATETIME
+ datatype = _coerce_date_literal(expression.this, expression.unit)
+ elif expression.this.type.this in exp.DataType.TEMPORAL_TYPES:
+ datatype = _coerce_date(expression.this, expression.unit)
else:
- datatype = expression.this.type
+ datatype = exp.DataType.Type.UNKNOWN
self._set_type(expression, datatype)
return expression
@@ -547,3 +543,19 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
self._set_type(expression, exp.DataType.Type.UNKNOWN)
return expression
+
+ def _annotate_div(self, expression: exp.Div) -> exp.Div:
+ self._annotate_args(expression)
+
+ left_type, right_type = expression.left.type.this, expression.right.type.this # type: ignore
+
+ if (
+ expression.args.get("typed")
+ and left_type in exp.DataType.INTEGER_TYPES
+ and right_type in exp.DataType.INTEGER_TYPES
+ ):
+ self._set_type(expression, exp.DataType.Type.BIGINT)
+ else:
+ self._set_type(expression, self._maybe_coerce(left_type, right_type))
+
+ return expression
diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py
index fc5c348..faf18c6 100644
--- a/sqlglot/optimizer/canonicalize.py
+++ b/sqlglot/optimizer/canonicalize.py
@@ -1,8 +1,10 @@
from __future__ import annotations
import itertools
+import typing as t
from sqlglot import exp
+from sqlglot.helper import is_date_unit, is_iso_date, is_iso_datetime
def canonicalize(expression: exp.Expression) -> exp.Expression:
@@ -20,7 +22,7 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:
expression = replace_date_funcs(expression)
expression = coerce_type(expression)
expression = remove_redundant_casts(expression)
- expression = ensure_bool_predicates(expression)
+ expression = ensure_bools(expression, _replace_int_predicate)
expression = remove_ascending_order(expression)
return expression
@@ -40,8 +42,22 @@ def replace_date_funcs(node: exp.Expression) -> exp.Expression:
return node
+COERCIBLE_DATE_OPS = (
+ exp.Add,
+ exp.Sub,
+ exp.EQ,
+ exp.NEQ,
+ exp.GT,
+ exp.GTE,
+ exp.LT,
+ exp.LTE,
+ exp.NullSafeEQ,
+ exp.NullSafeNEQ,
+)
+
+
def coerce_type(node: exp.Expression) -> exp.Expression:
- if isinstance(node, exp.Binary):
+ if isinstance(node, COERCIBLE_DATE_OPS):
_coerce_date(node.left, node.right)
elif isinstance(node, exp.Between):
_coerce_date(node.this, node.args["low"])
@@ -49,6 +65,10 @@ def coerce_type(node: exp.Expression) -> exp.Expression:
*exp.DataType.TEMPORAL_TYPES
):
_replace_cast(node.expression, exp.DataType.Type.DATETIME)
+ elif isinstance(node, (exp.DateAdd, exp.DateSub, exp.DateTrunc)):
+ _coerce_timeunit_arg(node.this, node.unit)
+ elif isinstance(node, exp.DateDiff):
+ _coerce_datediff_args(node)
return node
@@ -64,17 +84,21 @@ def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
return expression
-def ensure_bool_predicates(expression: exp.Expression) -> exp.Expression:
+def ensure_bools(
+ expression: exp.Expression, replace_func: t.Callable[[exp.Expression], None]
+) -> exp.Expression:
if isinstance(expression, exp.Connector):
- _replace_int_predicate(expression.left)
- _replace_int_predicate(expression.right)
-
- elif isinstance(expression, (exp.Where, exp.Having)) or (
+ replace_func(expression.left)
+ replace_func(expression.right)
+ elif isinstance(expression, exp.Not):
+ replace_func(expression.this)
# We can't replace num in CASE x WHEN num ..., because it's not the full predicate
- isinstance(expression, exp.If)
- and not (isinstance(expression.parent, exp.Case) and expression.parent.this)
+ elif isinstance(expression, exp.If) and not (
+ isinstance(expression.parent, exp.Case) and expression.parent.this
):
- _replace_int_predicate(expression.this)
+ replace_func(expression.this)
+ elif isinstance(expression, (exp.Where, exp.Having)):
+ replace_func(expression.this)
return expression
@@ -89,22 +113,59 @@ def remove_ascending_order(expression: exp.Expression) -> exp.Expression:
def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
for a, b in itertools.permutations([a, b]):
+ if isinstance(b, exp.Interval):
+ a = _coerce_timeunit_arg(a, b.unit)
if (
a.type
and a.type.this == exp.DataType.Type.DATE
and b.type
- and b.type.this not in (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL)
+ and b.type.this
+ not in (
+ exp.DataType.Type.DATE,
+ exp.DataType.Type.INTERVAL,
+ )
):
_replace_cast(b, exp.DataType.Type.DATE)
+def _coerce_timeunit_arg(arg: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.Expression:
+ if not arg.type:
+ return arg
+
+ if arg.type.this in exp.DataType.TEXT_TYPES:
+ date_text = arg.name
+ is_iso_date_ = is_iso_date(date_text)
+
+ if is_iso_date_ and is_date_unit(unit):
+ return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATE))
+
+ # An ISO date is also an ISO datetime, but not vice versa
+ if is_iso_date_ or is_iso_datetime(date_text):
+ return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATETIME))
+
+ elif arg.type.this == exp.DataType.Type.DATE and not is_date_unit(unit):
+ return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATETIME))
+
+ return arg
+
+
+def _coerce_datediff_args(node: exp.DateDiff) -> None:
+ for e in (node.this, node.expression):
+ if e.type.this not in exp.DataType.TEMPORAL_TYPES:
+ e.replace(exp.cast(e.copy(), to=exp.DataType.Type.DATETIME))
+
+
def _replace_cast(node: exp.Expression, to: exp.DataType.Type) -> None:
node.replace(exp.cast(node.copy(), to=to))
+# This transform was originally designed for Presto; there is a similar one for T-SQL.
+# It differs in that it only operates on int types, because Presto has a boolean type
+# whereas T-SQL doesn't (people use bits instead). For example, this is an illegal
+# Presto query: WITH y AS (SELECT TRUE AS x) SELECT x = 0 FROM y
def _replace_int_predicate(expression: exp.Expression) -> None:
if isinstance(expression, exp.Coalesce):
for _, child in expression.iter_expressions():
_replace_int_predicate(child)
elif expression.type and expression.type.this in exp.DataType.INTEGER_TYPES:
- expression.replace(exp.NEQ(this=expression.copy(), expression=exp.Literal.number(0)))
+ expression.replace(expression.neq(0))
diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py
index b0b2b3d..a74bea7 100644
--- a/sqlglot/optimizer/merge_subqueries.py
+++ b/sqlglot/optimizer/merge_subqueries.py
@@ -186,13 +186,13 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
and not (
isinstance(from_or_join, exp.Join)
and inner_select.args.get("where")
- and from_or_join.side in {"FULL", "LEFT", "RIGHT"}
+ and from_or_join.side in ("FULL", "LEFT", "RIGHT")
)
and not (
isinstance(from_or_join, exp.From)
and inner_select.args.get("where")
and any(
- j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", [])
+ j.side in ("FULL", "RIGHT") for j in outer_scope.expression.args.get("joins", [])
)
)
and not _outer_select_joins_on_inner_select_join()
diff --git a/sqlglot/optimizer/normalize_identifiers.py b/sqlglot/optimizer/normalize_identifiers.py
index 154256e..3361a33 100644
--- a/sqlglot/optimizer/normalize_identifiers.py
+++ b/sqlglot/optimizer/normalize_identifiers.py
@@ -13,7 +13,7 @@ def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
@t.overload
-def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Expression:
+def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier:
...
@@ -48,11 +48,11 @@ def normalize_identifiers(expression, dialect=None):
Returns:
The transformed expression.
"""
+ dialect = Dialect.get_or_raise(dialect)
+
if isinstance(expression, str):
expression = exp.parse_identifier(expression, dialect=dialect)
- dialect = Dialect.get_or_raise(dialect)
-
def _normalize(node: E) -> E:
if not node.meta.get("case_sensitive"):
exp.replace_children(node, _normalize)
diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py
index abac63b..1c96e95 100644
--- a/sqlglot/optimizer/optimizer.py
+++ b/sqlglot/optimizer/optimizer.py
@@ -42,8 +42,8 @@ RULES = (
def optimize(
expression: str | exp.Expression,
schema: t.Optional[dict | Schema] = None,
- db: t.Optional[str] = None,
- catalog: t.Optional[str] = None,
+ db: t.Optional[str | exp.Identifier] = None,
+ catalog: t.Optional[str | exp.Identifier] = None,
dialect: DialectType = None,
rules: t.Sequence[t.Callable] = RULES,
**kwargs,
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index b06ea1d..742cdf5 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -8,7 +8,7 @@ from sqlglot._typing import E
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get
-from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope
+from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema
@@ -58,7 +58,7 @@ def qualify_columns(
if not isinstance(scope.expression, exp.UDTF):
_expand_stars(scope, resolver, using_column_tables, pseudocolumns)
- _qualify_outputs(scope)
+ qualify_outputs(scope)
_expand_group_by(scope)
_expand_order_by(scope, resolver)
@@ -237,7 +237,7 @@ def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
ordereds = order.expressions
for ordered, new_expression in zip(
ordereds,
- _expand_positional_references(scope, (o.this for o in ordereds)),
+ _expand_positional_references(scope, (o.this for o in ordereds), alias=True),
):
for agg in ordered.find_all(exp.AggFunc):
for col in agg.find_all(exp.Column):
@@ -259,17 +259,23 @@ def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
)
-def _expand_positional_references(scope: Scope, expressions: t.Iterable[E]) -> t.List[E]:
- new_nodes = []
+def _expand_positional_references(
+ scope: Scope, expressions: t.Iterable[exp.Expression], alias: bool = False
+) -> t.List[exp.Expression]:
+ new_nodes: t.List[exp.Expression] = []
for node in expressions:
if node.is_int:
- select = _select_by_pos(scope, t.cast(exp.Literal, node)).this
+ select = _select_by_pos(scope, t.cast(exp.Literal, node))
- if isinstance(select, exp.Literal):
- new_nodes.append(node)
+ if alias:
+ new_nodes.append(exp.column(select.args["alias"].copy()))
else:
- new_nodes.append(select.copy())
- scope.clear_cache()
+ select = select.this
+
+ if isinstance(select, exp.Literal):
+ new_nodes.append(node)
+ else:
+ new_nodes.append(select.copy())
else:
new_nodes.append(node)
@@ -307,7 +313,9 @@ def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
if column_table:
column.set("table", column_table)
elif column_table not in scope.sources and (
- not scope.parent or column_table not in scope.parent.sources
+ not scope.parent
+ or column_table not in scope.parent.sources
+ or not scope.is_correlated_subquery
):
# structs are used like tables (e.g. "struct"."field"), so they need to be qualified
# separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))
@@ -381,15 +389,18 @@ def _expand_stars(
columns = [name for name in columns if name.upper() not in pseudocolumns]
if columns and "*" not in columns:
+ table_id = id(table)
+ columns_to_exclude = except_columns.get(table_id) or set()
+
if pivot and has_pivoted_source and pivot_columns and pivot_output_columns:
implicit_columns = [col for col in columns if col not in pivot_columns]
new_selections.extend(
exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
for name in implicit_columns + pivot_output_columns
+ if name not in columns_to_exclude
)
continue
- table_id = id(table)
for name in columns:
if name in using_column_tables and table in using_column_tables[name]:
if name in coalesced_columns:
@@ -406,7 +417,7 @@ def _expand_stars(
copy=False,
)
)
- elif name not in except_columns.get(table_id, set()):
+ elif name not in columns_to_exclude:
alias_ = replace_columns.get(table_id, {}).get(name, name)
column = exp.column(name, table=table)
new_selections.append(
@@ -448,10 +459,16 @@ def _add_replace_columns(
replace_columns[id(table)] = columns
-def _qualify_outputs(scope: Scope) -> None:
+def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
"""Ensure all output columns are aliased"""
- new_selections = []
+ if isinstance(scope_or_expression, exp.Expression):
+ scope = build_scope(scope_or_expression)
+ if not isinstance(scope, Scope):
+ return
+ else:
+ scope = scope_or_expression
+ new_selections = []
for i, (selection, aliased_column) in enumerate(
itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
):
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py
index 3a43e8f..57ecabe 100644
--- a/sqlglot/optimizer/qualify_tables.py
+++ b/sqlglot/optimizer/qualify_tables.py
@@ -1,8 +1,11 @@
+from __future__ import annotations
+
import itertools
import typing as t
from sqlglot import alias, exp
from sqlglot._typing import E
+from sqlglot.dialects.dialect import DialectType
from sqlglot.helper import csv_reader, name_sequence
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import Schema
@@ -10,9 +13,10 @@ from sqlglot.schema import Schema
def qualify_tables(
expression: E,
- db: t.Optional[str] = None,
- catalog: t.Optional[str] = None,
+ db: t.Optional[str | exp.Identifier] = None,
+ catalog: t.Optional[str | exp.Identifier] = None,
schema: t.Optional[Schema] = None,
+ dialect: DialectType = None,
) -> E:
"""
Rewrite sqlglot AST to have fully qualified tables. Join constructs such as
@@ -33,11 +37,14 @@ def qualify_tables(
db: Database name
catalog: Catalog name
schema: A schema to populate
+ dialect: The dialect to parse catalog and schema into.
Returns:
The qualified expression.
"""
next_alias_name = name_sequence("_q_")
+ db = exp.parse_identifier(db, dialect=dialect) if db else None
+ catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None
for scope in traverse_scope(expression):
for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
@@ -61,9 +68,9 @@ def qualify_tables(
if isinstance(source, exp.Table):
if isinstance(source.this, exp.Identifier):
if not source.args.get("db"):
- source.set("db", exp.to_identifier(db))
+ source.set("db", db)
if not source.args.get("catalog") and source.args.get("db"):
- source.set("catalog", exp.to_identifier(catalog))
+ source.set("catalog", catalog)
if not source.alias:
# Mutates the source by attaching an alias to it
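
A short sketch of the qualification (names are hypothetical; the alias is attached by the pass itself):

    from sqlglot import parse_one
    from sqlglot.optimizer.qualify_tables import qualify_tables

    expr = qualify_tables(parse_one("SELECT 1 FROM t"), db="db", catalog="cat")
    print(expr.sql())  # e.g. SELECT 1 FROM cat.db.t AS t
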
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index 4af5b49..b7e527e 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import itertools
import logging
import typing as t
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index af03332..d4e2e60 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -507,6 +507,9 @@ def simplify_literals(expression, root=True):
return exp.Literal.number(value[1:])
return exp.Literal.number(f"-{value}")
+ if type(expression) in INVERSE_DATE_OPS:
+ return _simplify_binary(expression, expression.this, expression.interval()) or expression
+
return expression
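Editor's note: with INVERSE_DATE_OPS wired into simplify_literals, date arithmetic over literals now folds for DateAdd-style nodes as well as the plain binary operators. A sketch (the output shape is illustrative):

import sqlglot
from sqlglot.optimizer.simplify import simplify

expr = sqlglot.parse_one("CAST('2023-01-01' AS DATE) + INTERVAL '1' DAY")
print(simplify(expr).sql())  # e.g. CAST('2023-01-02' AS DATE)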
@@ -530,22 +533,24 @@ def _simplify_binary(expression, a, b):
return exp.null()
if a.is_number and b.is_number:
- a = int(a.name) if a.is_int else Decimal(a.name)
- b = int(b.name) if b.is_int else Decimal(b.name)
+ num_a = int(a.name) if a.is_int else Decimal(a.name)
+ num_b = int(b.name) if b.is_int else Decimal(b.name)
if isinstance(expression, exp.Add):
- return exp.Literal.number(a + b)
- if isinstance(expression, exp.Sub):
- return exp.Literal.number(a - b)
+ return exp.Literal.number(num_a + num_b)
if isinstance(expression, exp.Mul):
- return exp.Literal.number(a * b)
+ return exp.Literal.number(num_a * num_b)
+
+ # Sub and Div are not associative, so we only simplify them when a and b share the same parent
+ if isinstance(expression, exp.Sub):
+ return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
if isinstance(expression, exp.Div):
# engines have differing int div behavior so intdiv is not safe
- if isinstance(a, int) and isinstance(b, int):
+ if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
return None
- return exp.Literal.number(a / b)
+ return exp.Literal.number(num_a / num_b)
- boolean = eval_boolean(expression, a, b)
+ boolean = eval_boolean(expression, num_a, num_b)
if boolean:
return boolean
@@ -557,15 +562,21 @@ def _simplify_binary(expression, a, b):
elif _is_date_literal(a) and isinstance(b, exp.Interval):
a, b = extract_date(a), extract_interval(b)
if a and b:
- if isinstance(expression, exp.Add):
+ if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
return date_literal(a + b)
- if isinstance(expression, exp.Sub):
+ if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
return date_literal(a - b)
elif isinstance(a, exp.Interval) and _is_date_literal(b):
a, b = extract_interval(a), extract_date(b)
# you cannot subtract a date from an interval
if a and b and isinstance(expression, exp.Add):
return date_literal(a + b)
+ elif _is_date_literal(a) and _is_date_literal(b):
+ if isinstance(expression, exp.Predicate):
+ a, b = extract_date(a), extract_date(b)
+ boolean = eval_boolean(expression, a, b)
+ if boolean:
+ return boolean
return None
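Editor's note: the new branch also lets predicates over two date literals be evaluated directly. A sketch:

import sqlglot
from sqlglot.optimizer.simplify import simplify

expr = sqlglot.parse_one("CAST('2023-01-01' AS DATE) < CAST('2023-02-01' AS DATE)")
print(simplify(expr).sql())  # TRUE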
@@ -590,6 +601,11 @@ def simplify_parens(expression):
return expression
+NONNULL_CONSTANTS = (
+ exp.Literal,
+ exp.Boolean,
+)
+
CONSTANTS = (
exp.Literal,
exp.Boolean,
@@ -597,11 +613,19 @@ CONSTANTS = (
)
+def _is_nonnull_constant(expression: exp.Expression) -> bool:
+ return isinstance(expression, NONNULL_CONSTANTS) or _is_date_literal(expression)
+
+
+def _is_constant(expression: exp.Expression) -> bool:
+ return isinstance(expression, CONSTANTS) or _is_date_literal(expression)
+
+
def simplify_coalesce(expression):
# COALESCE(x) -> x
if (
isinstance(expression, exp.Coalesce)
- and not expression.expressions
+ and (not expression.expressions or _is_nonnull_constant(expression.this))
# COALESCE is also used as a Spark partitioning hint
and not isinstance(expression.parent, exp.Hint)
):
@@ -621,12 +645,12 @@ def simplify_coalesce(expression):
# This transformation is valid for non-constants,
# but it really only does anything if they are both constants.
- if not isinstance(other, CONSTANTS):
+ if not _is_constant(other):
return expression
# Find the first constant arg
for arg_index, arg in enumerate(coalesce.expressions):
- if isinstance(arg, CONSTANTS):
+ if _is_constant(arg):
break
else:
return expression
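Editor's note: with the _is_nonnull_constant check, a COALESCE whose first argument is a non-NULL constant now collapses immediately. A sketch:

import sqlglot
from sqlglot.optimizer.simplify import simplify

print(simplify(sqlglot.parse_one("COALESCE('a', x, y)")).sql())  # 'a'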
@@ -656,7 +680,6 @@ def simplify_coalesce(expression):
CONCATS = (exp.Concat, exp.DPipe)
-SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe)
def simplify_concat(expression):
@@ -672,10 +695,15 @@ def simplify_concat(expression):
sep_expr, *expressions = expression.expressions
sep = sep_expr.name
concat_type = exp.ConcatWs
+ args = {}
else:
expressions = expression.expressions
sep = ""
- concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
+ concat_type = exp.Concat
+ args = {
+ "safe": expression.args.get("safe"),
+ "coalesce": expression.args.get("coalesce"),
+ }
new_args = []
for is_string_group, group in itertools.groupby(
@@ -692,7 +720,7 @@ def simplify_concat(expression):
if concat_type is exp.ConcatWs:
new_args = [sep_expr] + new_args
- return concat_type(expressions=new_args)
+ return concat_type(expressions=new_args, **args)
def simplify_conditionals(expression):
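Editor's note: simplify_concat now always rebuilds a plain exp.Concat and forwards the safe/coalesce flags, instead of choosing between Concat and the removed SafeConcat. A sketch; the MySQL dialect choice is illustrative:

import sqlglot
from sqlglot.optimizer.simplify import simplify

expr = sqlglot.parse_one("CONCAT('a', 'b', x)", read="mysql")
print(simplify(expr).sql("mysql"))  # e.g. CONCAT('ab', x), NULL semantics preserved on the node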
@@ -947,7 +975,7 @@ def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.da
def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
if isinstance(cast, exp.Cast):
to = cast.to
- elif isinstance(cast, exp.TsOrDsToDate):
+ elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
to = exp.DataType.build(exp.DataType.Type.DATE)
else:
return None
@@ -966,12 +994,11 @@ def _is_date_literal(expression: exp.Expression) -> bool:
def extract_interval(expression):
- n = int(expression.name)
- unit = expression.text("unit").lower()
-
try:
+ n = int(expression.name)
+ unit = expression.text("unit").lower()
return interval(unit, n)
- except (UnsupportedUnit, ModuleNotFoundError):
+ except (UnsupportedUnit, ModuleNotFoundError, ValueError):
return None
@@ -1099,8 +1126,6 @@ GEN_MAP = {
exp.DataType: lambda e: f"{e.this.name} {gen(tuple(e.args.values())[1:])}",
exp.Div: lambda e: _binary(e, "/"),
exp.Dot: lambda e: _binary(e, "."),
- exp.DPipe: lambda e: _binary(e, "||"),
- exp.SafeDPipe: lambda e: _binary(e, "||"),
exp.EQ: lambda e: _binary(e, "="),
exp.GT: lambda e: _binary(e, ">"),
exp.GTE: lambda e: _binary(e, ">="),
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index 1dab600..c7e27a3 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -13,6 +13,7 @@ from sqlglot.trie import TrieResult, in_trie, new_trie
if t.TYPE_CHECKING:
from sqlglot._typing import E
+ from sqlglot.dialects.dialect import Dialect, DialectType
logger = logging.getLogger("sqlglot")
@@ -46,6 +47,19 @@ def binary_range_parser(
)
+def parse_logarithm(args: t.List, dialect: Dialect) -> exp.Func:
+ # Default argument order is base, expression
+ this = seq_get(args, 0)
+ expression = seq_get(args, 1)
+
+ if expression:
+ if not dialect.LOG_BASE_FIRST:
+ this, expression = expression, this
+ return exp.Log(this=this, expression=expression)
+
+ return (exp.Ln if dialect.parser_class.LOG_DEFAULTS_TO_LN else exp.Log)(this=this)
+
+
class _Parser(type):
def __new__(cls, clsname, bases, attrs):
klass = super().__new__(cls, clsname, bases, attrs)
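Editor's note: parse_logarithm replaces the method-based _parse_logarithm and receives the active dialect, so LOG argument order (LOG_BASE_FIRST) and LN defaulting are resolved at parse time. A sketch; the MySQL line assumes its LOG_DEFAULTS_TO_LN setting:

import sqlglot

print(sqlglot.transpile("LOG(x)", read="mysql")[0])  # e.g. LN(x): MySQL's one-arg LOG is a natural log
print(sqlglot.transpile("LOG(2, 8)")[0])             # LOG(2, 8): base-first by default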
@@ -72,13 +86,24 @@ class Parser(metaclass=_Parser):
"""
FUNCTIONS: t.Dict[str, t.Callable] = {
- **{name: f.from_arg_list for f in exp.ALL_FUNCTIONS for name in f.sql_names()},
+ **{name: func.from_arg_list for name, func in exp.FUNCTION_BY_NAME.items()},
+ "CONCAT": lambda args, dialect: exp.Concat(
+ expressions=args,
+ safe=not dialect.STRICT_STRING_CONCAT,
+ coalesce=dialect.CONCAT_COALESCE,
+ ),
+ "CONCAT_WS": lambda args, dialect: exp.ConcatWs(
+ expressions=args,
+ safe=not dialect.STRICT_STRING_CONCAT,
+ coalesce=dialect.CONCAT_COALESCE,
+ ),
"DATE_TO_DATE_STR": lambda args: exp.Cast(
this=seq_get(args, 0),
to=exp.DataType(this=exp.DataType.Type.TEXT),
),
"GLOB": lambda args: exp.Glob(this=seq_get(args, 1), expression=seq_get(args, 0)),
"LIKE": parse_like,
+ "LOG": parse_logarithm,
"TIME_TO_TIME_STR": lambda args: exp.Cast(
this=seq_get(args, 0),
to=exp.DataType(this=exp.DataType.Type.TEXT),
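Editor's note: CONCAT and CONCAT_WS are now parsed as ordinary functions whose NULL-handling semantics are recorded on the node (safe, coalesce) rather than encoded in dialect-specific subclasses. A sketch that just inspects the flags:

import sqlglot

concat = sqlglot.parse_one("CONCAT(a, b)", read="mysql")
print(concat.args.get("safe"), concat.args.get("coalesce"))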
@@ -229,7 +254,7 @@ class Parser(metaclass=_Parser):
TokenType.SOME: exp.Any,
}
- RESERVED_KEYWORDS = {
+ RESERVED_TOKENS = {
*Tokenizer.SINGLE_TOKENS.values(),
TokenType.SELECT,
}
@@ -245,9 +270,11 @@ class Parser(metaclass=_Parser):
CREATABLES = {
TokenType.COLUMN,
+ TokenType.CONSTRAINT,
TokenType.FUNCTION,
TokenType.INDEX,
TokenType.PROCEDURE,
+ TokenType.FOREIGN_KEY,
*DB_CREATABLES,
}
@@ -291,6 +318,7 @@ class Parser(metaclass=_Parser):
TokenType.NATURAL,
TokenType.NEXT,
TokenType.OFFSET,
+ TokenType.OPERATOR,
TokenType.ORDINALITY,
TokenType.OVERLAPS,
TokenType.OVERWRITE,
@@ -299,7 +327,10 @@ class Parser(metaclass=_Parser):
TokenType.PIVOT,
TokenType.PRAGMA,
TokenType.RANGE,
+ TokenType.RECURSIVE,
TokenType.REFERENCES,
+ TokenType.REFRESH,
+ TokenType.REPLACE,
TokenType.RIGHT,
TokenType.ROW,
TokenType.ROWS,
@@ -390,6 +421,7 @@ class Parser(metaclass=_Parser):
}
EQUALITY = {
+ TokenType.COLON_EQ: exp.PropertyEQ,
TokenType.EQ: exp.EQ,
TokenType.NEQ: exp.NEQ,
TokenType.NULLSAFE_EQ: exp.NullSafeEQ,
@@ -406,7 +438,6 @@ class Parser(metaclass=_Parser):
TokenType.AMP: exp.BitwiseAnd,
TokenType.CARET: exp.BitwiseXor,
TokenType.PIPE: exp.BitwiseOr,
- TokenType.DPIPE: exp.DPipe,
}
TERM = {
@@ -423,6 +454,8 @@ class Parser(metaclass=_Parser):
TokenType.STAR: exp.Mul,
}
+ EXPONENT: t.Dict[TokenType, t.Type[exp.Expression]] = {}
+
TIMES = {
TokenType.TIME,
TokenType.TIMETZ,
@@ -558,6 +591,7 @@ class Parser(metaclass=_Parser):
TokenType.MERGE: lambda self: self._parse_merge(),
TokenType.PIVOT: lambda self: self._parse_simplified_pivot(),
TokenType.PRAGMA: lambda self: self.expression(exp.Pragma, this=self._parse_expression()),
+ TokenType.REFRESH: lambda self: self._parse_refresh(),
TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
TokenType.SET: lambda self: self._parse_set(),
TokenType.UNCACHE: lambda self: self._parse_uncache(),
@@ -697,6 +731,7 @@ class Parser(metaclass=_Parser):
exp.StabilityProperty, this=exp.Literal.string("STABLE")
),
"STORED": lambda self: self._parse_stored(),
+ "SYSTEM_VERSIONING": lambda self: self._parse_system_versioning_property(),
"TBLPROPERTIES": lambda self: self._parse_wrapped_csv(self._parse_property),
"TEMP": lambda self: self.expression(exp.TemporaryProperty),
"TEMPORARY": lambda self: self.expression(exp.TemporaryProperty),
@@ -754,6 +789,7 @@ class Parser(metaclass=_Parser):
)
or self.expression(exp.OnProperty, this=self._parse_id_var()),
"PATH": lambda self: self.expression(exp.PathColumnConstraint, this=self._parse_string()),
+ "PERIOD": lambda self: self._parse_period_for_system_time(),
"PRIMARY KEY": lambda self: self._parse_primary_key(),
"REFERENCES": lambda self: self._parse_references(match=False),
"TITLE": lambda self: self.expression(
@@ -775,7 +811,7 @@ class Parser(metaclass=_Parser):
"RENAME": lambda self: self._parse_alter_table_rename(),
}
- SCHEMA_UNNAMED_CONSTRAINTS = {"CHECK", "FOREIGN KEY", "LIKE", "PRIMARY KEY", "UNIQUE"}
+ SCHEMA_UNNAMED_CONSTRAINTS = {"CHECK", "FOREIGN KEY", "LIKE", "PRIMARY KEY", "UNIQUE", "PERIOD"}
NO_PAREN_FUNCTION_PARSERS = {
"ANY": lambda self: self.expression(exp.Any, this=self._parse_bitwise()),
@@ -794,14 +830,11 @@ class Parser(metaclass=_Parser):
FUNCTION_PARSERS = {
"ANY_VALUE": lambda self: self._parse_any_value(),
"CAST": lambda self: self._parse_cast(self.STRICT_CAST),
- "CONCAT": lambda self: self._parse_concat(),
- "CONCAT_WS": lambda self: self._parse_concat_ws(),
"CONVERT": lambda self: self._parse_convert(self.STRICT_CAST),
"DECODE": lambda self: self._parse_decode(),
"EXTRACT": lambda self: self._parse_extract(),
"JSON_OBJECT": lambda self: self._parse_json_object(),
"JSON_TABLE": lambda self: self._parse_json_table(),
- "LOG": lambda self: self._parse_logarithm(),
"MATCH": lambda self: self._parse_match_against(),
"OPENJSON": lambda self: self._parse_open_json(),
"POSITION": lambda self: self._parse_position(),
@@ -877,6 +910,7 @@ class Parser(metaclass=_Parser):
CLONE_KINDS = {"TIMESTAMP", "OFFSET", "STATEMENT"}
OPCLASS_FOLLOW_KEYWORDS = {"ASC", "DESC", "NULLS"}
+ OPTYPE_FOLLOW_TOKENS = {TokenType.COMMA, TokenType.R_PAREN}
TABLE_INDEX_HINT_TOKENS = {TokenType.FORCE, TokenType.IGNORE, TokenType.USE}
@@ -896,17 +930,13 @@ class Parser(metaclass=_Parser):
STRICT_CAST = True
- # A NULL arg in CONCAT yields NULL by default
- CONCAT_NULL_OUTPUTS_STRING = False
-
PREFIXED_PIVOT_COLUMNS = False
IDENTIFY_PIVOT_STRINGS = False
- LOG_BASE_FIRST = True
LOG_DEFAULTS_TO_LN = False
# Whether or not ADD is present for each column added by ALTER TABLE
- ALTER_TABLE_ADD_COLUMN_KEYWORD = True
+ ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = True
# Whether or not the table sample clause expects CSV syntax
TABLESAMPLE_CSV = False
@@ -921,6 +951,7 @@ class Parser(metaclass=_Parser):
"error_level",
"error_message_context",
"max_errors",
+ "dialect",
"sql",
"errors",
"_tokens",
@@ -929,35 +960,25 @@ class Parser(metaclass=_Parser):
"_next",
"_prev",
"_prev_comments",
- "_tokenizer",
)
# Autofilled
- TOKENIZER_CLASS: t.Type[Tokenizer] = Tokenizer
- INDEX_OFFSET: int = 0
- UNNEST_COLUMN_ONLY: bool = False
- ALIAS_POST_TABLESAMPLE: bool = False
- STRICT_STRING_CONCAT = False
- SUPPORTS_USER_DEFINED_TYPES = True
- NORMALIZE_FUNCTIONS = "upper"
- NULL_ORDERING: str = "nulls_are_small"
SHOW_TRIE: t.Dict = {}
SET_TRIE: t.Dict = {}
- FORMAT_MAPPING: t.Dict[str, str] = {}
- FORMAT_TRIE: t.Dict = {}
- TIME_MAPPING: t.Dict[str, str] = {}
- TIME_TRIE: t.Dict = {}
def __init__(
self,
error_level: t.Optional[ErrorLevel] = None,
error_message_context: int = 100,
max_errors: int = 3,
+ dialect: DialectType = None,
):
+ from sqlglot.dialects import Dialect
+
self.error_level = error_level or ErrorLevel.IMMEDIATE
self.error_message_context = error_message_context
self.max_errors = max_errors
- self._tokenizer = self.TOKENIZER_CLASS()
+ self.dialect = Dialect.get_or_raise(dialect)
self.reset()
def reset(self):
@@ -1384,7 +1405,7 @@ class Parser(metaclass=_Parser):
if self._match_texts(self.CLONE_KEYWORDS):
copy = self._prev.text.lower() == "copy"
clone = self._parse_table(schema=True)
- when = self._match_texts({"AT", "BEFORE"}) and self._prev.text.upper()
+ when = self._match_texts(("AT", "BEFORE")) and self._prev.text.upper()
clone_kind = (
self._match(TokenType.L_PAREN)
and self._match_texts(self.CLONE_KINDS)
@@ -1524,6 +1545,22 @@ class Parser(metaclass=_Parser):
return self.expression(exp.StabilityProperty, this=exp.Literal.string("VOLATILE"))
+ def _parse_system_versioning_property(self) -> exp.WithSystemVersioningProperty:
+ self._match_pair(TokenType.EQ, TokenType.ON)
+
+ prop = self.expression(exp.WithSystemVersioningProperty)
+ if self._match(TokenType.L_PAREN):
+ self._match_text_seq("HISTORY_TABLE", "=")
+ prop.set("this", self._parse_table_parts())
+
+ if self._match(TokenType.COMMA):
+ self._match_text_seq("DATA_CONSISTENCY_CHECK", "=")
+ prop.set("expression", self._advance_any() and self._prev.text.upper())
+
+ self._match_r_paren()
+
+ return prop
+
def _parse_with_property(
self,
) -> t.Optional[exp.Expression] | t.List[exp.Expression]:
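Editor's note: these property/constraint parsers add support for temporal (system-versioned) tables. A sketch against T-SQL; the table and column names are illustrative and the round-trip shape may vary:

import sqlglot

sql = """
CREATE TABLE dbo.t (
  a INT,
  valid_from DATETIME2 GENERATED ALWAYS AS ROW START,
  valid_to DATETIME2 GENERATED ALWAYS AS ROW END,
  PERIOD FOR SYSTEM_TIME (valid_from, valid_to)
) WITH (SYSTEM_VERSIONING=ON (HISTORY_TABLE=dbo.t_history))
"""
print(sqlglot.parse_one(sql, read="tsql").sql("tsql"))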
@@ -2140,7 +2177,11 @@ class Parser(metaclass=_Parser):
return self._parse_expressions()
def _parse_select(
- self, nested: bool = False, table: bool = False, parse_subquery_alias: bool = True
+ self,
+ nested: bool = False,
+ table: bool = False,
+ parse_subquery_alias: bool = True,
+ parse_set_operation: bool = True,
) -> t.Optional[exp.Expression]:
cte = self._parse_with()
@@ -2216,7 +2257,11 @@ class Parser(metaclass=_Parser):
t.cast(exp.From, self._parse_from(skip_from_token=True))
)
else:
- this = self._parse_table() if table else self._parse_select(nested=True)
+ this = (
+ self._parse_table()
+ if table
+ else self._parse_select(nested=True, parse_set_operation=False)
+ )
this = self._parse_set_operations(self._parse_query_modifiers(this))
self._match_r_paren()
@@ -2235,7 +2280,9 @@ class Parser(metaclass=_Parser):
else:
this = None
- return self._parse_set_operations(this)
+ if parse_set_operation:
+ return self._parse_set_operations(this)
+ return this
def _parse_with(self, skip_with_token: bool = False) -> t.Optional[exp.With]:
if not skip_with_token and not self._match(TokenType.WITH):
@@ -2563,9 +2610,8 @@ class Parser(metaclass=_Parser):
if self._match_texts(self.OPCLASS_FOLLOW_KEYWORDS, advance=False):
return this
- opclass = self._parse_var(any_token=True)
- if opclass:
- return self.expression(exp.Opclass, this=this, expression=opclass)
+ if not self._match_set(self.OPTYPE_FOLLOW_TOKENS, advance=False):
+ return self.expression(exp.Opclass, this=this, expression=self._parse_table_parts())
return this
@@ -2630,7 +2676,7 @@ class Parser(metaclass=_Parser):
while self._match_set(self.TABLE_INDEX_HINT_TOKENS):
hint = exp.IndexTableHint(this=self._prev.text.upper())
- self._match_texts({"INDEX", "KEY"})
+ self._match_texts(("INDEX", "KEY"))
if self._match(TokenType.FOR):
hint.set("target", self._advance_any() and self._prev.text.upper())
@@ -2650,7 +2696,7 @@ class Parser(metaclass=_Parser):
def _parse_table_parts(self, schema: bool = False) -> exp.Table:
catalog = None
db = None
- table = self._parse_table_part(schema=schema)
+ table: t.Optional[exp.Expression | str] = self._parse_table_part(schema=schema)
while self._match(TokenType.DOT):
if catalog:
@@ -2661,7 +2707,7 @@ class Parser(metaclass=_Parser):
else:
catalog = db
db = table
- table = self._parse_table_part(schema=schema)
+ table = self._parse_table_part(schema=schema) or ""
if not table:
self.raise_error(f"Expected table name but got {self._curr}")
@@ -2709,7 +2755,7 @@ class Parser(metaclass=_Parser):
if version:
this.set("version", version)
- if self.ALIAS_POST_TABLESAMPLE:
+ if self.dialect.ALIAS_POST_TABLESAMPLE:
table_sample = self._parse_table_sample()
alias = self._parse_table_alias(alias_tokens=alias_tokens or self.TABLE_ALIAS_TOKENS)
@@ -2724,7 +2770,7 @@ class Parser(metaclass=_Parser):
if not this.args.get("pivots"):
this.set("pivots", self._parse_pivots())
- if not self.ALIAS_POST_TABLESAMPLE:
+ if not self.dialect.ALIAS_POST_TABLESAMPLE:
table_sample = self._parse_table_sample()
if table_sample:
@@ -2776,13 +2822,13 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.UNNEST):
return None
- expressions = self._parse_wrapped_csv(self._parse_type)
+ expressions = self._parse_wrapped_csv(self._parse_equality)
offset = self._match_pair(TokenType.WITH, TokenType.ORDINALITY)
alias = self._parse_table_alias() if with_alias else None
if alias:
- if self.UNNEST_COLUMN_ONLY:
+ if self.dialect.UNNEST_COLUMN_ONLY:
if alias.args.get("columns"):
self.raise_error("Unexpected extra column alias in unnest.")
@@ -2845,7 +2891,7 @@ class Parser(metaclass=_Parser):
num = (
self._parse_factor()
if self._match(TokenType.NUMBER, advance=False)
- else self._parse_primary()
+ else self._parse_primary() or self._parse_placeholder()
)
if self._match_text_seq("BUCKET"):
@@ -3108,10 +3154,10 @@ class Parser(metaclass=_Parser):
if (
not explicitly_null_ordered
and (
- (not desc and self.NULL_ORDERING == "nulls_are_small")
- or (desc and self.NULL_ORDERING != "nulls_are_small")
+ (not desc and self.dialect.NULL_ORDERING == "nulls_are_small")
+ or (desc and self.dialect.NULL_ORDERING != "nulls_are_small")
)
- and self.NULL_ORDERING != "nulls_are_last"
+ and self.dialect.NULL_ORDERING != "nulls_are_last"
):
nulls_first = True
@@ -3124,7 +3170,7 @@ class Parser(metaclass=_Parser):
comments = self._prev_comments
if top:
limit_paren = self._match(TokenType.L_PAREN)
- expression = self._parse_number()
+ expression = self._parse_term() if limit_paren else self._parse_number()
if limit_paren:
self._match_r_paren()
@@ -3225,7 +3271,9 @@ class Parser(metaclass=_Parser):
this=this,
distinct=self._match(TokenType.DISTINCT) or not self._match(TokenType.ALL),
by_name=self._match_text_seq("BY", "NAME"),
- expression=self._parse_set_operations(self._parse_select(nested=True)),
+ expression=self._parse_set_operations(
+ self._parse_select(nested=True, parse_set_operation=False)
+ ),
)
def _parse_expression(self) -> t.Optional[exp.Expression]:
@@ -3287,7 +3335,8 @@ class Parser(metaclass=_Parser):
unnest = self._parse_unnest(with_alias=False)
if unnest:
this = self.expression(exp.In, this=this, unnest=unnest)
- elif self._match(TokenType.L_PAREN):
+ elif self._match_set((TokenType.L_PAREN, TokenType.L_BRACKET)):
+ matched_l_paren = self._prev.token_type == TokenType.L_PAREN
expressions = self._parse_csv(lambda: self._parse_select_or_expression(alias=alias))
if len(expressions) == 1 and isinstance(expressions[0], exp.Subqueryable):
@@ -3295,13 +3344,16 @@ class Parser(metaclass=_Parser):
else:
this = self.expression(exp.In, this=this, expressions=expressions)
- self._match_r_paren(this)
+ if matched_l_paren:
+ self._match_r_paren(this)
+ elif not self._match(TokenType.R_BRACKET, expression=this):
+ self.raise_error("Expecting ]")
else:
this = self.expression(exp.In, this=this, field=self._parse_field())
return this
- def _parse_between(self, this: exp.Expression) -> exp.Between:
+ def _parse_between(self, this: t.Optional[exp.Expression]) -> exp.Between:
low = self._parse_bitwise()
self._match(TokenType.AND)
high = self._parse_bitwise()
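Editor's note: the base parser now also accepts a bracketed list after IN. A sketch with the default dialect:

import sqlglot

print(sqlglot.parse_one("x IN [1, 2, 3]").sql())  # x IN (1, 2, 3)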
@@ -3357,6 +3409,13 @@ class Parser(metaclass=_Parser):
this=this,
expression=self._parse_term(),
)
+ elif self.dialect.DPIPE_IS_STRING_CONCAT and self._match(TokenType.DPIPE):
+ this = self.expression(
+ exp.DPipe,
+ this=this,
+ expression=self._parse_term(),
+ safe=not self.dialect.STRICT_STRING_CONCAT,
+ )
elif self._match(TokenType.DQMARK):
this = self.expression(exp.Coalesce, this=this, expressions=self._parse_term())
elif self._match_pair(TokenType.LT, TokenType.LT):
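Editor's note: || is now parsed straight into a DPipe carrying a dialect-derived safe flag, instead of going through the bitwise-operator table. A sketch with the default dialect:

import sqlglot
from sqlglot import exp

node = sqlglot.parse_one("'a' || b")
assert isinstance(node, exp.DPipe)
print(node.args.get("safe"))  # True wherever STRICT_STRING_CONCAT is off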
@@ -3376,7 +3435,17 @@ class Parser(metaclass=_Parser):
return self._parse_tokens(self._parse_factor, self.TERM)
def _parse_factor(self) -> t.Optional[exp.Expression]:
- return self._parse_tokens(self._parse_unary, self.FACTOR)
+ if self.EXPONENT:
+ factor = self._parse_tokens(self._parse_exponent, self.FACTOR)
+ else:
+ factor = self._parse_tokens(self._parse_unary, self.FACTOR)
+ if isinstance(factor, exp.Div):
+ factor.args["typed"] = self.dialect.TYPED_DIVISION
+ factor.args["safe"] = self.dialect.SAFE_DIVISION
+ return factor
+
+ def _parse_exponent(self) -> t.Optional[exp.Expression]:
+ return self._parse_tokens(self._parse_unary, self.EXPONENT)
def _parse_unary(self) -> t.Optional[exp.Expression]:
if self._match_set(self.UNARY_PARSERS):
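Editor's note: division nodes now record the parsing dialect's semantics (typed integer division, NULL-safe division by zero), and dialects can install an EXPONENT table for operators binding tighter than * and /. A sketch that inspects the flags:

import sqlglot
from sqlglot import exp

div = sqlglot.parse_one("a / b").find(exp.Div)
print(div.args.get("typed"), div.args.get("safe"))  # False False for the default dialect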
@@ -3427,14 +3496,14 @@ class Parser(metaclass=_Parser):
)
if identifier:
- tokens = self._tokenizer.tokenize(identifier.name)
+ tokens = self.dialect.tokenize(identifier.name)
if len(tokens) != 1:
self.raise_error("Unexpected identifier", self._prev)
if tokens[0].token_type in self.TYPE_TOKENS:
self._prev = tokens[0]
- elif self.SUPPORTS_USER_DEFINED_TYPES:
+ elif self.dialect.SUPPORTS_USER_DEFINED_TYPES:
type_name = identifier.name
while self._match(TokenType.DOT):
@@ -3713,6 +3782,7 @@ class Parser(metaclass=_Parser):
if not self._curr:
return None
+ comments = self._curr.comments
token_type = self._curr.token_type
this = self._curr.text
upper = this.upper()
@@ -3754,13 +3824,22 @@ class Parser(metaclass=_Parser):
args = self._parse_csv(lambda: self._parse_lambda(alias=alias))
if function and not anonymous:
- func = self.validate_expression(function(args), args)
- if not self.NORMALIZE_FUNCTIONS:
+ if "dialect" in function.__code__.co_varnames:
+ func = function(args, dialect=self.dialect)
+ else:
+ func = function(args)
+
+ func = self.validate_expression(func, args)
+ if not self.dialect.NORMALIZE_FUNCTIONS:
func.meta["name"] = this
+
this = func
else:
this = self.expression(exp.Anonymous, this=this, expressions=args)
+ if isinstance(this, exp.Expression):
+ this.add_comments(comments)
+
self._match_r_paren(this)
return self._parse_window(this)
@@ -3875,6 +3954,11 @@ class Parser(metaclass=_Parser):
not_null=self._match_pair(TokenType.NOT, TokenType.NULL),
)
)
+ elif kind and self._match_pair(TokenType.ALIAS, TokenType.L_PAREN, advance=False):
+ self._match(TokenType.ALIAS)
+ constraints.append(
+ self.expression(exp.TransformColumnConstraint, this=self._parse_field())
+ )
while True:
constraint = self._parse_column_constraint()
@@ -3917,7 +4001,11 @@ class Parser(metaclass=_Parser):
def _parse_generated_as_identity(
self,
- ) -> exp.GeneratedAsIdentityColumnConstraint | exp.ComputedColumnConstraint:
+ ) -> (
+ exp.GeneratedAsIdentityColumnConstraint
+ | exp.ComputedColumnConstraint
+ | exp.GeneratedAsRowColumnConstraint
+ ):
if self._match_text_seq("BY", "DEFAULT"):
on_null = self._match_pair(TokenType.ON, TokenType.NULL)
this = self.expression(
@@ -3928,6 +4016,14 @@ class Parser(metaclass=_Parser):
this = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=True)
self._match(TokenType.ALIAS)
+
+ if self._match_text_seq("ROW"):
+ start = self._match_text_seq("START")
+ if not start:
+ self._match(TokenType.END)
+ hidden = self._match_text_seq("HIDDEN")
+ return self.expression(exp.GeneratedAsRowColumnConstraint, start=start, hidden=hidden)
+
identity = self._match_text_seq("IDENTITY")
if self._match(TokenType.L_PAREN):
@@ -4100,6 +4196,16 @@ class Parser(metaclass=_Parser):
def _parse_primary_key_part(self) -> t.Optional[exp.Expression]:
return self._parse_field()
+ def _parse_period_for_system_time(self) -> exp.PeriodForSystemTimeConstraint:
+ self._match(TokenType.TIMESTAMP_SNAPSHOT)
+
+ id_vars = self._parse_wrapped_id_vars()
+ return self.expression(
+ exp.PeriodForSystemTimeConstraint,
+ this=seq_get(id_vars, 0),
+ expression=seq_get(id_vars, 1),
+ )
+
def _parse_primary_key(
self, wrapped_optional: bool = False, in_props: bool = False
) -> exp.PrimaryKeyColumnConstraint | exp.PrimaryKey:
@@ -4145,7 +4251,7 @@ class Parser(metaclass=_Parser):
elif not this or this.name.upper() == "ARRAY":
this = self.expression(exp.Array, expressions=expressions)
else:
- expressions = apply_index_offset(this, expressions, -self.INDEX_OFFSET)
+ expressions = apply_index_offset(this, expressions, -self.dialect.INDEX_OFFSET)
this = self.expression(exp.Bracket, this=this, expressions=expressions)
self._add_comments(this)
@@ -4259,8 +4365,8 @@ class Parser(metaclass=_Parser):
format=exp.Literal.string(
format_time(
fmt_string.this if fmt_string else "",
- self.FORMAT_MAPPING or self.TIME_MAPPING,
- self.FORMAT_TRIE or self.TIME_TRIE,
+ self.dialect.FORMAT_MAPPING or self.dialect.TIME_MAPPING,
+ self.dialect.FORMAT_TRIE or self.dialect.TIME_TRIE,
)
),
)
@@ -4280,30 +4386,6 @@ class Parser(metaclass=_Parser):
exp.Cast if strict else exp.TryCast, this=this, to=to, format=fmt, safe=safe
)
- def _parse_concat(self) -> t.Optional[exp.Expression]:
- args = self._parse_csv(self._parse_conjunction)
- if self.CONCAT_NULL_OUTPUTS_STRING:
- args = self._ensure_string_if_null(args)
-
- # Some dialects (e.g. Trino) don't allow a single-argument CONCAT call, so when
- # we find such a call we replace it with its argument.
- if len(args) == 1:
- return args[0]
-
- return self.expression(
- exp.Concat if self.STRICT_STRING_CONCAT else exp.SafeConcat, expressions=args
- )
-
- def _parse_concat_ws(self) -> t.Optional[exp.Expression]:
- args = self._parse_csv(self._parse_conjunction)
- if len(args) < 2:
- return self.expression(exp.ConcatWs, expressions=args)
- delim, *values = args
- if self.CONCAT_NULL_OUTPUTS_STRING:
- values = self._ensure_string_if_null(values)
-
- return self.expression(exp.ConcatWs, expressions=[delim] + values)
-
def _parse_string_agg(self) -> exp.Expression:
if self._match(TokenType.DISTINCT):
args: t.List[t.Optional[exp.Expression]] = [
@@ -4495,19 +4577,6 @@ class Parser(metaclass=_Parser):
empty_handling=empty_handling,
)
- def _parse_logarithm(self) -> exp.Func:
- # Default argument order is base, expression
- args = self._parse_csv(self._parse_range)
-
- if len(args) > 1:
- if not self.LOG_BASE_FIRST:
- args.reverse()
- return exp.Log.from_arg_list(args)
-
- return self.expression(
- exp.Ln if self.LOG_DEFAULTS_TO_LN else exp.Log, this=seq_get(args, 0)
- )
-
def _parse_match_against(self) -> exp.MatchAgainst:
expressions = self._parse_csv(self._parse_column)
@@ -4755,6 +4824,7 @@ class Parser(metaclass=_Parser):
self, this: t.Optional[exp.Expression], explicit: bool = False
) -> t.Optional[exp.Expression]:
any_token = self._match(TokenType.ALIAS)
+ comments = self._prev_comments
if explicit and not any_token:
return this
@@ -4762,6 +4832,7 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.L_PAREN):
aliases = self.expression(
exp.Aliases,
+ comments=comments,
this=this,
expressions=self._parse_csv(lambda: self._parse_id_var(any_token)),
)
@@ -4771,7 +4842,7 @@ class Parser(metaclass=_Parser):
alias = self._parse_id_var(any_token)
if alias:
- return self.expression(exp.Alias, this=this, alias=alias)
+ return self.expression(exp.Alias, comments=comments, this=this, alias=alias)
return this
@@ -4792,8 +4863,8 @@ class Parser(metaclass=_Parser):
return None
def _parse_string(self) -> t.Optional[exp.Expression]:
- if self._match(TokenType.STRING):
- return self.PRIMARY_PARSERS[TokenType.STRING](self, self._prev)
+ if self._match_set((TokenType.STRING, TokenType.RAW_STRING)):
+ return self.PRIMARY_PARSERS[self._prev.token_type](self, self._prev)
return self._parse_placeholder()
def _parse_string_as_identifier(self) -> t.Optional[exp.Identifier]:
@@ -4821,7 +4892,7 @@ class Parser(metaclass=_Parser):
return self._parse_placeholder()
def _advance_any(self) -> t.Optional[Token]:
- if self._curr and self._curr.token_type not in self.RESERVED_KEYWORDS:
+ if self._curr and self._curr.token_type not in self.RESERVED_TOKENS:
self._advance()
return self._prev
return None
@@ -4951,7 +5022,7 @@ class Parser(metaclass=_Parser):
if self._match_texts(self.TRANSACTION_KIND):
this = self._prev.text
- self._match_texts({"TRANSACTION", "WORK"})
+ self._match_texts(("TRANSACTION", "WORK"))
modes = []
while True:
@@ -4971,7 +5042,7 @@ class Parser(metaclass=_Parser):
savepoint = None
is_rollback = self._prev.token_type == TokenType.ROLLBACK
- self._match_texts({"TRANSACTION", "WORK"})
+ self._match_texts(("TRANSACTION", "WORK"))
if self._match_text_seq("TO"):
self._match_text_seq("SAVEPOINT")
@@ -4986,6 +5057,10 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Commit, chain=chain)
+ def _parse_refresh(self) -> exp.Refresh:
+ self._match(TokenType.TABLE)
+ return self.expression(exp.Refresh, this=self._parse_string() or self._parse_table())
+
def _parse_add_column(self) -> t.Optional[exp.Expression]:
if not self._match_text_seq("ADD"):
return None
@@ -5050,10 +5125,9 @@ class Parser(metaclass=_Parser):
return self._parse_csv(self._parse_add_constraint)
self._retreat(index)
- if not self.ALTER_TABLE_ADD_COLUMN_KEYWORD and self._match_text_seq("ADD"):
- return self._parse_csv(self._parse_field_def)
-
- return self._parse_csv(self._parse_add_column)
+ if not self.ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN and self._match_text_seq("ADD"):
+ return self._parse_wrapped_csv(self._parse_field_def, optional=True)
+ return self._parse_wrapped_csv(self._parse_add_column, optional=True)
def _parse_alter_table_alter(self) -> exp.AlterColumn:
self._match(TokenType.COLUMN)
@@ -5198,7 +5272,7 @@ class Parser(metaclass=_Parser):
) -> t.Optional[exp.Expression]:
index = self._index
- if kind in {"GLOBAL", "SESSION"} and self._match_text_seq("TRANSACTION"):
+ if kind in ("GLOBAL", "SESSION") and self._match_text_seq("TRANSACTION"):
return self._parse_set_transaction(global_=kind == "GLOBAL")
left = self._parse_primary() or self._parse_id_var()
@@ -5292,7 +5366,9 @@ class Parser(metaclass=_Parser):
self._match_r_paren()
return self.expression(exp.DictRange, this=this, min=min, max=max)
- def _parse_comprehension(self, this: exp.Expression) -> t.Optional[exp.Comprehension]:
+ def _parse_comprehension(
+ self, this: t.Optional[exp.Expression]
+ ) -> t.Optional[exp.Comprehension]:
index = self._index
expression = self._parse_column()
if not self._match(TokenType.IN):
@@ -5441,10 +5517,3 @@ class Parser(metaclass=_Parser):
else:
column.replace(dot_or_id)
return node
-
- def _ensure_string_if_null(self, values: t.List[exp.Expression]) -> t.List[exp.Expression]:
- return [
- exp.func("COALESCE", exp.cast(value, "text"), exp.Literal.string(""))
- for value in values
- if value
- ]
diff --git a/sqlglot/schema.py b/sqlglot/schema.py
index acf9bc4..54c08dd 100644
--- a/sqlglot/schema.py
+++ b/sqlglot/schema.py
@@ -15,8 +15,6 @@ if t.TYPE_CHECKING:
ColumnMapping = t.Union[t.Dict, str, StructType, t.List]
-TABLE_ARGS = ("this", "db", "catalog")
-
class Schema(abc.ABC):
"""Abstract base class for database schemas"""
@@ -147,7 +145,7 @@ class AbstractMappingSchema:
if not depth: # None
self._supported_table_args = tuple()
elif 1 <= depth <= 3:
- self._supported_table_args = TABLE_ARGS[:depth]
+ self._supported_table_args = exp.TABLE_PARTS[:depth]
else:
raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
@@ -156,7 +154,7 @@ class AbstractMappingSchema:
def table_parts(self, table: exp.Table) -> t.List[str]:
if isinstance(table.this, exp.ReadCSV):
return [table.this.name]
- return [table.text(part) for part in TABLE_ARGS if table.text(part)]
+ return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)]
def find(
self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
@@ -365,13 +363,11 @@ class MappingSchema(AbstractMappingSchema, Schema):
f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}."
)
- normalized_keys = [
- self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys
- ]
+ normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
for column_name, column_type in columns.items():
nested_set(
normalized_mapping,
- normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)],
+ normalized_keys + [self._normalize_name(column_name)],
column_type,
)
@@ -383,21 +379,19 @@ class MappingSchema(AbstractMappingSchema, Schema):
dialect: DialectType = None,
normalize: t.Optional[bool] = None,
) -> exp.Table:
- normalized_table = exp.maybe_parse(
- table, into=exp.Table, dialect=dialect or self.dialect, copy=True
- )
+ dialect = dialect or self.dialect
+ normalize = self.normalize if normalize is None else normalize
- for arg in TABLE_ARGS:
- value = normalized_table.args.get(arg)
- if isinstance(value, (str, exp.Identifier)):
- normalized_table.set(
- arg,
- exp.to_identifier(
- self._normalize_name(
- value, dialect=dialect, is_table=True, normalize=normalize
- )
- ),
- )
+ normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)
+
+ if normalize:
+ for arg in exp.TABLE_PARTS:
+ value = normalized_table.args.get(arg)
+ if isinstance(value, exp.Identifier):
+ normalized_table.set(
+ arg,
+ normalize_name(value, dialect=dialect, is_table=True, normalize=normalize),
+ )
return normalized_table
@@ -413,7 +407,7 @@ class MappingSchema(AbstractMappingSchema, Schema):
dialect=dialect or self.dialect,
is_table=is_table,
normalize=self.normalize if normalize is None else normalize,
- )
+ ).name
def depth(self) -> int:
if not self.empty and not self._depth:
@@ -451,16 +445,16 @@ def normalize_name(
dialect: DialectType = None,
is_table: bool = False,
normalize: t.Optional[bool] = True,
-) -> str:
+) -> exp.Identifier:
if isinstance(identifier, str):
identifier = exp.parse_identifier(identifier, dialect=dialect)
if not normalize:
- return identifier.name
+ return identifier
- # This can be useful for normalize_identifier
+ # this is used by normalize_identifier; BigQuery has special rules pertaining to tables
identifier.meta["is_table"] = is_table
- return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name
+ return Dialect.get_or_raise(dialect).normalize_identifier(identifier)
def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
diff --git a/sqlglot/time.py b/sqlglot/time.py
index c286ec1..50ec2ec 100644
--- a/sqlglot/time.py
+++ b/sqlglot/time.py
@@ -42,6 +42,10 @@ def format_time(
end -= 1
chars = sym
sym = None
+ else:
+ chars = chars[0]
+ end = start + 1
+
start += len(chars)
chunks.append(chars)
current = trie
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index 9784c63..e4c3204 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -7,6 +7,9 @@ from sqlglot.errors import TokenError
from sqlglot.helper import AutoName
from sqlglot.trie import TrieResult, in_trie, new_trie
+if t.TYPE_CHECKING:
+ from sqlglot.dialects.dialect import DialectType
+
class TokenType(AutoName):
L_PAREN = auto()
@@ -34,6 +37,7 @@ class TokenType(AutoName):
EQ = auto()
NEQ = auto()
NULLSAFE_EQ = auto()
+ COLON_EQ = auto()
AND = auto()
OR = auto()
AMP = auto()
@@ -56,6 +60,7 @@ class TokenType(AutoName):
SESSION_PARAMETER = auto()
DAMP = auto()
XOR = auto()
+ DSTAR = auto()
BLOCK_START = auto()
BLOCK_END = auto()
@@ -274,6 +279,7 @@ class TokenType(AutoName):
OBJECT_IDENTIFIER = auto()
OFFSET = auto()
ON = auto()
+ OPERATOR = auto()
ORDER_BY = auto()
ORDERED = auto()
ORDINALITY = auto()
@@ -295,6 +301,7 @@ class TokenType(AutoName):
QUOTE = auto()
RANGE = auto()
RECURSIVE = auto()
+ REFRESH = auto()
REPLACE = auto()
RETURNING = auto()
REFERENCES = auto()
@@ -371,7 +378,7 @@ class Token:
col: int = 1,
start: int = 0,
end: int = 0,
- comments: t.List[str] = [],
+ comments: t.Optional[t.List[str]] = None,
) -> None:
"""Token initializer.
@@ -390,7 +397,7 @@ class Token:
self.col = col
self.start = start
self.end = end
- self.comments = comments
+ self.comments = [] if comments is None else comments
def __repr__(self) -> str:
attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__)
@@ -497,11 +504,8 @@ class Tokenizer(metaclass=_Tokenizer):
QUOTES: t.List[t.Tuple[str, str] | str] = ["'"]
STRING_ESCAPES = ["'"]
VAR_SINGLE_TOKENS: t.Set[str] = set()
- ESCAPE_SEQUENCES: t.Dict[str, str] = {}
# Autofilled
- IDENTIFIERS_CAN_START_WITH_DIGIT: bool = False
-
_COMMENTS: t.Dict[str, str] = {}
_FORMAT_STRINGS: t.Dict[str, t.Tuple[str, TokenType]] = {}
_IDENTIFIERS: t.Dict[str, str] = {}
@@ -523,6 +527,7 @@ class Tokenizer(metaclass=_Tokenizer):
"<=": TokenType.LTE,
"<>": TokenType.NEQ,
"!=": TokenType.NEQ,
+ ":=": TokenType.COLON_EQ,
"<=>": TokenType.NULLSAFE_EQ,
"->": TokenType.ARROW,
"->>": TokenType.DARROW,
@@ -689,17 +694,22 @@ class Tokenizer(metaclass=_Tokenizer):
"BOOLEAN": TokenType.BOOLEAN,
"BYTE": TokenType.TINYINT,
"MEDIUMINT": TokenType.MEDIUMINT,
+ "INT1": TokenType.TINYINT,
"TINYINT": TokenType.TINYINT,
+ "INT16": TokenType.SMALLINT,
"SHORT": TokenType.SMALLINT,
"SMALLINT": TokenType.SMALLINT,
"INT128": TokenType.INT128,
+ "HUGEINT": TokenType.INT128,
"INT2": TokenType.SMALLINT,
"INTEGER": TokenType.INT,
"INT": TokenType.INT,
"INT4": TokenType.INT,
+ "INT32": TokenType.INT,
+ "INT64": TokenType.BIGINT,
"LONG": TokenType.BIGINT,
"BIGINT": TokenType.BIGINT,
- "INT8": TokenType.BIGINT,
+ "INT8": TokenType.TINYINT,
"DEC": TokenType.DECIMAL,
"DECIMAL": TokenType.DECIMAL,
"BIGDECIMAL": TokenType.BIGDECIMAL,
@@ -781,7 +791,6 @@ class Tokenizer(metaclass=_Tokenizer):
"\t": TokenType.SPACE,
"\n": TokenType.BREAK,
"\r": TokenType.BREAK,
- "\r\n": TokenType.BREAK,
}
COMMANDS = {
@@ -803,6 +812,7 @@ class Tokenizer(metaclass=_Tokenizer):
"sql",
"size",
"tokens",
+ "dialect",
"_start",
"_current",
"_line",
@@ -814,7 +824,10 @@ class Tokenizer(metaclass=_Tokenizer):
"_prev_token_line",
)
- def __init__(self) -> None:
+ def __init__(self, dialect: DialectType = None) -> None:
+ from sqlglot.dialects import Dialect
+
+ self.dialect = Dialect.get_or_raise(dialect)
self.reset()
def reset(self) -> None:
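Editor's note: tokenizers are now constructed with a Dialect instance, so per-dialect settings (escape sequences, digit-leading identifiers) are read off self.dialect at runtime; callers typically tokenize through the dialect. A sketch:

import sqlglot

tokens = sqlglot.Dialect.get_or_raise("duckdb").tokenize("SELECT 1")
print([token.token_type for token in tokens])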
@@ -850,13 +863,26 @@ class Tokenizer(metaclass=_Tokenizer):
def _scan(self, until: t.Optional[t.Callable] = None) -> None:
while self.size and not self._end:
- self._start = self._current
- self._advance()
+ current = self._current
+
+ # skip spaces inline rather than iteratively calling _advance(),
+ # for performance reasons
+ while current < self.size:
+ char = self.sql[current]
+
+ if char.isspace() and (char == " " or char == "\t"):
+ current += 1
+ else:
+ break
+
+ n = current - self._current
+ self._start = current
+ self._advance(n if n > 1 else 1)
if self._char is None:
break
- if self._char not in self.WHITE_SPACE:
+ if not self._char.isspace():
if self._char.isdigit():
self._scan_number()
elif self._char in self._IDENTIFIERS:
@@ -881,6 +907,10 @@ class Tokenizer(metaclass=_Tokenizer):
def _advance(self, i: int = 1, alnum: bool = False) -> None:
if self.WHITE_SPACE.get(self._char) is TokenType.BREAK:
+ # Ensures we don't count an extra line if we get a \r\n line break sequence
+ if self._char == "\r" and self._peek == "\n":
+ i = 2
+
self._col = 1
self._line += 1
else:
@@ -982,7 +1012,7 @@ class Tokenizer(metaclass=_Tokenizer):
if end < self.size:
char = self.sql[end]
single_token = single_token or char in self.SINGLE_TOKENS
- is_space = char in self.WHITE_SPACE
+ is_space = char.isspace()
if not is_space or not prev_space:
if is_space:
@@ -994,7 +1024,7 @@ class Tokenizer(metaclass=_Tokenizer):
skip = True
else:
char = ""
- chars = " "
+ break
if word:
if self._scan_string(word):
@@ -1086,7 +1116,7 @@ class Tokenizer(metaclass=_Tokenizer):
self._add(TokenType.NUMBER, number_text)
self._add(TokenType.DCOLON, "::")
return self._add(token_type, literal)
- elif self.IDENTIFIERS_CAN_START_WITH_DIGIT:
+ elif self.dialect.IDENTIFIERS_CAN_START_WITH_DIGIT:
return self._add(TokenType.VAR)
self._advance(-len(literal))
@@ -1208,8 +1238,12 @@ class Tokenizer(metaclass=_Tokenizer):
if self._end:
raise TokenError(f"Missing {delimiter} from {self._line}:{self._start}")
- if self.ESCAPE_SEQUENCES and self._peek and self._char in self.STRING_ESCAPES:
- escaped_sequence = self.ESCAPE_SEQUENCES.get(self._char + self._peek)
+ if (
+ self.dialect.ESCAPE_SEQUENCES
+ and self._peek
+ and self._char in self.STRING_ESCAPES
+ ):
+ escaped_sequence = self.dialect.ESCAPE_SEQUENCES.get(self._char + self._peek)
if escaped_sequence:
self._advance(2)
text += escaped_sequence
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py
index 445fda6..03acc2b 100644
--- a/sqlglot/transforms.py
+++ b/sqlglot/transforms.py
@@ -141,7 +141,7 @@ def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expr
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
- """Convert cross join unnest into lateral view explode (used in presto -> hive)."""
+ """Convert cross join unnest into lateral view explode."""
if isinstance(expression, exp.Select):
for join in expression.args.get("joins") or []:
unnest = join.this
@@ -166,7 +166,7 @@ def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
- """Convert explode/posexplode into unnest (used in hive -> presto)."""
+ """Convert explode/posexplode into unnest."""
def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
if isinstance(expression, exp.Select):
@@ -199,11 +199,11 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp
explode_alias = ""
if isinstance(select, exp.Alias):
- explode_alias = select.alias
+ explode_alias = select.args["alias"]
alias = select
elif isinstance(select, exp.Aliases):
- pos_alias = select.aliases[0].name
- explode_alias = select.aliases[1].name
+ pos_alias = select.aliases[0]
+ explode_alias = select.aliases[1]
alias = select.replace(exp.alias_(select.this, "", copy=False))
else:
alias = select.replace(exp.alias_(select, ""))
@@ -230,9 +230,12 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp
alias.set("alias", exp.to_identifier(explode_alias))
+ series_table_alias = series.args["alias"].this
column = exp.If(
- this=exp.column(series_alias).eq(exp.column(pos_alias)),
- true=exp.column(explode_alias),
+ this=exp.column(series_alias, table=series_table_alias).eq(
+ exp.column(pos_alias, table=unnest_source_alias)
+ ),
+ true=exp.column(explode_alias, table=unnest_source_alias),
)
explode.replace(column)
@@ -242,8 +245,10 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp
expressions.insert(
expressions.index(alias) + 1,
exp.If(
- this=exp.column(series_alias).eq(exp.column(pos_alias)),
- true=exp.column(pos_alias),
+ this=exp.column(series_alias, table=series_table_alias).eq(
+ exp.column(pos_alias, table=unnest_source_alias)
+ ),
+ true=exp.column(pos_alias, table=unnest_source_alias),
).as_(pos_alias),
)
expression.set("expressions", expressions)
@@ -276,10 +281,12 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp
size = size - 1
expression.where(
- exp.column(series_alias)
- .eq(exp.column(pos_alias))
+ exp.column(series_alias, table=series_table_alias)
+ .eq(exp.column(pos_alias, table=unnest_source_alias))
.or_(
- (exp.column(series_alias) > size).and_(exp.column(pos_alias).eq(size))
+ (exp.column(series_alias, table=series_table_alias) > size).and_(
+ exp.column(pos_alias, table=unnest_source_alias).eq(size)
+ )
),
copy=False,
)
@@ -386,14 +393,16 @@ def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
full_outer_joins = [
(index, join)
for index, join in enumerate(expression.args.get("joins") or [])
- if join.side == "FULL" and join.kind == "OUTER"
+ if join.side == "FULL"
]
if len(full_outer_joins) == 1:
expression_copy = expression.copy()
+ expression.set("limit", None)
index, full_outer_join = full_outer_joins[0]
full_outer_join.set("side", "left")
expression_copy.args["joins"][index].set("side", "right")
+ expression_copy.args.pop("with", None) # remove CTEs from RIGHT side
return exp.union(expression, expression_copy, copy=False)
@@ -430,6 +439,33 @@ def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
return expression
+def ensure_bools(expression: exp.Expression) -> exp.Expression:
+ """Converts numeric values used in conditions into explicit boolean expressions."""
+ from sqlglot.optimizer.canonicalize import ensure_bools
+
+ def _ensure_bool(node: exp.Expression) -> None:
+ if (
+ node.is_number
+ or node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
+ or (isinstance(node, exp.Column) and not node.type)
+ ):
+ node.replace(node.neq(0))
+
+ for node, *_ in expression.walk():
+ ensure_bools(node, _ensure_bool)
+
+ return expression
+
+
+def unqualify_columns(expression: exp.Expression) -> exp.Expression:
+ for column in expression.find_all(exp.Column):
+ # We only want to pop off the table, db and catalog args
+ for part in column.parts[:-1]:
+ part.pop()
+
+ return expression
+
+
def preprocess(
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
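Editor's note: the new unqualify_columns transform strips every column down to its final name part, leaving tables untouched. A sketch:

import sqlglot
from sqlglot.transforms import unqualify_columns

print(unqualify_columns(sqlglot.parse_one("SELECT t.a, db.t.b FROM db.t")).sql())
# e.g. SELECT a, b FROM db.t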