author    Daniel Baumann <daniel.baumann@progress-linux.org> 2023-09-07 11:39:48 +0000
committer Daniel Baumann <daniel.baumann@progress-linux.org> 2023-09-07 11:39:48 +0000
commit    f73e9af131151f1e058446361c35b05c4c90bf10 (patch)
tree      ed425b89f12d3f5e4709290bdc03d876f365bc97 /sqlglot
parent    Releasing debian version 17.12.0-1. (diff)
Merging upstream version 18.2.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
-rw-r--r--  sqlglot/dataframe/README.md                 |  34
-rw-r--r--  sqlglot/dataframe/sql/column.py             |  22
-rw-r--r--  sqlglot/dataframe/sql/dataframe.py          |  34
-rw-r--r--  sqlglot/dataframe/sql/functions.py          |   8
-rw-r--r--  sqlglot/dataframe/sql/normalize.py          |   3
-rw-r--r--  sqlglot/dataframe/sql/readwriter.py         |  23
-rw-r--r--  sqlglot/dataframe/sql/session.py            |  65
-rw-r--r--  sqlglot/dataframe/sql/window.py             |   4
-rw-r--r--  sqlglot/dialects/bigquery.py                |  86
-rw-r--r--  sqlglot/dialects/clickhouse.py              |  52
-rw-r--r--  sqlglot/dialects/databricks.py              |  15
-rw-r--r--  sqlglot/dialects/dialect.py                 |  20
-rw-r--r--  sqlglot/dialects/doris.py                   |   1
-rw-r--r--  sqlglot/dialects/drill.py                   |   9
-rw-r--r--  sqlglot/dialects/duckdb.py                  |  38
-rw-r--r--  sqlglot/dialects/hive.py                    |  55
-rw-r--r--  sqlglot/dialects/mysql.py                   |  32
-rw-r--r--  sqlglot/dialects/oracle.py                  |  11
-rw-r--r--  sqlglot/dialects/postgres.py                |  38
-rw-r--r--  sqlglot/dialects/presto.py                  |  54
-rw-r--r--  sqlglot/dialects/redshift.py                |  14
-rw-r--r--  sqlglot/dialects/snowflake.py               |  78
-rw-r--r--  sqlglot/dialects/spark.py                   |  10
-rw-r--r--  sqlglot/dialects/spark2.py                  |  31
-rw-r--r--  sqlglot/dialects/sqlite.py                  |   5
-rw-r--r--  sqlglot/dialects/teradata.py                |   4
-rw-r--r--  sqlglot/dialects/trino.py                   |   3
-rw-r--r--  sqlglot/dialects/tsql.py                    | 157
-rw-r--r--  sqlglot/expressions.py                      | 242
-rw-r--r--  sqlglot/generator.py                        | 149
-rw-r--r--  sqlglot/helper.py                           |  30
-rw-r--r--  sqlglot/optimizer/__init__.py               |   9
-rw-r--r--  sqlglot/optimizer/annotate_types.py         |  39
-rw-r--r--  sqlglot/optimizer/eliminate_subqueries.py   |   9
-rw-r--r--  sqlglot/optimizer/optimize_joins.py         |   7
-rw-r--r--  sqlglot/optimizer/pushdown_predicates.py    |  14
-rw-r--r--  sqlglot/optimizer/scope.py                  |  72
-rw-r--r--  sqlglot/optimizer/simplify.py               |  36
-rw-r--r--  sqlglot/parser.py                           | 321
-rw-r--r--  sqlglot/tokens.py                           |  45
-rw-r--r--  sqlglot/transforms.py                       |  10
41 files changed, 1424 insertions(+), 465 deletions(-)
diff --git a/sqlglot/dataframe/README.md b/sqlglot/dataframe/README.md
index 86fdc4b..adde9a1 100644
--- a/sqlglot/dataframe/README.md
+++ b/sqlglot/dataframe/README.md
@@ -21,10 +21,12 @@ Currently many of the common operations are covered and more functionality will
* Ex: `['cola', 'colb']`
* The lack of types may limit functionality in future releases.
* See [Registering Custom Schema](#registering-custom-schema-class) for information on how to skip this step if the information is stored externally.
+* If your output SQL dialect is not Spark, then configure the SparkSession to use that dialect
+ * Ex: `SparkSession().builder.config("sqlframe.dialect", "bigquery").getOrCreate()`
+ * See [dialects](https://github.com/tobymao/sqlglot/tree/main/sqlglot/dialects) for a full list of dialects.
* Add `.sql(pretty=True)` to your final DataFrame command to return a list of sql statements to run that command.
- * In most cases a single SQL statement is returned. Currently the only exception is when caching DataFrames which isn't supported in other dialects.
- * Spark is the default output dialect. See [dialects](https://github.com/tobymao/sqlglot/tree/main/sqlglot/dialects) for a full list of dialects.
- * Ex: `.sql(pretty=True, dialect='bigquery')`
+ * In most cases a single SQL statement is returned. Currently the only exception is when caching DataFrames which isn't supported in other dialects.
+ * Ex: `.sql(pretty=True)`
## Examples
@@ -33,6 +35,8 @@ import sqlglot
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F
+dialect = "spark"
+
sqlglot.schema.add_table(
'employee',
{
@@ -41,10 +45,10 @@ sqlglot.schema.add_table(
'lname': 'STRING',
'age': 'INT',
},
- dialect="spark",
+ dialect=dialect,
) # Register the table structure prior to reading from the table
-spark = SparkSession()
+spark = SparkSession.builder.config("sqlframe.dialect", dialect).getOrCreate()
df = (
spark
@@ -53,7 +57,7 @@ df = (
.agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
)
-print(df.sql(pretty=True)) # Spark will be the dialect used by default
+print(df.sql(pretty=True))
```
```sparksql
@@ -81,7 +85,7 @@ class ExternalSchema(Schema):
sqlglot.schema = ExternalSchema()
-spark = SparkSession()
+spark = SparkSession() # Spark will be used by default if not specified in the SparkSession config
df = (
spark
@@ -119,11 +123,14 @@ schema = types.StructType([
])
sql_statements = (
- SparkSession()
+ SparkSession
+ .builder
+ .config("sqlframe.dialect", "bigquery")
+ .getOrCreate()
.createDataFrame(data, schema)
.groupBy(F.col("age"))
.agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
- .sql(dialect="bigquery")
+ .sql()
)
result = None
@@ -166,11 +173,14 @@ schema = types.StructType([
])
sql_statements = (
- SparkSession()
+ SparkSession
+ .builder
+ .config("sqlframe.dialect", "snowflake")
+ .getOrCreate()
.createDataFrame(data, schema)
.groupBy(F.col("age"))
.agg(F.countDistinct(F.col("lname")).alias("num_employees"))
- .sql(dialect="snowflake")
+ .sql()
)
try:
@@ -210,7 +220,7 @@ sql_statements = (
.createDataFrame(data, schema)
.groupBy(F.col("age"))
.agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
- .sql(dialect="spark")
+ .sql()
)
pyspark = PySparkSession.builder.master("local[*]").getOrCreate()
diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py
index fcfd71e..3acf494 100644
--- a/sqlglot/dataframe/sql/column.py
+++ b/sqlglot/dataframe/sql/column.py
@@ -5,7 +5,6 @@ import typing as t
import sqlglot
from sqlglot import expressions as exp
from sqlglot.dataframe.sql.types import DataType
-from sqlglot.dialects import Spark
from sqlglot.helper import flatten, is_iterable
if t.TYPE_CHECKING:
@@ -15,19 +14,20 @@ if t.TYPE_CHECKING:
class Column:
def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
+ from sqlglot.dataframe.sql.session import SparkSession
+
if isinstance(expression, Column):
expression = expression.expression # type: ignore
elif expression is None or not isinstance(expression, (str, exp.Expression)):
expression = self._lit(expression).expression # type: ignore
-
- expression = sqlglot.maybe_parse(expression, dialect="spark")
+ elif not isinstance(expression, exp.Column):
+ expression = sqlglot.maybe_parse(expression, dialect=SparkSession().dialect).transform(
+ SparkSession().dialect.normalize_identifier, copy=False
+ )
if expression is None:
raise ValueError(f"Could not parse {expression}")
- if isinstance(expression, exp.Column):
- expression.transform(Spark.normalize_identifier, copy=False)
-
- self.expression: exp.Expression = expression
+ self.expression: exp.Expression = expression # type: ignore
def __repr__(self):
return repr(self.expression)
@@ -207,7 +207,9 @@ class Column:
return Column(expression)
def sql(self, **kwargs) -> str:
- return self.expression.sql(**{"dialect": "spark", **kwargs})
+ from sqlglot.dataframe.sql.session import SparkSession
+
+ return self.expression.sql(**{"dialect": SparkSession().dialect, **kwargs})
def alias(self, name: str) -> Column:
new_expression = exp.alias_(self.column_expression, name)
@@ -264,9 +266,11 @@ class Column:
Functionality Difference: PySpark cast accepts a datatype instance of the datatype class
Sqlglot doesn't currently replicate this class so it only accepts a string
"""
+ from sqlglot.dataframe.sql.session import SparkSession
+
if isinstance(dataType, DataType):
dataType = dataType.simpleString()
- return Column(exp.cast(self.column_expression, dataType, dialect="spark"))
+ return Column(exp.cast(self.column_expression, dataType, dialect=SparkSession().dialect))
def startswith(self, value: t.Union[str, Column]) -> Column:
value = self._lit(value) if not isinstance(value, Column) else value
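Note on the column.py change above: parsing and identifier normalization now go through the session's dialect instead of the hard-coded Spark dialect. A minimal sketch of the underlying sqlglot calls the new code relies on (the "snowflake" dialect and column name are illustrative choices, output unverified):

```python
import sqlglot
from sqlglot import Dialect

# Instantiate a dialect the same way the diff does: Dialect.get_or_raise(...)()
dialect = Dialect.get_or_raise("snowflake")()

# Parse with that dialect, then normalize identifiers with its rules
# (Snowflake upper-cases unquoted identifiers).
parsed = sqlglot.maybe_parse("employee_id", dialect=dialect)
normalized = parsed.transform(dialect.normalize_identifier, copy=False)
print(normalized.sql(dialect=dialect))  # expected: EMPLOYEE_ID
```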
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
index 64cceea..f515608 100644
--- a/sqlglot/dataframe/sql/dataframe.py
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -1,12 +1,13 @@
from __future__ import annotations
import functools
+import logging
import typing as t
import zlib
from copy import copy
import sqlglot
-from sqlglot import expressions as exp
+from sqlglot import Dialect, expressions as exp
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.group import GroupedData
@@ -18,6 +19,7 @@ from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
from sqlglot.dataframe.sql.window import Window
from sqlglot.helper import ensure_list, object_to_dict, seq_get
from sqlglot.optimizer import optimize as optimize_func
+from sqlglot.optimizer.qualify_columns import quote_identifiers
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql._typing import (
@@ -27,7 +29,9 @@ if t.TYPE_CHECKING:
OutputExpressionContainer,
)
from sqlglot.dataframe.sql.session import SparkSession
+ from sqlglot.dialects.dialect import DialectType
+logger = logging.getLogger("sqlglot")
JOIN_HINTS = {
"BROADCAST",
@@ -264,7 +268,9 @@ class DataFrame:
@classmethod
def _create_hash_from_expression(cls, expression: exp.Expression) -> str:
- value = expression.sql(dialect="spark").encode("utf-8")
+ from sqlglot.dataframe.sql.session import SparkSession
+
+ value = expression.sql(dialect=SparkSession().dialect).encode("utf-8")
return f"t{zlib.crc32(value)}"[:6]
def _get_select_expressions(
@@ -291,7 +297,15 @@ class DataFrame:
select_expressions.append(expression_select_pair) # type: ignore
return select_expressions
- def sql(self, dialect="spark", optimize=True, **kwargs) -> t.List[str]:
+ def sql(
+ self, dialect: t.Optional[DialectType] = None, optimize: bool = True, **kwargs
+ ) -> t.List[str]:
+ from sqlglot.dataframe.sql.session import SparkSession
+
+ if dialect and Dialect.get_or_raise(dialect)() != SparkSession().dialect:
+ logger.warning(
+ f"The recommended way of defining a dialect is by doing `SparkSession.builder.config('sqlframe.dialect', '{dialect}').getOrCreate()`. It is no longer needed then when calling `sql`. If you run into issues try updating your query to use this pattern."
+ )
df = self._resolve_pending_hints()
select_expressions = df._get_select_expressions()
output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
@@ -299,7 +313,10 @@ class DataFrame:
for expression_type, select_expression in select_expressions:
select_expression = select_expression.transform(replace_id_value, replacement_mapping)
if optimize:
- select_expression = t.cast(exp.Select, optimize_func(select_expression))
+ quote_identifiers(select_expression)
+ select_expression = t.cast(
+ exp.Select, optimize_func(select_expression, dialect=SparkSession().dialect)
+ )
select_expression = df._replace_cte_names_with_hashes(select_expression)
expression: t.Union[exp.Select, exp.Cache, exp.Drop]
if expression_type == exp.Cache:
@@ -313,10 +330,12 @@ class DataFrame:
sqlglot.schema.add_table(
cache_table_name,
{
- expression.alias_or_name: expression.type.sql("spark")
+ expression.alias_or_name: expression.type.sql(
+ dialect=SparkSession().dialect
+ )
for expression in select_expression.expressions
},
- dialect="spark",
+ dialect=SparkSession().dialect,
)
cache_storage_level = select_expression.args["cache_storage_level"]
options = [
@@ -345,7 +364,8 @@ class DataFrame:
output_expressions.append(expression)
return [
- expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions
+ expression.sql(**{"dialect": SparkSession().dialect, **kwargs})
+ for expression in output_expressions
]
def copy(self, **kwargs) -> DataFrame:
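Note on the dataframe.py change above: `sql()` no longer honors a per-call `dialect` argument. Generation always uses the session dialect; a passed `dialect` only triggers the warning. A sketch of the intended usage (the dialect names are illustrative):

```python
from sqlglot.dataframe.sql.session import SparkSession

spark = SparkSession.builder.config("sqlframe.dialect", "bigquery").getOrCreate()
df = spark.sql("SELECT 1 AS x")

print(df.sql())              # generated with the session's BigQuery dialect
df.sql(dialect="snowflake")  # logs the warning; output still uses the session dialect
```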
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
index 4002cfe..d0ae50c 100644
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -368,9 +368,7 @@ def covar_samp(col1: ColumnOrName, col2: ColumnOrName) -> Column:
def first(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column:
- if ignorenulls is not None:
- return Column.invoke_anonymous_function(col, "FIRST", ignorenulls)
- return Column.invoke_anonymous_function(col, "FIRST")
+ return Column.invoke_expression_over_column(col, expression.First, ignore_nulls=ignorenulls)
def grouping_id(*cols: ColumnOrName) -> Column:
@@ -394,9 +392,7 @@ def isnull(col: ColumnOrName) -> Column:
def last(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column:
- if ignorenulls is not None:
- return Column.invoke_anonymous_function(col, "LAST", ignorenulls)
- return Column.invoke_anonymous_function(col, "LAST")
+ return Column.invoke_expression_over_column(col, expression.Last, ignore_nulls=ignorenulls)
def monotonically_increasing_id() -> Column:
diff --git a/sqlglot/dataframe/sql/normalize.py b/sqlglot/dataframe/sql/normalize.py
index 4eec782..f68bacb 100644
--- a/sqlglot/dataframe/sql/normalize.py
+++ b/sqlglot/dataframe/sql/normalize.py
@@ -5,7 +5,6 @@ import typing as t
from sqlglot import expressions as exp
from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
-from sqlglot.dialects import Spark
from sqlglot.helper import ensure_list
NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column])
@@ -20,7 +19,7 @@ def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[
for expression in expressions:
identifiers = expression.find_all(exp.Identifier)
for identifier in identifiers:
- Spark.normalize_identifier(identifier)
+ identifier.transform(spark.dialect.normalize_identifier)
replace_alias_name_with_cte_name(spark, expression_context, identifier)
replace_branch_and_sequence_ids_with_cte_name(spark, expression_context, identifier)
diff --git a/sqlglot/dataframe/sql/readwriter.py b/sqlglot/dataframe/sql/readwriter.py
index 9d87d4a..0804486 100644
--- a/sqlglot/dataframe/sql/readwriter.py
+++ b/sqlglot/dataframe/sql/readwriter.py
@@ -4,7 +4,6 @@ import typing as t
import sqlglot
from sqlglot import expressions as exp
-from sqlglot.dialects import Spark
from sqlglot.helper import object_to_dict
if t.TYPE_CHECKING:
@@ -18,15 +17,25 @@ class DataFrameReader:
def table(self, tableName: str) -> DataFrame:
from sqlglot.dataframe.sql.dataframe import DataFrame
+ from sqlglot.dataframe.sql.session import SparkSession
- sqlglot.schema.add_table(tableName, dialect="spark")
+ sqlglot.schema.add_table(tableName, dialect=SparkSession().dialect)
return DataFrame(
self.spark,
exp.Select()
- .from_(exp.to_table(tableName, dialect="spark").transform(Spark.normalize_identifier))
+ .from_(
+ exp.to_table(tableName, dialect=SparkSession().dialect).transform(
+ SparkSession().dialect.normalize_identifier
+ )
+ )
.select(
- *(column for column in sqlglot.schema.column_names(tableName, dialect="spark"))
+ *(
+ column
+ for column in sqlglot.schema.column_names(
+ tableName, dialect=SparkSession().dialect
+ )
+ )
),
)
@@ -63,6 +72,8 @@ class DataFrameWriter:
return self.copy(by_name=True)
def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter:
+ from sqlglot.dataframe.sql.session import SparkSession
+
output_expression_container = exp.Insert(
**{
"this": exp.to_table(tableName),
@@ -71,7 +82,9 @@ class DataFrameWriter:
)
df = self._df.copy(output_expression_container=output_expression_container)
if self._by_name:
- columns = sqlglot.schema.column_names(tableName, only_visible=True, dialect="spark")
+ columns = sqlglot.schema.column_names(
+ tableName, only_visible=True, dialect=SparkSession().dialect
+ )
df = df._convert_leaf_to_cte().select(*columns)
return self.copy(_df=df)
diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py
index b883359..531ee17 100644
--- a/sqlglot/dataframe/sql/session.py
+++ b/sqlglot/dataframe/sql/session.py
@@ -5,31 +5,35 @@ import uuid
from collections import defaultdict
import sqlglot
-from sqlglot import expressions as exp
+from sqlglot import Dialect, expressions as exp
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.dataframe import DataFrame
from sqlglot.dataframe.sql.readwriter import DataFrameReader
from sqlglot.dataframe.sql.types import StructType
from sqlglot.dataframe.sql.util import get_column_mapping_from_schema_input
+from sqlglot.helper import classproperty
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql._typing import ColumnLiterals, SchemaInput
class SparkSession:
- known_ids: t.ClassVar[t.Set[str]] = set()
- known_branch_ids: t.ClassVar[t.Set[str]] = set()
- known_sequence_ids: t.ClassVar[t.Set[str]] = set()
- name_to_sequence_id_mapping: t.ClassVar[t.Dict[str, t.List[str]]] = defaultdict(list)
+ DEFAULT_DIALECT = "spark"
+ _instance = None
def __init__(self):
- self.incrementing_id = 1
-
- def __getattr__(self, name: str) -> SparkSession:
- return self
-
- def __call__(self, *args, **kwargs) -> SparkSession:
- return self
+ if not hasattr(self, "known_ids"):
+ self.known_ids = set()
+ self.known_branch_ids = set()
+ self.known_sequence_ids = set()
+ self.name_to_sequence_id_mapping = defaultdict(list)
+ self.incrementing_id = 1
+ self.dialect = Dialect.get_or_raise(self.DEFAULT_DIALECT)()
+
+ def __new__(cls, *args, **kwargs) -> SparkSession:
+ if cls._instance is None:
+ cls._instance = super().__new__(cls)
+ return cls._instance
@property
def read(self) -> DataFrameReader:
@@ -101,7 +105,7 @@ class SparkSession:
return DataFrame(self, sel_expression)
def sql(self, sqlQuery: str) -> DataFrame:
- expression = sqlglot.parse_one(sqlQuery, read="spark")
+ expression = sqlglot.parse_one(sqlQuery, read=self.dialect)
if isinstance(expression, exp.Select):
df = DataFrame(self, expression)
df = df._convert_leaf_to_cte()
@@ -149,3 +153,38 @@ class SparkSession:
def _add_alias_to_mapping(self, name: str, sequence_id: str):
self.name_to_sequence_id_mapping[name].append(sequence_id)
+
+ class Builder:
+ SQLFRAME_DIALECT_KEY = "sqlframe.dialect"
+
+ def __init__(self):
+ self.dialect = "spark"
+
+ def __getattr__(self, item) -> SparkSession.Builder:
+ return self
+
+ def __call__(self, *args, **kwargs):
+ return self
+
+ def config(
+ self,
+ key: t.Optional[str] = None,
+ value: t.Optional[t.Any] = None,
+ *,
+ map: t.Optional[t.Dict[str, t.Any]] = None,
+ **kwargs: t.Any,
+ ) -> SparkSession.Builder:
+ if key == self.SQLFRAME_DIALECT_KEY:
+ self.dialect = value
+ elif map and self.SQLFRAME_DIALECT_KEY in map:
+ self.dialect = map[self.SQLFRAME_DIALECT_KEY]
+ return self
+
+ def getOrCreate(self) -> SparkSession:
+ spark = SparkSession()
+ spark.dialect = Dialect.get_or_raise(self.dialect)()
+ return spark
+
+ @classproperty
+ def builder(cls) -> Builder:
+ return cls.Builder()
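Note on the session.py change above: `SparkSession` is now a process-wide singleton, and `Builder.getOrCreate()` mutates the shared instance's dialect, which is how the column/dataframe/window modules can all read it back via `SparkSession().dialect`. A sketch of the resulting behavior:

```python
from sqlglot.dataframe.sql.session import SparkSession

a = SparkSession()
b = SparkSession.builder.config("sqlframe.dialect", "duckdb").getOrCreate()

assert a is b      # __new__ always returns the one shared instance
print(a.dialect)   # now the DuckDB dialect instance, visible through `a` too
```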
diff --git a/sqlglot/dataframe/sql/window.py b/sqlglot/dataframe/sql/window.py
index c54c07e..c1d913f 100644
--- a/sqlglot/dataframe/sql/window.py
+++ b/sqlglot/dataframe/sql/window.py
@@ -48,7 +48,9 @@ class WindowSpec:
return WindowSpec(self.expression.copy())
def sql(self, **kwargs) -> str:
- return self.expression.sql(dialect="spark", **kwargs)
+ from sqlglot.dataframe.sql.session import SparkSession
+
+ return self.expression.sql(dialect=SparkSession().dialect, **kwargs)
def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
from sqlglot.dataframe.sql.column import Column
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 71977dd..d763ed0 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -12,6 +12,7 @@ from sqlglot.dialects.dialect import (
datestrtodate_sql,
format_time_lambda,
inline_array_sql,
+ json_keyvalue_comma_sql,
max_or_greatest,
min_or_least,
no_ilike_sql,
@@ -29,8 +30,8 @@ logger = logging.getLogger("sqlglot")
def _date_add_sql(
data_type: str, kind: str
-) -> t.Callable[[generator.Generator, exp.Expression], str]:
- def func(self, expression):
+) -> t.Callable[[BigQuery.Generator, exp.Expression], str]:
+ def func(self: BigQuery.Generator, expression: exp.Expression) -> str:
this = self.sql(expression, "this")
unit = expression.args.get("unit")
unit = exp.var(unit.name.upper() if unit else "DAY")
@@ -40,7 +41,7 @@ def _date_add_sql(
return func
-def _derived_table_values_to_unnest(self: generator.Generator, expression: exp.Values) -> str:
+def _derived_table_values_to_unnest(self: BigQuery.Generator, expression: exp.Values) -> str:
if not expression.find_ancestor(exp.From, exp.Join):
return self.values_sql(expression)
@@ -64,7 +65,7 @@ def _derived_table_values_to_unnest(self: generator.Generator, expression: exp.V
return self.unnest_sql(exp.Unnest(expressions=[exp.Array(expressions=structs)]))
-def _returnsproperty_sql(self: generator.Generator, expression: exp.ReturnsProperty) -> str:
+def _returnsproperty_sql(self: BigQuery.Generator, expression: exp.ReturnsProperty) -> str:
this = expression.this
if isinstance(this, exp.Schema):
this = f"{this.this} <{self.expressions(this)}>"
@@ -73,7 +74,7 @@ def _returnsproperty_sql(self: generator.Generator, expression: exp.ReturnsPrope
return f"RETURNS {this}"
-def _create_sql(self: generator.Generator, expression: exp.Create) -> str:
+def _create_sql(self: BigQuery.Generator, expression: exp.Create) -> str:
kind = expression.args["kind"]
returns = expression.find(exp.ReturnsProperty)
@@ -94,14 +95,20 @@ def _unqualify_unnest(expression: exp.Expression) -> exp.Expression:
These are added by the optimizer's qualify_column step.
"""
- from sqlglot.optimizer.scope import Scope
+ from sqlglot.optimizer.scope import find_all_in_scope
if isinstance(expression, exp.Select):
- for unnest in expression.find_all(exp.Unnest):
- if isinstance(unnest.parent, (exp.From, exp.Join)) and unnest.alias:
- for column in Scope(expression).find_all(exp.Column):
- if column.table == unnest.alias:
- column.set("table", None)
+ unnest_aliases = {
+ unnest.alias
+ for unnest in find_all_in_scope(expression, exp.Unnest)
+ if isinstance(unnest.parent, (exp.From, exp.Join))
+ }
+ if unnest_aliases:
+ for column in expression.find_all(exp.Column):
+ if column.table in unnest_aliases:
+ column.set("table", None)
+ elif column.db in unnest_aliases:
+ column.set("db", None)
return expression
@@ -261,6 +268,7 @@ class BigQuery(Dialect):
"TIMESTAMP": TokenType.TIMESTAMPTZ,
"NOT DETERMINISTIC": TokenType.VOLATILE,
"UNKNOWN": TokenType.NULL,
+ "FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT,
}
KEYWORDS.pop("DIV")
@@ -270,6 +278,8 @@ class BigQuery(Dialect):
LOG_BASE_FIRST = False
LOG_DEFAULTS_TO_LN = True
+ SUPPORTS_USER_DEFINED_TYPES = False
+
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"DATE": _parse_date,
@@ -299,6 +309,8 @@ class BigQuery(Dialect):
if re.compile(str(seq_get(args, 1))).groups == 1
else None,
),
+ "SHA256": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(256)),
+ "SHA512": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(512)),
"SPLIT": lambda args: exp.Split(
# https://cloud.google.com/bigquery/docs/reference/standard-sql/string_functions#split
this=seq_get(args, 0),
@@ -346,7 +358,7 @@ class BigQuery(Dialect):
}
def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]:
- this = super()._parse_table_part(schema=schema)
+ this = super()._parse_table_part(schema=schema) or self._parse_number()
# https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#table_names
if isinstance(this, exp.Identifier):
@@ -356,6 +368,17 @@ class BigQuery(Dialect):
table_name += f"-{self._prev.text}"
this = exp.Identifier(this=table_name, quoted=this.args.get("quoted"))
+ elif isinstance(this, exp.Literal):
+ table_name = this.name
+
+ if (
+ self._curr
+ and self._prev.end == self._curr.start - 1
+ and self._parse_var(any_token=True)
+ ):
+ table_name += self._prev.text
+
+ this = exp.Identifier(this=table_name, quoted=True)
return this
@@ -374,6 +397,27 @@ class BigQuery(Dialect):
return table
+ def _parse_json_object(self) -> exp.JSONObject:
+ json_object = super()._parse_json_object()
+ array_kv_pair = seq_get(json_object.expressions, 0)
+
+ # Converts BQ's "signature 2" of JSON_OBJECT into SQLGlot's canonical representation
+ # https://cloud.google.com/bigquery/docs/reference/standard-sql/json_functions#json_object_signature2
+ if (
+ array_kv_pair
+ and isinstance(array_kv_pair.this, exp.Array)
+ and isinstance(array_kv_pair.expression, exp.Array)
+ ):
+ keys = array_kv_pair.this.expressions
+ values = array_kv_pair.expression.expressions
+
+ json_object.set(
+ "expressions",
+ [exp.JSONKeyValue(this=k, expression=v) for k, v in zip(keys, values)],
+ )
+
+ return json_object
+
class Generator(generator.Generator):
EXPLICIT_UNION = True
INTERVAL_ALLOWS_PLURAL_FORM = False
@@ -383,6 +427,7 @@ class BigQuery(Dialect):
LIMIT_FETCH = "LIMIT"
RENAME_TABLE_WITH_DB = False
ESCAPE_LINE_BREAK = True
+ NVL2_SUPPORTED = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@@ -405,6 +450,7 @@ class BigQuery(Dialect):
exp.ILike: no_ilike_sql,
exp.IntDiv: rename_func("DIV"),
exp.JSONFormat: rename_func("TO_JSON_STRING"),
+ exp.JSONKeyValue: json_keyvalue_comma_sql,
exp.Max: max_or_greatest,
exp.MD5: lambda self, e: self.func("TO_HEX", self.func("MD5", e.this)),
exp.MD5Digest: rename_func("MD5"),
@@ -428,6 +474,9 @@ class BigQuery(Dialect):
_alias_ordered_group,
]
),
+ exp.SHA2: lambda self, e: self.func(
+ f"SHA256" if e.text("length") == "256" else "SHA512", e.this
+ ),
exp.StabilityProperty: lambda self, e: f"DETERMINISTIC"
if e.name == "IMMUTABLE"
else "NOT DETERMINISTIC",
@@ -591,6 +640,13 @@ class BigQuery(Dialect):
return super().attimezone_sql(expression)
+ def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
+ # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#json_literals
+ if expression.is_type("json"):
+ return f"JSON {self.sql(expression, 'this')}"
+
+ return super().cast_sql(expression, safe_prefix=safe_prefix)
+
def trycast_sql(self, expression: exp.TryCast) -> str:
return self.cast_sql(expression, safe_prefix="SAFE_")
@@ -630,3 +686,9 @@ class BigQuery(Dialect):
def with_properties(self, properties: exp.Properties) -> str:
return self.properties(properties, prefix=self.seg("OPTIONS"))
+
+ def version_sql(self, expression: exp.Version) -> str:
+ if expression.name == "TIMESTAMP":
+ expression = expression.copy()
+ expression.set("this", "SYSTEM_TIME")
+ return super().version_sql(expression)
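Note on the bigquery.py changes above: SHA256/SHA512 now round-trip through the canonical `exp.SHA2`, and casts to JSON use BigQuery's JSON literal form. A sketch of the expected behavior (outputs inferred from the rules above, not from a run):

```python
import sqlglot

# SHA256(x) parses to SHA2(x, 256); the exp.SHA2 transform generates it back.
print(sqlglot.transpile("SELECT SHA256(col)", read="bigquery", write="bigquery")[0])
# expected: SELECT SHA256(col)

# Casts to JSON are emitted as BigQuery JSON literals.
print(sqlglot.transpile("SELECT CAST('{\"a\": 1}' AS JSON)", write="bigquery")[0])
# expected: SELECT JSON '{"a": 1}'
```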
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index cfde5fd..a38a239 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -11,6 +11,7 @@ from sqlglot.dialects.dialect import (
var_map_sql,
)
from sqlglot.errors import ParseError
+from sqlglot.helper import seq_get
from sqlglot.parser import parse_var_map
from sqlglot.tokens import Token, TokenType
@@ -63,9 +64,23 @@ class ClickHouse(Dialect):
}
class Parser(parser.Parser):
+ SUPPORTS_USER_DEFINED_TYPES = False
+
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"ANY": exp.AnyValue.from_arg_list,
+ "DATE_ADD": lambda args: exp.DateAdd(
+ this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
+ ),
+ "DATEADD": lambda args: exp.DateAdd(
+ this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
+ ),
+ "DATE_DIFF": lambda args: exp.DateDiff(
+ this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
+ ),
+ "DATEDIFF": lambda args: exp.DateDiff(
+ this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
+ ),
"MAP": parse_var_map,
"MATCH": exp.RegexpLike.from_arg_list,
"UNIQ": exp.ApproxDistinct.from_arg_list,
@@ -147,7 +162,7 @@ class ClickHouse(Dialect):
this = self._parse_id_var()
self._match(TokenType.COLON)
- kind = self._parse_types(check_func=False) or (
+ kind = self._parse_types(check_func=False, allow_identifiers=False) or (
self._match_text_seq("IDENTIFIER") and "Identifier"
)
@@ -249,7 +264,7 @@ class ClickHouse(Dialect):
def _parse_func_params(
self, this: t.Optional[exp.Func] = None
- ) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
+ ) -> t.Optional[t.List[exp.Expression]]:
if self._match_pair(TokenType.R_PAREN, TokenType.L_PAREN):
return self._parse_csv(self._parse_lambda)
@@ -267,9 +282,7 @@ class ClickHouse(Dialect):
return self.expression(exp.Quantile, this=params[0], quantile=this)
return self.expression(exp.Quantile, this=this, quantile=exp.Literal.number(0.5))
- def _parse_wrapped_id_vars(
- self, optional: bool = False
- ) -> t.List[t.Optional[exp.Expression]]:
+ def _parse_wrapped_id_vars(self, optional: bool = False) -> t.List[exp.Expression]:
return super()._parse_wrapped_id_vars(optional=True)
def _parse_primary_key(
@@ -292,9 +305,22 @@ class ClickHouse(Dialect):
class Generator(generator.Generator):
QUERY_HINTS = False
STRUCT_DELIMITER = ("(", ")")
+ NVL2_SUPPORTED = False
+
+ STRING_TYPE_MAPPING = {
+ exp.DataType.Type.CHAR: "String",
+ exp.DataType.Type.LONGBLOB: "String",
+ exp.DataType.Type.LONGTEXT: "String",
+ exp.DataType.Type.MEDIUMBLOB: "String",
+ exp.DataType.Type.MEDIUMTEXT: "String",
+ exp.DataType.Type.TEXT: "String",
+ exp.DataType.Type.VARBINARY: "String",
+ exp.DataType.Type.VARCHAR: "String",
+ }
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
+ **STRING_TYPE_MAPPING,
exp.DataType.Type.ARRAY: "Array",
exp.DataType.Type.BIGINT: "Int64",
exp.DataType.Type.DATETIME64: "DateTime64",
@@ -328,6 +354,12 @@ class ClickHouse(Dialect):
exp.ApproxDistinct: rename_func("uniq"),
exp.Array: inline_array_sql,
exp.CastToStrType: rename_func("CAST"),
+ exp.DateAdd: lambda self, e: self.func(
+ "DATE_ADD", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
+ ),
+ exp.DateDiff: lambda self, e: self.func(
+ "DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
+ ),
exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)),
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
@@ -364,6 +396,16 @@ class ClickHouse(Dialect):
"NAMED COLLECTION",
}
+ def datatype_sql(self, expression: exp.DataType) -> str:
+ # String is the standard ClickHouse type, every other variant is just an alias.
+ # Additionally, any supplied length parameter will be ignored.
+ #
+ # https://clickhouse.com/docs/en/sql-reference/data-types/string
+ if expression.this in self.STRING_TYPE_MAPPING:
+ return "String"
+
+ return super().datatype_sql(expression)
+
def safeconcat_sql(self, expression: exp.SafeConcat) -> str:
# Clickhouse errors out if we try to cast a NULL value to TEXT
expression = expression.copy()
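Note on the clickhouse.py changes above: every text/blob variant collapses to ClickHouse's single `String` type, and any supplied length parameter is dropped. A sketch (output inferred, not verified):

```python
import sqlglot

print(sqlglot.transpile("CREATE TABLE t (a VARCHAR(10), b TEXT)", write="clickhouse")[0])
# expected: CREATE TABLE t (a String, b String)
```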
diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py
index 2149aca..6ec0487 100644
--- a/sqlglot/dialects/databricks.py
+++ b/sqlglot/dialects/databricks.py
@@ -1,7 +1,7 @@
from __future__ import annotations
from sqlglot import exp, transforms
-from sqlglot.dialects.dialect import parse_date_delta
+from sqlglot.dialects.dialect import parse_date_delta, timestamptrunc_sql
from sqlglot.dialects.spark import Spark
from sqlglot.dialects.tsql import generate_date_delta_with_unit_sql
from sqlglot.tokens import TokenType
@@ -28,6 +28,19 @@ class Databricks(Spark):
**Spark.Generator.TRANSFORMS,
exp.DateAdd: generate_date_delta_with_unit_sql,
exp.DateDiff: generate_date_delta_with_unit_sql,
+ exp.DatetimeAdd: lambda self, e: self.func(
+ "TIMESTAMPADD", e.text("unit"), e.expression, e.this
+ ),
+ exp.DatetimeSub: lambda self, e: self.func(
+ "TIMESTAMPADD",
+ e.text("unit"),
+ exp.Mul(this=e.expression.copy(), expression=exp.Literal.number(-1)),
+ e.this,
+ ),
+ exp.DatetimeDiff: lambda self, e: self.func(
+ "TIMESTAMPDIFF", e.text("unit"), e.expression, e.this
+ ),
+ exp.DatetimeTrunc: timestamptrunc_sql,
exp.JSONExtract: lambda self, e: self.binary(e, ":"),
exp.Select: transforms.preprocess(
[
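Note on the databricks.py changes above: datetime subtraction is expressed as `TIMESTAMPADD` with the increment multiplied by -1. A sketch that builds the expression by hand (output inferred, not verified):

```python
from sqlglot import exp

node = exp.DatetimeSub(
    this=exp.column("ts"),
    expression=exp.Literal.number(2),
    unit=exp.var("HOUR"),
)
print(node.sql(dialect="databricks"))
# expected: TIMESTAMPADD(HOUR, 2 * -1, ts)
```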
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 132496f..1bfbfef 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -109,8 +109,7 @@ class _Dialect(type):
for k, v in vars(klass).items()
if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__")
},
- "STRING_ESCAPE": klass.tokenizer_class.STRING_ESCAPES[0],
- "IDENTIFIER_ESCAPE": klass.tokenizer_class.IDENTIFIER_ESCAPES[0],
+ "TOKENIZER_CLASS": klass.tokenizer_class,
}
if enum not in ("", "bigquery"):
@@ -345,7 +344,7 @@ def arrow_json_extract_scalar_sql(
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
- return f"[{self.expressions(expression)}]"
+ return f"[{self.expressions(expression, flat=True)}]"
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
@@ -415,9 +414,9 @@ def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
- this = self.sql(expression, "this")
- struct_key = self.sql(exp.Identifier(this=expression.expression.copy(), quoted=True))
- return f"{this}.{struct_key}"
+ return (
+ f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
+ )
def var_map_sql(
@@ -722,3 +721,12 @@ def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
+
+
+def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
+ return self.func("MAX", expression.this)
+
+
+# Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon
+def json_keyvalue_comma_sql(self, expression: exp.JSONKeyValue) -> str:
+ return f"{self.sql(expression, 'this')}, {self.sql(expression, 'expression')}"
diff --git a/sqlglot/dialects/doris.py b/sqlglot/dialects/doris.py
index 160c23c..4b8919c 100644
--- a/sqlglot/dialects/doris.py
+++ b/sqlglot/dialects/doris.py
@@ -37,7 +37,6 @@ class Doris(MySQL):
**MySQL.Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.ArrayAgg: rename_func("COLLECT_LIST"),
- exp.Coalesce: rename_func("NVL"),
exp.CurrentTimestamp: lambda *_: "NOW()",
exp.DateTrunc: lambda self, e: self.func(
"DATE_TRUNC", e.this, "'" + e.text("unit") + "'"
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
index 1b2681d..c811c86 100644
--- a/sqlglot/dialects/drill.py
+++ b/sqlglot/dialects/drill.py
@@ -16,8 +16,8 @@ from sqlglot.dialects.dialect import (
)
-def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]:
- def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
+def _date_add_sql(kind: str) -> t.Callable[[Drill.Generator, exp.DateAdd | exp.DateSub], str]:
+ def func(self: Drill.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
this = self.sql(expression, "this")
unit = exp.var(expression.text("unit").upper() or "DAY")
return f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))})"
@@ -25,7 +25,7 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e
return func
-def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str:
+def _str_to_date(self: Drill.Generator, expression: exp.StrToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format == Drill.DATE_FORMAT:
@@ -73,7 +73,6 @@ class Drill(Dialect):
}
class Tokenizer(tokens.Tokenizer):
- QUOTES = ["'"]
IDENTIFIERS = ["`"]
STRING_ESCAPES = ["\\"]
ENCODE = "utf-8"
@@ -81,6 +80,7 @@ class Drill(Dialect):
class Parser(parser.Parser):
STRICT_CAST = False
CONCAT_NULL_OUTPUTS_STRING = True
+ SUPPORTS_USER_DEFINED_TYPES = False
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
@@ -95,6 +95,7 @@ class Drill(Dialect):
JOIN_HINTS = False
TABLE_HINTS = False
QUERY_HINTS = False
+ NVL2_SUPPORTED = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index 8253b52..684e35e 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -13,6 +13,7 @@ from sqlglot.dialects.dialect import (
datestrtodate_sql,
encode_decode_sql,
format_time_lambda,
+ inline_array_sql,
no_comment_column_constraint_sql,
no_properties_sql,
no_safe_divide_sql,
@@ -30,13 +31,13 @@ from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
-def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str:
+def _ts_or_ds_add_sql(self: DuckDB.Generator, expression: exp.TsOrDsAdd) -> str:
this = self.sql(expression, "this")
unit = self.sql(expression, "unit").strip("'") or "DAY"
return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))}"
-def _date_delta_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
+def _date_delta_sql(self: DuckDB.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
this = self.sql(expression, "this")
unit = self.sql(expression, "unit").strip("'") or "DAY"
op = "+" if isinstance(expression, exp.DateAdd) else "-"
@@ -44,7 +45,7 @@ def _date_delta_sql(self: generator.Generator, expression: exp.DateAdd | exp.Dat
# BigQuery -> DuckDB conversion for the DATE function
-def _date_sql(self: generator.Generator, expression: exp.Date) -> str:
+def _date_sql(self: DuckDB.Generator, expression: exp.Date) -> str:
result = f"CAST({self.sql(expression, 'this')} AS DATE)"
zone = self.sql(expression, "zone")
@@ -58,13 +59,13 @@ def _date_sql(self: generator.Generator, expression: exp.Date) -> str:
return result
-def _array_sort_sql(self: generator.Generator, expression: exp.ArraySort) -> str:
+def _array_sort_sql(self: DuckDB.Generator, expression: exp.ArraySort) -> str:
if expression.expression:
self.unsupported("DUCKDB ARRAY_SORT does not support a comparator")
return f"ARRAY_SORT({self.sql(expression, 'this')})"
-def _sort_array_sql(self: generator.Generator, expression: exp.SortArray) -> str:
+def _sort_array_sql(self: DuckDB.Generator, expression: exp.SortArray) -> str:
this = self.sql(expression, "this")
if expression.args.get("asc") == exp.false():
return f"ARRAY_REVERSE_SORT({this})"
@@ -79,14 +80,14 @@ def _parse_date_diff(args: t.List) -> exp.Expression:
return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0))
-def _struct_sql(self: generator.Generator, expression: exp.Struct) -> str:
+def _struct_sql(self: DuckDB.Generator, expression: exp.Struct) -> str:
args = [
f"'{e.name or e.this.name}': {self.sql(e, 'expression')}" for e in expression.expressions
]
return f"{{{', '.join(args)}}}"
-def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
+def _datatype_sql(self: DuckDB.Generator, expression: exp.DataType) -> str:
if expression.is_type("array"):
return f"{self.expressions(expression, flat=True)}[]"
@@ -97,7 +98,7 @@ def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
return self.datatype_sql(expression)
-def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> str:
+def _json_format_sql(self: DuckDB.Generator, expression: exp.JSONFormat) -> str:
sql = self.func("TO_JSON", expression.this, expression.args.get("options"))
return f"CAST({sql} AS TEXT)"
@@ -134,6 +135,7 @@ class DuckDB(Dialect):
class Parser(parser.Parser):
CONCAT_NULL_OUTPUTS_STRING = True
+ SUPPORTS_USER_DEFINED_TYPES = False
BITWISE = {
**parser.Parser.BITWISE,
@@ -183,18 +185,12 @@ class DuckDB(Dialect):
),
}
- TYPE_TOKENS = {
- *parser.Parser.TYPE_TOKENS,
- TokenType.UBIGINT,
- TokenType.UINT,
- TokenType.USMALLINT,
- TokenType.UTINYINT,
- }
-
def _parse_types(
- self, check_func: bool = False, schema: bool = False
+ self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True
) -> t.Optional[exp.Expression]:
- this = super()._parse_types(check_func=check_func, schema=schema)
+ this = super()._parse_types(
+ check_func=check_func, schema=schema, allow_identifiers=allow_identifiers
+ )
# DuckDB treats NUMERIC and DECIMAL without precision as DECIMAL(18, 3)
# See: https://duckdb.org/docs/sql/data_types/numeric
@@ -207,6 +203,9 @@ class DuckDB(Dialect):
return this
+ def _parse_struct_types(self) -> t.Optional[exp.Expression]:
+ return self._parse_field_def()
+
def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]:
if len(aggregations) == 1:
return super()._pivot_column_names(aggregations)
@@ -219,13 +218,14 @@ class DuckDB(Dialect):
LIMIT_FETCH = "LIMIT"
STRUCT_DELIMITER = ("(", ")")
RENAME_TABLE_WITH_DB = False
+ NVL2_SUPPORTED = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.Array: lambda self, e: self.func("ARRAY", e.expressions[0])
if e.expressions and e.expressions[0].find(exp.Select)
- else rename_func("LIST_VALUE")(self, e),
+ else inline_array_sql(self, e),
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.ArraySort: _array_sort_sql,
exp.ArraySum: rename_func("LIST_SUM"),
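Note on the duckdb.py change above: arrays are now generated as bracket literals via `inline_array_sql` instead of `LIST_VALUE(...)`. A sketch (output inferred, not verified):

```python
import sqlglot

print(sqlglot.transpile("SELECT ARRAY(1, 2, 3)", read="spark", write="duckdb")[0])
# expected: SELECT [1, 2, 3]   (previously LIST_VALUE(1, 2, 3))
```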
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 584acc6..8b17c06 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -50,7 +50,7 @@ TIME_DIFF_FACTOR = {
DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH")
-def _add_date_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
+def _add_date_sql(self: Hive.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
unit = expression.text("unit").upper()
func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1))
@@ -69,7 +69,7 @@ def _add_date_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateS
return self.func(func, expression.this, modified_increment)
-def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
+def _date_diff_sql(self: Hive.Generator, expression: exp.DateDiff) -> str:
unit = expression.text("unit").upper()
factor = TIME_DIFF_FACTOR.get(unit)
@@ -87,7 +87,7 @@ def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
return f"{diff_sql}{multiplier_sql}"
-def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> str:
+def _json_format_sql(self: Hive.Generator, expression: exp.JSONFormat) -> str:
this = expression.this
if isinstance(this, exp.Cast) and this.is_type("json") and this.this.is_string:
# Since FROM_JSON requires a nested type, we always wrap the json string with
@@ -103,21 +103,21 @@ def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> s
return self.func("TO_JSON", this, expression.args.get("options"))
-def _array_sort_sql(self: generator.Generator, expression: exp.ArraySort) -> str:
+def _array_sort_sql(self: Hive.Generator, expression: exp.ArraySort) -> str:
if expression.expression:
self.unsupported("Hive SORT_ARRAY does not support a comparator")
return f"SORT_ARRAY({self.sql(expression, 'this')})"
-def _property_sql(self: generator.Generator, expression: exp.Property) -> str:
+def _property_sql(self: Hive.Generator, expression: exp.Property) -> str:
return f"'{expression.name}'={self.sql(expression, 'value')}"
-def _str_to_unix_sql(self: generator.Generator, expression: exp.StrToUnix) -> str:
+def _str_to_unix_sql(self: Hive.Generator, expression: exp.StrToUnix) -> str:
return self.func("UNIX_TIMESTAMP", expression.this, time_format("hive")(self, expression))
-def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate) -> str:
+def _str_to_date_sql(self: Hive.Generator, expression: exp.StrToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT):
@@ -125,7 +125,7 @@ def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate) -> st
return f"CAST({this} AS DATE)"
-def _str_to_time_sql(self: generator.Generator, expression: exp.StrToTime) -> str:
+def _str_to_time_sql(self: Hive.Generator, expression: exp.StrToTime) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT):
@@ -133,13 +133,13 @@ def _str_to_time_sql(self: generator.Generator, expression: exp.StrToTime) -> st
return f"CAST({this} AS TIMESTAMP)"
-def _time_to_str(self: generator.Generator, expression: exp.TimeToStr) -> str:
+def _time_to_str(self: Hive.Generator, expression: exp.TimeToStr) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
return f"DATE_FORMAT({this}, {time_format})"
-def _to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
+def _to_date_sql(self: Hive.Generator, expression: exp.TsOrDsToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format and time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT):
@@ -206,6 +206,8 @@ class Hive(Dialect):
"MSCK REPAIR": TokenType.COMMAND,
"REFRESH": TokenType.COMMAND,
"WITH SERDEPROPERTIES": TokenType.SERDE_PROPERTIES,
+ "TIMESTAMP AS OF": TokenType.TIMESTAMP_SNAPSHOT,
+ "VERSION AS OF": TokenType.VERSION_SNAPSHOT,
}
NUMERIC_LITERALS = {
@@ -220,6 +222,7 @@ class Hive(Dialect):
class Parser(parser.Parser):
LOG_DEFAULTS_TO_LN = True
STRICT_CAST = False
+ SUPPORTS_USER_DEFINED_TYPES = False
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
@@ -257,6 +260,11 @@ class Hive(Dialect):
),
"SIZE": exp.ArraySize.from_arg_list,
"SPLIT": exp.RegexpSplit.from_arg_list,
+ "STR_TO_MAP": lambda args: exp.StrToMap(
+ this=seq_get(args, 0),
+ pair_delim=seq_get(args, 1) or exp.Literal.string(","),
+ key_value_delim=seq_get(args, 2) or exp.Literal.string(":"),
+ ),
"TO_DATE": format_time_lambda(exp.TsOrDsToDate, "hive"),
"TO_JSON": exp.JSONFormat.from_arg_list,
"UNBASE64": exp.FromBase64.from_arg_list,
@@ -313,7 +321,7 @@ class Hive(Dialect):
)
def _parse_types(
- self, check_func: bool = False, schema: bool = False
+ self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True
) -> t.Optional[exp.Expression]:
"""
Spark (and most likely Hive) treats casts to CHAR(length) and VARCHAR(length) as casts to
@@ -333,7 +341,9 @@ class Hive(Dialect):
Reference: https://spark.apache.org/docs/latest/sql-ref-datatypes.html
"""
- this = super()._parse_types(check_func=check_func, schema=schema)
+ this = super()._parse_types(
+ check_func=check_func, schema=schema, allow_identifiers=allow_identifiers
+ )
if this and not schema:
return this.transform(
@@ -345,6 +355,16 @@ class Hive(Dialect):
return this
+ def _parse_partition_and_order(
+ self,
+ ) -> t.Tuple[t.List[exp.Expression], t.Optional[exp.Expression]]:
+ return (
+ self._parse_csv(self._parse_conjunction)
+ if self._match_set({TokenType.PARTITION_BY, TokenType.DISTRIBUTE_BY})
+ else [],
+ super()._parse_order(skip_order_token=self._match(TokenType.SORT_BY)),
+ )
+
class Generator(generator.Generator):
LIMIT_FETCH = "LIMIT"
TABLESAMPLE_WITH_METHOD = False
@@ -354,6 +374,7 @@ class Hive(Dialect):
QUERY_HINTS = False
INDEX_ON = "ON TABLE"
EXTRACT_ALLOWS_QUOTES = False
+ NVL2_SUPPORTED = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@@ -376,6 +397,7 @@ class Hive(Dialect):
]
),
exp.Property: _property_sql,
+ exp.AnyValue: rename_func("FIRST"),
exp.ApproxDistinct: approx_count_distinct_sql,
exp.ArrayConcat: rename_func("CONCAT"),
exp.ArrayJoin: lambda self, e: self.func("CONCAT_WS", e.expression, e.this),
@@ -402,6 +424,9 @@ class Hive(Dialect):
exp.MD5Digest: lambda self, e: self.func("UNHEX", self.func("MD5", e.this)),
exp.Min: min_or_least,
exp.MonthsBetween: lambda self, e: self.func("MONTHS_BETWEEN", e.this, e.expression),
+ exp.NotNullColumnConstraint: lambda self, e: ""
+ if e.args.get("allow_null")
+ else "NOT NULL",
exp.VarMap: var_map_sql,
exp.Create: create_with_partitions_sql,
exp.Quantile: rename_func("PERCENTILE"),
@@ -472,7 +497,7 @@ class Hive(Dialect):
elif expression.this in exp.DataType.TEMPORAL_TYPES:
expression = exp.DataType.build(expression.this)
elif expression.is_type("float"):
- size_expression = expression.find(exp.DataTypeSize)
+ size_expression = expression.find(exp.DataTypeParam)
if size_expression:
size = int(size_expression.name)
expression = (
@@ -480,3 +505,7 @@ class Hive(Dialect):
)
return super().datatype_sql(expression)
+
+ def version_sql(self, expression: exp.Version) -> str:
+ sql = super().version_sql(expression)
+ return sql.replace("FOR ", "", 1)
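Note on the hive.py changes above: `STR_TO_MAP` now fills in Hive's documented default delimiters when they are omitted. A sketch inspecting the parsed node:

```python
import sqlglot
from sqlglot import exp

node = sqlglot.parse_one("SELECT STR_TO_MAP('a:1,b:2')", read="hive")
str_to_map = node.find(exp.StrToMap)
print(str_to_map.args["pair_delim"])       # defaults to the literal ','
print(str_to_map.args["key_value_delim"])  # defaults to the literal ':'
```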
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 9ab4ce8..f9249eb 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import (
arrow_json_extract_scalar_sql,
datestrtodate_sql,
format_time_lambda,
+ json_keyvalue_comma_sql,
locate_to_strposition,
max_or_greatest,
min_or_least,
@@ -32,7 +33,7 @@ def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[MySQL.Parser], ex
return _parse
-def _date_trunc_sql(self: generator.Generator, expression: exp.DateTrunc) -> str:
+def _date_trunc_sql(self: MySQL.Generator, expression: exp.DateTrunc) -> str:
expr = self.sql(expression, "this")
unit = expression.text("unit")
@@ -63,12 +64,12 @@ def _str_to_date(args: t.List) -> exp.StrToDate:
return exp.StrToDate(this=seq_get(args, 0), format=date_format)
-def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate | exp.StrToTime) -> str:
+def _str_to_date_sql(self: MySQL.Generator, expression: exp.StrToDate | exp.StrToTime) -> str:
date_format = self.format_time(expression)
return f"STR_TO_DATE({self.sql(expression.this)}, {date_format})"
-def _trim_sql(self: generator.Generator, expression: exp.Trim) -> str:
+def _trim_sql(self: MySQL.Generator, expression: exp.Trim) -> str:
target = self.sql(expression, "this")
trim_type = self.sql(expression, "position")
remove_chars = self.sql(expression, "expression")
@@ -83,8 +84,8 @@ def _trim_sql(self: generator.Generator, expression: exp.Trim) -> str:
return f"TRIM({trim_type}{remove_chars}{from_part}{target})"
-def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]:
- def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
+def _date_add_sql(kind: str) -> t.Callable[[MySQL.Generator, exp.DateAdd | exp.DateSub], str]:
+ def func(self: MySQL.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
this = self.sql(expression, "this")
unit = expression.text("unit").upper() or "DAY"
return f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))})"
@@ -93,6 +94,9 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e
class MySQL(Dialect):
+ # https://dev.mysql.com/doc/refman/8.0/en/identifiers.html
+ IDENTIFIERS_CAN_START_WITH_DIGIT = True
+
TIME_FORMAT = "'%Y-%m-%d %T'"
DPIPE_IS_STRING_CONCAT = False
@@ -129,6 +133,7 @@ class MySQL(Dialect):
"LONGTEXT": TokenType.LONGTEXT,
"MEDIUMBLOB": TokenType.MEDIUMBLOB,
"MEDIUMTEXT": TokenType.MEDIUMTEXT,
+ "MEDIUMINT": TokenType.MEDIUMINT,
"MEMBER OF": TokenType.MEMBER_OF,
"SEPARATOR": TokenType.SEPARATOR,
"START": TokenType.BEGIN,
@@ -136,6 +141,7 @@ class MySQL(Dialect):
"SIGNED INTEGER": TokenType.BIGINT,
"UNSIGNED": TokenType.UBIGINT,
"UNSIGNED INTEGER": TokenType.UBIGINT,
+ "YEAR": TokenType.YEAR,
"_ARMSCII8": TokenType.INTRODUCER,
"_ASCII": TokenType.INTRODUCER,
"_BIG5": TokenType.INTRODUCER,
@@ -185,6 +191,8 @@ class MySQL(Dialect):
COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SHOW}
class Parser(parser.Parser):
+ SUPPORTS_USER_DEFINED_TYPES = False
+
FUNC_TOKENS = {
*parser.Parser.FUNC_TOKENS,
TokenType.DATABASE,
@@ -492,6 +500,17 @@ class MySQL(Dialect):
return self.expression(exp.SetItem, this=charset, collate=collate, kind="NAMES")
+ def _parse_type(self) -> t.Optional[exp.Expression]:
+ # mysql binary is special and can work anywhere, even in order by operations
+ # it operates like a no paren func
+ if self._match(TokenType.BINARY, advance=False):
+ data_type = self._parse_types(check_func=True, allow_identifiers=False)
+
+ if isinstance(data_type, exp.DataType):
+ return self.expression(exp.Cast, this=self._parse_column(), to=data_type)
+
+ return super()._parse_type()
+
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
NULL_ORDERING_SUPPORTED = False
@@ -500,6 +519,7 @@ class MySQL(Dialect):
DUPLICATE_KEY_UPDATE_WITH_SET = False
QUERY_HINT_SEP = " "
VALUES_AS_TABLE = False
+ NVL2_SUPPORTED = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@@ -515,6 +535,7 @@ class MySQL(Dialect):
exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""",
exp.ILike: no_ilike_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
+ exp.JSONKeyValue: json_keyvalue_comma_sql,
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
@@ -524,6 +545,7 @@ class MySQL(Dialect):
exp.StrPosition: strposition_to_locate_sql,
exp.StrToDate: _str_to_date_sql,
exp.StrToTime: _str_to_date_sql,
+ exp.Stuff: rename_func("INSERT"),
exp.TableSample: no_tablesample_sql,
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimeStrToTime: lambda self, e: self.sql(exp.cast(e.this, "datetime", copy=True)),
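Note on the mysql.py changes above: the prefix `BINARY` operator is parsed into a canonical cast so it can be generated for other dialects. A sketch (output inferred, not verified):

```python
import sqlglot

print(sqlglot.transpile("SELECT BINARY col FROM t", read="mysql")[0])
# expected: SELECT CAST(col AS BINARY) FROM t
```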
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index 1f63e9f..279ed31 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -8,7 +8,7 @@ from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
-def _parse_xml_table(self: parser.Parser) -> exp.XMLTable:
+def _parse_xml_table(self: Oracle.Parser) -> exp.XMLTable:
this = self._parse_string()
passing = None
@@ -22,7 +22,7 @@ def _parse_xml_table(self: parser.Parser) -> exp.XMLTable:
by_ref = self._match_text_seq("RETURNING", "SEQUENCE", "BY", "REF")
if self._match_text_seq("COLUMNS"):
- columns = self._parse_csv(lambda: self._parse_column_def(self._parse_field(any_token=True)))
+ columns = self._parse_csv(self._parse_field_def)
return self.expression(exp.XMLTable, this=this, passing=passing, columns=columns, by_ref=by_ref)
@@ -78,6 +78,10 @@ class Oracle(Dialect):
)
}
+ # SELECT UNIQUE .. is old-style Oracle syntax for SELECT DISTINCT ..
+ # Reference: https://stackoverflow.com/a/336455
+ DISTINCT_TOKENS = {TokenType.DISTINCT, TokenType.UNIQUE}
+
def _parse_column(self) -> t.Optional[exp.Expression]:
column = super()._parse_column()
if column:
@@ -129,7 +133,6 @@ class Oracle(Dialect):
),
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.ILike: no_ilike_sql,
- exp.Coalesce: rename_func("NVL"),
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "),
@@ -162,7 +165,7 @@ class Oracle(Dialect):
return f"XMLTABLE({self.sep('')}{self.indent(this + passing + by_ref + columns)}{self.seg(')', sep='')}"
class Tokenizer(tokens.Tokenizer):
- VAR_SINGLE_TOKENS = {"@"}
+ VAR_SINGLE_TOKENS = {"@", "$", "#"}
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
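Note on the oracle.py change above: `DISTINCT_TOKENS` now accepts Oracle's legacy `SELECT UNIQUE` spelling as `SELECT DISTINCT`. A sketch (output inferred, not verified):

```python
import sqlglot

print(sqlglot.transpile("SELECT UNIQUE name FROM employees", read="oracle")[0])
# expected: SELECT DISTINCT name FROM employees
```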
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 73ca4e5..c26e121 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -5,6 +5,7 @@ import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
+ any_value_to_max_sql,
arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
datestrtodate_sql,
@@ -39,8 +40,8 @@ DATE_DIFF_FACTOR = {
}
-def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]:
- def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
+def _date_add_sql(kind: str) -> t.Callable[[Postgres.Generator, exp.DateAdd | exp.DateSub], str]:
+ def func(self: Postgres.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
expression = expression.copy()
this = self.sql(expression, "this")
@@ -56,7 +57,7 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e
return func
-def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
+def _date_diff_sql(self: Postgres.Generator, expression: exp.DateDiff) -> str:
unit = expression.text("unit").upper()
factor = DATE_DIFF_FACTOR.get(unit)
@@ -82,7 +83,7 @@ def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
return f"CAST({unit} AS BIGINT)"
-def _substring_sql(self: generator.Generator, expression: exp.Substring) -> str:
+def _substring_sql(self: Postgres.Generator, expression: exp.Substring) -> str:
this = self.sql(expression, "this")
start = self.sql(expression, "start")
length = self.sql(expression, "length")
@@ -93,7 +94,7 @@ def _substring_sql(self: generator.Generator, expression: exp.Substring) -> str:
return f"SUBSTRING({this}{from_part}{for_part})"
-def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> str:
+def _string_agg_sql(self: Postgres.Generator, expression: exp.GroupConcat) -> str:
expression = expression.copy()
separator = expression.args.get("separator") or exp.Literal.string(",")
@@ -107,7 +108,7 @@ def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> s
return f"STRING_AGG({self.format_args(this, separator)}{order})"
-def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
+def _datatype_sql(self: Postgres.Generator, expression: exp.DataType) -> str:
if expression.is_type("array"):
return f"{self.expressions(expression, flat=True)}[]"
return self.datatype_sql(expression)
@@ -254,6 +255,7 @@ class Postgres(Dialect):
"~~*": TokenType.ILIKE,
"~*": TokenType.IRLIKE,
"~": TokenType.RLIKE,
+ "@@": TokenType.DAT,
"@>": TokenType.AT_GT,
"<@": TokenType.LT_AT,
"BEGIN": TokenType.COMMAND,
@@ -273,6 +275,18 @@ class Postgres(Dialect):
"SMALLSERIAL": TokenType.SMALLSERIAL,
"TEMP": TokenType.TEMPORARY,
"CSTRING": TokenType.PSEUDO_TYPE,
+ "OID": TokenType.OBJECT_IDENTIFIER,
+ "REGCLASS": TokenType.OBJECT_IDENTIFIER,
+ "REGCOLLATION": TokenType.OBJECT_IDENTIFIER,
+ "REGCONFIG": TokenType.OBJECT_IDENTIFIER,
+ "REGDICTIONARY": TokenType.OBJECT_IDENTIFIER,
+ "REGNAMESPACE": TokenType.OBJECT_IDENTIFIER,
+ "REGOPER": TokenType.OBJECT_IDENTIFIER,
+ "REGOPERATOR": TokenType.OBJECT_IDENTIFIER,
+ "REGPROC": TokenType.OBJECT_IDENTIFIER,
+ "REGPROCEDURE": TokenType.OBJECT_IDENTIFIER,
+ "REGROLE": TokenType.OBJECT_IDENTIFIER,
+ "REGTYPE": TokenType.OBJECT_IDENTIFIER,
}
SINGLE_TOKENS = {
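
These OID alias pseudo-types should now be recognized in casts rather than failing to parse; a minimal sketch, assuming the parser changes elsewhere in this commit wire OBJECT_IDENTIFIER tokens through to exp.ObjectIdentifier:

```python
import sqlglot

# Postgres object-identifier types survive transpilation.
print(sqlglot.transpile("SELECT 'pg_class'::REGCLASS", read="postgres", write="postgres")[0])
# expected roughly: SELECT CAST('pg_class' AS REGCLASS)
```
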
@@ -312,6 +326,9 @@ class Postgres(Dialect):
RANGE_PARSERS = {
**parser.Parser.RANGE_PARSERS,
TokenType.DAMP: binary_range_parser(exp.ArrayOverlaps),
+ TokenType.DAT: lambda self, this: self.expression(
+ exp.MatchAgainst, this=self._parse_bitwise(), expressions=[this]
+ ),
TokenType.AT_GT: binary_range_parser(exp.ArrayContains),
TokenType.LT_AT: binary_range_parser(exp.ArrayContained),
}
@@ -343,6 +360,7 @@ class Postgres(Dialect):
JOIN_HINTS = False
TABLE_HINTS = False
QUERY_HINTS = False
+ NVL2_SUPPORTED = False
PARAMETER_TOKEN = "$"
TYPE_MAPPING = {
@@ -357,6 +375,8 @@ class Postgres(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
+ exp.AnyValue: any_value_to_max_sql,
+ exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.BitwiseXor: lambda self, e: self.binary(e, "#"),
exp.ColumnDef: transforms.preprocess([_auto_increment_to_serial, _serial_to_generated]),
exp.Explode: rename_func("UNNEST"),
@@ -416,3 +436,9 @@ class Postgres(Dialect):
expression.set("this", exp.paren(expression.this, copy=False))
return super().bracket_sql(expression)
+
+ def matchagainst_sql(self, expression: exp.MatchAgainst) -> str:
+ this = self.sql(expression, "this")
+ expressions = [f"{self.sql(e)} @@ {this}" for e in expression.expressions]
+ sql = " OR ".join(expressions)
+ return f"({sql})" if len(expressions) > 1 else sql
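
Taken together, the new `@@` token, range parser and `matchagainst_sql` give the text-search operator a lossless round trip; a sketch (not from the patch):

```python
import sqlglot

# body @@ query parses into exp.MatchAgainst and renders back unchanged.
print(sqlglot.transpile("SELECT * FROM docs WHERE body @@ query", read="postgres", write="postgres")[0])
# expected: SELECT * FROM docs WHERE body @@ query
```
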
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 078da0b..4b54e95 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -26,13 +26,13 @@ from sqlglot.helper import apply_index_offset, seq_get
from sqlglot.tokens import TokenType
-def _approx_distinct_sql(self: generator.Generator, expression: exp.ApproxDistinct) -> str:
+def _approx_distinct_sql(self: Presto.Generator, expression: exp.ApproxDistinct) -> str:
accuracy = expression.args.get("accuracy")
accuracy = ", " + self.sql(accuracy) if accuracy else ""
return f"APPROX_DISTINCT({self.sql(expression, 'this')}{accuracy})"
-def _explode_to_unnest_sql(self: generator.Generator, expression: exp.Lateral) -> str:
+def _explode_to_unnest_sql(self: Presto.Generator, expression: exp.Lateral) -> str:
if isinstance(expression.this, (exp.Explode, exp.Posexplode)):
expression = expression.copy()
return self.sql(
@@ -48,12 +48,12 @@ def _explode_to_unnest_sql(self: generator.Generator, expression: exp.Lateral) -
return self.lateral_sql(expression)
-def _initcap_sql(self: generator.Generator, expression: exp.Initcap) -> str:
+def _initcap_sql(self: Presto.Generator, expression: exp.Initcap) -> str:
regex = r"(\w)(\w*)"
return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))"
-def _no_sort_array(self: generator.Generator, expression: exp.SortArray) -> str:
+def _no_sort_array(self: Presto.Generator, expression: exp.SortArray) -> str:
if expression.args.get("asc") == exp.false():
comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END"
else:
@@ -61,7 +61,7 @@ def _no_sort_array(self: generator.Generator, expression: exp.SortArray) -> str:
return self.func("ARRAY_SORT", expression.this, comparator)
-def _schema_sql(self: generator.Generator, expression: exp.Schema) -> str:
+def _schema_sql(self: Presto.Generator, expression: exp.Schema) -> str:
if isinstance(expression.parent, exp.Property):
columns = ", ".join(f"'{c.name}'" for c in expression.expressions)
return f"ARRAY[{columns}]"
@@ -75,25 +75,25 @@ def _schema_sql(self: generator.Generator, expression: exp.Schema) -> str:
return self.schema_sql(expression)
-def _quantile_sql(self: generator.Generator, expression: exp.Quantile) -> str:
+def _quantile_sql(self: Presto.Generator, expression: exp.Quantile) -> str:
self.unsupported("Presto does not support exact quantiles")
return f"APPROX_PERCENTILE({self.sql(expression, 'this')}, {self.sql(expression, 'quantile')})"
def _str_to_time_sql(
- self: generator.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate
+ self: Presto.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate
) -> str:
return f"DATE_PARSE({self.sql(expression, 'this')}, {self.format_time(expression)})"
-def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
+def _ts_or_ds_to_date_sql(self: Presto.Generator, expression: exp.TsOrDsToDate) -> str:
time_format = self.format_time(expression)
if time_format and time_format not in (Presto.TIME_FORMAT, Presto.DATE_FORMAT):
return exp.cast(_str_to_time_sql(self, expression), "DATE").sql(dialect="presto")
return exp.cast(exp.cast(expression.this, "TIMESTAMP", copy=True), "DATE").sql(dialect="presto")
-def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str:
+def _ts_or_ds_add_sql(self: Presto.Generator, expression: exp.TsOrDsAdd) -> str:
this = expression.this
if not isinstance(this, exp.CurrentDate):
@@ -153,6 +153,20 @@ def _unnest_sequence(expression: exp.Expression) -> exp.Expression:
return expression
+def _first_last_sql(self: Presto.Generator, expression: exp.First | exp.Last) -> str:
+ """
+ Trino doesn't support FIRST / LAST as functions, but they're valid in the context
+ of MATCH_RECOGNIZE, so we need to preserve them in that case. In all other cases
+ they're converted into an ARBITRARY call.
+
+ Reference: https://trino.io/docs/current/sql/match-recognize.html#logical-navigation-functions
+ """
+ if isinstance(expression.find_ancestor(exp.MatchRecognize, exp.Select), exp.MatchRecognize):
+ return self.function_fallback_sql(expression)
+
+ return rename_func("ARBITRARY")(self, expression)
+
+
class Presto(Dialect):
INDEX_OFFSET = 1
NULL_ORDERING = "nulls_are_last"
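
A sketch of the FIRST/LAST fallback described in the docstring above (expected output assumed):

```python
import sqlglot

# Outside of MATCH_RECOGNIZE, FIRST/LAST degrade to ARBITRARY on Presto.
print(sqlglot.transpile("SELECT FIRST(x) FROM t", read="spark", write="presto")[0])
# expected: SELECT ARBITRARY(x) FROM t
```
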
@@ -178,6 +192,7 @@ class Presto(Dialect):
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
+ "ARBITRARY": exp.AnyValue.from_arg_list,
"APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list,
"APPROX_PERCENTILE": _approx_percentile,
"BITWISE_AND": binary_from_function(exp.BitwiseAnd),
@@ -205,7 +220,14 @@ class Presto(Dialect):
"REGEXP_EXTRACT": lambda args: exp.RegexpExtract(
this=seq_get(args, 0), expression=seq_get(args, 1), group=seq_get(args, 2)
),
+ "REGEXP_REPLACE": lambda args: exp.RegexpReplace(
+ this=seq_get(args, 0),
+ expression=seq_get(args, 1),
+ replacement=seq_get(args, 2) or exp.Literal.string(""),
+ ),
+ "ROW": exp.Struct.from_arg_list,
"SEQUENCE": exp.GenerateSeries.from_arg_list,
+ "SPLIT_TO_MAP": exp.StrToMap.from_arg_list,
"STRPOS": lambda args: exp.StrPosition(
this=seq_get(args, 0), substr=seq_get(args, 1), instance=seq_get(args, 2)
),
@@ -225,6 +247,7 @@ class Presto(Dialect):
QUERY_HINTS = False
IS_BOOL_ALLOWED = False
TZ_TO_WITH_TIME_ZONE = True
+ NVL2_SUPPORTED = False
STRUCT_DELIMITER = ("(", ")")
PROPERTIES_LOCATION = {
@@ -242,10 +265,13 @@ class Presto(Dialect):
exp.DataType.Type.TIMETZ: "TIME",
exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
exp.DataType.Type.STRUCT: "ROW",
+ exp.DataType.Type.DATETIME: "TIMESTAMP",
+ exp.DataType.Type.DATETIME64: "TIMESTAMP",
}
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
+ exp.AnyValue: rename_func("ARBITRARY"),
exp.ApproxDistinct: _approx_distinct_sql,
exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
@@ -268,15 +294,23 @@ class Presto(Dialect):
),
exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.DATE_FORMAT}) AS DATE)",
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)",
+ exp.DateSub: lambda self, e: self.func(
+ "DATE_ADD",
+ exp.Literal.string(e.text("unit") or "day"),
+ e.expression * -1,
+ e.this,
+ ),
exp.Decode: lambda self, e: encode_decode_sql(self, e, "FROM_UTF8"),
exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.DATEINT_FORMAT}) AS DATE)",
exp.Encode: lambda self, e: encode_decode_sql(self, e, "TO_UTF8"),
exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
+ exp.First: _first_last_sql,
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.Hex: rename_func("TO_HEX"),
exp.If: if_sql,
exp.ILike: no_ilike_sql,
exp.Initcap: _initcap_sql,
+ exp.Last: _first_last_sql,
exp.Lateral: _explode_to_unnest_sql,
exp.Left: left_to_substring_sql,
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
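
Since Presto has no DATE_SUB, the new transform negates the amount and emits DATE_ADD; a sketch, with the exact literal formatting being an assumption:

```python
import sqlglot

print(sqlglot.transpile("SELECT DATE_SUB(d, INTERVAL 2 DAY)", read="mysql", write="presto")[0])
# expected roughly: SELECT DATE_ADD('DAY', 2 * -1, d)
```
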
@@ -301,8 +335,10 @@ class Presto(Dialect):
exp.SortArray: _no_sort_array,
exp.StrPosition: rename_func("STRPOS"),
exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)",
+ exp.StrToMap: rename_func("SPLIT_TO_MAP"),
exp.StrToTime: _str_to_time_sql,
exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))",
+ exp.Struct: rename_func("ROW"),
exp.StructExtract: struct_extract_sql,
exp.Table: transforms.preprocess([_unnest_sequence]),
exp.TimestampTrunc: timestamptrunc_sql,
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index 30731e1..351c5df 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -13,7 +13,7 @@ from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
-def _json_sql(self: Postgres.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar) -> str:
+def _json_sql(self: Redshift.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar) -> str:
return f'{self.sql(expression, "this")}."{expression.expression.name}"'
@@ -37,6 +37,8 @@ class Redshift(Postgres):
}
class Parser(Postgres.Parser):
+ SUPPORTS_USER_DEFINED_TYPES = False
+
FUNCTIONS = {
**Postgres.Parser.FUNCTIONS,
"ADD_MONTHS": lambda args: exp.DateAdd(
@@ -55,9 +57,11 @@ class Redshift(Postgres):
}
def _parse_types(
- self, check_func: bool = False, schema: bool = False
+ self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True
) -> t.Optional[exp.Expression]:
- this = super()._parse_types(check_func=check_func, schema=schema)
+ this = super()._parse_types(
+ check_func=check_func, schema=schema, allow_identifiers=allow_identifiers
+ )
if (
isinstance(this, exp.DataType)
@@ -100,6 +104,7 @@ class Redshift(Postgres):
QUERY_HINTS = False
VALUES_AS_TABLE = False
TZ_TO_WITH_TIME_ZONE = True
+ NVL2_SUPPORTED = True
TYPE_MAPPING = {
**Postgres.Generator.TYPE_MAPPING,
@@ -142,6 +147,9 @@ class Redshift(Postgres):
# Redshift uses the POW | POWER (expr1, expr2) syntax instead of expr1 ^ expr2 (postgres)
TRANSFORMS.pop(exp.Pow)
+ # Redshift supports ANY_VALUE(..)
+ TRANSFORMS.pop(exp.AnyValue)
+
RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot", "type"}
def with_properties(self, properties: exp.Properties) -> str:
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 9733a85..8d8183c 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -90,7 +90,7 @@ def _parse_datediff(args: t.List) -> exp.DateDiff:
return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0))
-def _unix_to_time_sql(self: generator.Generator, expression: exp.UnixToTime) -> str:
+def _unix_to_time_sql(self: Snowflake.Generator, expression: exp.UnixToTime) -> str:
scale = expression.args.get("scale")
timestamp = self.sql(expression, "this")
if scale in [None, exp.UnixToTime.SECONDS]:
@@ -105,7 +105,7 @@ def _unix_to_time_sql(self: generator.Generator, expression: exp.UnixToTime) ->
# https://docs.snowflake.com/en/sql-reference/functions/date_part.html
# https://docs.snowflake.com/en/sql-reference/functions-date-time.html#label-supported-date-time-parts
-def _parse_date_part(self: parser.Parser) -> t.Optional[exp.Expression]:
+def _parse_date_part(self: Snowflake.Parser) -> t.Optional[exp.Expression]:
this = self._parse_var() or self._parse_type()
if not this:
@@ -156,7 +156,7 @@ def _nullifzero_to_if(args: t.List) -> exp.If:
return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0))
-def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
+def _datatype_sql(self: Snowflake.Generator, expression: exp.DataType) -> str:
if expression.is_type("array"):
return "ARRAY"
elif expression.is_type("map"):
@@ -164,6 +164,17 @@ def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
return self.datatype_sql(expression)
+def _regexpilike_sql(self: Snowflake.Generator, expression: exp.RegexpILike) -> str:
+ flag = expression.text("flag")
+
+ if "i" not in flag:
+ flag += "i"
+
+ return self.func(
+ "REGEXP_LIKE", expression.this, expression.expression, exp.Literal.string(flag)
+ )
+
+
def _parse_convert_timezone(args: t.List) -> t.Union[exp.Anonymous, exp.AtTimeZone]:
if len(args) == 3:
return exp.Anonymous(this="CONVERT_TIMEZONE", expressions=args)
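
A sketch of the forced case-insensitive flag, reading a Postgres `~*` (which maps to exp.RegexpILike); the output is assumed:

```python
import sqlglot

print(sqlglot.transpile("SELECT x FROM t WHERE x ~* 'ab+'", read="postgres", write="snowflake")[0])
# expected: SELECT x FROM t WHERE REGEXP_LIKE(x, 'ab+', 'i')
```
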
@@ -179,6 +190,13 @@ def _parse_regexp_replace(args: t.List) -> exp.RegexpReplace:
return regexp_replace
+def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[Snowflake.Parser], exp.Show]:
+ def _parse(self: Snowflake.Parser) -> exp.Show:
+ return self._parse_show_snowflake(*args, **kwargs)
+
+ return _parse
+
+
class Snowflake(Dialect):
# https://docs.snowflake.com/en/sql-reference/identifiers-syntax
RESOLVES_IDENTIFIERS_AS_UPPERCASE = True
@@ -216,6 +234,7 @@ class Snowflake(Dialect):
class Parser(parser.Parser):
IDENTIFY_PIVOT_STRINGS = True
+ SUPPORTS_USER_DEFINED_TYPES = False
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
@@ -230,6 +249,7 @@ class Snowflake(Dialect):
"DATEDIFF": _parse_datediff,
"DIV0": _div0_to_if,
"IFF": exp.If.from_arg_list,
+ "LISTAGG": exp.GroupConcat.from_arg_list,
"NULLIFZERO": _nullifzero_to_if,
"OBJECT_CONSTRUCT": _parse_object_construct,
"REGEXP_REPLACE": _parse_regexp_replace,
@@ -250,11 +270,6 @@ class Snowflake(Dialect):
}
FUNCTION_PARSERS.pop("TRIM")
- FUNC_TOKENS = {
- *parser.Parser.FUNC_TOKENS,
- TokenType.TABLE,
- }
-
COLUMN_OPERATORS = {
**parser.Parser.COLUMN_OPERATORS,
TokenType.COLON: lambda self, this, path: self.expression(
@@ -281,6 +296,16 @@ class Snowflake(Dialect):
),
}
+ STATEMENT_PARSERS = {
+ **parser.Parser.STATEMENT_PARSERS,
+ TokenType.SHOW: lambda self: self._parse_show(),
+ }
+
+ SHOW_PARSERS = {
+ "PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
+ "TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
+ }
+
def _parse_id_var(
self,
any_token: bool = True,
@@ -296,8 +321,24 @@ class Snowflake(Dialect):
return super()._parse_id_var(any_token=any_token, tokens=tokens)
+ def _parse_show_snowflake(self, this: str) -> exp.Show:
+ scope = None
+ scope_kind = None
+
+ if self._match(TokenType.IN):
+ if self._match_text_seq("ACCOUNT"):
+ scope_kind = "ACCOUNT"
+ elif self._match_set(self.DB_CREATABLES):
+ scope_kind = self._prev.text
+ if self._curr:
+ scope = self._parse_table()
+ elif self._curr:
+ scope_kind = "TABLE"
+ scope = self._parse_table()
+
+ return self.expression(exp.Show, this=this, scope=scope, scope_kind=scope_kind)
+
class Tokenizer(tokens.Tokenizer):
- QUOTES = ["'"]
STRING_ESCAPES = ["\\", "'"]
HEX_STRINGS = [("x'", "'"), ("X'", "'")]
RAW_STRINGS = ["$$"]
@@ -331,6 +372,8 @@ class Snowflake(Dialect):
VAR_SINGLE_TOKENS = {"$"}
+ COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SHOW}
+
class Generator(generator.Generator):
PARAMETER_TOKEN = "$"
MATCHED_BY_SOURCE = False
@@ -355,6 +398,7 @@ class Snowflake(Dialect):
exp.DataType: _datatype_sql,
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.Extract: rename_func("DATE_PART"),
+ exp.GroupConcat: rename_func("LISTAGG"),
exp.If: rename_func("IFF"),
exp.LogicalAnd: rename_func("BOOLAND_AGG"),
exp.LogicalOr: rename_func("BOOLOR_AGG"),
@@ -362,6 +406,7 @@ class Snowflake(Dialect):
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
+ exp.RegexpILike: _regexpilike_sql,
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.StarMap: rename_func("OBJECT_CONSTRUCT"),
exp.StartsWith: rename_func("STARTSWITH"),
@@ -373,6 +418,7 @@ class Snowflake(Dialect):
"OBJECT_CONSTRUCT",
*(arg for expression in e.expressions for arg in expression.flatten()),
),
+ exp.Stuff: rename_func("INSERT"),
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToStr: lambda self, e: self.func(
@@ -403,6 +449,16 @@ class Snowflake(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
+ def show_sql(self, expression: exp.Show) -> str:
+ scope = self.sql(expression, "scope")
+ scope = f" {scope}" if scope else ""
+
+ scope_kind = self.sql(expression, "scope_kind")
+ if scope_kind:
+ scope_kind = f" IN {scope_kind}"
+
+ return f"SHOW {expression.name}{scope_kind}{scope}"
+
def regexpextract_sql(self, expression: exp.RegexpExtract) -> str:
# Other dialects don't support all of the following parameters, so we need to
# generate default values as necessary to ensure the transpilation is correct
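
With SHOW removed from the tokenizer's COMMANDS and the parsers added above, SHOW statements become structured expressions instead of opaque commands; a sketch:

```python
import sqlglot

ast = sqlglot.parse_one("SHOW PRIMARY KEYS IN ACCOUNT", read="snowflake")
print(type(ast).__name__)            # Show, rather than a raw Command
print(ast.sql(dialect="snowflake"))  # expected: SHOW PRIMARY KEYS IN ACCOUNT
```
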
@@ -436,7 +492,9 @@ class Snowflake(Dialect):
kind_value = expression.args.get("kind") or "TABLE"
kind = f" {kind_value}" if kind_value else ""
this = f" {self.sql(expression, 'this')}"
- return f"DESCRIBE{kind}{this}"
+ expressions = self.expressions(expression, flat=True)
+ expressions = f" {expressions}" if expressions else ""
+ return f"DESCRIBE{kind}{this}{expressions}"
def generatedasidentitycolumnconstraint_sql(
self, expression: exp.GeneratedAsIdentityColumnConstraint
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 7c8982b..a4435f6 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -38,9 +38,15 @@ class Spark(Spark2):
class Parser(Spark2.Parser):
FUNCTIONS = {
**Spark2.Parser.FUNCTIONS,
+ "ANY_VALUE": lambda args: exp.AnyValue(
+ this=seq_get(args, 0), ignore_nulls=seq_get(args, 1)
+ ),
"DATEDIFF": _parse_datediff,
}
+ FUNCTION_PARSERS = Spark2.Parser.FUNCTION_PARSERS.copy()
+ FUNCTION_PARSERS.pop("ANY_VALUE")
+
class Generator(Spark2.Generator):
TYPE_MAPPING = {
**Spark2.Generator.TYPE_MAPPING,
@@ -56,9 +62,13 @@ class Spark(Spark2):
"DATEADD", e.args.get("unit") or "DAY", e.expression, e.this
),
}
+ TRANSFORMS.pop(exp.AnyValue)
TRANSFORMS.pop(exp.DateDiff)
TRANSFORMS.pop(exp.Group)
+ def anyvalue_sql(self, expression: exp.AnyValue) -> str:
+ return self.function_fallback_sql(expression)
+
def datediff_sql(self, expression: exp.DateDiff) -> str:
unit = self.sql(expression, "unit")
end = self.sql(expression, "this")
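
A sketch of the two-argument form now supported (Spark 3's ANY_VALUE(expr, ignoreNulls)); output assumed:

```python
import sqlglot

print(sqlglot.transpile("SELECT ANY_VALUE(x, TRUE) FROM t", read="spark", write="spark")[0])
# expected: SELECT ANY_VALUE(x, TRUE) FROM t
```
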
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
index ceb48f8..4489b6b 100644
--- a/sqlglot/dialects/spark2.py
+++ b/sqlglot/dialects/spark2.py
@@ -15,7 +15,7 @@ from sqlglot.dialects.hive import Hive
from sqlglot.helper import seq_get
-def _create_sql(self: Hive.Generator, e: exp.Create) -> str:
+def _create_sql(self: Spark2.Generator, e: exp.Create) -> str:
kind = e.args["kind"]
properties = e.args.get("properties")
@@ -31,17 +31,21 @@ def _create_sql(self: Hive.Generator, e: exp.Create) -> str:
return create_with_partitions_sql(self, e)
-def _map_sql(self: Hive.Generator, expression: exp.Map) -> str:
- keys = self.sql(expression.args["keys"])
- values = self.sql(expression.args["values"])
- return f"MAP_FROM_ARRAYS({keys}, {values})"
+def _map_sql(self: Spark2.Generator, expression: exp.Map) -> str:
+ keys = expression.args.get("keys")
+ values = expression.args.get("values")
+
+ if not keys or not values:
+ return "MAP()"
+
+ return f"MAP_FROM_ARRAYS({self.sql(keys)}, {self.sql(values)})"
def _parse_as_cast(to_type: str) -> t.Callable[[t.List], exp.Expression]:
return lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build(to_type))
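
The guard above turns a crash into valid output when the map has no entries; a direct sketch on the AST:

```python
from sqlglot import exp

# An exp.Map without keys/values now renders as MAP() instead of raising a KeyError.
print(exp.Map().sql(dialect="spark"))  # expected: MAP()
```
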
-def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str:
+def _str_to_date(self: Spark2.Generator, expression: exp.StrToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format == Hive.DATE_FORMAT:
@@ -49,7 +53,7 @@ def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str:
return f"TO_DATE({this}, {time_format})"
-def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str:
+def _unix_to_time_sql(self: Spark2.Generator, expression: exp.UnixToTime) -> str:
scale = expression.args.get("scale")
timestamp = self.sql(expression, "this")
if scale is None:
@@ -110,6 +114,13 @@ def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
return expression
+def _insert_sql(self: Spark2.Generator, expression: exp.Insert) -> str:
+ if expression.expression.args.get("with"):
+ expression = expression.copy()
+ expression.set("with", expression.expression.args.pop("with"))
+ return self.insert_sql(expression)
+
+
class Spark2(Hive):
class Parser(Hive.Parser):
FUNCTIONS = {
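
The new `_insert_sql` hoists a CTE off the inserted query so it lands before INSERT, where Spark expects it; a sketch with the expected output assumed:

```python
import sqlglot

sql = "INSERT INTO t WITH cte AS (SELECT 1 AS x) SELECT x FROM cte"
print(sqlglot.transpile(sql, read="spark", write="spark")[0])
# expected: WITH cte AS (SELECT 1 AS x) INSERT INTO t SELECT x FROM cte
```
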
@@ -169,10 +180,7 @@ class Spark2(Hive):
class Generator(Hive.Generator):
QUERY_HINTS = True
-
- TYPE_MAPPING = {
- **Hive.Generator.TYPE_MAPPING,
- }
+ NVL2_SUPPORTED = True
PROPERTIES_LOCATION = {
**Hive.Generator.PROPERTIES_LOCATION,
@@ -197,6 +205,7 @@ class Spark2(Hive):
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
exp.From: transforms.preprocess([_unalias_pivot]),
+ exp.Insert: _insert_sql,
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.LogicalOr: rename_func("BOOL_OR"),
exp.Map: _map_sql,
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index 90b774e..7bfdf1c 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -5,6 +5,7 @@ import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
+ any_value_to_max_sql,
arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
concat_to_dpipe_sql,
@@ -18,7 +19,7 @@ from sqlglot.dialects.dialect import (
from sqlglot.tokens import TokenType
-def _date_add_sql(self: generator.Generator, expression: exp.DateAdd) -> str:
+def _date_add_sql(self: SQLite.Generator, expression: exp.DateAdd) -> str:
modifier = expression.expression
modifier = modifier.name if modifier.is_string else self.sql(modifier)
unit = expression.args.get("unit")
@@ -78,6 +79,7 @@ class SQLite(Dialect):
JOIN_HINTS = False
TABLE_HINTS = False
QUERY_HINTS = False
+ NVL2_SUPPORTED = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@@ -103,6 +105,7 @@ class SQLite(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
+ exp.AnyValue: any_value_to_max_sql,
exp.Concat: concat_to_dpipe_sql,
exp.CountIf: count_if_to_sum,
exp.Create: transforms.preprocess([_transform_create]),
diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py
index 2be1a62..163cc13 100644
--- a/sqlglot/dialects/teradata.py
+++ b/sqlglot/dialects/teradata.py
@@ -95,6 +95,9 @@ class Teradata(Dialect):
STATEMENT_PARSERS = {
**parser.Parser.STATEMENT_PARSERS,
+ TokenType.DATABASE: lambda self: self.expression(
+ exp.Use, this=self._parse_table(schema=False)
+ ),
TokenType.REPLACE: lambda self: self._parse_create(),
}
@@ -165,6 +168,7 @@ class Teradata(Dialect):
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.StrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})",
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
+ exp.Use: lambda self, e: f"DATABASE {self.sql(e, 'this')}",
}
def partitionedbyproperty_sql(self, expression: exp.PartitionedByProperty) -> str:
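
Both directions of the new DATABASE <-> USE mapping, as a sketch:

```python
import sqlglot

print(sqlglot.transpile("DATABASE my_db", read="teradata", write="snowflake")[0])  # expected: USE my_db
print(sqlglot.transpile("USE my_db", read="snowflake", write="teradata")[0])       # expected: DATABASE my_db
```
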
diff --git a/sqlglot/dialects/trino.py b/sqlglot/dialects/trino.py
index af0f78d..0c953a1 100644
--- a/sqlglot/dialects/trino.py
+++ b/sqlglot/dialects/trino.py
@@ -13,3 +13,6 @@ class Trino(Presto):
class Tokenizer(Presto.Tokenizer):
HEX_STRINGS = [("X'", "'")]
+
+ class Parser(Presto.Parser):
+ SUPPORTS_USER_DEFINED_TYPES = False
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 131307f..b26f499 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -7,6 +7,7 @@ import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
+ any_value_to_max_sql,
max_or_greatest,
min_or_least,
parse_date_delta,
@@ -79,22 +80,23 @@ def _format_time_lambda(
def _parse_format(args: t.List) -> exp.Expression:
- assert len(args) == 2
+ this = seq_get(args, 0)
+ fmt = seq_get(args, 1)
+ culture = seq_get(args, 2)
- fmt = args[1]
- number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.name)
+ number_fmt = fmt and (fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.name))
if number_fmt:
- return exp.NumberToStr(this=args[0], format=fmt)
+ return exp.NumberToStr(this=this, format=fmt, culture=culture)
- return exp.TimeToStr(
- this=args[0],
- format=exp.Literal.string(
+ if fmt:
+ fmt = exp.Literal.string(
format_time(fmt.name, TSQL.FORMAT_TIME_MAPPING)
if len(fmt.name) == 1
else format_time(fmt.name, TSQL.TIME_MAPPING)
- ),
- )
+ )
+
+ return exp.TimeToStr(this=this, format=fmt, culture=culture)
def _parse_eomonth(args: t.List) -> exp.Expression:
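
FORMAT's optional third (culture) argument is now carried through instead of tripping the old two-argument assert; a sketch:

```python
import sqlglot

print(sqlglot.transpile("SELECT FORMAT(1234567.89, 'N', 'de-de')", read="tsql", write="tsql")[0])
# expected: SELECT FORMAT(1234567.89, 'N', 'de-de')
```
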
@@ -130,13 +132,13 @@ def _parse_hashbytes(args: t.List) -> exp.Expression:
def generate_date_delta_with_unit_sql(
- self: generator.Generator, expression: exp.DateAdd | exp.DateDiff
+ self: TSQL.Generator, expression: exp.DateAdd | exp.DateDiff
) -> str:
func = "DATEADD" if isinstance(expression, exp.DateAdd) else "DATEDIFF"
return self.func(func, expression.text("unit"), expression.expression, expression.this)
-def _format_sql(self: generator.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str:
+def _format_sql(self: TSQL.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str:
fmt = (
expression.args["format"]
if isinstance(expression, exp.NumberToStr)
@@ -147,10 +149,10 @@ def _format_sql(self: generator.Generator, expression: exp.NumberToStr | exp.Tim
)
)
)
- return self.func("FORMAT", expression.this, fmt)
+ return self.func("FORMAT", expression.this, fmt, expression.args.get("culture"))
-def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> str:
+def _string_agg_sql(self: TSQL.Generator, expression: exp.GroupConcat) -> str:
expression = expression.copy()
this = expression.this
@@ -332,10 +334,12 @@ class TSQL(Dialect):
"SQL_VARIANT": TokenType.VARIANT,
"TOP": TokenType.TOP,
"UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER,
+ "UPDATE STATISTICS": TokenType.COMMAND,
"VARCHAR(MAX)": TokenType.TEXT,
"XML": TokenType.XML,
"OUTPUT": TokenType.RETURNING,
"SYSTEM_USER": TokenType.CURRENT_USER,
+ "FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT,
}
class Parser(parser.Parser):
@@ -395,7 +399,9 @@ class TSQL(Dialect):
CONCAT_NULL_OUTPUTS_STRING = True
- def _parse_projections(self) -> t.List[t.Optional[exp.Expression]]:
+ ALTER_TABLE_ADD_COLUMN_KEYWORD = False
+
+ def _parse_projections(self) -> t.List[exp.Expression]:
"""
T-SQL supports the syntax alias = expression in the SELECT's projection list,
so we transform all parsed Selects to convert their EQ projections into Aliases.
@@ -458,43 +464,6 @@ class TSQL(Dialect):
return self._parse_as_command(self._prev)
- def _parse_system_time(self) -> t.Optional[exp.Expression]:
- if not self._match_text_seq("FOR", "SYSTEM_TIME"):
- return None
-
- if self._match_text_seq("AS", "OF"):
- system_time = self.expression(
- exp.SystemTime, this=self._parse_bitwise(), kind="AS OF"
- )
- elif self._match_set((TokenType.FROM, TokenType.BETWEEN)):
- kind = self._prev.text
- this = self._parse_bitwise()
- self._match_texts(("TO", "AND"))
- expression = self._parse_bitwise()
- system_time = self.expression(
- exp.SystemTime, this=this, expression=expression, kind=kind
- )
- elif self._match_text_seq("CONTAINED", "IN"):
- args = self._parse_wrapped_csv(self._parse_bitwise)
- system_time = self.expression(
- exp.SystemTime,
- this=seq_get(args, 0),
- expression=seq_get(args, 1),
- kind="CONTAINED IN",
- )
- elif self._match(TokenType.ALL):
- system_time = self.expression(exp.SystemTime, kind="ALL")
- else:
- system_time = None
- self.raise_error("Unable to parse FOR SYSTEM_TIME clause")
-
- return system_time
-
- def _parse_table_parts(self, schema: bool = False) -> exp.Table:
- table = super()._parse_table_parts(schema=schema)
- table.set("system_time", self._parse_system_time())
- return table
-
def _parse_returns(self) -> exp.ReturnsProperty:
table = self._parse_id_var(any_token=False, tokens=self.RETURNS_TABLE_TOKENS)
returns = super()._parse_returns()
@@ -589,14 +558,36 @@ class TSQL(Dialect):
return create
+ def _parse_if(self) -> t.Optional[exp.Expression]:
+ index = self._index
+
+ if self._match_text_seq("OBJECT_ID"):
+ self._parse_wrapped_csv(self._parse_string)
+ if self._match_text_seq("IS", "NOT", "NULL") and self._match(TokenType.DROP):
+ return self._parse_drop(exists=True)
+ self._retreat(index)
+
+ return super()._parse_if()
+
+ def _parse_unique(self) -> exp.UniqueColumnConstraint:
+ return self.expression(
+ exp.UniqueColumnConstraint,
+ this=None
+ if self._curr and self._curr.text.upper() in {"CLUSTERED", "NONCLUSTERED"}
+ else self._parse_schema(self._parse_id_var(any_token=False)),
+ )
+
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
LIMIT_IS_TOP = True
QUERY_HINTS = False
RETURNING_END = False
+ NVL2_SUPPORTED = False
+ ALTER_TABLE_ADD_COLUMN_KEYWORD = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
+ exp.DataType.Type.BOOLEAN: "BIT",
exp.DataType.Type.DECIMAL: "NUMERIC",
exp.DataType.Type.DATETIME: "DATETIME2",
exp.DataType.Type.INT: "INTEGER",
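
The `_parse_if` override above recognizes the common guarded-drop idiom; a sketch, with the expected output an assumption:

```python
import sqlglot

print(sqlglot.transpile("IF OBJECT_ID('t') IS NOT NULL DROP TABLE t", read="tsql", write="tsql")[0])
# expected: DROP TABLE IF EXISTS t
```
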
@@ -607,6 +598,8 @@ class TSQL(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
+ exp.AnyValue: any_value_to_max_sql,
+ exp.AutoIncrementColumnConstraint: lambda *_: "IDENTITY",
exp.DateAdd: generate_date_delta_with_unit_sql,
exp.DateDiff: generate_date_delta_with_unit_sql,
exp.CurrentDate: rename_func("GETDATE"),
@@ -651,25 +644,44 @@ class TSQL(Dialect):
return sql
- def offset_sql(self, expression: exp.Offset) -> str:
- return f"{super().offset_sql(expression)} ROWS"
+ def create_sql(self, expression: exp.Create) -> str:
+ expression = expression.copy()
+ kind = self.sql(expression, "kind").upper()
+ exists = expression.args.pop("exists", None)
+ sql = super().create_sql(expression)
+
+ if exists:
+ table = expression.find(exp.Table)
+ identifier = self.sql(exp.Literal.string(exp.table_name(table) if table else ""))
+ if kind == "SCHEMA":
+ sql = f"""IF NOT EXISTS (SELECT * FROM information_schema.schemata WHERE schema_name = {identifier}) EXEC('{sql}')"""
+ elif kind == "TABLE":
+ sql = f"""IF NOT EXISTS (SELECT * FROM information_schema.tables WHERE table_name = {identifier}) EXEC('{sql}')"""
+ elif kind == "INDEX":
+ index = self.sql(exp.Literal.string(expression.this.text("this")))
+ sql = f"""IF NOT EXISTS (SELECT * FROM sys.indexes WHERE object_id = object_id({identifier}) AND name = {index}) EXEC('{sql}')"""
+ elif expression.args.get("replace"):
+ sql = sql.replace("CREATE OR REPLACE ", "CREATE OR ALTER ", 1)
- def systemtime_sql(self, expression: exp.SystemTime) -> str:
- kind = expression.args["kind"]
- if kind == "ALL":
- return "FOR SYSTEM_TIME ALL"
+ return sql
- start = self.sql(expression, "this")
- if kind == "AS OF":
- return f"FOR SYSTEM_TIME AS OF {start}"
+ def offset_sql(self, expression: exp.Offset) -> str:
+ return f"{super().offset_sql(expression)} ROWS"
- end = self.sql(expression, "expression")
- if kind == "FROM":
- return f"FOR SYSTEM_TIME FROM {start} TO {end}"
- if kind == "BETWEEN":
- return f"FOR SYSTEM_TIME BETWEEN {start} AND {end}"
+ def version_sql(self, expression: exp.Version) -> str:
+ name = "SYSTEM_TIME" if expression.name == "TIMESTAMP" else expression.name
+ this = f"FOR {name}"
+ expr = expression.expression
+ kind = expression.text("kind")
+ if kind in ("FROM", "BETWEEN"):
+ args = expr.expressions
+ sep = "TO" if kind == "FROM" else "AND"
+ expr_sql = f"{self.sql(seq_get(args, 0))} {sep} {self.sql(seq_get(args, 1))}"
+ else:
+ expr_sql = self.sql(expr)
- return f"FOR SYSTEM_TIME CONTAINED IN ({start}, {end})"
+ expr_sql = f" {expr_sql}" if expr_sql else ""
+ return f"{this} {kind}{expr_sql}"
def returnsproperty_sql(self, expression: exp.ReturnsProperty) -> str:
table = expression.args.get("table")
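
Since T-SQL has no native IF NOT EXISTS on CREATE, the new `create_sql` wraps the statement in an existence check plus EXEC; a sketch (exact formatting assumed):

```python
import sqlglot

print(sqlglot.transpile("CREATE TABLE IF NOT EXISTS t (x INT)", read="duckdb", write="tsql")[0])
# expected roughly:
# IF NOT EXISTS (SELECT * FROM information_schema.tables WHERE table_name = 't') EXEC('CREATE TABLE t (x INTEGER)')
```
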
@@ -713,3 +725,16 @@ class TSQL(Dialect):
identifier = f"#{identifier}"
return identifier
+
+ def constraint_sql(self, expression: exp.Constraint) -> str:
+ this = self.sql(expression, "this")
+ expressions = self.expressions(expression, flat=True, sep=" ")
+ return f"CONSTRAINT {this} {expressions}"
+
+ # https://learn.microsoft.com/en-us/answers/questions/448821/create-table-in-sql-server
+ def generatedasidentitycolumnconstraint_sql(
+ self, expression: exp.GeneratedAsIdentityColumnConstraint
+ ) -> str:
+ start = self.sql(expression, "start") or "1"
+ increment = self.sql(expression, "increment") or "1"
+ return f"IDENTITY({start}, {increment})"
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 57b8bfa..0479da0 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -1035,12 +1035,13 @@ class Clone(Expression):
"this": True,
"when": False,
"kind": False,
+ "shallow": False,
"expression": False,
}
class Describe(Expression):
- arg_types = {"this": True, "kind": False}
+ arg_types = {"this": True, "kind": False, "expressions": False}
class Pragma(Expression):
@@ -1070,6 +1071,8 @@ class Show(Expression):
"like": False,
"where": False,
"db": False,
+ "scope": False,
+ "scope_kind": False,
"full": False,
"mutex": False,
"query": False,
@@ -1207,6 +1210,10 @@ class Comment(Expression):
arg_types = {"this": True, "kind": True, "expression": True, "exists": False}
+class Comprehension(Expression):
+ arg_types = {"this": True, "expression": True, "iterator": True, "condition": False}
+
+
# https://clickhouse.com/docs/en/engines/table-engines/mergetree-family/mergetree#mergetree-table-ttl
class MergeTreeTTLAction(Expression):
arg_types = {
@@ -1269,6 +1276,10 @@ class CheckColumnConstraint(ColumnConstraintKind):
pass
+class ClusteredColumnConstraint(ColumnConstraintKind):
+ pass
+
+
class CollateColumnConstraint(ColumnConstraintKind):
pass
@@ -1316,6 +1327,14 @@ class InlineLengthColumnConstraint(ColumnConstraintKind):
pass
+class NonClusteredColumnConstraint(ColumnConstraintKind):
+ pass
+
+
+class NotForReplicationColumnConstraint(ColumnConstraintKind):
+ arg_types = {}
+
+
class NotNullColumnConstraint(ColumnConstraintKind):
arg_types = {"allow_null": False}
@@ -1345,6 +1364,12 @@ class PathColumnConstraint(ColumnConstraintKind):
pass
+# computed column expression
+# https://learn.microsoft.com/en-us/sql/t-sql/statements/create-table-transact-sql?view=sql-server-ver16
+class ComputedColumnConstraint(ColumnConstraintKind):
+ arg_types = {"this": True, "persisted": False, "not_null": False}
+
+
class Constraint(Expression):
arg_types = {"this": True, "expressions": True}
@@ -1489,6 +1514,15 @@ class Check(Expression):
pass
+# https://docs.snowflake.com/en/sql-reference/constructs/connect-by
+class Connect(Expression):
+ arg_types = {"start": False, "connect": True}
+
+
+class Prior(Expression):
+ pass
+
+
class Directory(Expression):
# https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-dml-insert-overwrite-directory-hive.html
arg_types = {"this": True, "local": False, "row_format": False}
@@ -1578,6 +1612,7 @@ class Insert(DDL):
"alternative": False,
"where": False,
"ignore": False,
+ "by_name": False,
}
def with_(
@@ -2045,8 +2080,12 @@ class NoPrimaryIndexProperty(Property):
arg_types = {}
+class OnProperty(Property):
+ arg_types = {"this": True}
+
+
class OnCommitProperty(Property):
- arg_type = {"delete": False}
+ arg_types = {"delete": False}
class PartitionedByProperty(Property):
@@ -2282,6 +2321,16 @@ class Subqueryable(Unionable):
def named_selects(self) -> t.List[str]:
raise NotImplementedError("Subqueryable objects must implement `named_selects`")
+ def select(
+ self,
+ *expressions: t.Optional[ExpOrStr],
+ append: bool = True,
+ dialect: DialectType = None,
+ copy: bool = True,
+ **opts,
+ ) -> Subqueryable:
+ raise NotImplementedError("Subqueryable objects must implement `select`")
+
def with_(
self,
alias: ExpOrStr,
@@ -2323,6 +2372,7 @@ QUERY_MODIFIERS = {
"match": False,
"laterals": False,
"joins": False,
+ "connect": False,
"pivots": False,
"where": False,
"group": False,
@@ -2363,6 +2413,7 @@ class Table(Expression):
"pivots": False,
"hints": False,
"system_time": False,
+ "version": False,
}
@property
@@ -2403,21 +2454,13 @@ class Table(Expression):
return parts
-# See the TSQL "Querying data in a system-versioned temporal table" page
-class SystemTime(Expression):
- arg_types = {
- "this": False,
- "expression": False,
- "kind": True,
- }
-
-
class Union(Subqueryable):
arg_types = {
"with": False,
"this": True,
"expression": True,
"distinct": False,
+ "by_name": False,
**QUERY_MODIFIERS,
}
@@ -2529,6 +2572,7 @@ class Update(Expression):
"from": False,
"where": False,
"returning": False,
+ "order": False,
"limit": False,
}
@@ -2545,6 +2589,20 @@ class Var(Expression):
pass
+class Version(Expression):
+ """
+    Time travel, snapshot or versioned queries (Iceberg, Delta Lake, BigQuery, SQL Server temporal tables, etc.):
+ https://trino.io/docs/current/connector/iceberg.html?highlight=snapshot#using-snapshots
+ https://www.databricks.com/blog/2019/02/04/introducing-delta-time-travel-for-large-scale-data-lakes.html
+ https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax#for_system_time_as_of
+ https://learn.microsoft.com/en-us/sql/relational-databases/tables/querying-data-in-a-system-versioned-temporal-table?view=sql-server-ver16
+    `this` is either TIMESTAMP or VERSION
+    `kind` is one of ("AS OF", "BETWEEN")
+ """
+
+ arg_types = {"this": True, "kind": True, "expression": False}
+
+
class Schema(Expression):
arg_types = {"this": False, "expressions": False}
@@ -3263,6 +3321,23 @@ class Subquery(DerivedTable, Unionable):
expression = expression.this
return expression
+ def unwrap(self) -> Subquery:
+ expression = self
+ while expression.same_parent and expression.is_wrapper:
+ expression = t.cast(Subquery, expression.parent)
+ return expression
+
+ @property
+ def is_wrapper(self) -> bool:
+ """
+ Whether this Subquery acts as a simple wrapper around another expression.
+
+ SELECT * FROM (((SELECT * FROM t)))
+ ^
+ This corresponds to a "wrapper" Subquery node
+ """
+ return all(v is None for k, v in self.args.items() if k != "this")
+
@property
def is_star(self) -> bool:
return self.this.is_star
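
A sketch of the wrapper semantics (default dialect, outputs assumed):

```python
import sqlglot
from sqlglot import exp

wrapped = sqlglot.parse_one("SELECT * FROM (((SELECT 1)))").find(exp.Subquery)
aliased = sqlglot.parse_one("SELECT * FROM (SELECT 1) AS x").find(exp.Subquery)
print(wrapped.is_wrapper)  # True: nothing but parentheses around the inner query
print(aliased.is_wrapper)  # False: the alias makes it more than a wrapper
```
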
@@ -3313,7 +3388,7 @@ class Pivot(Expression):
}
-class Window(Expression):
+class Window(Condition):
arg_types = {
"this": True,
"partition_by": False,
@@ -3375,7 +3450,7 @@ class Boolean(Condition):
pass
-class DataTypeSize(Expression):
+class DataTypeParam(Expression):
arg_types = {"this": True, "expression": False}
@@ -3386,6 +3461,7 @@ class DataType(Expression):
"nested": False,
"values": False,
"prefix": False,
+ "kind": False,
}
class Type(AutoName):
@@ -3432,6 +3508,7 @@ class DataType(Expression):
LOWCARDINALITY = auto()
MAP = auto()
MEDIUMBLOB = auto()
+ MEDIUMINT = auto()
MEDIUMTEXT = auto()
MONEY = auto()
NCHAR = auto()
@@ -3475,6 +3552,7 @@ class DataType(Expression):
VARCHAR = auto()
VARIANT = auto()
XML = auto()
+ YEAR = auto()
TEXT_TYPES = {
Type.CHAR,
@@ -3498,7 +3576,10 @@ class DataType(Expression):
Type.DOUBLE,
}
- NUMERIC_TYPES = {*INTEGER_TYPES, *FLOAT_TYPES}
+ NUMERIC_TYPES = {
+ *INTEGER_TYPES,
+ *FLOAT_TYPES,
+ }
TEMPORAL_TYPES = {
Type.TIME,
@@ -3511,23 +3592,39 @@ class DataType(Expression):
Type.DATETIME64,
}
- META_TYPES = {"UNKNOWN", "NULL"}
-
@classmethod
def build(
- cls, dtype: str | DataType | DataType.Type, dialect: DialectType = None, **kwargs
+ cls,
+ dtype: str | DataType | DataType.Type,
+ dialect: DialectType = None,
+ udt: bool = False,
+ **kwargs,
) -> DataType:
+ """
+ Constructs a DataType object.
+
+ Args:
+ dtype: the data type of interest.
+ dialect: the dialect to use for parsing `dtype`, in case it's a string.
+ udt: when set to True, `dtype` will be used as-is if it can't be parsed into a
+ DataType, thus creating a user-defined type.
+        kwargs: additional arguments to pass to the constructor of DataType.
+
+ Returns:
+ The constructed DataType object.
+ """
from sqlglot import parse_one
if isinstance(dtype, str):
- upper = dtype.upper()
- if upper in DataType.META_TYPES:
- data_type_exp: t.Optional[Expression] = DataType(this=DataType.Type[upper])
- else:
- data_type_exp = parse_one(dtype, read=dialect, into=DataType)
+ if dtype.upper() == "UNKNOWN":
+ return DataType(this=DataType.Type.UNKNOWN, **kwargs)
- if data_type_exp is None:
- raise ValueError(f"Unparsable data type value: {dtype}")
+ try:
+ data_type_exp = parse_one(dtype, read=dialect, into=DataType)
+ except ParseError:
+ if udt:
+ return DataType(this=DataType.Type.USERDEFINED, kind=dtype, **kwargs)
+ raise
elif isinstance(dtype, DataType.Type):
data_type_exp = DataType(this=dtype)
elif isinstance(dtype, DataType):
@@ -3538,7 +3635,31 @@ class DataType(Expression):
return DataType(**{**data_type_exp.args, **kwargs})
def is_type(self, *dtypes: str | DataType | DataType.Type) -> bool:
- return any(self.this == DataType.build(dtype).this for dtype in dtypes)
+ """
+ Checks whether this DataType matches one of the provided data types. Nested types or precision
+ will be compared using "structural equivalence" semantics, so e.g. array<int> != array<float>.
+
+ Args:
+ dtypes: the data types to compare this DataType to.
+
+ Returns:
+ True, if and only if there is a type in `dtypes` which is equal to this DataType.
+ """
+ for dtype in dtypes:
+ other = DataType.build(dtype, udt=True)
+
+ if (
+ other.expressions
+ or self.this == DataType.Type.USERDEFINED
+ or other.this == DataType.Type.USERDEFINED
+ ):
+ matches = self == other
+ else:
+ matches = self.this == other.this
+
+ if matches:
+ return True
+ return False
# https://www.postgresql.org/docs/15/datatype-pseudo.html
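
A sketch of the structural comparison semantics and the UDT fallback described above:

```python
from sqlglot import exp

dt = exp.DataType.build("ARRAY<INT>")
print(dt.is_type("ARRAY<INT>"))    # True
print(dt.is_type("ARRAY<FLOAT>"))  # False: nested types compare structurally
print(dt.is_type("ARRAY"))         # True: a bare type only compares the outer kind

# A name that can't be parsed as a type falls back to USERDEFINED when udt=True.
udt = exp.DataType.build("my_custom_type", udt=True)
print(udt.sql())  # expected: my_custom_type, kept as-is via the "kind" arg
```
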
@@ -3546,6 +3667,11 @@ class PseudoType(Expression):
pass
+# https://www.postgresql.org/docs/15/datatype-oid.html
+class ObjectIdentifier(Expression):
+ pass
+
+
# WHERE x <OP> EXISTS|ALL|ANY|SOME(SELECT ...)
class SubqueryPredicate(Predicate):
pass
@@ -4005,6 +4131,7 @@ class ArrayAny(Func):
class ArrayConcat(Func):
+ _sql_names = ["ARRAY_CONCAT", "ARRAY_CAT"]
arg_types = {"this": True, "expressions": False}
is_var_len_args = True
@@ -4047,7 +4174,15 @@ class Avg(AggFunc):
class AnyValue(AggFunc):
- arg_types = {"this": True, "having": False, "max": False}
+ arg_types = {"this": True, "having": False, "max": False, "ignore_nulls": False}
+
+
+class First(Func):
+ arg_types = {"this": True, "ignore_nulls": False}
+
+
+class Last(Func):
+ arg_types = {"this": True, "ignore_nulls": False}
class Case(Func):
@@ -4086,18 +4221,29 @@ class Cast(Func):
return self.name
def is_type(self, *dtypes: str | DataType | DataType.Type) -> bool:
- return self.to.is_type(*dtypes)
+ """
+ Checks whether this Cast's DataType matches one of the provided data types. Nested types
+ like arrays or structs will be compared using "structural equivalence" semantics, so e.g.
+ array<int> != array<float>.
+ Args:
+ dtypes: the data types to compare this Cast's DataType to.
-class CastToStrType(Func):
- arg_types = {"this": True, "expression": True}
+ Returns:
+ True, if and only if there is a type in `dtypes` which is equal to this Cast's DataType.
+ """
+ return self.to.is_type(*dtypes)
-class Collate(Binary):
+class TryCast(Cast):
pass
-class TryCast(Cast):
+class CastToStrType(Func):
+ arg_types = {"this": True, "to": True}
+
+
+class Collate(Binary):
pass
@@ -4310,7 +4456,7 @@ class Greatest(Func):
is_var_len_args = True
-class GroupConcat(Func):
+class GroupConcat(AggFunc):
arg_types = {"this": True, "separator": False}
@@ -4648,8 +4794,19 @@ class StrToUnix(Func):
arg_types = {"this": False, "format": False}
+# https://prestodb.io/docs/current/functions/string.html
+# https://spark.apache.org/docs/latest/api/sql/index.html#str_to_map
+class StrToMap(Func):
+ arg_types = {
+ "this": True,
+ "pair_delim": False,
+ "key_value_delim": False,
+ "duplicate_resolution_callback": False,
+ }
+
+
class NumberToStr(Func):
- arg_types = {"this": True, "format": True}
+ arg_types = {"this": True, "format": True, "culture": False}
class FromBase(Func):
@@ -4665,6 +4822,13 @@ class StructExtract(Func):
arg_types = {"this": True, "expression": True}
+# https://learn.microsoft.com/en-us/sql/t-sql/functions/stuff-transact-sql?view=sql-server-ver16
+# https://docs.snowflake.com/en/sql-reference/functions/insert
+class Stuff(Func):
+ _sql_names = ["STUFF", "INSERT"]
+ arg_types = {"this": True, "start": True, "length": True, "expression": True}
+
+
class Sum(AggFunc):
pass
@@ -4686,7 +4850,7 @@ class StddevSamp(AggFunc):
class TimeToStr(Func):
- arg_types = {"this": True, "format": True}
+ arg_types = {"this": True, "format": True, "culture": False}
class TimeToTimeStr(Func):
@@ -5724,9 +5888,9 @@ def table_(
The new Table instance.
"""
return Table(
- this=to_identifier(table, quoted=quoted),
- db=to_identifier(db, quoted=quoted),
- catalog=to_identifier(catalog, quoted=quoted),
+ this=to_identifier(table, quoted=quoted) if table else None,
+ db=to_identifier(db, quoted=quoted) if db else None,
+ catalog=to_identifier(catalog, quoted=quoted) if catalog else None,
alias=TableAlias(this=to_identifier(alias)) if alias else None,
)
@@ -5844,8 +6008,8 @@ def convert(value: t.Any, copy: bool = False) -> Expression:
return Array(expressions=[convert(v, copy=copy) for v in value])
if isinstance(value, dict):
return Map(
- keys=[convert(k, copy=copy) for k in value],
- values=[convert(v, copy=copy) for v in value.values()],
+ keys=Array(expressions=[convert(k, copy=copy) for k in value]),
+ values=Array(expressions=[convert(v, copy=copy) for v in value.values()]),
)
raise ValueError(f"Cannot convert {value}")
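
The dict branch now matches exp.Map's Array-based schema; a sketch:

```python
from sqlglot import exp

m = exp.convert({"a": 1, "b": 2})
print(m.sql(dialect="spark"))
# expected: MAP_FROM_ARRAYS(ARRAY('a', 'b'), ARRAY(1, 2))
```
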
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index f8d7d68..306df81 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -8,7 +8,7 @@ from sqlglot import exp
from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages
from sqlglot.helper import apply_index_offset, csv, seq_get
from sqlglot.time import format_time
-from sqlglot.tokens import TokenType
+from sqlglot.tokens import Tokenizer, TokenType
logger = logging.getLogger("sqlglot")
@@ -61,6 +61,7 @@ class Generator:
exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}",
exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args.get('default') else ''}CHARACTER SET={self.sql(e, 'this')}",
exp.CheckColumnConstraint: lambda self, e: f"CHECK ({self.sql(e, 'this')})",
+ exp.ClusteredColumnConstraint: lambda self, e: f"CLUSTERED ({self.expressions(e, 'this', indent=False)})",
exp.CollateColumnConstraint: lambda self, e: f"COLLATE {self.sql(e, 'this')}",
exp.CopyGrantsProperty: lambda self, e: "COPY GRANTS",
exp.CommentColumnConstraint: lambda self, e: f"COMMENT {self.sql(e, 'this')}",
@@ -78,7 +79,10 @@ class Generator:
exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG",
exp.MaterializedProperty: lambda self, e: "MATERIALIZED",
exp.NoPrimaryIndexProperty: lambda self, e: "NO PRIMARY INDEX",
+ exp.NonClusteredColumnConstraint: lambda self, e: f"NONCLUSTERED ({self.expressions(e, 'this', indent=False)})",
+ exp.NotForReplicationColumnConstraint: lambda self, e: "NOT FOR REPLICATION",
exp.OnCommitProperty: lambda self, e: f"ON COMMIT {'DELETE' if e.args.get('delete') else 'PRESERVE'} ROWS",
+ exp.OnProperty: lambda self, e: f"ON {self.sql(e, 'this')}",
exp.OnUpdateColumnConstraint: lambda self, e: f"ON UPDATE {self.sql(e, 'this')}",
exp.PathColumnConstraint: lambda self, e: f"PATH {self.sql(e, 'this')}",
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
@@ -171,6 +175,9 @@ class Generator:
# Whether or not TIMETZ / TIMESTAMPTZ will be generated using the "WITH TIME ZONE" syntax
TZ_TO_WITH_TIME_ZONE = False
+ # Whether or not the NVL2 function is supported
+ NVL2_SUPPORTED = True
+
# https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax
SELECT_KINDS: t.Tuple[str, ...] = ("STRUCT", "VALUE")
@@ -179,6 +186,9 @@ class Generator:
# SELECT * VALUES into SELECT UNION
VALUES_AS_TABLE = True
+ # Whether or not the word COLUMN is included when adding a column with ALTER TABLE
+ ALTER_TABLE_ADD_COLUMN_KEYWORD = True
+
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
@@ -245,6 +255,7 @@ class Generator:
exp.MaterializedProperty: exp.Properties.Location.POST_CREATE,
exp.MergeBlockRatioProperty: exp.Properties.Location.POST_NAME,
exp.NoPrimaryIndexProperty: exp.Properties.Location.POST_EXPRESSION,
+ exp.OnProperty: exp.Properties.Location.POST_SCHEMA,
exp.OnCommitProperty: exp.Properties.Location.POST_EXPRESSION,
exp.Order: exp.Properties.Location.POST_SCHEMA,
exp.PartitionedByProperty: exp.Properties.Location.POST_WITH,
@@ -317,8 +328,7 @@ class Generator:
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
- STRING_ESCAPE = "'"
- IDENTIFIER_ESCAPE = '"'
+ TOKENIZER_CLASS = Tokenizer
# Delimiters for bit, hex, byte and raw literals
BIT_START: t.Optional[str] = None
@@ -379,8 +389,10 @@ class Generator:
)
self.unsupported_messages: t.List[str] = []
- self._escaped_quote_end: str = self.STRING_ESCAPE + self.QUOTE_END
- self._escaped_identifier_end: str = self.IDENTIFIER_ESCAPE + self.IDENTIFIER_END
+ self._escaped_quote_end: str = self.TOKENIZER_CLASS.STRING_ESCAPES[0] + self.QUOTE_END
+ self._escaped_identifier_end: str = (
+ self.TOKENIZER_CLASS.IDENTIFIER_ESCAPES[0] + self.IDENTIFIER_END
+ )
self._cache: t.Optional[t.Dict[int, str]] = None
def generate(
@@ -626,6 +638,16 @@ class Generator:
kind_sql = self.sql(expression, "kind").strip()
return f"CONSTRAINT {this} {kind_sql}" if this else kind_sql
+ def computedcolumnconstraint_sql(self, expression: exp.ComputedColumnConstraint) -> str:
+ this = self.sql(expression, "this")
+ if expression.args.get("not_null"):
+ persisted = " PERSISTED NOT NULL"
+ elif expression.args.get("persisted"):
+ persisted = " PERSISTED"
+ else:
+ persisted = ""
+ return f"AS {this}{persisted}"
+
def autoincrementcolumnconstraint_sql(self, _) -> str:
return self.token_sql(TokenType.AUTO_INCREMENT)
@@ -642,8 +664,8 @@ class Generator:
) -> str:
this = ""
if expression.this is not None:
- on_null = "ON NULL " if expression.args.get("on_null") else ""
- this = " ALWAYS " if expression.this else f" BY DEFAULT {on_null}"
+ on_null = " ON NULL" if expression.args.get("on_null") else ""
+ this = " ALWAYS" if expression.this else f" BY DEFAULT{on_null}"
start = expression.args.get("start")
start = f"START WITH {start}" if start else ""
@@ -668,7 +690,7 @@ class Generator:
expr = self.sql(expression, "expression")
expr = f"({expr})" if expr else "IDENTITY"
- return f"GENERATED{this}AS {expr}{sequence_opts}"
+ return f"GENERATED{this} AS {expr}{sequence_opts}"
def notnullcolumnconstraint_sql(self, expression: exp.NotNullColumnConstraint) -> str:
return f"{'' if expression.args.get('allow_null') else 'NOT '}NULL"
@@ -774,14 +796,16 @@ class Generator:
def clone_sql(self, expression: exp.Clone) -> str:
this = self.sql(expression, "this")
+ shallow = "SHALLOW " if expression.args.get("shallow") else ""
+ this = f"{shallow}CLONE {this}"
when = self.sql(expression, "when")
if when:
kind = self.sql(expression, "kind")
expr = self.sql(expression, "expression")
- return f"CLONE {this} {when} ({kind} => {expr})"
+ return f"{this} {when} ({kind} => {expr})"
- return f"CLONE {this}"
+ return this
def describe_sql(self, expression: exp.Describe) -> str:
return f"DESCRIBE {self.sql(expression, 'this')}"
@@ -830,7 +854,7 @@ class Generator:
string = self.escape_str(expression.this.replace("\\", "\\\\"))
return f"{self.QUOTE_START}{string}{self.QUOTE_END}"
- def datatypesize_sql(self, expression: exp.DataTypeSize) -> str:
+ def datatypeparam_sql(self, expression: exp.DataTypeParam) -> str:
this = self.sql(expression, "this")
specifier = self.sql(expression, "expression")
specifier = f" {specifier}" if specifier else ""
@@ -839,11 +863,14 @@ class Generator:
def datatype_sql(self, expression: exp.DataType) -> str:
type_value = expression.this
- type_sql = (
- self.TYPE_MAPPING.get(type_value, type_value.value)
- if isinstance(type_value, exp.DataType.Type)
- else type_value
- )
+ if type_value == exp.DataType.Type.USERDEFINED and expression.args.get("kind"):
+ type_sql = self.sql(expression, "kind")
+ else:
+ type_sql = (
+ self.TYPE_MAPPING.get(type_value, type_value.value)
+ if isinstance(type_value, exp.DataType.Type)
+ else type_value
+ )
nested = ""
interior = self.expressions(expression, flat=True)
@@ -943,9 +970,9 @@ class Generator:
name = self.sql(expression, "this")
name = f"{name} " if name else ""
table = self.sql(expression, "table")
- table = f"{self.INDEX_ON} {table} " if table else ""
+ table = f"{self.INDEX_ON} {table}" if table else ""
using = self.sql(expression, "using")
- using = f"USING {using} " if using else ""
+ using = f" USING {using} " if using else ""
index = "INDEX " if not table else ""
columns = self.expressions(expression, key="columns", flat=True)
columns = f"({columns})" if columns else ""
@@ -1171,6 +1198,7 @@ class Generator:
where = f"{self.sep()}REPLACE WHERE {where}" if where else ""
expression_sql = f"{self.sep()}{self.sql(expression, 'expression')}"
conflict = self.sql(expression, "conflict")
+ by_name = " BY NAME" if expression.args.get("by_name") else ""
returning = self.sql(expression, "returning")
if self.RETURNING_END:
@@ -1178,7 +1206,7 @@ class Generator:
else:
expression_sql = f"{returning}{expression_sql}{conflict}"
- sql = f"INSERT{alternative}{ignore}{this}{exists}{partition_sql}{where}{expression_sql}"
+ sql = f"INSERT{alternative}{ignore}{this}{by_name}{exists}{partition_sql}{where}{expression_sql}"
return self.prepend_ctes(expression, sql)
def intersect_sql(self, expression: exp.Intersect) -> str:
@@ -1196,6 +1224,9 @@ class Generator:
def pseudotype_sql(self, expression: exp.PseudoType) -> str:
return expression.name.upper()
+ def objectidentifier_sql(self, expression: exp.ObjectIdentifier) -> str:
+ return expression.name.upper()
+
def onconflict_sql(self, expression: exp.OnConflict) -> str:
conflict = "ON DUPLICATE KEY" if expression.args.get("duplicate") else "ON CONFLICT"
constraint = self.sql(expression, "constraint")
@@ -1248,6 +1279,8 @@ class Generator:
if part
)
+ version = self.sql(expression, "version")
+ version = f" {version}" if version else ""
alias = self.sql(expression, "alias")
alias = f"{sep}{alias}" if alias else ""
hints = self.expressions(expression, key="hints", sep=" ")
@@ -1256,10 +1289,8 @@ class Generator:
pivots = f" {pivots}" if pivots else ""
joins = self.expressions(expression, key="joins", sep="", skip_first=True)
laterals = self.expressions(expression, key="laterals", sep="")
- system_time = expression.args.get("system_time")
- system_time = f" {self.sql(expression, 'system_time')}" if system_time else ""
- return f"{table}{system_time}{alias}{hints}{pivots}{joins}{laterals}"
+ return f"{table}{version}{alias}{hints}{pivots}{joins}{laterals}"
def tablesample_sql(
self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
@@ -1314,6 +1345,12 @@ class Generator:
nulls = ""
return f"{direction}{nulls}({expressions} FOR {field}){alias}"
+ def version_sql(self, expression: exp.Version) -> str:
+ this = f"FOR {expression.name}"
+ kind = expression.text("kind")
+ expr = self.sql(expression, "expression")
+ return f"{this} {kind} {expr}"
+
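Note: version_sql, together with the _parse_version and tokenizer changes below, covers time-travel clauses (AS OF, FROM/BETWEEN ranges, CONTAINED IN, ALL). A hedged sketch of the AS OF form:

    import sqlglot

    sql = "SELECT * FROM orders FOR VERSION AS OF 123456"
    print(sqlglot.parse_one(sql).sql())
    # expected to round-trip: SELECT * FROM orders FOR VERSION AS OF 123456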
def tuple_sql(self, expression: exp.Tuple) -> str:
return f"({self.expressions(expression, flat=True)})"
@@ -1323,12 +1360,13 @@ class Generator:
from_sql = self.sql(expression, "from")
where_sql = self.sql(expression, "where")
returning = self.sql(expression, "returning")
+ order = self.sql(expression, "order")
limit = self.sql(expression, "limit")
if self.RETURNING_END:
- expression_sql = f"{from_sql}{where_sql}{returning}{limit}"
+ expression_sql = f"{from_sql}{where_sql}{returning}"
else:
- expression_sql = f"{returning}{from_sql}{where_sql}{limit}"
- sql = f"UPDATE {this} SET {set_sql}{expression_sql}"
+ expression_sql = f"{returning}{from_sql}{where_sql}"
+ sql = f"UPDATE {this} SET {set_sql}{expression_sql}{order}{limit}"
return self.prepend_ctes(expression, sql)
def values_sql(self, expression: exp.Values) -> str:
@@ -1425,6 +1463,16 @@ class Generator:
this = self.indent(self.sql(expression, "this"))
return f"{self.seg('HAVING')}{self.sep()}{this}"
+ def connect_sql(self, expression: exp.Connect) -> str:
+ start = self.sql(expression, "start")
+ start = self.seg(f"START WITH {start}") if start else ""
+ connect = self.sql(expression, "connect")
+ connect = self.seg(f"CONNECT BY {connect}")
+ return start + connect
+
+ def prior_sql(self, expression: exp.Prior) -> str:
+ return f"PRIOR {self.sql(expression, 'this')}"
+
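Note: connect_sql and prior_sql pair with the START WITH / CONNECT BY parsing added in parser.py and tokens.py, i.e. Oracle-style hierarchical queries. A sketch with hypothetical employees/manager_id columns:

    import sqlglot

    sql = (
        "SELECT employee_id FROM employees "
        "START WITH manager_id IS NULL "
        "CONNECT BY PRIOR employee_id = manager_id"
    )
    # PRIOR is handled via a temporary NO_PAREN_FUNCTION_PARSERS entry (see parser.py)
    print(sqlglot.parse_one(sql, read="oracle").sql(dialect="oracle"))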
def join_sql(self, expression: exp.Join) -> str:
op_sql = " ".join(
op
@@ -1667,6 +1715,7 @@ class Generator:
return csv(
*sqls,
*[self.sql(join) for join in expression.args.get("joins") or []],
+ self.sql(expression, "connect"),
self.sql(expression, "match"),
*[self.sql(lateral) for lateral in expression.args.get("laterals") or []],
self.sql(expression, "where"),
@@ -1801,7 +1850,8 @@ class Generator:
def union_op(self, expression: exp.Union) -> str:
kind = " DISTINCT" if self.EXPLICIT_UNION else ""
kind = kind if expression.args.get("distinct") else " ALL"
- return f"UNION{kind}"
+ by_name = " BY NAME" if expression.args.get("by_name") else ""
+ return f"UNION{kind}{by_name}"
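
Note: this exposes DuckDB's UNION [ALL] BY NAME, which matches columns by name rather than by position. A sketch:

    import sqlglot

    sql = "SELECT a, b FROM t1 UNION ALL BY NAME SELECT b, a FROM t2"
    print(sqlglot.transpile(sql, read="duckdb", write="duckdb")[0])
    # expected: SELECT a, b FROM t1 UNION ALL BY NAME SELECT b, a FROM t2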
def unnest_sql(self, expression: exp.Unnest) -> str:
args = self.expressions(expression, flat=True)
@@ -2224,7 +2274,14 @@ class Generator:
actions = expression.args["actions"]
if isinstance(actions[0], exp.ColumnDef):
- actions = self.expressions(expression, key="actions", prefix="ADD COLUMN ")
+ if self.ALTER_TABLE_ADD_COLUMN_KEYWORD:
+ actions = self.expressions(
+ expression,
+ key="actions",
+ prefix="ADD COLUMN ",
+ )
+ else:
+ actions = f"ADD {self.expressions(expression, key='actions')}"
elif isinstance(actions[0], exp.Schema):
actions = self.expressions(expression, key="actions", prefix="ADD COLUMNS ")
elif isinstance(actions[0], exp.Delete):
@@ -2525,10 +2582,21 @@ class Generator:
return f"WHEN {matched}{source}{condition} THEN {then}"
def merge_sql(self, expression: exp.Merge) -> str:
- this = self.sql(expression, "this")
+ table = expression.this
+ table_alias = ""
+
+ hints = table.args.get("hints")
+ if hints and table.alias and isinstance(hints[0], exp.WithTableHint):
+ # T-SQL syntax is MERGE ... <target_table> [WITH (<merge_hint>)] [[AS] table_alias]
+ table = table.copy()
+ table_alias = f" AS {self.sql(table.args['alias'].pop())}"
+
+ this = self.sql(table)
using = f"USING {self.sql(expression, 'using')}"
on = f"ON {self.sql(expression, 'on')}"
- return f"MERGE INTO {this} {using} {on} {self.expressions(expression, sep=' ')}"
+ expressions = self.expressions(expression, sep=" ")
+
+ return f"MERGE INTO {this}{table_alias} {using} {on} {expressions}"
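
Note: the special-casing preserves T-SQL's hint-before-alias order. A sketch with a hypothetical target/source pair:

    import sqlglot

    sql = (
        "MERGE INTO dim WITH (HOLDLOCK) AS tgt "
        "USING src ON tgt.id = src.id "
        "WHEN MATCHED THEN UPDATE SET tgt.v = src.v"
    )
    # the alias is popped off the table and re-emitted after the WITH (...) hint
    print(sqlglot.parse_one(sql, read="tsql").sql(dialect="tsql"))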
def tochar_sql(self, expression: exp.ToChar) -> str:
if expression.args.get("format"):
@@ -2631,6 +2699,29 @@ class Generator:
options = f" {options}" if options else ""
return f"{kind}{this}{type_}{schema}{options}"
+ def nvl2_sql(self, expression: exp.Nvl2) -> str:
+ if self.NVL2_SUPPORTED:
+ return self.function_fallback_sql(expression)
+
+ case = exp.Case().when(
+ expression.this.is_(exp.null()).not_(copy=False),
+ expression.args["true"].copy(),
+ copy=False,
+ )
+ else_cond = expression.args.get("false")
+ if else_cond:
+ case.else_(else_cond.copy(), copy=False)
+
+ return self.sql(case)
+
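Note: when a target dialect opts out via NVL2_SUPPORTED = False, NVL2(a, b, c) degrades to the equivalent CASE. A minimal sketch that forces the fallback through a throwaway Generator subclass (NoNvl2 is hypothetical; the base Generator keeps the flag True):

    from sqlglot import parse_one
    from sqlglot.generator import Generator

    class NoNvl2(Generator):
        NVL2_SUPPORTED = False  # force the CASE rewrite above

    print(NoNvl2().generate(parse_one("NVL2(a, b, c)")))
    # expected: CASE WHEN NOT a IS NULL THEN b ELSE c END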
+ def comprehension_sql(self, expression: exp.Comprehension) -> str:
+ this = self.sql(expression, "this")
+ expr = self.sql(expression, "expression")
+ iterator = self.sql(expression, "iterator")
+ condition = self.sql(expression, "condition")
+ condition = f" IF {condition}" if condition else ""
+ return f"{this} FOR {expr} IN {iterator}{condition}"
+
def cached_generator(
cache: t.Optional[t.Dict[int, str]] = None
diff --git a/sqlglot/helper.py b/sqlglot/helper.py
index a863017..7335d1e 100644
--- a/sqlglot/helper.py
+++ b/sqlglot/helper.py
@@ -33,6 +33,15 @@ class AutoName(Enum):
return name
+class classproperty(property):
+ """
+    Similar to a normal property, but resolved on the class itself (i.e. it decorates a classmethod-style getter)
+ """
+
+ def __get__(self, obj: t.Any, owner: t.Any = None) -> t.Any:
+ return classmethod(self.fget).__get__(None, owner)() # type: ignore
+
+
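Note: a tiny usage sketch for classproperty (the Settings class is hypothetical):

    from sqlglot.helper import classproperty

    class Settings:
        _defaults = {"dialect": "duckdb"}

        @classproperty
        def dialect(cls):
            # resolved against the class itself; no instance required
            return cls._defaults["dialect"]

    assert Settings.dialect == "duckdb"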
def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
"""Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds."""
try:
@@ -137,9 +146,9 @@ def subclasses(
def apply_index_offset(
this: exp.Expression,
- expressions: t.List[t.Optional[E]],
+ expressions: t.List[E],
offset: int,
-) -> t.List[t.Optional[E]]:
+) -> t.List[E]:
"""
Applies an offset to a given integer literal expression.
@@ -170,15 +179,14 @@ def apply_index_offset(
):
return expressions
- if expression:
- if not expression.type:
- annotate_types(expression)
- if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
- logger.warning("Applying array index offset (%s)", offset)
- expression = simplify(
- exp.Add(this=expression.copy(), expression=exp.Literal.number(offset))
- )
- return [expression]
+ if not expression.type:
+ annotate_types(expression)
+ if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
+ logger.warning("Applying array index offset (%s)", offset)
+ expression = simplify(
+ exp.Add(this=expression.copy(), expression=exp.Literal.number(offset))
+ )
+ return [expression]
return expressions
diff --git a/sqlglot/optimizer/__init__.py b/sqlglot/optimizer/__init__.py
index 719a77e..ee48006 100644
--- a/sqlglot/optimizer/__init__.py
+++ b/sqlglot/optimizer/__init__.py
@@ -1,2 +1,9 @@
from sqlglot.optimizer.optimizer import RULES, optimize
-from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope
+from sqlglot.optimizer.scope import (
+ Scope,
+ build_scope,
+ find_all_in_scope,
+ find_in_scope,
+ traverse_scope,
+ walk_in_scope,
+)
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index e7cb80b..a429655 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -203,10 +203,15 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
for expr_type in expressions
},
exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
+ exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
+ exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
+ exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
+ exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
+ exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
@@ -220,6 +225,10 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
}
+ NESTED_TYPES = {
+ exp.DataType.Type.ARRAY,
+ }
+
# Specifies what types a given type can be coerced into (autofilled)
COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}
@@ -299,19 +308,22 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
def _maybe_coerce(
self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
- ) -> exp.DataType.Type:
- # We propagate the NULL / UNKNOWN types upwards if found
- if isinstance(type1, exp.DataType):
- type1 = type1.this
- if isinstance(type2, exp.DataType):
- type2 = type2.this
+ ) -> exp.DataType | exp.DataType.Type:
+ type1_value = type1.this if isinstance(type1, exp.DataType) else type1
+ type2_value = type2.this if isinstance(type2, exp.DataType) else type2
- if exp.DataType.Type.NULL in (type1, type2):
+ # We propagate the NULL / UNKNOWN types upwards if found
+ if exp.DataType.Type.NULL in (type1_value, type2_value):
return exp.DataType.Type.NULL
- if exp.DataType.Type.UNKNOWN in (type1, type2):
+ if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
return exp.DataType.Type.UNKNOWN
- return type2 if type2 in self.coerces_to.get(type1, {}) else type1 # type: ignore
+ if type1_value in self.NESTED_TYPES:
+ return type1
+ if type2_value in self.NESTED_TYPES:
+ return type2
+
+ return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value # type: ignore
# Note: the following "no_type_check" decorators were added because mypy was yelling due
# to assigning Type values to expression.type (since its getter returns Optional[DataType]).
@@ -368,7 +380,9 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
return self._annotate_args(expression)
@t.no_type_check
- def _annotate_by_args(self, expression: E, *args: str, promote: bool = False) -> E:
+ def _annotate_by_args(
+ self, expression: E, *args: str, promote: bool = False, array: bool = False
+ ) -> E:
self._annotate_args(expression)
expressions: t.List[exp.Expression] = []
@@ -388,4 +402,9 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
elif expression.type.this in exp.DataType.FLOAT_TYPES:
expression.type = exp.DataType.Type.DOUBLE
+ if array:
+ expression.type = exp.DataType(
+ this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True
+ )
+
return expression
diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py
index af42f25..1ab7768 100644
--- a/sqlglot/optimizer/eliminate_subqueries.py
+++ b/sqlglot/optimizer/eliminate_subqueries.py
@@ -142,13 +142,14 @@ def _eliminate_derived_table(scope, existing_ctes, taken):
if scope.parent.pivots or isinstance(scope.parent.expression, exp.Lateral):
return None
- parent = scope.expression.parent
+ # Get rid of redundant exp.Subquery expressions, i.e. those that are just used as wrappers
+ to_replace = scope.expression.parent.unwrap()
name, cte = _new_cte(scope, existing_ctes, taken)
+ table = exp.alias_(exp.table_(name), alias=to_replace.alias or name)
+ table.set("joins", to_replace.args.get("joins"))
- table = exp.alias_(exp.table_(name), alias=parent.alias or name)
- table.set("joins", parent.args.get("joins"))
+ to_replace.replace(table)
- parent.replace(table)
return cte
diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py
index 7b3b2b1..9d401fc 100644
--- a/sqlglot/optimizer/optimize_joins.py
+++ b/sqlglot/optimizer/optimize_joins.py
@@ -72,8 +72,13 @@ def normalize(expression):
if not any(join.args.get(k) for k in JOIN_ATTRS):
join.set("kind", "CROSS")
- if join.kind != "CROSS":
+ if join.kind == "CROSS":
+ join.set("on", None)
+ else:
join.set("kind", None)
+
+ if not join.args.get("on") and not join.args.get("using"):
+ join.set("on", exp.true())
return expression
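
Note: normalize now canonicalizes join conditions in both directions: CROSS joins lose any ON clause, and qualified joins without a condition gain an explicit ON TRUE. A sketch:

    from sqlglot import parse_one
    from sqlglot.optimizer.optimize_joins import normalize

    print(normalize(parse_one("SELECT * FROM a INNER JOIN b")).sql())
    # expected: SELECT * FROM a JOIN b ON TRUE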
diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py
index 58b988d..f7348b5 100644
--- a/sqlglot/optimizer/pushdown_predicates.py
+++ b/sqlglot/optimizer/pushdown_predicates.py
@@ -1,6 +1,6 @@
from sqlglot import exp
from sqlglot.optimizer.normalize import normalized
-from sqlglot.optimizer.scope import build_scope
+from sqlglot.optimizer.scope import build_scope, find_in_scope
from sqlglot.optimizer.simplify import simplify
@@ -81,7 +81,11 @@ def pushdown_cnf(predicates, scope, scope_ref_count):
break
if isinstance(node, exp.Select):
predicate.replace(exp.true())
- node.where(replace_aliases(node, predicate), copy=False)
+ inner_predicate = replace_aliases(node, predicate)
+ if find_in_scope(inner_predicate, exp.AggFunc):
+ node.having(inner_predicate, copy=False)
+ else:
+ node.where(inner_predicate, copy=False)
def pushdown_dnf(predicates, scope, scope_ref_count):
@@ -142,7 +146,11 @@ def pushdown_dnf(predicates, scope, scope_ref_count):
if isinstance(node, exp.Join):
node.on(predicate, copy=False)
elif isinstance(node, exp.Select):
- node.where(replace_aliases(node, predicate), copy=False)
+ inner_predicate = replace_aliases(node, predicate)
+ if find_in_scope(inner_predicate, exp.AggFunc):
+ node.having(inner_predicate, copy=False)
+ else:
+ node.where(inner_predicate, copy=False)
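
Note: predicates that contain an aggregate after alias replacement cannot legally live in WHERE, so both the CNF and DNF paths now route them into HAVING. A hedged sketch (the consumed outer predicate is left behind as WHERE TRUE for a later simplify pass):

    from sqlglot import parse_one
    from sqlglot.optimizer.pushdown_predicates import pushdown_predicates

    sql = (
        "SELECT * FROM (SELECT a, SUM(b) AS total FROM t GROUP BY a) AS agg "
        "WHERE agg.total > 100"
    )
    # the filter should land in HAVING SUM(b) > 100 inside the subquery,
    # not in an (invalid) inner WHERE
    print(pushdown_predicates(parse_one(sql)).sql())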
def nodes_for_predicate(predicate, sources, scope_ref_count):
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index fb12384..435899a 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -6,7 +6,7 @@ from enum import Enum, auto
from sqlglot import exp
from sqlglot.errors import OptimizeError
-from sqlglot.helper import find_new_name
+from sqlglot.helper import ensure_collection, find_new_name
logger = logging.getLogger("sqlglot")
@@ -141,38 +141,10 @@ class Scope:
return walk_in_scope(self.expression, bfs=bfs)
def find(self, *expression_types, bfs=True):
- """
- Returns the first node in this scope which matches at least one of the specified types.
-
- This does NOT traverse into subscopes.
-
- Args:
- expression_types (type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
-
- Returns:
- exp.Expression: the node which matches the criteria or None if no node matching
- the criteria was found.
- """
- return next(self.find_all(*expression_types, bfs=bfs), None)
+ return find_in_scope(self.expression, expression_types, bfs=bfs)
def find_all(self, *expression_types, bfs=True):
- """
- Returns a generator object which visits all nodes in this scope and only yields those that
- match at least one of the specified expression types.
-
- This does NOT traverse into subscopes.
-
- Args:
- expression_types (type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
-
- Yields:
- exp.Expression: nodes
- """
- for expression, *_ in self.walk(bfs=bfs):
- if isinstance(expression, expression_types):
- yield expression
+ return find_all_in_scope(self.expression, expression_types, bfs=bfs)
def replace(self, old, new):
"""
@@ -800,3 +772,41 @@ def walk_in_scope(expression, bfs=True):
for key in ("joins", "laterals", "pivots"):
for arg in node.args.get(key) or []:
yield from walk_in_scope(arg, bfs=bfs)
+
+
+def find_all_in_scope(expression, expression_types, bfs=True):
+ """
+ Returns a generator object which visits all nodes in this scope and only yields those that
+ match at least one of the specified expression types.
+
+ This does NOT traverse into subscopes.
+
+ Args:
+        expression (exp.Expression): the root expression of the scope to search.
+ expression_types (tuple[type]|type): the expression type(s) to match.
+ bfs (bool): True to use breadth-first search, False to use depth-first.
+
+ Yields:
+ exp.Expression: nodes
+ """
+ for expression, *_ in walk_in_scope(expression, bfs=bfs):
+ if isinstance(expression, tuple(ensure_collection(expression_types))):
+ yield expression
+
+
+def find_in_scope(expression, expression_types, bfs=True):
+ """
+ Returns the first node in this scope which matches at least one of the specified types.
+
+ This does NOT traverse into subscopes.
+
+ Args:
+        expression (exp.Expression): the root expression of the scope to search.
+ expression_types (tuple[type]|type): the expression type(s) to match.
+ bfs (bool): True to use breadth-first search, False to use depth-first.
+
+ Returns:
+ exp.Expression: the node which matches the criteria or None if no node matching
+ the criteria was found.
+ """
+ return next(find_all_in_scope(expression, expression_types, bfs=bfs), None)
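
Note: the module-level helpers make scope-bounded searches usable without a Scope instance. A sketch:

    from sqlglot import exp, parse_one
    from sqlglot.optimizer.scope import find_all_in_scope

    ast = parse_one("SELECT a, (SELECT MAX(b) FROM u) AS m FROM t")
    # only the outer query's columns; the subquery is a separate scope
    print([col.sql() for col in find_all_in_scope(ast, exp.Column)])
    # expected: ['a']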
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index e550603..3974ea4 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -69,10 +69,10 @@ def simplify(expression):
node = flatten(node)
node = simplify_connectors(node, root)
node = remove_compliments(node, root)
+ node = simplify_coalesce(node)
node.parent = expression.parent
node = simplify_literals(node, root)
node = simplify_parens(node)
- node = simplify_coalesce(node)
if root:
expression.replace(node)
@@ -350,7 +350,8 @@ def absorb_and_eliminate(expression, root=True):
def simplify_literals(expression, root=True):
if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
return _flat_simplify(expression, _simplify_binary, root)
- elif isinstance(expression, exp.Neg):
+
+ if isinstance(expression, exp.Neg):
this = expression.this
if this.is_number:
value = this.name
@@ -430,13 +431,14 @@ def simplify_parens(expression):
if not isinstance(this, exp.Select) and (
not isinstance(parent, (exp.Condition, exp.Binary))
- or isinstance(this, exp.Predicate)
+ or isinstance(parent, exp.Paren)
or not isinstance(this, exp.Binary)
+ or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
):
- return expression.this
+ return this
return expression
@@ -488,18 +490,20 @@ def simplify_coalesce(expression):
coalesce = coalesce if coalesce.expressions else coalesce.this
# This expression is more complex than when we started, but it will get simplified further
- return exp.or_(
- exp.and_(
- coalesce.is_(exp.null()).not_(copy=False),
- expression.copy(),
- copy=False,
- ),
- exp.and_(
- coalesce.is_(exp.null()),
- type(expression)(this=arg.copy(), expression=other.copy()),
+ return exp.paren(
+ exp.or_(
+ exp.and_(
+ coalesce.is_(exp.null()).not_(copy=False),
+ expression.copy(),
+ copy=False,
+ ),
+ exp.and_(
+ coalesce.is_(exp.null()),
+ type(expression)(this=arg.copy(), expression=other.copy()),
+ copy=False,
+ ),
copy=False,
- ),
- copy=False,
+ )
)
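
Note: wrapping the rewrite in exp.paren preserves precedence when the original comparison sat under a connector. A hedged sketch of the end-to-end effect:

    from sqlglot import parse_one
    from sqlglot.optimizer.simplify import simplify

    print(simplify(parse_one("SELECT * FROM t WHERE COALESCE(x, 1) = 2 AND y = 3")).sql())
    # roughly: SELECT * FROM t WHERE NOT x IS NULL AND x = 2 AND y = 3
    # (the paren guards the OR produced mid-rewrite from binding to AND y = 3)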
@@ -642,7 +646,7 @@ def _flat_simplify(expression, simplifier, root=True):
for b in queue:
result = simplifier(expression, a, b)
- if result:
+ if result and result is not expression:
queue.remove(b)
queue.appendleft(result)
break
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index 3db4453..f8690d5 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -136,6 +136,7 @@ class Parser(metaclass=_Parser):
TokenType.UINT128,
TokenType.INT256,
TokenType.UINT256,
+ TokenType.MEDIUMINT,
TokenType.FIXEDSTRING,
TokenType.FLOAT,
TokenType.DOUBLE,
@@ -186,6 +187,7 @@ class Parser(metaclass=_Parser):
TokenType.SMALLSERIAL,
TokenType.BIGSERIAL,
TokenType.XML,
+ TokenType.YEAR,
TokenType.UNIQUEIDENTIFIER,
TokenType.USERDEFINED,
TokenType.MONEY,
@@ -194,9 +196,12 @@ class Parser(metaclass=_Parser):
TokenType.IMAGE,
TokenType.VARIANT,
TokenType.OBJECT,
+ TokenType.OBJECT_IDENTIFIER,
TokenType.INET,
TokenType.IPADDRESS,
TokenType.IPPREFIX,
+ TokenType.UNKNOWN,
+ TokenType.NULL,
*ENUM_TYPE_TOKENS,
*NESTED_TYPE_TOKENS,
}
@@ -332,6 +337,7 @@ class Parser(metaclass=_Parser):
TokenType.INDEX,
TokenType.ISNULL,
TokenType.ILIKE,
+ TokenType.INSERT,
TokenType.LIKE,
TokenType.MERGE,
TokenType.OFFSET,
@@ -487,7 +493,7 @@ class Parser(metaclass=_Parser):
exp.Cluster: lambda self: self._parse_sort(exp.Cluster, TokenType.CLUSTER_BY),
exp.Column: lambda self: self._parse_column(),
exp.Condition: lambda self: self._parse_conjunction(),
- exp.DataType: lambda self: self._parse_types(),
+ exp.DataType: lambda self: self._parse_types(allow_identifiers=False),
exp.Expression: lambda self: self._parse_statement(),
exp.From: lambda self: self._parse_from(),
exp.Group: lambda self: self._parse_group(),
@@ -523,9 +529,6 @@ class Parser(metaclass=_Parser):
TokenType.DESC: lambda self: self._parse_describe(),
TokenType.DESCRIBE: lambda self: self._parse_describe(),
TokenType.DROP: lambda self: self._parse_drop(),
- TokenType.FROM: lambda self: exp.select("*").from_(
- t.cast(exp.From, self._parse_from(skip_from_token=True))
- ),
TokenType.INSERT: lambda self: self._parse_insert(),
TokenType.LOAD: lambda self: self._parse_load(),
TokenType.MERGE: lambda self: self._parse_merge(),
@@ -578,7 +581,7 @@ class Parser(metaclass=_Parser):
TokenType.PLACEHOLDER: lambda self: self.expression(exp.Placeholder),
TokenType.PARAMETER: lambda self: self._parse_parameter(),
TokenType.COLON: lambda self: self.expression(exp.Placeholder, this=self._prev.text)
- if self._match_set((TokenType.NUMBER, TokenType.VAR))
+ if self._match(TokenType.NUMBER) or self._match_set(self.ID_VAR_TOKENS)
else None,
}
@@ -593,6 +596,7 @@ class Parser(metaclass=_Parser):
TokenType.OVERLAPS: binary_range_parser(exp.Overlaps),
TokenType.RLIKE: binary_range_parser(exp.RegexpLike),
TokenType.SIMILAR_TO: binary_range_parser(exp.SimilarTo),
+ TokenType.FOR: lambda self, this: self._parse_comprehension(this),
}
PROPERTY_PARSERS: t.Dict[str, t.Callable] = {
@@ -684,6 +688,12 @@ class Parser(metaclass=_Parser):
exp.CommentColumnConstraint, this=self._parse_string()
),
"COMPRESS": lambda self: self._parse_compress(),
+ "CLUSTERED": lambda self: self.expression(
+ exp.ClusteredColumnConstraint, this=self._parse_wrapped_csv(self._parse_ordered)
+ ),
+ "NONCLUSTERED": lambda self: self.expression(
+ exp.NonClusteredColumnConstraint, this=self._parse_wrapped_csv(self._parse_ordered)
+ ),
"DEFAULT": lambda self: self.expression(
exp.DefaultColumnConstraint, this=self._parse_bitwise()
),
@@ -698,8 +708,11 @@ class Parser(metaclass=_Parser):
"LIKE": lambda self: self._parse_create_like(),
"NOT": lambda self: self._parse_not_constraint(),
"NULL": lambda self: self.expression(exp.NotNullColumnConstraint, allow_null=True),
- "ON": lambda self: self._match(TokenType.UPDATE)
- and self.expression(exp.OnUpdateColumnConstraint, this=self._parse_function()),
+ "ON": lambda self: (
+ self._match(TokenType.UPDATE)
+ and self.expression(exp.OnUpdateColumnConstraint, this=self._parse_function())
+ )
+ or self.expression(exp.OnProperty, this=self._parse_id_var()),
"PATH": lambda self: self.expression(exp.PathColumnConstraint, this=self._parse_string()),
"PRIMARY KEY": lambda self: self._parse_primary_key(),
"REFERENCES": lambda self: self._parse_references(match=False),
@@ -709,6 +722,9 @@ class Parser(metaclass=_Parser):
"TTL": lambda self: self.expression(exp.MergeTreeTTL, expressions=[self._parse_bitwise()]),
"UNIQUE": lambda self: self._parse_unique(),
"UPPERCASE": lambda self: self.expression(exp.UppercaseColumnConstraint),
+ "WITH": lambda self: self.expression(
+ exp.Properties, expressions=self._parse_wrapped_csv(self._parse_property)
+ ),
}
ALTER_PARSERS = {
@@ -728,6 +744,11 @@ class Parser(metaclass=_Parser):
"NEXT": lambda self: self._parse_next_value_for(),
}
+ INVALID_FUNC_NAME_TOKENS = {
+ TokenType.IDENTIFIER,
+ TokenType.STRING,
+ }
+
FUNCTIONS_WITH_ALIASED_ARGS = {"STRUCT"}
FUNCTION_PARSERS = {
@@ -774,6 +795,8 @@ class Parser(metaclass=_Parser):
self._parse_sort(exp.Distribute, TokenType.DISTRIBUTE_BY),
),
TokenType.SORT_BY: lambda self: ("sort", self._parse_sort(exp.Sort, TokenType.SORT_BY)),
+ TokenType.CONNECT_BY: lambda self: ("connect", self._parse_connect(skip_start_token=True)),
+ TokenType.START_WITH: lambda self: ("connect", self._parse_connect()),
}
SET_PARSERS = {
@@ -815,6 +838,8 @@ class Parser(metaclass=_Parser):
ADD_CONSTRAINT_TOKENS = {TokenType.CONSTRAINT, TokenType.PRIMARY_KEY, TokenType.FOREIGN_KEY}
+ DISTINCT_TOKENS = {TokenType.DISTINCT}
+
STRICT_CAST = True
# A NULL arg in CONCAT yields NULL by default
@@ -826,6 +851,11 @@ class Parser(metaclass=_Parser):
LOG_BASE_FIRST = True
LOG_DEFAULTS_TO_LN = False
+ SUPPORTS_USER_DEFINED_TYPES = True
+
+    # Whether or not each column added by ALTER TABLE is prefixed with "ADD COLUMN" (otherwise a single ADD precedes the column list)
+ ALTER_TABLE_ADD_COLUMN_KEYWORD = True
+
__slots__ = (
"error_level",
"error_message_context",
@@ -838,9 +868,11 @@ class Parser(metaclass=_Parser):
"_next",
"_prev",
"_prev_comments",
+ "_tokenizer",
)
# Autofilled
+ TOKENIZER_CLASS: t.Type[Tokenizer] = Tokenizer
INDEX_OFFSET: int = 0
UNNEST_COLUMN_ONLY: bool = False
ALIAS_POST_TABLESAMPLE: bool = False
@@ -863,6 +895,7 @@ class Parser(metaclass=_Parser):
self.error_level = error_level or ErrorLevel.IMMEDIATE
self.error_message_context = error_message_context
self.max_errors = max_errors
+ self._tokenizer = self.TOKENIZER_CLASS()
self.reset()
def reset(self):
@@ -1148,7 +1181,7 @@ class Parser(metaclass=_Parser):
expression = self._parse_set_operations(expression) if expression else self._parse_select()
return self._parse_query_modifiers(expression)
- def _parse_drop(self) -> exp.Drop | exp.Command:
+ def _parse_drop(self, exists: bool = False) -> exp.Drop | exp.Command:
start = self._prev
temporary = self._match(TokenType.TEMPORARY)
materialized = self._match_text_seq("MATERIALIZED")
@@ -1160,7 +1193,7 @@ class Parser(metaclass=_Parser):
return self.expression(
exp.Drop,
comments=start.comments,
- exists=self._parse_exists(),
+ exists=exists or self._parse_exists(),
this=self._parse_table(schema=True),
kind=kind,
temporary=temporary,
@@ -1274,6 +1307,8 @@ class Parser(metaclass=_Parser):
if self._match_text_seq("WITH", "NO", "SCHEMA", "BINDING"):
no_schema_binding = True
+ shallow = self._match_text_seq("SHALLOW")
+
if self._match_text_seq("CLONE"):
clone = self._parse_table(schema=True)
when = self._match_texts({"AT", "BEFORE"}) and self._prev.text.upper()
@@ -1285,7 +1320,12 @@ class Parser(metaclass=_Parser):
clone_expression = self._match(TokenType.FARROW) and self._parse_bitwise()
self._match(TokenType.R_PAREN)
clone = self.expression(
- exp.Clone, this=clone, when=when, kind=clone_kind, expression=clone_expression
+ exp.Clone,
+ this=clone,
+ when=when,
+ kind=clone_kind,
+ shallow=shallow,
+ expression=clone_expression,
)
return self.expression(
@@ -1349,7 +1389,11 @@ class Parser(metaclass=_Parser):
if assignment:
key = self._parse_var_or_string()
self._match(TokenType.EQ)
- return self.expression(exp.Property, this=key, value=self._parse_column())
+ return self.expression(
+ exp.Property,
+ this=key,
+ value=self._parse_column() or self._parse_var(any_token=True),
+ )
return None
@@ -1409,7 +1453,7 @@ class Parser(metaclass=_Parser):
def _parse_with_property(
self,
- ) -> t.Optional[exp.Expression] | t.List[t.Optional[exp.Expression]]:
+ ) -> t.Optional[exp.Expression] | t.List[exp.Expression]:
if self._match(TokenType.L_PAREN, advance=False):
return self._parse_wrapped_csv(self._parse_property)
@@ -1622,7 +1666,7 @@ class Parser(metaclass=_Parser):
override=override,
)
- def _parse_partition_by(self) -> t.List[t.Optional[exp.Expression]]:
+ def _parse_partition_by(self) -> t.List[exp.Expression]:
if self._match(TokenType.PARTITION_BY):
return self._parse_csv(self._parse_conjunction)
return []
@@ -1652,9 +1696,9 @@ class Parser(metaclass=_Parser):
def _parse_on_property(self) -> t.Optional[exp.Expression]:
if self._match_text_seq("COMMIT", "PRESERVE", "ROWS"):
return exp.OnCommitProperty()
- elif self._match_text_seq("COMMIT", "DELETE", "ROWS"):
+ if self._match_text_seq("COMMIT", "DELETE", "ROWS"):
return exp.OnCommitProperty(delete=True)
- return None
+ return self.expression(exp.OnProperty, this=self._parse_schema(self._parse_id_var()))
def _parse_distkey(self) -> exp.DistKeyProperty:
return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_id_var))
@@ -1709,8 +1753,10 @@ class Parser(metaclass=_Parser):
def _parse_describe(self) -> exp.Describe:
kind = self._match_set(self.CREATABLES) and self._prev.text
- this = self._parse_table()
- return self.expression(exp.Describe, this=this, kind=kind)
+ this = self._parse_table(schema=True)
+ properties = self._parse_properties()
+ expressions = properties.expressions if properties else None
+ return self.expression(exp.Describe, this=this, kind=kind, expressions=expressions)
def _parse_insert(self) -> exp.Insert:
comments = ensure_list(self._prev_comments)
@@ -1741,6 +1787,7 @@ class Parser(metaclass=_Parser):
exp.Insert,
comments=comments,
this=this,
+ by_name=self._match_text_seq("BY", "NAME"),
exists=self._parse_exists(),
partition=self._parse_partition(),
where=self._match_pair(TokenType.REPLACE, TokenType.WHERE)
@@ -1895,6 +1942,7 @@ class Parser(metaclass=_Parser):
"from": self._parse_from(joins=True),
"where": self._parse_where(),
"returning": returning or self._parse_returning(),
+ "order": self._parse_order(),
"limit": self._parse_limit(),
},
)
@@ -1948,13 +1996,14 @@ class Parser(metaclass=_Parser):
# https://prestodb.io/docs/current/sql/values.html
return self.expression(exp.Tuple, expressions=[self._parse_conjunction()])
- def _parse_projections(self) -> t.List[t.Optional[exp.Expression]]:
+ def _parse_projections(self) -> t.List[exp.Expression]:
return self._parse_expressions()
def _parse_select(
self, nested: bool = False, table: bool = False, parse_subquery_alias: bool = True
) -> t.Optional[exp.Expression]:
cte = self._parse_with()
+
if cte:
this = self._parse_statement()
@@ -1967,12 +2016,18 @@ class Parser(metaclass=_Parser):
else:
self.raise_error(f"{this.key} does not support CTE")
this = cte
- elif self._match(TokenType.SELECT):
+
+ return this
+
+        # DuckDB supports a leading FROM clause, e.g. "FROM x SELECT a" or just "FROM x"
+ from_ = self._parse_from() if self._match(TokenType.FROM, advance=False) else None
+
+ if self._match(TokenType.SELECT):
comments = self._prev_comments
hint = self._parse_hint()
all_ = self._match(TokenType.ALL)
- distinct = self._match(TokenType.DISTINCT)
+ distinct = self._match_set(self.DISTINCT_TOKENS)
kind = (
self._match(TokenType.ALIAS)
@@ -2006,7 +2061,9 @@ class Parser(metaclass=_Parser):
if into:
this.set("into", into)
- from_ = self._parse_from()
+ if not from_:
+ from_ = self._parse_from()
+
if from_:
this.set("from", from_)
@@ -2033,6 +2090,8 @@ class Parser(metaclass=_Parser):
expressions=self._parse_csv(self._parse_value),
alias=self._parse_table_alias(),
)
+ elif from_:
+ this = exp.select("*").from_(from_.this, copy=False)
else:
this = None
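
Note: a sketch of the FROM-first forms this enables:

    import sqlglot

    # FROM-first, as in the DuckDB shell
    print(sqlglot.parse_one("FROM t SELECT a, b", read="duckdb").sql())
    # expected: SELECT a, b FROM t

    # a bare FROM becomes SELECT *
    print(sqlglot.parse_one("FROM t", read="duckdb").sql())
    # expected: SELECT * FROM t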
@@ -2491,6 +2550,11 @@ class Parser(metaclass=_Parser):
if schema:
return self._parse_schema(this=this)
+ version = self._parse_version()
+
+ if version:
+ this.set("version", version)
+
if self.ALIAS_POST_TABLESAMPLE:
table_sample = self._parse_table_sample()
@@ -2498,11 +2562,11 @@ class Parser(metaclass=_Parser):
if alias:
this.set("alias", alias)
+ this.set("hints", self._parse_table_hints())
+
if not this.args.get("pivots"):
this.set("pivots", self._parse_pivots())
- this.set("hints", self._parse_table_hints())
-
if not self.ALIAS_POST_TABLESAMPLE:
table_sample = self._parse_table_sample()
@@ -2516,6 +2580,37 @@ class Parser(metaclass=_Parser):
return this
+ def _parse_version(self) -> t.Optional[exp.Version]:
+ if self._match(TokenType.TIMESTAMP_SNAPSHOT):
+ this = "TIMESTAMP"
+ elif self._match(TokenType.VERSION_SNAPSHOT):
+ this = "VERSION"
+ else:
+ return None
+
+ if self._match_set((TokenType.FROM, TokenType.BETWEEN)):
+ kind = self._prev.text.upper()
+ start = self._parse_bitwise()
+ self._match_texts(("TO", "AND"))
+ end = self._parse_bitwise()
+ expression: t.Optional[exp.Expression] = self.expression(
+ exp.Tuple, expressions=[start, end]
+ )
+ elif self._match_text_seq("CONTAINED", "IN"):
+ kind = "CONTAINED IN"
+ expression = self.expression(
+ exp.Tuple, expressions=self._parse_wrapped_csv(self._parse_bitwise)
+ )
+ elif self._match(TokenType.ALL):
+ kind = "ALL"
+ expression = None
+ else:
+ self._match_text_seq("AS", "OF")
+ kind = "AS OF"
+ expression = self._parse_type()
+
+ return self.expression(exp.Version, this=this, expression=expression, kind=kind)
+
def _parse_unnest(self, with_alias: bool = True) -> t.Optional[exp.Unnest]:
if not self._match(TokenType.UNNEST):
return None
@@ -2760,7 +2855,7 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Group, **elements) # type: ignore
- def _parse_grouping_sets(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
+ def _parse_grouping_sets(self) -> t.Optional[t.List[exp.Expression]]:
if not self._match(TokenType.GROUPING_SETS):
return None
@@ -2784,6 +2879,22 @@ class Parser(metaclass=_Parser):
return None
return self.expression(exp.Qualify, this=self._parse_conjunction())
+ def _parse_connect(self, skip_start_token: bool = False) -> t.Optional[exp.Connect]:
+ if skip_start_token:
+ start = None
+ elif self._match(TokenType.START_WITH):
+ start = self._parse_conjunction()
+ else:
+ return None
+
+ self._match(TokenType.CONNECT_BY)
+ self.NO_PAREN_FUNCTION_PARSERS["PRIOR"] = lambda self: self.expression(
+ exp.Prior, this=self._parse_bitwise()
+ )
+ connect = self._parse_conjunction()
+ self.NO_PAREN_FUNCTION_PARSERS.pop("PRIOR")
+ return self.expression(exp.Connect, start=start, connect=connect)
+
def _parse_order(
self, this: t.Optional[exp.Expression] = None, skip_order_token: bool = False
) -> t.Optional[exp.Expression]:
@@ -2929,6 +3040,7 @@ class Parser(metaclass=_Parser):
expression,
this=this,
distinct=self._match(TokenType.DISTINCT) or not self._match(TokenType.ALL),
+ by_name=self._match_text_seq("BY", "NAME"),
expression=self._parse_set_operations(self._parse_select(nested=True)),
)
@@ -3017,6 +3129,8 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Escape, this=this, expression=self._parse_string())
def _parse_interval(self) -> t.Optional[exp.Interval]:
+ index = self._index
+
if not self._match(TokenType.INTERVAL):
return None
@@ -3025,7 +3139,11 @@ class Parser(metaclass=_Parser):
else:
this = self._parse_term()
- unit = self._parse_function() or self._parse_var()
+ if not this:
+ self._retreat(index)
+ return None
+
+ unit = self._parse_function() or self._parse_var(any_token=True)
# Most dialects support, e.g., the form INTERVAL '5' day, thus we try to parse
# each INTERVAL expression into this canonical form so it's easy to transpile
@@ -3036,12 +3154,12 @@ class Parser(metaclass=_Parser):
if len(parts) == 2:
if unit:
- # this is not actually a unit, it's something else
+ # This is not actually a unit, it's something else (e.g. a "window side")
unit = None
self._retreat(self._index - 1)
- else:
- this = exp.Literal.string(parts[0])
- unit = self.expression(exp.Var, this=parts[1])
+
+ this = exp.Literal.string(parts[0])
+ unit = self.expression(exp.Var, this=parts[1])
return self.expression(exp.Interval, this=this, unit=unit)
@@ -3087,7 +3205,7 @@ class Parser(metaclass=_Parser):
return interval
index = self._index
- data_type = self._parse_types(check_func=True)
+ data_type = self._parse_types(check_func=True, allow_identifiers=False)
this = self._parse_column()
if data_type:
@@ -3103,30 +3221,50 @@ class Parser(metaclass=_Parser):
return this
- def _parse_type_size(self) -> t.Optional[exp.DataTypeSize]:
+ def _parse_type_size(self) -> t.Optional[exp.DataTypeParam]:
this = self._parse_type()
if not this:
return None
return self.expression(
- exp.DataTypeSize, this=this, expression=self._parse_var(any_token=True)
+ exp.DataTypeParam, this=this, expression=self._parse_var(any_token=True)
)
def _parse_types(
- self, check_func: bool = False, schema: bool = False
+ self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True
) -> t.Optional[exp.Expression]:
index = self._index
prefix = self._match_text_seq("SYSUDTLIB", ".")
if not self._match_set(self.TYPE_TOKENS):
- return None
+ identifier = allow_identifiers and self._parse_id_var(
+ any_token=False, tokens=(TokenType.VAR,)
+ )
+
+ if identifier:
+ tokens = self._tokenizer.tokenize(identifier.name)
+
+ if len(tokens) != 1:
+ self.raise_error("Unexpected identifier", self._prev)
+
+ if tokens[0].token_type in self.TYPE_TOKENS:
+ self._prev = tokens[0]
+ elif self.SUPPORTS_USER_DEFINED_TYPES:
+ return identifier
+ else:
+ return None
+ else:
+ return None
type_token = self._prev.token_type
if type_token == TokenType.PSEUDO_TYPE:
return self.expression(exp.PseudoType, this=self._prev.text)
+ if type_token == TokenType.OBJECT_IDENTIFIER:
+ return self.expression(exp.ObjectIdentifier, this=self._prev.text)
+
nested = type_token in self.NESTED_TYPE_TOKENS
is_struct = type_token in self.STRUCT_TYPE_TOKENS
expressions = None
@@ -3137,7 +3275,9 @@ class Parser(metaclass=_Parser):
expressions = self._parse_csv(self._parse_struct_types)
elif nested:
expressions = self._parse_csv(
- lambda: self._parse_types(check_func=check_func, schema=schema)
+ lambda: self._parse_types(
+ check_func=check_func, schema=schema, allow_identifiers=allow_identifiers
+ )
)
elif type_token in self.ENUM_TYPE_TOKENS:
expressions = self._parse_csv(self._parse_equality)
@@ -3151,14 +3291,16 @@ class Parser(metaclass=_Parser):
maybe_func = True
this: t.Optional[exp.Expression] = None
- values: t.Optional[t.List[t.Optional[exp.Expression]]] = None
+ values: t.Optional[t.List[exp.Expression]] = None
if nested and self._match(TokenType.LT):
if is_struct:
expressions = self._parse_csv(self._parse_struct_types)
else:
expressions = self._parse_csv(
- lambda: self._parse_types(check_func=check_func, schema=schema)
+ lambda: self._parse_types(
+ check_func=check_func, schema=schema, allow_identifiers=allow_identifiers
+ )
)
if not self._match(TokenType.GT):
@@ -3355,7 +3497,7 @@ class Parser(metaclass=_Parser):
upper = this.upper()
parser = self.NO_PAREN_FUNCTION_PARSERS.get(upper)
- if optional_parens and parser:
+ if optional_parens and parser and token_type not in self.INVALID_FUNC_NAME_TOKENS:
self._advance()
return parser(self)
@@ -3442,7 +3584,9 @@ class Parser(metaclass=_Parser):
index = self._index
if self._match(TokenType.L_PAREN):
- expressions = self._parse_csv(self._parse_id_var)
+ expressions = t.cast(
+ t.List[t.Optional[exp.Expression]], self._parse_csv(self._parse_id_var)
+ )
if not self._match(TokenType.R_PAREN):
self._retreat(index)
@@ -3481,14 +3625,14 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.L_PAREN):
return this
- args = self._parse_csv(
- lambda: self._parse_constraint()
- or self._parse_column_def(self._parse_field(any_token=True))
- )
+ args = self._parse_csv(lambda: self._parse_constraint() or self._parse_field_def())
self._match_r_paren()
return self.expression(exp.Schema, this=this, expressions=args)
+ def _parse_field_def(self) -> t.Optional[exp.Expression]:
+ return self._parse_column_def(self._parse_field(any_token=True))
+
def _parse_column_def(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
# column defs are not really columns, they're identifiers
if isinstance(this, exp.Column):
@@ -3499,7 +3643,18 @@ class Parser(metaclass=_Parser):
if self._match_text_seq("FOR", "ORDINALITY"):
return self.expression(exp.ColumnDef, this=this, ordinality=True)
- constraints = []
+ constraints: t.List[exp.Expression] = []
+
+ if not kind and self._match(TokenType.ALIAS):
+ constraints.append(
+ self.expression(
+ exp.ComputedColumnConstraint,
+ this=self._parse_conjunction(),
+ persisted=self._match_text_seq("PERSISTED"),
+ not_null=self._match_pair(TokenType.NOT, TokenType.NULL),
+ )
+ )
+
while True:
constraint = self._parse_column_constraint()
if not constraint:
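
Note: the new branch parses computed columns (column AS (expr) [PERSISTED] [NOT NULL]) into exp.ComputedColumnConstraint. A sketch against T-SQL, where the syntax is common; the exact output spelling depends on the dialect's generator:

    import sqlglot

    sql = "CREATE TABLE t (a INT, b AS (a * 2) PERSISTED)"
    print(sqlglot.parse_one(sql, read="tsql").sql(dialect="tsql"))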
@@ -3553,7 +3708,7 @@ class Parser(metaclass=_Parser):
identity = self._match_text_seq("IDENTITY")
if self._match(TokenType.L_PAREN):
- if self._match_text_seq("START", "WITH"):
+ if self._match(TokenType.START_WITH):
this.set("start", self._parse_bitwise())
if self._match_text_seq("INCREMENT", "BY"):
this.set("increment", self._parse_bitwise())
@@ -3580,11 +3735,13 @@ class Parser(metaclass=_Parser):
def _parse_not_constraint(
self,
- ) -> t.Optional[exp.NotNullColumnConstraint | exp.CaseSpecificColumnConstraint]:
+ ) -> t.Optional[exp.Expression]:
if self._match_text_seq("NULL"):
return self.expression(exp.NotNullColumnConstraint)
if self._match_text_seq("CASESPECIFIC"):
return self.expression(exp.CaseSpecificColumnConstraint, not_=True)
+ if self._match_text_seq("FOR", "REPLICATION"):
+ return self.expression(exp.NotForReplicationColumnConstraint)
return None
def _parse_column_constraint(self) -> t.Optional[exp.Expression]:
@@ -3729,7 +3886,7 @@ class Parser(metaclass=_Parser):
bracket_kind = self._prev.token_type
if self._match(TokenType.COLON):
- expressions: t.List[t.Optional[exp.Expression]] = [
+ expressions: t.List[exp.Expression] = [
self.expression(exp.Slice, expression=self._parse_conjunction())
]
else:
@@ -3844,17 +4001,17 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.ALIAS):
if self._match(TokenType.COMMA):
- return self.expression(
- exp.CastToStrType, this=this, expression=self._parse_string()
- )
- else:
- self.raise_error("Expected AS after CAST")
+ return self.expression(exp.CastToStrType, this=this, to=self._parse_string())
+
+ self.raise_error("Expected AS after CAST")
fmt = None
to = self._parse_types()
if not to:
self.raise_error("Expected TYPE after CAST")
+ elif isinstance(to, exp.Identifier):
+ to = exp.DataType.build(to.name, udt=True)
elif to.this == exp.DataType.Type.CHAR:
if self._match(TokenType.CHARACTER_SET):
to = self.expression(exp.CharacterSet, this=self._parse_var_or_string())
@@ -3908,7 +4065,7 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.COMMA):
args.extend(self._parse_csv(self._parse_conjunction))
else:
- args = self._parse_csv(self._parse_conjunction)
+ args = self._parse_csv(self._parse_conjunction) # type: ignore
index = self._index
if not self._match(TokenType.R_PAREN) and args:
@@ -3991,10 +4148,10 @@ class Parser(metaclass=_Parser):
def _parse_json_key_value(self) -> t.Optional[exp.JSONKeyValue]:
self._match_text_seq("KEY")
- key = self._parse_field()
- self._match(TokenType.COLON)
+ key = self._parse_column()
+ self._match_set((TokenType.COLON, TokenType.COMMA))
self._match_text_seq("VALUE")
- value = self._parse_field()
+ value = self._parse_bitwise()
if not key and not value:
return None
@@ -4116,7 +4273,7 @@ class Parser(metaclass=_Parser):
# Postgres supports the form: substring(string [from int] [for int])
# https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6
- args = self._parse_csv(self._parse_bitwise)
+ args = t.cast(t.List[t.Optional[exp.Expression]], self._parse_csv(self._parse_bitwise))
if self._match(TokenType.FROM):
args.append(self._parse_bitwise())
@@ -4149,7 +4306,7 @@ class Parser(metaclass=_Parser):
exp.Trim, this=this, position=position, expression=expression, collation=collation
)
- def _parse_window_clause(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
+ def _parse_window_clause(self) -> t.Optional[t.List[exp.Expression]]:
return self._match(TokenType.WINDOW) and self._parse_csv(self._parse_named_window)
def _parse_named_window(self) -> t.Optional[exp.Expression]:
@@ -4216,8 +4373,7 @@ class Parser(metaclass=_Parser):
if self._match_text_seq("LAST"):
first = False
- partition = self._parse_partition_by()
- order = self._parse_order()
+ partition, order = self._parse_partition_and_order()
kind = self._match_set((TokenType.ROWS, TokenType.RANGE)) and self._prev.text
if kind:
@@ -4256,6 +4412,11 @@ class Parser(metaclass=_Parser):
return window
+ def _parse_partition_and_order(
+ self,
+ ) -> t.Tuple[t.List[exp.Expression], t.Optional[exp.Expression]]:
+ return self._parse_partition_by(), self._parse_order()
+
def _parse_window_spec(self) -> t.Dict[str, t.Optional[str | exp.Expression]]:
self._match(TokenType.BETWEEN)
@@ -4377,14 +4538,14 @@ class Parser(metaclass=_Parser):
self._advance(-1)
return None
- def _parse_except(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
+ def _parse_except(self) -> t.Optional[t.List[exp.Expression]]:
if not self._match(TokenType.EXCEPT):
return None
if self._match(TokenType.L_PAREN, advance=False):
return self._parse_wrapped_csv(self._parse_column)
return self._parse_csv(self._parse_column)
- def _parse_replace(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
+ def _parse_replace(self) -> t.Optional[t.List[exp.Expression]]:
if not self._match(TokenType.REPLACE):
return None
if self._match(TokenType.L_PAREN, advance=False):
@@ -4393,7 +4554,7 @@ class Parser(metaclass=_Parser):
def _parse_csv(
self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA
- ) -> t.List[t.Optional[exp.Expression]]:
+ ) -> t.List[exp.Expression]:
parse_result = parse_method()
items = [parse_result] if parse_result is not None else []
@@ -4420,12 +4581,12 @@ class Parser(metaclass=_Parser):
return this
- def _parse_wrapped_id_vars(self, optional: bool = False) -> t.List[t.Optional[exp.Expression]]:
+ def _parse_wrapped_id_vars(self, optional: bool = False) -> t.List[exp.Expression]:
return self._parse_wrapped_csv(self._parse_id_var, optional=optional)
def _parse_wrapped_csv(
self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA, optional: bool = False
- ) -> t.List[t.Optional[exp.Expression]]:
+ ) -> t.List[exp.Expression]:
return self._parse_wrapped(
lambda: self._parse_csv(parse_method, sep=sep), optional=optional
)
@@ -4439,7 +4600,7 @@ class Parser(metaclass=_Parser):
self._match_r_paren()
return parse_result
- def _parse_expressions(self) -> t.List[t.Optional[exp.Expression]]:
+ def _parse_expressions(self) -> t.List[exp.Expression]:
return self._parse_csv(self._parse_expression)
def _parse_select_or_expression(self, alias: bool = False) -> t.Optional[exp.Expression]:
@@ -4498,7 +4659,7 @@ class Parser(metaclass=_Parser):
self._match(TokenType.COLUMN)
exists_column = self._parse_exists(not_=True)
- expression = self._parse_column_def(self._parse_field(any_token=True))
+ expression = self._parse_field_def()
if expression:
expression.set("exists", exists_column)
@@ -4549,13 +4710,16 @@ class Parser(metaclass=_Parser):
return self.expression(exp.AddConstraint, this=this, expression=expression)
- def _parse_alter_table_add(self) -> t.List[t.Optional[exp.Expression]]:
+ def _parse_alter_table_add(self) -> t.List[exp.Expression]:
index = self._index - 1
if self._match_set(self.ADD_CONSTRAINT_TOKENS):
return self._parse_csv(self._parse_add_constraint)
self._retreat(index)
+ if not self.ALTER_TABLE_ADD_COLUMN_KEYWORD and self._match_text_seq("ADD"):
+ return self._parse_csv(self._parse_field_def)
+
return self._parse_csv(self._parse_add_column)
def _parse_alter_table_alter(self) -> exp.AlterColumn:
@@ -4576,7 +4740,7 @@ class Parser(metaclass=_Parser):
using=self._match(TokenType.USING) and self._parse_conjunction(),
)
- def _parse_alter_table_drop(self) -> t.List[t.Optional[exp.Expression]]:
+ def _parse_alter_table_drop(self) -> t.List[exp.Expression]:
index = self._index - 1
partition_exists = self._parse_exists()
@@ -4619,6 +4783,9 @@ class Parser(metaclass=_Parser):
self._match(TokenType.INTO)
target = self._parse_table()
+ if target and self._match(TokenType.ALIAS, advance=False):
+ target.set("alias", self._parse_table_alias())
+
self._match(TokenType.USING)
using = self._parse_table()
@@ -4685,8 +4852,7 @@ class Parser(metaclass=_Parser):
parser = self._find_parser(self.SHOW_PARSERS, self.SHOW_TRIE)
if parser:
return parser(self)
- self._advance()
- return self.expression(exp.Show, this=self._prev.text.upper())
+ return self._parse_as_command(self._prev)
def _parse_set_item_assignment(
self, kind: t.Optional[str] = None
@@ -4786,6 +4952,19 @@ class Parser(metaclass=_Parser):
self._match_r_paren()
return self.expression(exp.DictRange, this=this, min=min, max=max)
+ def _parse_comprehension(self, this: exp.Expression) -> exp.Comprehension:
+ expression = self._parse_column()
+ self._match(TokenType.IN)
+ iterator = self._parse_column()
+ condition = self._parse_conjunction() if self._match_text_seq("IF") else None
+ return self.expression(
+ exp.Comprehension,
+ this=this,
+ expression=expression,
+ iterator=iterator,
+ condition=condition,
+ )
+
def _find_parser(
self, parsers: t.Dict[str, t.Callable], trie: t.Dict
) -> t.Optional[t.Callable]:
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index d278dbf..83b97d6 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -48,6 +48,7 @@ class TokenType(AutoName):
HASH_ARROW = auto()
DHASH_ARROW = auto()
LR_ARROW = auto()
+ DAT = auto()
LT_AT = auto()
AT_GT = auto()
DOLLAR = auto()
@@ -84,6 +85,7 @@ class TokenType(AutoName):
UTINYINT = auto()
SMALLINT = auto()
USMALLINT = auto()
+ MEDIUMINT = auto()
INT = auto()
UINT = auto()
BIGINT = auto()
@@ -140,6 +142,7 @@ class TokenType(AutoName):
SMALLSERIAL = auto()
BIGSERIAL = auto()
XML = auto()
+ YEAR = auto()
UNIQUEIDENTIFIER = auto()
USERDEFINED = auto()
MONEY = auto()
@@ -157,6 +160,7 @@ class TokenType(AutoName):
FIXEDSTRING = auto()
LOWCARDINALITY = auto()
NESTED = auto()
+ UNKNOWN = auto()
# keywords
ALIAS = auto()
@@ -180,6 +184,7 @@ class TokenType(AutoName):
COMMAND = auto()
COMMENT = auto()
COMMIT = auto()
+ CONNECT_BY = auto()
CONSTRAINT = auto()
CREATE = auto()
CROSS = auto()
@@ -256,6 +261,7 @@ class TokenType(AutoName):
NEXT = auto()
NOTNULL = auto()
NULL = auto()
+ OBJECT_IDENTIFIER = auto()
OFFSET = auto()
ON = auto()
ORDER_BY = auto()
@@ -298,6 +304,7 @@ class TokenType(AutoName):
SIMILAR_TO = auto()
SOME = auto()
SORT_BY = auto()
+ START_WITH = auto()
STRUCT = auto()
TABLE_SAMPLE = auto()
TEMPORARY = auto()
@@ -319,6 +326,8 @@ class TokenType(AutoName):
WINDOW = auto()
WITH = auto()
UNIQUE = auto()
+ VERSION_SNAPSHOT = auto()
+ TIMESTAMP_SNAPSHOT = auto()
class Token:
@@ -530,6 +539,7 @@ class Tokenizer(metaclass=_Tokenizer):
"COLLATE": TokenType.COLLATE,
"COLUMN": TokenType.COLUMN,
"COMMIT": TokenType.COMMIT,
+ "CONNECT BY": TokenType.CONNECT_BY,
"CONSTRAINT": TokenType.CONSTRAINT,
"CREATE": TokenType.CREATE,
"CROSS": TokenType.CROSS,
@@ -636,6 +646,7 @@ class Tokenizer(metaclass=_Tokenizer):
"SIMILAR TO": TokenType.SIMILAR_TO,
"SOME": TokenType.SOME,
"SORT BY": TokenType.SORT_BY,
+ "START WITH": TokenType.START_WITH,
"TABLE": TokenType.TABLE,
"TABLESAMPLE": TokenType.TABLE_SAMPLE,
"TEMP": TokenType.TEMPORARY,
@@ -643,6 +654,7 @@ class Tokenizer(metaclass=_Tokenizer):
"THEN": TokenType.THEN,
"TRUE": TokenType.TRUE,
"UNION": TokenType.UNION,
+ "UNKNOWN": TokenType.UNKNOWN,
"UNNEST": TokenType.UNNEST,
"UNPIVOT": TokenType.UNPIVOT,
"UPDATE": TokenType.UPDATE,
@@ -739,6 +751,8 @@ class Tokenizer(metaclass=_Tokenizer):
"TRUNCATE": TokenType.COMMAND,
"VACUUM": TokenType.COMMAND,
"USER-DEFINED": TokenType.USERDEFINED,
+ "FOR VERSION": TokenType.VERSION_SNAPSHOT,
+ "FOR TIMESTAMP": TokenType.TIMESTAMP_SNAPSHOT,
}
WHITE_SPACE: t.Dict[t.Optional[str], TokenType] = {
@@ -941,8 +955,8 @@ class Tokenizer(metaclass=_Tokenizer):
if result == TrieResult.EXISTS:
word = chars
+ end = self._current + size
size += 1
- end = self._current - 1 + size
if end < self.size:
char = self.sql[end]
@@ -961,21 +975,20 @@ class Tokenizer(metaclass=_Tokenizer):
char = ""
chars = " "
- if not word:
- if self._char in self.SINGLE_TOKENS:
- self._add(self.SINGLE_TOKENS[self._char], text=self._char)
+ if word:
+ if self._scan_string(word):
return
- self._scan_var()
- return
-
- if self._scan_string(word):
- return
- if self._scan_comment(word):
+ if self._scan_comment(word):
+ return
+ if prev_space or single_token or not char:
+ self._advance(size - 1)
+ word = word.upper()
+ self._add(self.KEYWORDS[word], text=word)
+ return
+ if self._char in self.SINGLE_TOKENS:
+ self._add(self.SINGLE_TOKENS[self._char], text=self._char)
return
-
- self._advance(size - 1)
- word = word.upper()
- self._add(self.KEYWORDS[word], text=word)
+ self._scan_var()
def _scan_comment(self, comment_start: str) -> bool:
if comment_start not in self._COMMENTS:
@@ -1053,8 +1066,8 @@ class Tokenizer(metaclass=_Tokenizer):
elif self.IDENTIFIERS_CAN_START_WITH_DIGIT:
return self._add(TokenType.VAR)
- self._add(TokenType.NUMBER, number_text)
- return self._advance(-len(literal))
+ self._advance(-len(literal))
+ return self._add(TokenType.NUMBER, number_text)
else:
return self._add(TokenType.NUMBER)
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py
index 7c7c2a7..48ea8dc 100644
--- a/sqlglot/transforms.py
+++ b/sqlglot/transforms.py
@@ -68,11 +68,17 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
if order:
window.set("order", order.pop().copy())
+ else:
+ window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))
window = exp.alias_(window, row_number)
expression.select(window, copy=False)
- return exp.select(*outer_selects).from_(expression.subquery()).where(f'"{row_number}" = 1')
+ return (
+ exp.select(*outer_selects)
+ .from_(expression.subquery())
+ .where(exp.column(row_number).eq(1))
+ )
return expression
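
Note: with no ORDER BY to steal, the window now orders by the DISTINCT ON columns themselves, and the outer filter is built as a proper column equality instead of a quoted string. A hedged sketch:

    from sqlglot import parse_one
    from sqlglot.transforms import eliminate_distinct_on

    ast = parse_one("SELECT DISTINCT ON (a) a, b FROM t", read="postgres")
    print(eliminate_distinct_on(ast).sql())
    # roughly: SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER
    #   (PARTITION BY a ORDER BY a) AS _row_number FROM t) WHERE _row_number = 1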
@@ -126,7 +132,7 @@ def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expr
"""
for node in expression.find_all(exp.DataType):
node.set(
- "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeSize)]
+ "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)]
)
return expression