From f73e9af131151f1e058446361c35b05c4c90bf10 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Thu, 7 Sep 2023 13:39:48 +0200 Subject: Merging upstream version 18.2.0. Signed-off-by: Daniel Baumann --- sqlglot/dataframe/README.md | 34 ++-- sqlglot/dataframe/sql/column.py | 22 +- sqlglot/dataframe/sql/dataframe.py | 34 +++- sqlglot/dataframe/sql/functions.py | 8 +- sqlglot/dataframe/sql/normalize.py | 3 +- sqlglot/dataframe/sql/readwriter.py | 23 ++- sqlglot/dataframe/sql/session.py | 65 ++++-- sqlglot/dataframe/sql/window.py | 4 +- sqlglot/dialects/bigquery.py | 86 ++++++-- sqlglot/dialects/clickhouse.py | 52 ++++- sqlglot/dialects/databricks.py | 15 +- sqlglot/dialects/dialect.py | 20 +- sqlglot/dialects/doris.py | 1 - sqlglot/dialects/drill.py | 9 +- sqlglot/dialects/duckdb.py | 38 ++-- sqlglot/dialects/hive.py | 55 +++-- sqlglot/dialects/mysql.py | 32 ++- sqlglot/dialects/oracle.py | 11 +- sqlglot/dialects/postgres.py | 38 +++- sqlglot/dialects/presto.py | 54 ++++- sqlglot/dialects/redshift.py | 14 +- sqlglot/dialects/snowflake.py | 78 +++++++- sqlglot/dialects/spark.py | 10 + sqlglot/dialects/spark2.py | 31 ++- sqlglot/dialects/sqlite.py | 5 +- sqlglot/dialects/teradata.py | 4 + sqlglot/dialects/trino.py | 3 + sqlglot/dialects/tsql.py | 157 +++++++++------ sqlglot/expressions.py | 242 ++++++++++++++++++---- sqlglot/generator.py | 149 +++++++++++--- sqlglot/helper.py | 30 ++- sqlglot/optimizer/__init__.py | 9 +- sqlglot/optimizer/annotate_types.py | 39 +++- sqlglot/optimizer/eliminate_subqueries.py | 9 +- sqlglot/optimizer/optimize_joins.py | 7 +- sqlglot/optimizer/pushdown_predicates.py | 14 +- sqlglot/optimizer/scope.py | 72 ++++--- sqlglot/optimizer/simplify.py | 36 ++-- sqlglot/parser.py | 321 +++++++++++++++++++++++------- sqlglot/tokens.py | 45 +++-- sqlglot/transforms.py | 10 +- 41 files changed, 1424 insertions(+), 465 deletions(-) (limited to 'sqlglot') diff --git a/sqlglot/dataframe/README.md b/sqlglot/dataframe/README.md index 86fdc4b..adde9a1 100644 --- a/sqlglot/dataframe/README.md +++ b/sqlglot/dataframe/README.md @@ -21,10 +21,12 @@ Currently many of the common operations are covered and more functionality will * Ex: `['cola', 'colb']` * The lack of types may limit functionality in future releases. * See [Registering Custom Schema](#registering-custom-schema-class) for information on how to skip this step if the information is stored externally. +* If your output SQL dialect is not Spark, then configure the SparkSession to use that dialect + * Ex: `SparkSession().builder.config("sqlframe.dialect", "bigquery").getOrCreate()` + * See [dialects](https://github.com/tobymao/sqlglot/tree/main/sqlglot/dialects) for a full list of dialects. * Add `.sql(pretty=True)` to your final DataFrame command to return a list of sql statements to run that command. - * In most cases a single SQL statement is returned. Currently the only exception is when caching DataFrames which isn't supported in other dialects. - * Spark is the default output dialect. See [dialects](https://github.com/tobymao/sqlglot/tree/main/sqlglot/dialects) for a full list of dialects. - * Ex: `.sql(pretty=True, dialect='bigquery')` + * In most cases a single SQL statement is returned. Currently the only exception is when caching DataFrames which isn't supported in other dialects. 
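The bullets above describe the new builder-based dialect configuration. A minimal end-to-end sketch (table and column names are illustrative, assuming the `sqlglot.dataframe` API in this release):

```python
import sqlglot
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F

# Register the table schema first, in the output dialect
sqlglot.schema.add_table("t", {"a": "INT", "b": "INT"}, dialect="bigquery")

spark = SparkSession.builder.config("sqlframe.dialect", "bigquery").getOrCreate()
df = (
    spark.read.table("t")
    .groupBy(F.col("a"))
    .agg(F.countDistinct(F.col("b")).alias("n"))
)
print(df.sql(pretty=True))  # a list with one BigQuery SELECT statement
```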
+ * Ex: `.sql(pretty=True)` ## Examples @@ -33,6 +35,8 @@ import sqlglot from sqlglot.dataframe.sql.session import SparkSession from sqlglot.dataframe.sql import functions as F +dialect = "spark" + sqlglot.schema.add_table( 'employee', { @@ -41,10 +45,10 @@ sqlglot.schema.add_table( 'lname': 'STRING', 'age': 'INT', }, - dialect="spark", + dialect=dialect, ) # Register the table structure prior to reading from the table -spark = SparkSession() +spark = SparkSession.builder.config("sqlframe.dialect", dialect).getOrCreate() df = ( spark @@ -53,7 +57,7 @@ df = ( .agg(F.countDistinct(F.col("employee_id")).alias("num_employees")) ) -print(df.sql(pretty=True)) # Spark will be the dialect used by default +print(df.sql(pretty=True)) ``` ```sparksql @@ -81,7 +85,7 @@ class ExternalSchema(Schema): sqlglot.schema = ExternalSchema() -spark = SparkSession() +spark = SparkSession() # Spark will be used by default is not specific in SparkSession config df = ( spark @@ -119,11 +123,14 @@ schema = types.StructType([ ]) sql_statements = ( - SparkSession() + SparkSession + .builder + .config("sqlframe.dialect", "bigquery") + .getOrCreate() .createDataFrame(data, schema) .groupBy(F.col("age")) .agg(F.countDistinct(F.col("employee_id")).alias("num_employees")) - .sql(dialect="bigquery") + .sql() ) result = None @@ -166,11 +173,14 @@ schema = types.StructType([ ]) sql_statements = ( - SparkSession() + SparkSession + .builder + .config("sqlframe.dialect", "snowflake") + .getOrCreate() .createDataFrame(data, schema) .groupBy(F.col("age")) .agg(F.countDistinct(F.col("lname")).alias("num_employees")) - .sql(dialect="snowflake") + .sql() ) try: @@ -210,7 +220,7 @@ sql_statements = ( .createDataFrame(data, schema) .groupBy(F.col("age")) .agg(F.countDistinct(F.col("employee_id")).alias("num_employees")) - .sql(dialect="spark") + .sql() ) pyspark = PySparkSession.builder.master("local[*]").getOrCreate() diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py index fcfd71e..3acf494 100644 --- a/sqlglot/dataframe/sql/column.py +++ b/sqlglot/dataframe/sql/column.py @@ -5,7 +5,6 @@ import typing as t import sqlglot from sqlglot import expressions as exp from sqlglot.dataframe.sql.types import DataType -from sqlglot.dialects import Spark from sqlglot.helper import flatten, is_iterable if t.TYPE_CHECKING: @@ -15,19 +14,20 @@ if t.TYPE_CHECKING: class Column: def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]): + from sqlglot.dataframe.sql.session import SparkSession + if isinstance(expression, Column): expression = expression.expression # type: ignore elif expression is None or not isinstance(expression, (str, exp.Expression)): expression = self._lit(expression).expression # type: ignore - - expression = sqlglot.maybe_parse(expression, dialect="spark") + elif not isinstance(expression, exp.Column): + expression = sqlglot.maybe_parse(expression, dialect=SparkSession().dialect).transform( + SparkSession().dialect.normalize_identifier, copy=False + ) if expression is None: raise ValueError(f"Could not parse {expression}") - if isinstance(expression, exp.Column): - expression.transform(Spark.normalize_identifier, copy=False) - - self.expression: exp.Expression = expression + self.expression: exp.Expression = expression # type: ignore def __repr__(self): return repr(self.expression) @@ -207,7 +207,9 @@ class Column: return Column(expression) def sql(self, **kwargs) -> str: - return self.expression.sql(**{"dialect": "spark", **kwargs}) + from sqlglot.dataframe.sql.session 
import SparkSession + + return self.expression.sql(**{"dialect": SparkSession().dialect, **kwargs}) def alias(self, name: str) -> Column: new_expression = exp.alias_(self.column_expression, name) @@ -264,9 +266,11 @@ class Column: Functionality Difference: PySpark cast accepts a datatype instance of the datatype class Sqlglot doesn't currently replicate this class so it only accepts a string """ + from sqlglot.dataframe.sql.session import SparkSession + if isinstance(dataType, DataType): dataType = dataType.simpleString() - return Column(exp.cast(self.column_expression, dataType, dialect="spark")) + return Column(exp.cast(self.column_expression, dataType, dialect=SparkSession().dialect)) def startswith(self, value: t.Union[str, Column]) -> Column: value = self._lit(value) if not isinstance(value, Column) else value diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py index 64cceea..f515608 100644 --- a/sqlglot/dataframe/sql/dataframe.py +++ b/sqlglot/dataframe/sql/dataframe.py @@ -1,12 +1,13 @@ from __future__ import annotations import functools +import logging import typing as t import zlib from copy import copy import sqlglot -from sqlglot import expressions as exp +from sqlglot import Dialect, expressions as exp from sqlglot.dataframe.sql import functions as F from sqlglot.dataframe.sql.column import Column from sqlglot.dataframe.sql.group import GroupedData @@ -18,6 +19,7 @@ from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join from sqlglot.dataframe.sql.window import Window from sqlglot.helper import ensure_list, object_to_dict, seq_get from sqlglot.optimizer import optimize as optimize_func +from sqlglot.optimizer.qualify_columns import quote_identifiers if t.TYPE_CHECKING: from sqlglot.dataframe.sql._typing import ( @@ -27,7 +29,9 @@ if t.TYPE_CHECKING: OutputExpressionContainer, ) from sqlglot.dataframe.sql.session import SparkSession + from sqlglot.dialects.dialect import DialectType +logger = logging.getLogger("sqlglot") JOIN_HINTS = { "BROADCAST", @@ -264,7 +268,9 @@ class DataFrame: @classmethod def _create_hash_from_expression(cls, expression: exp.Expression) -> str: - value = expression.sql(dialect="spark").encode("utf-8") + from sqlglot.dataframe.sql.session import SparkSession + + value = expression.sql(dialect=SparkSession().dialect).encode("utf-8") return f"t{zlib.crc32(value)}"[:6] def _get_select_expressions( @@ -291,7 +297,15 @@ class DataFrame: select_expressions.append(expression_select_pair) # type: ignore return select_expressions - def sql(self, dialect="spark", optimize=True, **kwargs) -> t.List[str]: + def sql( + self, dialect: t.Optional[DialectType] = None, optimize: bool = True, **kwargs + ) -> t.List[str]: + from sqlglot.dataframe.sql.session import SparkSession + + if dialect and Dialect.get_or_raise(dialect)() != SparkSession().dialect: + logger.warning( + f"The recommended way of defining a dialect is by doing `SparkSession.builder.config('sqlframe.dialect', '{dialect}').getOrCreate()`. It is no longer needed then when calling `sql`. If you run into issues try updating your query to use this pattern." 
+ ) df = self._resolve_pending_hints() select_expressions = df._get_select_expressions() output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = [] @@ -299,7 +313,10 @@ class DataFrame: for expression_type, select_expression in select_expressions: select_expression = select_expression.transform(replace_id_value, replacement_mapping) if optimize: - select_expression = t.cast(exp.Select, optimize_func(select_expression)) + quote_identifiers(select_expression) + select_expression = t.cast( + exp.Select, optimize_func(select_expression, dialect=SparkSession().dialect) + ) select_expression = df._replace_cte_names_with_hashes(select_expression) expression: t.Union[exp.Select, exp.Cache, exp.Drop] if expression_type == exp.Cache: @@ -313,10 +330,12 @@ class DataFrame: sqlglot.schema.add_table( cache_table_name, { - expression.alias_or_name: expression.type.sql("spark") + expression.alias_or_name: expression.type.sql( + dialect=SparkSession().dialect + ) for expression in select_expression.expressions }, - dialect="spark", + dialect=SparkSession().dialect, ) cache_storage_level = select_expression.args["cache_storage_level"] options = [ @@ -345,7 +364,8 @@ class DataFrame: output_expressions.append(expression) return [ - expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions + expression.sql(**{"dialect": SparkSession().dialect, **kwargs}) + for expression in output_expressions ] def copy(self, **kwargs) -> DataFrame: diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index 4002cfe..d0ae50c 100644 --- a/sqlglot/dataframe/sql/functions.py +++ b/sqlglot/dataframe/sql/functions.py @@ -368,9 +368,7 @@ def covar_samp(col1: ColumnOrName, col2: ColumnOrName) -> Column: def first(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column: - if ignorenulls is not None: - return Column.invoke_anonymous_function(col, "FIRST", ignorenulls) - return Column.invoke_anonymous_function(col, "FIRST") + return Column.invoke_expression_over_column(col, expression.First, ignore_nulls=ignorenulls) def grouping_id(*cols: ColumnOrName) -> Column: @@ -394,9 +392,7 @@ def isnull(col: ColumnOrName) -> Column: def last(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column: - if ignorenulls is not None: - return Column.invoke_anonymous_function(col, "LAST", ignorenulls) - return Column.invoke_anonymous_function(col, "LAST") + return Column.invoke_expression_over_column(col, expression.Last, ignore_nulls=ignorenulls) def monotonically_increasing_id() -> Column: diff --git a/sqlglot/dataframe/sql/normalize.py b/sqlglot/dataframe/sql/normalize.py index 4eec782..f68bacb 100644 --- a/sqlglot/dataframe/sql/normalize.py +++ b/sqlglot/dataframe/sql/normalize.py @@ -5,7 +5,6 @@ import typing as t from sqlglot import expressions as exp from sqlglot.dataframe.sql.column import Column from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join -from sqlglot.dialects import Spark from sqlglot.helper import ensure_list NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column]) @@ -20,7 +19,7 @@ def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[ for expression in expressions: identifiers = expression.find_all(exp.Identifier) for identifier in identifiers: - Spark.normalize_identifier(identifier) + identifier.transform(spark.dialect.normalize_identifier) replace_alias_name_with_cte_name(spark, expression_context, identifier) 
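The `first`/`last` change above routes these helpers through the new `exp.First`/`exp.Last` expressions instead of anonymous function calls. A small sketch (column name illustrative):

```python
from sqlglot import exp
from sqlglot.dataframe.sql import functions as F

col = F.first(F.col("x"), ignorenulls=True)
# Previously this produced an anonymous FIRST(...) call; now it is typed.
assert isinstance(col.expression, exp.First)
```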
replace_branch_and_sequence_ids_with_cte_name(spark, expression_context, identifier) diff --git a/sqlglot/dataframe/sql/readwriter.py b/sqlglot/dataframe/sql/readwriter.py index 9d87d4a..0804486 100644 --- a/sqlglot/dataframe/sql/readwriter.py +++ b/sqlglot/dataframe/sql/readwriter.py @@ -4,7 +4,6 @@ import typing as t import sqlglot from sqlglot import expressions as exp -from sqlglot.dialects import Spark from sqlglot.helper import object_to_dict if t.TYPE_CHECKING: @@ -18,15 +17,25 @@ class DataFrameReader: def table(self, tableName: str) -> DataFrame: from sqlglot.dataframe.sql.dataframe import DataFrame + from sqlglot.dataframe.sql.session import SparkSession - sqlglot.schema.add_table(tableName, dialect="spark") + sqlglot.schema.add_table(tableName, dialect=SparkSession().dialect) return DataFrame( self.spark, exp.Select() - .from_(exp.to_table(tableName, dialect="spark").transform(Spark.normalize_identifier)) + .from_( + exp.to_table(tableName, dialect=SparkSession().dialect).transform( + SparkSession().dialect.normalize_identifier + ) + ) .select( - *(column for column in sqlglot.schema.column_names(tableName, dialect="spark")) + *( + column + for column in sqlglot.schema.column_names( + tableName, dialect=SparkSession().dialect + ) + ) ), ) @@ -63,6 +72,8 @@ class DataFrameWriter: return self.copy(by_name=True) def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter: + from sqlglot.dataframe.sql.session import SparkSession + output_expression_container = exp.Insert( **{ "this": exp.to_table(tableName), @@ -71,7 +82,9 @@ class DataFrameWriter: ) df = self._df.copy(output_expression_container=output_expression_container) if self._by_name: - columns = sqlglot.schema.column_names(tableName, only_visible=True, dialect="spark") + columns = sqlglot.schema.column_names( + tableName, only_visible=True, dialect=SparkSession().dialect + ) df = df._convert_leaf_to_cte().select(*columns) return self.copy(_df=df) diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py index b883359..531ee17 100644 --- a/sqlglot/dataframe/sql/session.py +++ b/sqlglot/dataframe/sql/session.py @@ -5,31 +5,35 @@ import uuid from collections import defaultdict import sqlglot -from sqlglot import expressions as exp +from sqlglot import Dialect, expressions as exp from sqlglot.dataframe.sql import functions as F from sqlglot.dataframe.sql.dataframe import DataFrame from sqlglot.dataframe.sql.readwriter import DataFrameReader from sqlglot.dataframe.sql.types import StructType from sqlglot.dataframe.sql.util import get_column_mapping_from_schema_input +from sqlglot.helper import classproperty if t.TYPE_CHECKING: from sqlglot.dataframe.sql._typing import ColumnLiterals, SchemaInput class SparkSession: - known_ids: t.ClassVar[t.Set[str]] = set() - known_branch_ids: t.ClassVar[t.Set[str]] = set() - known_sequence_ids: t.ClassVar[t.Set[str]] = set() - name_to_sequence_id_mapping: t.ClassVar[t.Dict[str, t.List[str]]] = defaultdict(list) + DEFAULT_DIALECT = "spark" + _instance = None def __init__(self): - self.incrementing_id = 1 - - def __getattr__(self, name: str) -> SparkSession: - return self - - def __call__(self, *args, **kwargs) -> SparkSession: - return self + if not hasattr(self, "known_ids"): + self.known_ids = set() + self.known_branch_ids = set() + self.known_sequence_ids = set() + self.name_to_sequence_id_mapping = defaultdict(list) + self.incrementing_id = 1 + self.dialect = Dialect.get_or_raise(self.DEFAULT_DIALECT)() + + def __new__(cls, *args, 
**kwargs) -> SparkSession: + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance @property def read(self) -> DataFrameReader: @@ -101,7 +105,7 @@ class SparkSession: return DataFrame(self, sel_expression) def sql(self, sqlQuery: str) -> DataFrame: - expression = sqlglot.parse_one(sqlQuery, read="spark") + expression = sqlglot.parse_one(sqlQuery, read=self.dialect) if isinstance(expression, exp.Select): df = DataFrame(self, expression) df = df._convert_leaf_to_cte() @@ -149,3 +153,38 @@ class SparkSession: def _add_alias_to_mapping(self, name: str, sequence_id: str): self.name_to_sequence_id_mapping[name].append(sequence_id) + + class Builder: + SQLFRAME_DIALECT_KEY = "sqlframe.dialect" + + def __init__(self): + self.dialect = "spark" + + def __getattr__(self, item) -> SparkSession.Builder: + return self + + def __call__(self, *args, **kwargs): + return self + + def config( + self, + key: t.Optional[str] = None, + value: t.Optional[t.Any] = None, + *, + map: t.Optional[t.Dict[str, t.Any]] = None, + **kwargs: t.Any, + ) -> SparkSession.Builder: + if key == self.SQLFRAME_DIALECT_KEY: + self.dialect = value + elif map and self.SQLFRAME_DIALECT_KEY in map: + self.dialect = map[self.SQLFRAME_DIALECT_KEY] + return self + + def getOrCreate(self) -> SparkSession: + spark = SparkSession() + spark.dialect = Dialect.get_or_raise(self.dialect)() + return spark + + @classproperty + def builder(cls) -> Builder: + return cls.Builder() diff --git a/sqlglot/dataframe/sql/window.py b/sqlglot/dataframe/sql/window.py index c54c07e..c1d913f 100644 --- a/sqlglot/dataframe/sql/window.py +++ b/sqlglot/dataframe/sql/window.py @@ -48,7 +48,9 @@ class WindowSpec: return WindowSpec(self.expression.copy()) def sql(self, **kwargs) -> str: - return self.expression.sql(dialect="spark", **kwargs) + from sqlglot.dataframe.sql.session import SparkSession + + return self.expression.sql(dialect=SparkSession().dialect, **kwargs) def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec: from sqlglot.dataframe.sql.column import Column diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 71977dd..d763ed0 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -12,6 +12,7 @@ from sqlglot.dialects.dialect import ( datestrtodate_sql, format_time_lambda, inline_array_sql, + json_keyvalue_comma_sql, max_or_greatest, min_or_least, no_ilike_sql, @@ -29,8 +30,8 @@ logger = logging.getLogger("sqlglot") def _date_add_sql( data_type: str, kind: str -) -> t.Callable[[generator.Generator, exp.Expression], str]: - def func(self, expression): +) -> t.Callable[[BigQuery.Generator, exp.Expression], str]: + def func(self: BigQuery.Generator, expression: exp.Expression) -> str: this = self.sql(expression, "this") unit = expression.args.get("unit") unit = exp.var(unit.name.upper() if unit else "DAY") @@ -40,7 +41,7 @@ def _date_add_sql( return func -def _derived_table_values_to_unnest(self: generator.Generator, expression: exp.Values) -> str: +def _derived_table_values_to_unnest(self: BigQuery.Generator, expression: exp.Values) -> str: if not expression.find_ancestor(exp.From, exp.Join): return self.values_sql(expression) @@ -64,7 +65,7 @@ def _derived_table_values_to_unnest(self: generator.Generator, expression: exp.V return self.unnest_sql(exp.Unnest(expressions=[exp.Array(expressions=structs)])) -def _returnsproperty_sql(self: generator.Generator, expression: exp.ReturnsProperty) -> str: +def _returnsproperty_sql(self: 
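A condensed, standalone sketch (a toy class, not the library code) of the singleton-plus-builder pattern `SparkSession` now follows: `__new__` hands back one shared instance, and `Builder.getOrCreate()` mutates its dialect:

```python
from __future__ import annotations


class Session:
    _instance: Session | None = None

    def __new__(cls) -> Session:
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        # Only initialize state on first construction of the singleton
        if not hasattr(self, "dialect"):
            self.dialect = "spark"

    class Builder:
        def __init__(self) -> None:
            self.dialect = "spark"

        def config(self, key: str, value: str) -> Session.Builder:
            if key == "sqlframe.dialect":
                self.dialect = value
            return self

        def getOrCreate(self) -> Session:
            session = Session()
            session.dialect = self.dialect
            return session


# The builder always resolves to the same shared session
assert Session.Builder().config("sqlframe.dialect", "duckdb").getOrCreate() is Session()
```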
BigQuery.Generator, expression: exp.ReturnsProperty) -> str: this = expression.this if isinstance(this, exp.Schema): this = f"{this.this} <{self.expressions(this)}>" @@ -73,7 +74,7 @@ def _returnsproperty_sql(self: generator.Generator, expression: exp.ReturnsPrope return f"RETURNS {this}" -def _create_sql(self: generator.Generator, expression: exp.Create) -> str: +def _create_sql(self: BigQuery.Generator, expression: exp.Create) -> str: kind = expression.args["kind"] returns = expression.find(exp.ReturnsProperty) @@ -94,14 +95,20 @@ def _unqualify_unnest(expression: exp.Expression) -> exp.Expression: These are added by the optimizer's qualify_column step. """ - from sqlglot.optimizer.scope import Scope + from sqlglot.optimizer.scope import find_all_in_scope if isinstance(expression, exp.Select): - for unnest in expression.find_all(exp.Unnest): - if isinstance(unnest.parent, (exp.From, exp.Join)) and unnest.alias: - for column in Scope(expression).find_all(exp.Column): - if column.table == unnest.alias: - column.set("table", None) + unnest_aliases = { + unnest.alias + for unnest in find_all_in_scope(expression, exp.Unnest) + if isinstance(unnest.parent, (exp.From, exp.Join)) + } + if unnest_aliases: + for column in expression.find_all(exp.Column): + if column.table in unnest_aliases: + column.set("table", None) + elif column.db in unnest_aliases: + column.set("db", None) return expression @@ -261,6 +268,7 @@ class BigQuery(Dialect): "TIMESTAMP": TokenType.TIMESTAMPTZ, "NOT DETERMINISTIC": TokenType.VOLATILE, "UNKNOWN": TokenType.NULL, + "FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT, } KEYWORDS.pop("DIV") @@ -270,6 +278,8 @@ class BigQuery(Dialect): LOG_BASE_FIRST = False LOG_DEFAULTS_TO_LN = True + SUPPORTS_USER_DEFINED_TYPES = False + FUNCTIONS = { **parser.Parser.FUNCTIONS, "DATE": _parse_date, @@ -299,6 +309,8 @@ class BigQuery(Dialect): if re.compile(str(seq_get(args, 1))).groups == 1 else None, ), + "SHA256": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(256)), + "SHA512": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(512)), "SPLIT": lambda args: exp.Split( # https://cloud.google.com/bigquery/docs/reference/standard-sql/string_functions#split this=seq_get(args, 0), @@ -346,7 +358,7 @@ class BigQuery(Dialect): } def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]: - this = super()._parse_table_part(schema=schema) + this = super()._parse_table_part(schema=schema) or self._parse_number() # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#table_names if isinstance(this, exp.Identifier): @@ -356,6 +368,17 @@ class BigQuery(Dialect): table_name += f"-{self._prev.text}" this = exp.Identifier(this=table_name, quoted=this.args.get("quoted")) + elif isinstance(this, exp.Literal): + table_name = this.name + + if ( + self._curr + and self._prev.end == self._curr.start - 1 + and self._parse_var(any_token=True) + ): + table_name += self._prev.text + + this = exp.Identifier(this=table_name, quoted=True) return this @@ -374,6 +397,27 @@ class BigQuery(Dialect): return table + def _parse_json_object(self) -> exp.JSONObject: + json_object = super()._parse_json_object() + array_kv_pair = seq_get(json_object.expressions, 0) + + # Converts BQ's "signature 2" of JSON_OBJECT into SQLGlot's canonical representation + # https://cloud.google.com/bigquery/docs/reference/standard-sql/json_functions#json_object_signature2 + if ( + array_kv_pair + and isinstance(array_kv_pair.this, exp.Array) + and 
isinstance(array_kv_pair.expression, exp.Array) + ): + keys = array_kv_pair.this.expressions + values = array_kv_pair.expression.expressions + + json_object.set( + "expressions", + [exp.JSONKeyValue(this=k, expression=v) for k, v in zip(keys, values)], + ) + + return json_object + class Generator(generator.Generator): EXPLICIT_UNION = True INTERVAL_ALLOWS_PLURAL_FORM = False @@ -383,6 +427,7 @@ class BigQuery(Dialect): LIMIT_FETCH = "LIMIT" RENAME_TABLE_WITH_DB = False ESCAPE_LINE_BREAK = True + NVL2_SUPPORTED = False TRANSFORMS = { **generator.Generator.TRANSFORMS, @@ -405,6 +450,7 @@ class BigQuery(Dialect): exp.ILike: no_ilike_sql, exp.IntDiv: rename_func("DIV"), exp.JSONFormat: rename_func("TO_JSON_STRING"), + exp.JSONKeyValue: json_keyvalue_comma_sql, exp.Max: max_or_greatest, exp.MD5: lambda self, e: self.func("TO_HEX", self.func("MD5", e.this)), exp.MD5Digest: rename_func("MD5"), @@ -428,6 +474,9 @@ class BigQuery(Dialect): _alias_ordered_group, ] ), + exp.SHA2: lambda self, e: self.func( + f"SHA256" if e.text("length") == "256" else "SHA512", e.this + ), exp.StabilityProperty: lambda self, e: f"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC", @@ -591,6 +640,13 @@ class BigQuery(Dialect): return super().attimezone_sql(expression) + def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str: + # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#json_literals + if expression.is_type("json"): + return f"JSON {self.sql(expression, 'this')}" + + return super().cast_sql(expression, safe_prefix=safe_prefix) + def trycast_sql(self, expression: exp.TryCast) -> str: return self.cast_sql(expression, safe_prefix="SAFE_") @@ -630,3 +686,9 @@ class BigQuery(Dialect): def with_properties(self, properties: exp.Properties) -> str: return self.properties(properties, prefix=self.seg("OPTIONS")) + + def version_sql(self, expression: exp.Version) -> str: + if expression.name == "TIMESTAMP": + expression = expression.copy() + expression.set("this", "SYSTEM_TIME") + return super().version_sql(expression) diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index cfde5fd..a38a239 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -11,6 +11,7 @@ from sqlglot.dialects.dialect import ( var_map_sql, ) from sqlglot.errors import ParseError +from sqlglot.helper import seq_get from sqlglot.parser import parse_var_map from sqlglot.tokens import Token, TokenType @@ -63,9 +64,23 @@ class ClickHouse(Dialect): } class Parser(parser.Parser): + SUPPORTS_USER_DEFINED_TYPES = False + FUNCTIONS = { **parser.Parser.FUNCTIONS, "ANY": exp.AnyValue.from_arg_list, + "DATE_ADD": lambda args: exp.DateAdd( + this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0) + ), + "DATEADD": lambda args: exp.DateAdd( + this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0) + ), + "DATE_DIFF": lambda args: exp.DateDiff( + this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0) + ), + "DATEDIFF": lambda args: exp.DateDiff( + this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0) + ), "MAP": parse_var_map, "MATCH": exp.RegexpLike.from_arg_list, "UNIQ": exp.ApproxDistinct.from_arg_list, @@ -147,7 +162,7 @@ class ClickHouse(Dialect): this = self._parse_id_var() self._match(TokenType.COLON) - kind = self._parse_types(check_func=False) or ( + kind = self._parse_types(check_func=False, allow_identifiers=False) or ( self._match_text_seq("IDENTIFIER") and 
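A hedged example of the SHA256/SHA512 mapping above: BigQuery's `SHA256` parses into `exp.SHA2` with `length=256` and is generated back as `SHA256`:

```python
import sqlglot
from sqlglot import exp

ast = sqlglot.parse_one("SELECT SHA256(col) FROM t", read="bigquery")
sha2 = ast.find(exp.SHA2)
assert sha2 is not None and sha2.text("length") == "256"

print(ast.sql(dialect="bigquery"))  # SELECT SHA256(col) FROM t
```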
"Identifier" ) @@ -249,7 +264,7 @@ class ClickHouse(Dialect): def _parse_func_params( self, this: t.Optional[exp.Func] = None - ) -> t.Optional[t.List[t.Optional[exp.Expression]]]: + ) -> t.Optional[t.List[exp.Expression]]: if self._match_pair(TokenType.R_PAREN, TokenType.L_PAREN): return self._parse_csv(self._parse_lambda) @@ -267,9 +282,7 @@ class ClickHouse(Dialect): return self.expression(exp.Quantile, this=params[0], quantile=this) return self.expression(exp.Quantile, this=this, quantile=exp.Literal.number(0.5)) - def _parse_wrapped_id_vars( - self, optional: bool = False - ) -> t.List[t.Optional[exp.Expression]]: + def _parse_wrapped_id_vars(self, optional: bool = False) -> t.List[exp.Expression]: return super()._parse_wrapped_id_vars(optional=True) def _parse_primary_key( @@ -292,9 +305,22 @@ class ClickHouse(Dialect): class Generator(generator.Generator): QUERY_HINTS = False STRUCT_DELIMITER = ("(", ")") + NVL2_SUPPORTED = False + + STRING_TYPE_MAPPING = { + exp.DataType.Type.CHAR: "String", + exp.DataType.Type.LONGBLOB: "String", + exp.DataType.Type.LONGTEXT: "String", + exp.DataType.Type.MEDIUMBLOB: "String", + exp.DataType.Type.MEDIUMTEXT: "String", + exp.DataType.Type.TEXT: "String", + exp.DataType.Type.VARBINARY: "String", + exp.DataType.Type.VARCHAR: "String", + } TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, + **STRING_TYPE_MAPPING, exp.DataType.Type.ARRAY: "Array", exp.DataType.Type.BIGINT: "Int64", exp.DataType.Type.DATETIME64: "DateTime64", @@ -328,6 +354,12 @@ class ClickHouse(Dialect): exp.ApproxDistinct: rename_func("uniq"), exp.Array: inline_array_sql, exp.CastToStrType: rename_func("CAST"), + exp.DateAdd: lambda self, e: self.func( + "DATE_ADD", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this + ), + exp.DateDiff: lambda self, e: self.func( + "DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this + ), exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL", exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)), exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", @@ -364,6 +396,16 @@ class ClickHouse(Dialect): "NAMED COLLECTION", } + def datatype_sql(self, expression: exp.DataType) -> str: + # String is the standard ClickHouse type, every other variant is just an alias. + # Additionally, any supplied length parameter will be ignored. 
+ # + # https://clickhouse.com/docs/en/sql-reference/data-types/string + if expression.this in self.STRING_TYPE_MAPPING: + return "String" + + return super().datatype_sql(expression) + def safeconcat_sql(self, expression: exp.SafeConcat) -> str: # Clickhouse errors out if we try to cast a NULL value to TEXT expression = expression.copy() diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py index 2149aca..6ec0487 100644 --- a/sqlglot/dialects/databricks.py +++ b/sqlglot/dialects/databricks.py @@ -1,7 +1,7 @@ from __future__ import annotations from sqlglot import exp, transforms -from sqlglot.dialects.dialect import parse_date_delta +from sqlglot.dialects.dialect import parse_date_delta, timestamptrunc_sql from sqlglot.dialects.spark import Spark from sqlglot.dialects.tsql import generate_date_delta_with_unit_sql from sqlglot.tokens import TokenType @@ -28,6 +28,19 @@ class Databricks(Spark): **Spark.Generator.TRANSFORMS, exp.DateAdd: generate_date_delta_with_unit_sql, exp.DateDiff: generate_date_delta_with_unit_sql, + exp.DatetimeAdd: lambda self, e: self.func( + "TIMESTAMPADD", e.text("unit"), e.expression, e.this + ), + exp.DatetimeSub: lambda self, e: self.func( + "TIMESTAMPADD", + e.text("unit"), + exp.Mul(this=e.expression.copy(), expression=exp.Literal.number(-1)), + e.this, + ), + exp.DatetimeDiff: lambda self, e: self.func( + "TIMESTAMPDIFF", e.text("unit"), e.expression, e.this + ), + exp.DatetimeTrunc: timestamptrunc_sql, exp.JSONExtract: lambda self, e: self.binary(e, ":"), exp.Select: transforms.preprocess( [ diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 132496f..1bfbfef 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -109,8 +109,7 @@ class _Dialect(type): for k, v in vars(klass).items() if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__") }, - "STRING_ESCAPE": klass.tokenizer_class.STRING_ESCAPES[0], - "IDENTIFIER_ESCAPE": klass.tokenizer_class.IDENTIFIER_ESCAPES[0], + "TOKENIZER_CLASS": klass.tokenizer_class, } if enum not in ("", "bigquery"): @@ -345,7 +344,7 @@ def arrow_json_extract_scalar_sql( def inline_array_sql(self: Generator, expression: exp.Array) -> str: - return f"[{self.expressions(expression)}]" + return f"[{self.expressions(expression, flat=True)}]" def no_ilike_sql(self: Generator, expression: exp.ILike) -> str: @@ -415,9 +414,9 @@ def str_position_sql(self: Generator, expression: exp.StrPosition) -> str: def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str: - this = self.sql(expression, "this") - struct_key = self.sql(exp.Identifier(this=expression.expression.copy(), quoted=True)) - return f"{this}.{struct_key}" + return ( + f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}" + ) def var_map_sql( @@ -722,3 +721,12 @@ def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]: # Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc: return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0)) + + +def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str: + return self.func("MAX", expression.this) + + +# Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon +def json_keyvalue_comma_sql(self, expression: exp.JSONKeyValue) -> str: + return f"{self.sql(expression, 'this')}, {self.sql(expression, 'expression')}" diff --git 
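A sketch of the Databricks `DatetimeAdd` transform above (assuming BigQuery's `DATETIME_ADD` parses to `exp.DatetimeAdd`; the output shape is approximate):

```python
import sqlglot

# Expected, roughly: SELECT TIMESTAMPADD(HOUR, 1, d)
print(sqlglot.transpile("SELECT DATETIME_ADD(d, INTERVAL 1 HOUR)", read="bigquery", write="databricks")[0])
```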
a/sqlglot/dialects/doris.py b/sqlglot/dialects/doris.py index 160c23c..4b8919c 100644 --- a/sqlglot/dialects/doris.py +++ b/sqlglot/dialects/doris.py @@ -37,7 +37,6 @@ class Doris(MySQL): **MySQL.Generator.TRANSFORMS, exp.ApproxDistinct: approx_count_distinct_sql, exp.ArrayAgg: rename_func("COLLECT_LIST"), - exp.Coalesce: rename_func("NVL"), exp.CurrentTimestamp: lambda *_: "NOW()", exp.DateTrunc: lambda self, e: self.func( "DATE_TRUNC", e.this, "'" + e.text("unit") + "'" diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py index 1b2681d..c811c86 100644 --- a/sqlglot/dialects/drill.py +++ b/sqlglot/dialects/drill.py @@ -16,8 +16,8 @@ from sqlglot.dialects.dialect import ( ) -def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]: - def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str: +def _date_add_sql(kind: str) -> t.Callable[[Drill.Generator, exp.DateAdd | exp.DateSub], str]: + def func(self: Drill.Generator, expression: exp.DateAdd | exp.DateSub) -> str: this = self.sql(expression, "this") unit = exp.var(expression.text("unit").upper() or "DAY") return f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))})" @@ -25,7 +25,7 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e return func -def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str: +def _str_to_date(self: Drill.Generator, expression: exp.StrToDate) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) if time_format == Drill.DATE_FORMAT: @@ -73,7 +73,6 @@ class Drill(Dialect): } class Tokenizer(tokens.Tokenizer): - QUOTES = ["'"] IDENTIFIERS = ["`"] STRING_ESCAPES = ["\\"] ENCODE = "utf-8" @@ -81,6 +80,7 @@ class Drill(Dialect): class Parser(parser.Parser): STRICT_CAST = False CONCAT_NULL_OUTPUTS_STRING = True + SUPPORTS_USER_DEFINED_TYPES = False FUNCTIONS = { **parser.Parser.FUNCTIONS, @@ -95,6 +95,7 @@ class Drill(Dialect): JOIN_HINTS = False TABLE_HINTS = False QUERY_HINTS = False + NVL2_SUPPORTED = False TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index 8253b52..684e35e 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -13,6 +13,7 @@ from sqlglot.dialects.dialect import ( datestrtodate_sql, encode_decode_sql, format_time_lambda, + inline_array_sql, no_comment_column_constraint_sql, no_properties_sql, no_safe_divide_sql, @@ -30,13 +31,13 @@ from sqlglot.helper import seq_get from sqlglot.tokens import TokenType -def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str: +def _ts_or_ds_add_sql(self: DuckDB.Generator, expression: exp.TsOrDsAdd) -> str: this = self.sql(expression, "this") unit = self.sql(expression, "unit").strip("'") or "DAY" return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))}" -def _date_delta_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str: +def _date_delta_sql(self: DuckDB.Generator, expression: exp.DateAdd | exp.DateSub) -> str: this = self.sql(expression, "this") unit = self.sql(expression, "unit").strip("'") or "DAY" op = "+" if isinstance(expression, exp.DateAdd) else "-" @@ -44,7 +45,7 @@ def _date_delta_sql(self: generator.Generator, expression: exp.DateAdd | exp.Dat # BigQuery -> DuckDB conversion for the DATE function -def _date_sql(self: generator.Generator, expression: exp.Date) -> 
str: +def _date_sql(self: DuckDB.Generator, expression: exp.Date) -> str: result = f"CAST({self.sql(expression, 'this')} AS DATE)" zone = self.sql(expression, "zone") @@ -58,13 +59,13 @@ def _date_sql(self: generator.Generator, expression: exp.Date) -> str: return result -def _array_sort_sql(self: generator.Generator, expression: exp.ArraySort) -> str: +def _array_sort_sql(self: DuckDB.Generator, expression: exp.ArraySort) -> str: if expression.expression: self.unsupported("DUCKDB ARRAY_SORT does not support a comparator") return f"ARRAY_SORT({self.sql(expression, 'this')})" -def _sort_array_sql(self: generator.Generator, expression: exp.SortArray) -> str: +def _sort_array_sql(self: DuckDB.Generator, expression: exp.SortArray) -> str: this = self.sql(expression, "this") if expression.args.get("asc") == exp.false(): return f"ARRAY_REVERSE_SORT({this})" @@ -79,14 +80,14 @@ def _parse_date_diff(args: t.List) -> exp.Expression: return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)) -def _struct_sql(self: generator.Generator, expression: exp.Struct) -> str: +def _struct_sql(self: DuckDB.Generator, expression: exp.Struct) -> str: args = [ f"'{e.name or e.this.name}': {self.sql(e, 'expression')}" for e in expression.expressions ] return f"{{{', '.join(args)}}}" -def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str: +def _datatype_sql(self: DuckDB.Generator, expression: exp.DataType) -> str: if expression.is_type("array"): return f"{self.expressions(expression, flat=True)}[]" @@ -97,7 +98,7 @@ def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str: return self.datatype_sql(expression) -def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> str: +def _json_format_sql(self: DuckDB.Generator, expression: exp.JSONFormat) -> str: sql = self.func("TO_JSON", expression.this, expression.args.get("options")) return f"CAST({sql} AS TEXT)" @@ -134,6 +135,7 @@ class DuckDB(Dialect): class Parser(parser.Parser): CONCAT_NULL_OUTPUTS_STRING = True + SUPPORTS_USER_DEFINED_TYPES = False BITWISE = { **parser.Parser.BITWISE, @@ -183,18 +185,12 @@ class DuckDB(Dialect): ), } - TYPE_TOKENS = { - *parser.Parser.TYPE_TOKENS, - TokenType.UBIGINT, - TokenType.UINT, - TokenType.USMALLINT, - TokenType.UTINYINT, - } - def _parse_types( - self, check_func: bool = False, schema: bool = False + self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True ) -> t.Optional[exp.Expression]: - this = super()._parse_types(check_func=check_func, schema=schema) + this = super()._parse_types( + check_func=check_func, schema=schema, allow_identifiers=allow_identifiers + ) # DuckDB treats NUMERIC and DECIMAL without precision as DECIMAL(18, 3) # See: https://duckdb.org/docs/sql/data_types/numeric @@ -207,6 +203,9 @@ class DuckDB(Dialect): return this + def _parse_struct_types(self) -> t.Optional[exp.Expression]: + return self._parse_field_def() + def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]: if len(aggregations) == 1: return super()._pivot_column_names(aggregations) @@ -219,13 +218,14 @@ class DuckDB(Dialect): LIMIT_FETCH = "LIMIT" STRUCT_DELIMITER = ("(", ")") RENAME_TABLE_WITH_DB = False + NVL2_SUPPORTED = False TRANSFORMS = { **generator.Generator.TRANSFORMS, exp.ApproxDistinct: approx_count_distinct_sql, exp.Array: lambda self, e: self.func("ARRAY", e.expressions[0]) if e.expressions and e.expressions[0].find(exp.Select) - else rename_func("LIST_VALUE")(self, e), 
+ else inline_array_sql(self, e), exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.ArraySort: _array_sort_sql, exp.ArraySum: rename_func("LIST_SUM"), diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 584acc6..8b17c06 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -50,7 +50,7 @@ TIME_DIFF_FACTOR = { DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH") -def _add_date_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str: +def _add_date_sql(self: Hive.Generator, expression: exp.DateAdd | exp.DateSub) -> str: unit = expression.text("unit").upper() func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1)) @@ -69,7 +69,7 @@ def _add_date_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateS return self.func(func, expression.this, modified_increment) -def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str: +def _date_diff_sql(self: Hive.Generator, expression: exp.DateDiff) -> str: unit = expression.text("unit").upper() factor = TIME_DIFF_FACTOR.get(unit) @@ -87,7 +87,7 @@ def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str: return f"{diff_sql}{multiplier_sql}" -def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> str: +def _json_format_sql(self: Hive.Generator, expression: exp.JSONFormat) -> str: this = expression.this if isinstance(this, exp.Cast) and this.is_type("json") and this.this.is_string: # Since FROM_JSON requires a nested type, we always wrap the json string with @@ -103,21 +103,21 @@ def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> s return self.func("TO_JSON", this, expression.args.get("options")) -def _array_sort_sql(self: generator.Generator, expression: exp.ArraySort) -> str: +def _array_sort_sql(self: Hive.Generator, expression: exp.ArraySort) -> str: if expression.expression: self.unsupported("Hive SORT_ARRAY does not support a comparator") return f"SORT_ARRAY({self.sql(expression, 'this')})" -def _property_sql(self: generator.Generator, expression: exp.Property) -> str: +def _property_sql(self: Hive.Generator, expression: exp.Property) -> str: return f"'{expression.name}'={self.sql(expression, 'value')}" -def _str_to_unix_sql(self: generator.Generator, expression: exp.StrToUnix) -> str: +def _str_to_unix_sql(self: Hive.Generator, expression: exp.StrToUnix) -> str: return self.func("UNIX_TIMESTAMP", expression.this, time_format("hive")(self, expression)) -def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate) -> str: +def _str_to_date_sql(self: Hive.Generator, expression: exp.StrToDate) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) if time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT): @@ -125,7 +125,7 @@ def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate) -> st return f"CAST({this} AS DATE)" -def _str_to_time_sql(self: generator.Generator, expression: exp.StrToTime) -> str: +def _str_to_time_sql(self: Hive.Generator, expression: exp.StrToTime) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) if time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT): @@ -133,13 +133,13 @@ def _str_to_time_sql(self: generator.Generator, expression: exp.StrToTime) -> st return f"CAST({this} AS TIMESTAMP)" -def _time_to_str(self: generator.Generator, expression: exp.TimeToStr) -> str: +def _time_to_str(self: Hive.Generator, expression: exp.TimeToStr) -> str: this = 
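A hedged example of the DuckDB array change above, which renders plain array literals inline instead of via `LIST_VALUE`:

```python
import sqlglot

# SELECT [1, 2, 3]   (previously: SELECT LIST_VALUE(1, 2, 3))
print(sqlglot.transpile("SELECT ARRAY[1, 2, 3]", read="postgres", write="duckdb")[0])
```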
self.sql(expression, "this") time_format = self.format_time(expression) return f"DATE_FORMAT({this}, {time_format})" -def _to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str: +def _to_date_sql(self: Hive.Generator, expression: exp.TsOrDsToDate) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) if time_format and time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT): @@ -206,6 +206,8 @@ class Hive(Dialect): "MSCK REPAIR": TokenType.COMMAND, "REFRESH": TokenType.COMMAND, "WITH SERDEPROPERTIES": TokenType.SERDE_PROPERTIES, + "TIMESTAMP AS OF": TokenType.TIMESTAMP_SNAPSHOT, + "VERSION AS OF": TokenType.VERSION_SNAPSHOT, } NUMERIC_LITERALS = { @@ -220,6 +222,7 @@ class Hive(Dialect): class Parser(parser.Parser): LOG_DEFAULTS_TO_LN = True STRICT_CAST = False + SUPPORTS_USER_DEFINED_TYPES = False FUNCTIONS = { **parser.Parser.FUNCTIONS, @@ -257,6 +260,11 @@ class Hive(Dialect): ), "SIZE": exp.ArraySize.from_arg_list, "SPLIT": exp.RegexpSplit.from_arg_list, + "STR_TO_MAP": lambda args: exp.StrToMap( + this=seq_get(args, 0), + pair_delim=seq_get(args, 1) or exp.Literal.string(","), + key_value_delim=seq_get(args, 2) or exp.Literal.string(":"), + ), "TO_DATE": format_time_lambda(exp.TsOrDsToDate, "hive"), "TO_JSON": exp.JSONFormat.from_arg_list, "UNBASE64": exp.FromBase64.from_arg_list, @@ -313,7 +321,7 @@ class Hive(Dialect): ) def _parse_types( - self, check_func: bool = False, schema: bool = False + self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True ) -> t.Optional[exp.Expression]: """ Spark (and most likely Hive) treats casts to CHAR(length) and VARCHAR(length) as casts to @@ -333,7 +341,9 @@ class Hive(Dialect): Reference: https://spark.apache.org/docs/latest/sql-ref-datatypes.html """ - this = super()._parse_types(check_func=check_func, schema=schema) + this = super()._parse_types( + check_func=check_func, schema=schema, allow_identifiers=allow_identifiers + ) if this and not schema: return this.transform( @@ -345,6 +355,16 @@ class Hive(Dialect): return this + def _parse_partition_and_order( + self, + ) -> t.Tuple[t.List[exp.Expression], t.Optional[exp.Expression]]: + return ( + self._parse_csv(self._parse_conjunction) + if self._match_set({TokenType.PARTITION_BY, TokenType.DISTRIBUTE_BY}) + else [], + super()._parse_order(skip_order_token=self._match(TokenType.SORT_BY)), + ) + class Generator(generator.Generator): LIMIT_FETCH = "LIMIT" TABLESAMPLE_WITH_METHOD = False @@ -354,6 +374,7 @@ class Hive(Dialect): QUERY_HINTS = False INDEX_ON = "ON TABLE" EXTRACT_ALLOWS_QUOTES = False + NVL2_SUPPORTED = False TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, @@ -376,6 +397,7 @@ class Hive(Dialect): ] ), exp.Property: _property_sql, + exp.AnyValue: rename_func("FIRST"), exp.ApproxDistinct: approx_count_distinct_sql, exp.ArrayConcat: rename_func("CONCAT"), exp.ArrayJoin: lambda self, e: self.func("CONCAT_WS", e.expression, e.this), @@ -402,6 +424,9 @@ class Hive(Dialect): exp.MD5Digest: lambda self, e: self.func("UNHEX", self.func("MD5", e.this)), exp.Min: min_or_least, exp.MonthsBetween: lambda self, e: self.func("MONTHS_BETWEEN", e.this, e.expression), + exp.NotNullColumnConstraint: lambda self, e: "" + if e.args.get("allow_null") + else "NOT NULL", exp.VarMap: var_map_sql, exp.Create: create_with_partitions_sql, exp.Quantile: rename_func("PERCENTILE"), @@ -472,7 +497,7 @@ class Hive(Dialect): elif expression.this in exp.DataType.TEMPORAL_TYPES: expression = exp.DataType.build(expression.this) elif 
expression.is_type("float"): - size_expression = expression.find(exp.DataTypeSize) + size_expression = expression.find(exp.DataTypeParam) if size_expression: size = int(size_expression.name) expression = ( @@ -480,3 +505,7 @@ class Hive(Dialect): ) return super().datatype_sql(expression) + + def version_sql(self, expression: exp.Version) -> str: + sql = super().version_sql(expression) + return sql.replace("FOR ", "", 1) diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 9ab4ce8..f9249eb 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import ( arrow_json_extract_scalar_sql, datestrtodate_sql, format_time_lambda, + json_keyvalue_comma_sql, locate_to_strposition, max_or_greatest, min_or_least, @@ -32,7 +33,7 @@ def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[MySQL.Parser], ex return _parse -def _date_trunc_sql(self: generator.Generator, expression: exp.DateTrunc) -> str: +def _date_trunc_sql(self: MySQL.Generator, expression: exp.DateTrunc) -> str: expr = self.sql(expression, "this") unit = expression.text("unit") @@ -63,12 +64,12 @@ def _str_to_date(args: t.List) -> exp.StrToDate: return exp.StrToDate(this=seq_get(args, 0), format=date_format) -def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate | exp.StrToTime) -> str: +def _str_to_date_sql(self: MySQL.Generator, expression: exp.StrToDate | exp.StrToTime) -> str: date_format = self.format_time(expression) return f"STR_TO_DATE({self.sql(expression.this)}, {date_format})" -def _trim_sql(self: generator.Generator, expression: exp.Trim) -> str: +def _trim_sql(self: MySQL.Generator, expression: exp.Trim) -> str: target = self.sql(expression, "this") trim_type = self.sql(expression, "position") remove_chars = self.sql(expression, "expression") @@ -83,8 +84,8 @@ def _trim_sql(self: generator.Generator, expression: exp.Trim) -> str: return f"TRIM({trim_type}{remove_chars}{from_part}{target})" -def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]: - def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str: +def _date_add_sql(kind: str) -> t.Callable[[MySQL.Generator, exp.DateAdd | exp.DateSub], str]: + def func(self: MySQL.Generator, expression: exp.DateAdd | exp.DateSub) -> str: this = self.sql(expression, "this") unit = expression.text("unit").upper() or "DAY" return f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))})" @@ -93,6 +94,9 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e class MySQL(Dialect): + # https://dev.mysql.com/doc/refman/8.0/en/identifiers.html + IDENTIFIERS_CAN_START_WITH_DIGIT = True + TIME_FORMAT = "'%Y-%m-%d %T'" DPIPE_IS_STRING_CONCAT = False @@ -129,6 +133,7 @@ class MySQL(Dialect): "LONGTEXT": TokenType.LONGTEXT, "MEDIUMBLOB": TokenType.MEDIUMBLOB, "MEDIUMTEXT": TokenType.MEDIUMTEXT, + "MEDIUMINT": TokenType.MEDIUMINT, "MEMBER OF": TokenType.MEMBER_OF, "SEPARATOR": TokenType.SEPARATOR, "START": TokenType.BEGIN, @@ -136,6 +141,7 @@ class MySQL(Dialect): "SIGNED INTEGER": TokenType.BIGINT, "UNSIGNED": TokenType.UBIGINT, "UNSIGNED INTEGER": TokenType.UBIGINT, + "YEAR": TokenType.YEAR, "_ARMSCII8": TokenType.INTRODUCER, "_ASCII": TokenType.INTRODUCER, "_BIG5": TokenType.INTRODUCER, @@ -185,6 +191,8 @@ class MySQL(Dialect): COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SHOW} class Parser(parser.Parser): + SUPPORTS_USER_DEFINED_TYPES = False + 
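The Hive `STR_TO_MAP` parsing above fills in the documented delimiter defaults, `','` and `':'`, when they are omitted. A quick sketch:

```python
import sqlglot
from sqlglot import exp

ast = sqlglot.parse_one("SELECT STR_TO_MAP(x) FROM t", read="hive")
m = ast.find(exp.StrToMap)
assert m is not None
assert m.text("pair_delim") == "," and m.text("key_value_delim") == ":"
```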
FUNC_TOKENS = { *parser.Parser.FUNC_TOKENS, TokenType.DATABASE, @@ -492,6 +500,17 @@ class MySQL(Dialect): return self.expression(exp.SetItem, this=charset, collate=collate, kind="NAMES") + def _parse_type(self) -> t.Optional[exp.Expression]: + # mysql binary is special and can work anywhere, even in order by operations + # it operates like a no paren func + if self._match(TokenType.BINARY, advance=False): + data_type = self._parse_types(check_func=True, allow_identifiers=False) + + if isinstance(data_type, exp.DataType): + return self.expression(exp.Cast, this=self._parse_column(), to=data_type) + + return super()._parse_type() + class Generator(generator.Generator): LOCKING_READS_SUPPORTED = True NULL_ORDERING_SUPPORTED = False @@ -500,6 +519,7 @@ class MySQL(Dialect): DUPLICATE_KEY_UPDATE_WITH_SET = False QUERY_HINT_SEP = " " VALUES_AS_TABLE = False + NVL2_SUPPORTED = False TRANSFORMS = { **generator.Generator.TRANSFORMS, @@ -515,6 +535,7 @@ class MySQL(Dialect): exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""", exp.ILike: no_ilike_sql, exp.JSONExtractScalar: arrow_json_extract_scalar_sql, + exp.JSONKeyValue: json_keyvalue_comma_sql, exp.Max: max_or_greatest, exp.Min: min_or_least, exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"), @@ -524,6 +545,7 @@ class MySQL(Dialect): exp.StrPosition: strposition_to_locate_sql, exp.StrToDate: _str_to_date_sql, exp.StrToTime: _str_to_date_sql, + exp.Stuff: rename_func("INSERT"), exp.TableSample: no_tablesample_sql, exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), exp.TimeStrToTime: lambda self, e: self.sql(exp.cast(e.this, "datetime", copy=True)), diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 1f63e9f..279ed31 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -8,7 +8,7 @@ from sqlglot.helper import seq_get from sqlglot.tokens import TokenType -def _parse_xml_table(self: parser.Parser) -> exp.XMLTable: +def _parse_xml_table(self: Oracle.Parser) -> exp.XMLTable: this = self._parse_string() passing = None @@ -22,7 +22,7 @@ def _parse_xml_table(self: parser.Parser) -> exp.XMLTable: by_ref = self._match_text_seq("RETURNING", "SEQUENCE", "BY", "REF") if self._match_text_seq("COLUMNS"): - columns = self._parse_csv(lambda: self._parse_column_def(self._parse_field(any_token=True))) + columns = self._parse_csv(self._parse_field_def) return self.expression(exp.XMLTable, this=this, passing=passing, columns=columns, by_ref=by_ref) @@ -78,6 +78,10 @@ class Oracle(Dialect): ) } + # SELECT UNIQUE .. is old-style Oracle syntax for SELECT DISTINCT .. 
+ # Reference: https://stackoverflow.com/a/336455 + DISTINCT_TOKENS = {TokenType.DISTINCT, TokenType.UNIQUE} + def _parse_column(self) -> t.Optional[exp.Expression]: column = super()._parse_column() if column: @@ -129,7 +133,6 @@ class Oracle(Dialect): ), exp.Group: transforms.preprocess([transforms.unalias_group]), exp.ILike: no_ilike_sql, - exp.Coalesce: rename_func("NVL"), exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "), @@ -162,7 +165,7 @@ class Oracle(Dialect): return f"XMLTABLE({self.sep('')}{self.indent(this + passing + by_ref + columns)}{self.seg(')', sep='')}" class Tokenizer(tokens.Tokenizer): - VAR_SINGLE_TOKENS = {"@"} + VAR_SINGLE_TOKENS = {"@", "$", "#"} KEYWORDS = { **tokens.Tokenizer.KEYWORDS, diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 73ca4e5..c26e121 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -5,6 +5,7 @@ import typing as t from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, + any_value_to_max_sql, arrow_json_extract_scalar_sql, arrow_json_extract_sql, datestrtodate_sql, @@ -39,8 +40,8 @@ DATE_DIFF_FACTOR = { } -def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]: - def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str: +def _date_add_sql(kind: str) -> t.Callable[[Postgres.Generator, exp.DateAdd | exp.DateSub], str]: + def func(self: Postgres.Generator, expression: exp.DateAdd | exp.DateSub) -> str: expression = expression.copy() this = self.sql(expression, "this") @@ -56,7 +57,7 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e return func -def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str: +def _date_diff_sql(self: Postgres.Generator, expression: exp.DateDiff) -> str: unit = expression.text("unit").upper() factor = DATE_DIFF_FACTOR.get(unit) @@ -82,7 +83,7 @@ def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str: return f"CAST({unit} AS BIGINT)" -def _substring_sql(self: generator.Generator, expression: exp.Substring) -> str: +def _substring_sql(self: Postgres.Generator, expression: exp.Substring) -> str: this = self.sql(expression, "this") start = self.sql(expression, "start") length = self.sql(expression, "length") @@ -93,7 +94,7 @@ def _substring_sql(self: generator.Generator, expression: exp.Substring) -> str: return f"SUBSTRING({this}{from_part}{for_part})" -def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> str: +def _string_agg_sql(self: Postgres.Generator, expression: exp.GroupConcat) -> str: expression = expression.copy() separator = expression.args.get("separator") or exp.Literal.string(",") @@ -107,7 +108,7 @@ def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> s return f"STRING_AGG({self.format_args(this, separator)}{order})" -def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str: +def _datatype_sql(self: Postgres.Generator, expression: exp.DataType) -> str: if expression.is_type("array"): return f"{self.expressions(expression, flat=True)}[]" return self.datatype_sql(expression) @@ -254,6 +255,7 @@ class Postgres(Dialect): "~~*": TokenType.ILIKE, "~*": TokenType.IRLIKE, "~": TokenType.RLIKE, + "@@": TokenType.DAT, 
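A small example of the `DISTINCT_TOKENS` addition above; the output dialect here defaults to sqlglot's generic SQL:

```python
import sqlglot

# SELECT DISTINCT name FROM employees
print(sqlglot.transpile("SELECT UNIQUE name FROM employees", read="oracle")[0])
```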
"@>": TokenType.AT_GT, "<@": TokenType.LT_AT, "BEGIN": TokenType.COMMAND, @@ -273,6 +275,18 @@ class Postgres(Dialect): "SMALLSERIAL": TokenType.SMALLSERIAL, "TEMP": TokenType.TEMPORARY, "CSTRING": TokenType.PSEUDO_TYPE, + "OID": TokenType.OBJECT_IDENTIFIER, + "REGCLASS": TokenType.OBJECT_IDENTIFIER, + "REGCOLLATION": TokenType.OBJECT_IDENTIFIER, + "REGCONFIG": TokenType.OBJECT_IDENTIFIER, + "REGDICTIONARY": TokenType.OBJECT_IDENTIFIER, + "REGNAMESPACE": TokenType.OBJECT_IDENTIFIER, + "REGOPER": TokenType.OBJECT_IDENTIFIER, + "REGOPERATOR": TokenType.OBJECT_IDENTIFIER, + "REGPROC": TokenType.OBJECT_IDENTIFIER, + "REGPROCEDURE": TokenType.OBJECT_IDENTIFIER, + "REGROLE": TokenType.OBJECT_IDENTIFIER, + "REGTYPE": TokenType.OBJECT_IDENTIFIER, } SINGLE_TOKENS = { @@ -312,6 +326,9 @@ class Postgres(Dialect): RANGE_PARSERS = { **parser.Parser.RANGE_PARSERS, TokenType.DAMP: binary_range_parser(exp.ArrayOverlaps), + TokenType.DAT: lambda self, this: self.expression( + exp.MatchAgainst, this=self._parse_bitwise(), expressions=[this] + ), TokenType.AT_GT: binary_range_parser(exp.ArrayContains), TokenType.LT_AT: binary_range_parser(exp.ArrayContained), } @@ -343,6 +360,7 @@ class Postgres(Dialect): JOIN_HINTS = False TABLE_HINTS = False QUERY_HINTS = False + NVL2_SUPPORTED = False PARAMETER_TOKEN = "$" TYPE_MAPPING = { @@ -357,6 +375,8 @@ class Postgres(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, + exp.AnyValue: any_value_to_max_sql, + exp.ArrayConcat: rename_func("ARRAY_CAT"), exp.BitwiseXor: lambda self, e: self.binary(e, "#"), exp.ColumnDef: transforms.preprocess([_auto_increment_to_serial, _serial_to_generated]), exp.Explode: rename_func("UNNEST"), @@ -416,3 +436,9 @@ class Postgres(Dialect): expression.set("this", exp.paren(expression.this, copy=False)) return super().bracket_sql(expression) + + def matchagainst_sql(self, expression: exp.MatchAgainst) -> str: + this = self.sql(expression, "this") + expressions = [f"{self.sql(e)} @@ {this}" for e in expression.expressions] + sql = " OR ".join(expressions) + return f"({sql})" if len(expressions) > 1 else sql diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 078da0b..4b54e95 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -26,13 +26,13 @@ from sqlglot.helper import apply_index_offset, seq_get from sqlglot.tokens import TokenType -def _approx_distinct_sql(self: generator.Generator, expression: exp.ApproxDistinct) -> str: +def _approx_distinct_sql(self: Presto.Generator, expression: exp.ApproxDistinct) -> str: accuracy = expression.args.get("accuracy") accuracy = ", " + self.sql(accuracy) if accuracy else "" return f"APPROX_DISTINCT({self.sql(expression, 'this')}{accuracy})" -def _explode_to_unnest_sql(self: generator.Generator, expression: exp.Lateral) -> str: +def _explode_to_unnest_sql(self: Presto.Generator, expression: exp.Lateral) -> str: if isinstance(expression.this, (exp.Explode, exp.Posexplode)): expression = expression.copy() return self.sql( @@ -48,12 +48,12 @@ def _explode_to_unnest_sql(self: generator.Generator, expression: exp.Lateral) - return self.lateral_sql(expression) -def _initcap_sql(self: generator.Generator, expression: exp.Initcap) -> str: +def _initcap_sql(self: Presto.Generator, expression: exp.Initcap) -> str: regex = r"(\w)(\w*)" return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))" -def _no_sort_array(self: generator.Generator, expression: exp.SortArray) -> str: +def _no_sort_array(self: Presto.Generator, 
expression: exp.SortArray) -> str: if expression.args.get("asc") == exp.false(): comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END" else: @@ -61,7 +61,7 @@ def _no_sort_array(self: generator.Generator, expression: exp.SortArray) -> str: return self.func("ARRAY_SORT", expression.this, comparator) -def _schema_sql(self: generator.Generator, expression: exp.Schema) -> str: +def _schema_sql(self: Presto.Generator, expression: exp.Schema) -> str: if isinstance(expression.parent, exp.Property): columns = ", ".join(f"'{c.name}'" for c in expression.expressions) return f"ARRAY[{columns}]" @@ -75,25 +75,25 @@ def _schema_sql(self: generator.Generator, expression: exp.Schema) -> str: return self.schema_sql(expression) -def _quantile_sql(self: generator.Generator, expression: exp.Quantile) -> str: +def _quantile_sql(self: Presto.Generator, expression: exp.Quantile) -> str: self.unsupported("Presto does not support exact quantiles") return f"APPROX_PERCENTILE({self.sql(expression, 'this')}, {self.sql(expression, 'quantile')})" def _str_to_time_sql( - self: generator.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate + self: Presto.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate ) -> str: return f"DATE_PARSE({self.sql(expression, 'this')}, {self.format_time(expression)})" -def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str: +def _ts_or_ds_to_date_sql(self: Presto.Generator, expression: exp.TsOrDsToDate) -> str: time_format = self.format_time(expression) if time_format and time_format not in (Presto.TIME_FORMAT, Presto.DATE_FORMAT): return exp.cast(_str_to_time_sql(self, expression), "DATE").sql(dialect="presto") return exp.cast(exp.cast(expression.this, "TIMESTAMP", copy=True), "DATE").sql(dialect="presto") -def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str: +def _ts_or_ds_add_sql(self: Presto.Generator, expression: exp.TsOrDsAdd) -> str: this = expression.this if not isinstance(this, exp.CurrentDate): @@ -153,6 +153,20 @@ def _unnest_sequence(expression: exp.Expression) -> exp.Expression: return expression +def _first_last_sql(self: Presto.Generator, expression: exp.First | exp.Last) -> str: + """ + Trino doesn't support FIRST / LAST as functions, but they're valid in the context + of MATCH_RECOGNIZE, so we need to preserve them in that case. In all other cases + they're converted into an ARBITRARY call. 
+ + Reference: https://trino.io/docs/current/sql/match-recognize.html#logical-navigation-functions + """ + if isinstance(expression.find_ancestor(exp.MatchRecognize, exp.Select), exp.MatchRecognize): + return self.function_fallback_sql(expression) + + return rename_func("ARBITRARY")(self, expression) + + class Presto(Dialect): INDEX_OFFSET = 1 NULL_ORDERING = "nulls_are_last" @@ -178,6 +192,7 @@ class Presto(Dialect): class Parser(parser.Parser): FUNCTIONS = { **parser.Parser.FUNCTIONS, + "ARBITRARY": exp.AnyValue.from_arg_list, "APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list, "APPROX_PERCENTILE": _approx_percentile, "BITWISE_AND": binary_from_function(exp.BitwiseAnd), @@ -205,7 +220,14 @@ class Presto(Dialect): "REGEXP_EXTRACT": lambda args: exp.RegexpExtract( this=seq_get(args, 0), expression=seq_get(args, 1), group=seq_get(args, 2) ), + "REGEXP_REPLACE": lambda args: exp.RegexpReplace( + this=seq_get(args, 0), + expression=seq_get(args, 1), + replacement=seq_get(args, 2) or exp.Literal.string(""), + ), + "ROW": exp.Struct.from_arg_list, "SEQUENCE": exp.GenerateSeries.from_arg_list, + "SPLIT_TO_MAP": exp.StrToMap.from_arg_list, "STRPOS": lambda args: exp.StrPosition( this=seq_get(args, 0), substr=seq_get(args, 1), instance=seq_get(args, 2) ), @@ -225,6 +247,7 @@ class Presto(Dialect): QUERY_HINTS = False IS_BOOL_ALLOWED = False TZ_TO_WITH_TIME_ZONE = True + NVL2_SUPPORTED = False STRUCT_DELIMITER = ("(", ")") PROPERTIES_LOCATION = { @@ -242,10 +265,13 @@ class Presto(Dialect): exp.DataType.Type.TIMETZ: "TIME", exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP", exp.DataType.Type.STRUCT: "ROW", + exp.DataType.Type.DATETIME: "TIMESTAMP", + exp.DataType.Type.DATETIME64: "TIMESTAMP", } TRANSFORMS = { **generator.Generator.TRANSFORMS, + exp.AnyValue: rename_func("ARBITRARY"), exp.ApproxDistinct: _approx_distinct_sql, exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"), exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]", @@ -268,15 +294,23 @@ class Presto(Dialect): ), exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.DATE_FORMAT}) AS DATE)", exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)", + exp.DateSub: lambda self, e: self.func( + "DATE_ADD", + exp.Literal.string(e.text("unit") or "day"), + e.expression * -1, + e.this, + ), exp.Decode: lambda self, e: encode_decode_sql(self, e, "FROM_UTF8"), exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.DATEINT_FORMAT}) AS DATE)", exp.Encode: lambda self, e: encode_decode_sql(self, e, "TO_UTF8"), exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'", + exp.First: _first_last_sql, exp.Group: transforms.preprocess([transforms.unalias_group]), exp.Hex: rename_func("TO_HEX"), exp.If: if_sql, exp.ILike: no_ilike_sql, exp.Initcap: _initcap_sql, + exp.Last: _first_last_sql, exp.Lateral: _explode_to_unnest_sql, exp.Left: left_to_substring_sql, exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"), @@ -301,8 +335,10 @@ class Presto(Dialect): exp.SortArray: _no_sort_array, exp.StrPosition: rename_func("STRPOS"), exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)", + exp.StrToMap: rename_func("SPLIT_TO_MAP"), exp.StrToTime: _str_to_time_sql, exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))", + exp.Struct: rename_func("ROW"), exp.StructExtract: struct_extract_sql, exp.Table: transforms.preprocess([_unnest_sequence]), 
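The `_first_last_sql` helper above only preserves `FIRST`/`LAST` inside `MATCH_RECOGNIZE`; in ordinary aggregations they degrade to `ARBITRARY`, Presto's "any value from the group" function. Expected behavior, assuming a Spark source query:

```python
import sqlglot

# Outside MATCH_RECOGNIZE there is no FIRST/LAST in Presto, so the function
# should be rewritten to ARBITRARY.
print(sqlglot.transpile("SELECT FIRST(x) FROM t", read="spark", write="presto")[0])
# SELECT ARBITRARY(x) FROM t
```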
exp.TimestampTrunc: timestamptrunc_sql, diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index 30731e1..351c5df 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -13,7 +13,7 @@ from sqlglot.helper import seq_get from sqlglot.tokens import TokenType -def _json_sql(self: Postgres.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar) -> str: +def _json_sql(self: Redshift.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar) -> str: return f'{self.sql(expression, "this")}."{expression.expression.name}"' @@ -37,6 +37,8 @@ class Redshift(Postgres): } class Parser(Postgres.Parser): + SUPPORTS_USER_DEFINED_TYPES = False + FUNCTIONS = { **Postgres.Parser.FUNCTIONS, "ADD_MONTHS": lambda args: exp.DateAdd( @@ -55,9 +57,11 @@ class Redshift(Postgres): } def _parse_types( - self, check_func: bool = False, schema: bool = False + self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True ) -> t.Optional[exp.Expression]: - this = super()._parse_types(check_func=check_func, schema=schema) + this = super()._parse_types( + check_func=check_func, schema=schema, allow_identifiers=allow_identifiers + ) if ( isinstance(this, exp.DataType) @@ -100,6 +104,7 @@ class Redshift(Postgres): QUERY_HINTS = False VALUES_AS_TABLE = False TZ_TO_WITH_TIME_ZONE = True + NVL2_SUPPORTED = True TYPE_MAPPING = { **Postgres.Generator.TYPE_MAPPING, @@ -142,6 +147,9 @@ class Redshift(Postgres): # Redshift uses the POW | POWER (expr1, expr2) syntax instead of expr1 ^ expr2 (postgres) TRANSFORMS.pop(exp.Pow) + # Redshift supports ANY_VALUE(..) + TRANSFORMS.pop(exp.AnyValue) + RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot", "type"} def with_properties(self, properties: exp.Properties) -> str: diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 9733a85..8d8183c 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -90,7 +90,7 @@ def _parse_datediff(args: t.List) -> exp.DateDiff: return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)) -def _unix_to_time_sql(self: generator.Generator, expression: exp.UnixToTime) -> str: +def _unix_to_time_sql(self: Snowflake.Generator, expression: exp.UnixToTime) -> str: scale = expression.args.get("scale") timestamp = self.sql(expression, "this") if scale in [None, exp.UnixToTime.SECONDS]: @@ -105,7 +105,7 @@ def _unix_to_time_sql(self: generator.Generator, expression: exp.UnixToTime) -> # https://docs.snowflake.com/en/sql-reference/functions/date_part.html # https://docs.snowflake.com/en/sql-reference/functions-date-time.html#label-supported-date-time-parts -def _parse_date_part(self: parser.Parser) -> t.Optional[exp.Expression]: +def _parse_date_part(self: Snowflake.Parser) -> t.Optional[exp.Expression]: this = self._parse_var() or self._parse_type() if not this: @@ -156,7 +156,7 @@ def _nullifzero_to_if(args: t.List) -> exp.If: return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0)) -def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str: +def _datatype_sql(self: Snowflake.Generator, expression: exp.DataType) -> str: if expression.is_type("array"): return "ARRAY" elif expression.is_type("map"): @@ -164,6 +164,17 @@ def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str: return self.datatype_sql(expression) +def _regexpilike_sql(self: Snowflake.Generator, expression: exp.RegexpILike) -> str: + flag = expression.text("flag") + + if 
"i" not in flag: + flag += "i" + + return self.func( + "REGEXP_LIKE", expression.this, expression.expression, exp.Literal.string(flag) + ) + + def _parse_convert_timezone(args: t.List) -> t.Union[exp.Anonymous, exp.AtTimeZone]: if len(args) == 3: return exp.Anonymous(this="CONVERT_TIMEZONE", expressions=args) @@ -179,6 +190,13 @@ def _parse_regexp_replace(args: t.List) -> exp.RegexpReplace: return regexp_replace +def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[Snowflake.Parser], exp.Show]: + def _parse(self: Snowflake.Parser) -> exp.Show: + return self._parse_show_snowflake(*args, **kwargs) + + return _parse + + class Snowflake(Dialect): # https://docs.snowflake.com/en/sql-reference/identifiers-syntax RESOLVES_IDENTIFIERS_AS_UPPERCASE = True @@ -216,6 +234,7 @@ class Snowflake(Dialect): class Parser(parser.Parser): IDENTIFY_PIVOT_STRINGS = True + SUPPORTS_USER_DEFINED_TYPES = False FUNCTIONS = { **parser.Parser.FUNCTIONS, @@ -230,6 +249,7 @@ class Snowflake(Dialect): "DATEDIFF": _parse_datediff, "DIV0": _div0_to_if, "IFF": exp.If.from_arg_list, + "LISTAGG": exp.GroupConcat.from_arg_list, "NULLIFZERO": _nullifzero_to_if, "OBJECT_CONSTRUCT": _parse_object_construct, "REGEXP_REPLACE": _parse_regexp_replace, @@ -250,11 +270,6 @@ class Snowflake(Dialect): } FUNCTION_PARSERS.pop("TRIM") - FUNC_TOKENS = { - *parser.Parser.FUNC_TOKENS, - TokenType.TABLE, - } - COLUMN_OPERATORS = { **parser.Parser.COLUMN_OPERATORS, TokenType.COLON: lambda self, this, path: self.expression( @@ -281,6 +296,16 @@ class Snowflake(Dialect): ), } + STATEMENT_PARSERS = { + **parser.Parser.STATEMENT_PARSERS, + TokenType.SHOW: lambda self: self._parse_show(), + } + + SHOW_PARSERS = { + "PRIMARY KEYS": _show_parser("PRIMARY KEYS"), + "TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"), + } + def _parse_id_var( self, any_token: bool = True, @@ -296,8 +321,24 @@ class Snowflake(Dialect): return super()._parse_id_var(any_token=any_token, tokens=tokens) + def _parse_show_snowflake(self, this: str) -> exp.Show: + scope = None + scope_kind = None + + if self._match(TokenType.IN): + if self._match_text_seq("ACCOUNT"): + scope_kind = "ACCOUNT" + elif self._match_set(self.DB_CREATABLES): + scope_kind = self._prev.text + if self._curr: + scope = self._parse_table() + elif self._curr: + scope_kind = "TABLE" + scope = self._parse_table() + + return self.expression(exp.Show, this=this, scope=scope, scope_kind=scope_kind) + class Tokenizer(tokens.Tokenizer): - QUOTES = ["'"] STRING_ESCAPES = ["\\", "'"] HEX_STRINGS = [("x'", "'"), ("X'", "'")] RAW_STRINGS = ["$$"] @@ -331,6 +372,8 @@ class Snowflake(Dialect): VAR_SINGLE_TOKENS = {"$"} + COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SHOW} + class Generator(generator.Generator): PARAMETER_TOKEN = "$" MATCHED_BY_SOURCE = False @@ -355,6 +398,7 @@ class Snowflake(Dialect): exp.DataType: _datatype_sql, exp.DayOfWeek: rename_func("DAYOFWEEK"), exp.Extract: rename_func("DATE_PART"), + exp.GroupConcat: rename_func("LISTAGG"), exp.If: rename_func("IFF"), exp.LogicalAnd: rename_func("BOOLAND_AGG"), exp.LogicalOr: rename_func("BOOLOR_AGG"), @@ -362,6 +406,7 @@ class Snowflake(Dialect): exp.Max: max_or_greatest, exp.Min: min_or_least, exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", + exp.RegexpILike: _regexpilike_sql, exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), exp.StarMap: rename_func("OBJECT_CONSTRUCT"), exp.StartsWith: rename_func("STARTSWITH"), @@ -373,6 +418,7 @@ class Snowflake(Dialect): "OBJECT_CONSTRUCT", *(arg for 
expression in e.expressions for arg in expression.flatten()), ), + exp.Stuff: rename_func("INSERT"), exp.TimestampTrunc: timestamptrunc_sql, exp.TimeStrToTime: timestrtotime_sql, exp.TimeToStr: lambda self, e: self.func( @@ -403,6 +449,16 @@ class Snowflake(Dialect): exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } + def show_sql(self, expression: exp.Show) -> str: + scope = self.sql(expression, "scope") + scope = f" {scope}" if scope else "" + + scope_kind = self.sql(expression, "scope_kind") + if scope_kind: + scope_kind = f" IN {scope_kind}" + + return f"SHOW {expression.name}{scope_kind}{scope}" + def regexpextract_sql(self, expression: exp.RegexpExtract) -> str: # Other dialects don't support all of the following parameters, so we need to # generate default values as necessary to ensure the transpilation is correct @@ -436,7 +492,9 @@ class Snowflake(Dialect): kind_value = expression.args.get("kind") or "TABLE" kind = f" {kind_value}" if kind_value else "" this = f" {self.sql(expression, 'this')}" - return f"DESCRIBE{kind}{this}" + expressions = self.expressions(expression, flat=True) + expressions = f" {expressions}" if expressions else "" + return f"DESCRIBE{kind}{this}{expressions}" def generatedasidentitycolumnconstraint_sql( self, expression: exp.GeneratedAsIdentityColumnConstraint diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 7c8982b..a4435f6 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -38,9 +38,15 @@ class Spark(Spark2): class Parser(Spark2.Parser): FUNCTIONS = { **Spark2.Parser.FUNCTIONS, + "ANY_VALUE": lambda args: exp.AnyValue( + this=seq_get(args, 0), ignore_nulls=seq_get(args, 1) + ), "DATEDIFF": _parse_datediff, } + FUNCTION_PARSERS = Spark2.Parser.FUNCTION_PARSERS.copy() + FUNCTION_PARSERS.pop("ANY_VALUE") + class Generator(Spark2.Generator): TYPE_MAPPING = { **Spark2.Generator.TYPE_MAPPING, @@ -56,9 +62,13 @@ class Spark(Spark2): "DATEADD", e.args.get("unit") or "DAY", e.expression, e.this ), } + TRANSFORMS.pop(exp.AnyValue) TRANSFORMS.pop(exp.DateDiff) TRANSFORMS.pop(exp.Group) + def anyvalue_sql(self, expression: exp.AnyValue) -> str: + return self.function_fallback_sql(expression) + def datediff_sql(self, expression: exp.DateDiff) -> str: unit = self.sql(expression, "unit") end = self.sql(expression, "this") diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py index ceb48f8..4489b6b 100644 --- a/sqlglot/dialects/spark2.py +++ b/sqlglot/dialects/spark2.py @@ -15,7 +15,7 @@ from sqlglot.dialects.hive import Hive from sqlglot.helper import seq_get -def _create_sql(self: Hive.Generator, e: exp.Create) -> str: +def _create_sql(self: Spark2.Generator, e: exp.Create) -> str: kind = e.args["kind"] properties = e.args.get("properties") @@ -31,17 +31,21 @@ def _create_sql(self: Hive.Generator, e: exp.Create) -> str: return create_with_partitions_sql(self, e) -def _map_sql(self: Hive.Generator, expression: exp.Map) -> str: - keys = self.sql(expression.args["keys"]) - values = self.sql(expression.args["values"]) - return f"MAP_FROM_ARRAYS({keys}, {values})" +def _map_sql(self: Spark2.Generator, expression: exp.Map) -> str: + keys = expression.args.get("keys") + values = expression.args.get("values") + + if not keys or not values: + return "MAP()" + + return f"MAP_FROM_ARRAYS({self.sql(keys)}, {self.sql(values)})" def _parse_as_cast(to_type: str) -> t.Callable[[t.List], exp.Expression]: return lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build(to_type)) -def _str_to_date(self: 
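Popping `ANY_VALUE` from `FUNCTION_PARSERS` lets Spark's two-argument form `ANY_VALUE(expr, ignoreNulls)` parse through the lambda above, while dialects that lack the function altogether (Postgres, SQLite, T-SQL elsewhere in this patch) rewrite it via `any_value_to_max_sql`. Expected round trips, not verified output:

```python
import sqlglot

# Spark should keep the optional ignore-nulls flag intact.
print(sqlglot.transpile("SELECT ANY_VALUE(x, TRUE) FROM t", read="spark")[0])
# SELECT ANY_VALUE(x, TRUE) FROM t

# Dialects wired to any_value_to_max_sql fall back to MAX.
print(sqlglot.transpile("SELECT ANY_VALUE(x) FROM t", read="spark", write="postgres")[0])
# SELECT MAX(x) FROM t
```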
Hive.Generator, expression: exp.StrToDate) -> str: +def _str_to_date(self: Spark2.Generator, expression: exp.StrToDate) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) if time_format == Hive.DATE_FORMAT: @@ -49,7 +53,7 @@ def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str: return f"TO_DATE({this}, {time_format})" -def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str: +def _unix_to_time_sql(self: Spark2.Generator, expression: exp.UnixToTime) -> str: scale = expression.args.get("scale") timestamp = self.sql(expression, "this") if scale is None: @@ -110,6 +114,13 @@ def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression: return expression +def _insert_sql(self: Spark2.Generator, expression: exp.Insert) -> str: + if expression.expression.args.get("with"): + expression = expression.copy() + expression.set("with", expression.expression.args.pop("with")) + return self.insert_sql(expression) + + class Spark2(Hive): class Parser(Hive.Parser): FUNCTIONS = { @@ -169,10 +180,7 @@ class Spark2(Hive): class Generator(Hive.Generator): QUERY_HINTS = True - - TYPE_MAPPING = { - **Hive.Generator.TYPE_MAPPING, - } + NVL2_SUPPORTED = True PROPERTIES_LOCATION = { **Hive.Generator.PROPERTIES_LOCATION, @@ -197,6 +205,7 @@ class Spark2(Hive): exp.DayOfYear: rename_func("DAYOFYEAR"), exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}", exp.From: transforms.preprocess([_unalias_pivot]), + exp.Insert: _insert_sql, exp.LogicalAnd: rename_func("BOOL_AND"), exp.LogicalOr: rename_func("BOOL_OR"), exp.Map: _map_sql, diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index 90b774e..7bfdf1c 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -5,6 +5,7 @@ import typing as t from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, + any_value_to_max_sql, arrow_json_extract_scalar_sql, arrow_json_extract_sql, concat_to_dpipe_sql, @@ -18,7 +19,7 @@ from sqlglot.dialects.dialect import ( from sqlglot.tokens import TokenType -def _date_add_sql(self: generator.Generator, expression: exp.DateAdd) -> str: +def _date_add_sql(self: SQLite.Generator, expression: exp.DateAdd) -> str: modifier = expression.expression modifier = modifier.name if modifier.is_string else self.sql(modifier) unit = expression.args.get("unit") @@ -78,6 +79,7 @@ class SQLite(Dialect): JOIN_HINTS = False TABLE_HINTS = False QUERY_HINTS = False + NVL2_SUPPORTED = False TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, @@ -103,6 +105,7 @@ class SQLite(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, + exp.AnyValue: any_value_to_max_sql, exp.Concat: concat_to_dpipe_sql, exp.CountIf: count_if_to_sum, exp.Create: transforms.preprocess([_transform_create]), diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index 2be1a62..163cc13 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -95,6 +95,9 @@ class Teradata(Dialect): STATEMENT_PARSERS = { **parser.Parser.STATEMENT_PARSERS, + TokenType.DATABASE: lambda self: self.expression( + exp.Use, this=self._parse_table(schema=False) + ), TokenType.REPLACE: lambda self: self._parse_create(), } @@ -165,6 +168,7 @@ class Teradata(Dialect): exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), exp.StrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})", exp.ToChar: lambda self, e: 
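The Teradata `STATEMENT_PARSERS` entry above treats `DATABASE foo` as the generic `exp.Use`, and the `exp.Use` transform just below renders it back as `DATABASE`. Expected transpilations (database name is illustrative):

```python
import sqlglot

# Teradata's DATABASE <name> statement is modeled as USE ...
print(sqlglot.transpile("DATABASE my_db", read="teradata", write="snowflake")[0])
# USE my_db

# ... and USE renders back as DATABASE when targeting Teradata.
print(sqlglot.transpile("USE my_db", read="snowflake", write="teradata")[0])
# DATABASE my_db
```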
self.function_fallback_sql(e), + exp.Use: lambda self, e: f"DATABASE {self.sql(e, 'this')}", } def partitionedbyproperty_sql(self, expression: exp.PartitionedByProperty) -> str: diff --git a/sqlglot/dialects/trino.py b/sqlglot/dialects/trino.py index af0f78d..0c953a1 100644 --- a/sqlglot/dialects/trino.py +++ b/sqlglot/dialects/trino.py @@ -13,3 +13,6 @@ class Trino(Presto): class Tokenizer(Presto.Tokenizer): HEX_STRINGS = [("X'", "'")] + + class Parser(Presto.Parser): + SUPPORTS_USER_DEFINED_TYPES = False diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 131307f..b26f499 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -7,6 +7,7 @@ import typing as t from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, + any_value_to_max_sql, max_or_greatest, min_or_least, parse_date_delta, @@ -79,22 +80,23 @@ def _format_time_lambda( def _parse_format(args: t.List) -> exp.Expression: - assert len(args) == 2 + this = seq_get(args, 0) + fmt = seq_get(args, 1) + culture = seq_get(args, 2) - fmt = args[1] - number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.name) + number_fmt = fmt and (fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.name)) if number_fmt: - return exp.NumberToStr(this=args[0], format=fmt) + return exp.NumberToStr(this=this, format=fmt, culture=culture) - return exp.TimeToStr( - this=args[0], - format=exp.Literal.string( + if fmt: + fmt = exp.Literal.string( format_time(fmt.name, TSQL.FORMAT_TIME_MAPPING) if len(fmt.name) == 1 else format_time(fmt.name, TSQL.TIME_MAPPING) - ), - ) + ) + + return exp.TimeToStr(this=this, format=fmt, culture=culture) def _parse_eomonth(args: t.List) -> exp.Expression: @@ -130,13 +132,13 @@ def _parse_hashbytes(args: t.List) -> exp.Expression: def generate_date_delta_with_unit_sql( - self: generator.Generator, expression: exp.DateAdd | exp.DateDiff + self: TSQL.Generator, expression: exp.DateAdd | exp.DateDiff ) -> str: func = "DATEADD" if isinstance(expression, exp.DateAdd) else "DATEDIFF" return self.func(func, expression.text("unit"), expression.expression, expression.this) -def _format_sql(self: generator.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str: +def _format_sql(self: TSQL.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str: fmt = ( expression.args["format"] if isinstance(expression, exp.NumberToStr) @@ -147,10 +149,10 @@ def _format_sql(self: generator.Generator, expression: exp.NumberToStr | exp.Tim ) ) ) - return self.func("FORMAT", expression.this, fmt) + return self.func("FORMAT", expression.this, fmt, expression.args.get("culture")) -def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> str: +def _string_agg_sql(self: TSQL.Generator, expression: exp.GroupConcat) -> str: expression = expression.copy() this = expression.this @@ -332,10 +334,12 @@ class TSQL(Dialect): "SQL_VARIANT": TokenType.VARIANT, "TOP": TokenType.TOP, "UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER, + "UPDATE STATISTICS": TokenType.COMMAND, "VARCHAR(MAX)": TokenType.TEXT, "XML": TokenType.XML, "OUTPUT": TokenType.RETURNING, "SYSTEM_USER": TokenType.CURRENT_USER, + "FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT, } class Parser(parser.Parser): @@ -395,7 +399,9 @@ class TSQL(Dialect): CONCAT_NULL_OUTPUTS_STRING = True - def _parse_projections(self) -> t.List[t.Optional[exp.Expression]]: + ALTER_TABLE_ADD_COLUMN_KEYWORD = False + + def _parse_projections(self) -> 
t.List[exp.Expression]: """ T-SQL supports the syntax alias = expression in the SELECT's projection list, so we transform all parsed Selects to convert their EQ projections into Aliases. @@ -458,43 +464,6 @@ class TSQL(Dialect): return self._parse_as_command(self._prev) - def _parse_system_time(self) -> t.Optional[exp.Expression]: - if not self._match_text_seq("FOR", "SYSTEM_TIME"): - return None - - if self._match_text_seq("AS", "OF"): - system_time = self.expression( - exp.SystemTime, this=self._parse_bitwise(), kind="AS OF" - ) - elif self._match_set((TokenType.FROM, TokenType.BETWEEN)): - kind = self._prev.text - this = self._parse_bitwise() - self._match_texts(("TO", "AND")) - expression = self._parse_bitwise() - system_time = self.expression( - exp.SystemTime, this=this, expression=expression, kind=kind - ) - elif self._match_text_seq("CONTAINED", "IN"): - args = self._parse_wrapped_csv(self._parse_bitwise) - system_time = self.expression( - exp.SystemTime, - this=seq_get(args, 0), - expression=seq_get(args, 1), - kind="CONTAINED IN", - ) - elif self._match(TokenType.ALL): - system_time = self.expression(exp.SystemTime, kind="ALL") - else: - system_time = None - self.raise_error("Unable to parse FOR SYSTEM_TIME clause") - - return system_time - - def _parse_table_parts(self, schema: bool = False) -> exp.Table: - table = super()._parse_table_parts(schema=schema) - table.set("system_time", self._parse_system_time()) - return table - def _parse_returns(self) -> exp.ReturnsProperty: table = self._parse_id_var(any_token=False, tokens=self.RETURNS_TABLE_TOKENS) returns = super()._parse_returns() @@ -589,14 +558,36 @@ class TSQL(Dialect): return create + def _parse_if(self) -> t.Optional[exp.Expression]: + index = self._index + + if self._match_text_seq("OBJECT_ID"): + self._parse_wrapped_csv(self._parse_string) + if self._match_text_seq("IS", "NOT", "NULL") and self._match(TokenType.DROP): + return self._parse_drop(exists=True) + self._retreat(index) + + return super()._parse_if() + + def _parse_unique(self) -> exp.UniqueColumnConstraint: + return self.expression( + exp.UniqueColumnConstraint, + this=None + if self._curr and self._curr.text.upper() in {"CLUSTERED", "NONCLUSTERED"} + else self._parse_schema(self._parse_id_var(any_token=False)), + ) + class Generator(generator.Generator): LOCKING_READS_SUPPORTED = True LIMIT_IS_TOP = True QUERY_HINTS = False RETURNING_END = False + NVL2_SUPPORTED = False + ALTER_TABLE_ADD_COLUMN_KEYWORD = False TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, + exp.DataType.Type.BOOLEAN: "BIT", exp.DataType.Type.DECIMAL: "NUMERIC", exp.DataType.Type.DATETIME: "DATETIME2", exp.DataType.Type.INT: "INTEGER", @@ -607,6 +598,8 @@ class TSQL(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, + exp.AnyValue: any_value_to_max_sql, + exp.AutoIncrementColumnConstraint: lambda *_: "IDENTITY", exp.DateAdd: generate_date_delta_with_unit_sql, exp.DateDiff: generate_date_delta_with_unit_sql, exp.CurrentDate: rename_func("GETDATE"), @@ -651,25 +644,44 @@ class TSQL(Dialect): return sql - def offset_sql(self, expression: exp.Offset) -> str: - return f"{super().offset_sql(expression)} ROWS" + def create_sql(self, expression: exp.Create) -> str: + expression = expression.copy() + kind = self.sql(expression, "kind").upper() + exists = expression.args.pop("exists", None) + sql = super().create_sql(expression) + + if exists: + table = expression.find(exp.Table) + identifier = self.sql(exp.Literal.string(exp.table_name(table) if table else "")) + if kind == "SCHEMA": + 
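The `_parse_if` override above special-cases the classic T-SQL existence idiom `IF OBJECT_ID(...) IS NOT NULL DROP ...`, retreating to ordinary `IF` parsing when the pattern does not match. A sketch of the expected rewrite (object names are made up):

```python
import sqlglot

sql = "IF OBJECT_ID('dbo.t', 'U') IS NOT NULL DROP TABLE dbo.t"
# The guard should collapse into a dialect-portable DROP ... IF EXISTS.
print(sqlglot.transpile(sql, read="tsql", write="duckdb")[0])
# DROP TABLE IF EXISTS dbo.t
```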
sql = f"""IF NOT EXISTS (SELECT * FROM information_schema.schemata WHERE schema_name = {identifier}) EXEC('{sql}')""" + elif kind == "TABLE": + sql = f"""IF NOT EXISTS (SELECT * FROM information_schema.tables WHERE table_name = {identifier}) EXEC('{sql}')""" + elif kind == "INDEX": + index = self.sql(exp.Literal.string(expression.this.text("this"))) + sql = f"""IF NOT EXISTS (SELECT * FROM sys.indexes WHERE object_id = object_id({identifier}) AND name = {index}) EXEC('{sql}')""" + elif expression.args.get("replace"): + sql = sql.replace("CREATE OR REPLACE ", "CREATE OR ALTER ", 1) - def systemtime_sql(self, expression: exp.SystemTime) -> str: - kind = expression.args["kind"] - if kind == "ALL": - return "FOR SYSTEM_TIME ALL" + return sql - start = self.sql(expression, "this") - if kind == "AS OF": - return f"FOR SYSTEM_TIME AS OF {start}" + def offset_sql(self, expression: exp.Offset) -> str: + return f"{super().offset_sql(expression)} ROWS" - end = self.sql(expression, "expression") - if kind == "FROM": - return f"FOR SYSTEM_TIME FROM {start} TO {end}" - if kind == "BETWEEN": - return f"FOR SYSTEM_TIME BETWEEN {start} AND {end}" + def version_sql(self, expression: exp.Version) -> str: + name = "SYSTEM_TIME" if expression.name == "TIMESTAMP" else expression.name + this = f"FOR {name}" + expr = expression.expression + kind = expression.text("kind") + if kind in ("FROM", "BETWEEN"): + args = expr.expressions + sep = "TO" if kind == "FROM" else "AND" + expr_sql = f"{self.sql(seq_get(args, 0))} {sep} {self.sql(seq_get(args, 1))}" + else: + expr_sql = self.sql(expr) - return f"FOR SYSTEM_TIME CONTAINED IN ({start}, {end})" + expr_sql = f" {expr_sql}" if expr_sql else "" + return f"{this} {kind}{expr_sql}" def returnsproperty_sql(self, expression: exp.ReturnsProperty) -> str: table = expression.args.get("table") @@ -713,3 +725,16 @@ class TSQL(Dialect): identifier = f"#{identifier}" return identifier + + def constraint_sql(self, expression: exp.Constraint) -> str: + this = self.sql(expression, "this") + expressions = self.expressions(expression, flat=True, sep=" ") + return f"CONSTRAINT {this} {expressions}" + + # https://learn.microsoft.com/en-us/answers/questions/448821/create-table-in-sql-server + def generatedasidentitycolumnconstraint_sql( + self, expression: exp.GeneratedAsIdentityColumnConstraint + ) -> str: + start = self.sql(expression, "start") or "1" + increment = self.sql(expression, "increment") or "1" + return f"IDENTITY({start}, {increment})" diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 57b8bfa..0479da0 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -1035,12 +1035,13 @@ class Clone(Expression): "this": True, "when": False, "kind": False, + "shallow": False, "expression": False, } class Describe(Expression): - arg_types = {"this": True, "kind": False} + arg_types = {"this": True, "kind": False, "expressions": False} class Pragma(Expression): @@ -1070,6 +1071,8 @@ class Show(Expression): "like": False, "where": False, "db": False, + "scope": False, + "scope_kind": False, "full": False, "mutex": False, "query": False, @@ -1207,6 +1210,10 @@ class Comment(Expression): arg_types = {"this": True, "kind": True, "expression": True, "exists": False} +class Comprehension(Expression): + arg_types = {"this": True, "expression": True, "iterator": True, "condition": False} + + # https://clickhouse.com/docs/en/engines/table-engines/mergetree-family/mergetree#mergetree-table-ttl class MergeTreeTTLAction(Expression): arg_types = { @@ -1269,6 +1276,10 @@ 
class CheckColumnConstraint(ColumnConstraintKind): pass +class ClusteredColumnConstraint(ColumnConstraintKind): + pass + + class CollateColumnConstraint(ColumnConstraintKind): pass @@ -1316,6 +1327,14 @@ class InlineLengthColumnConstraint(ColumnConstraintKind): pass +class NonClusteredColumnConstraint(ColumnConstraintKind): + pass + + +class NotForReplicationColumnConstraint(ColumnConstraintKind): + arg_types = {} + + class NotNullColumnConstraint(ColumnConstraintKind): arg_types = {"allow_null": False} @@ -1345,6 +1364,12 @@ class PathColumnConstraint(ColumnConstraintKind): pass +# computed column expression +# https://learn.microsoft.com/en-us/sql/t-sql/statements/create-table-transact-sql?view=sql-server-ver16 +class ComputedColumnConstraint(ColumnConstraintKind): + arg_types = {"this": True, "persisted": False, "not_null": False} + + class Constraint(Expression): arg_types = {"this": True, "expressions": True} @@ -1489,6 +1514,15 @@ class Check(Expression): pass +# https://docs.snowflake.com/en/sql-reference/constructs/connect-by +class Connect(Expression): + arg_types = {"start": False, "connect": True} + + +class Prior(Expression): + pass + + class Directory(Expression): # https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-dml-insert-overwrite-directory-hive.html arg_types = {"this": True, "local": False, "row_format": False} @@ -1578,6 +1612,7 @@ class Insert(DDL): "alternative": False, "where": False, "ignore": False, + "by_name": False, } def with_( @@ -2045,8 +2080,12 @@ class NoPrimaryIndexProperty(Property): arg_types = {} +class OnProperty(Property): + arg_types = {"this": True} + + class OnCommitProperty(Property): - arg_type = {"delete": False} + arg_types = {"delete": False} class PartitionedByProperty(Property): @@ -2282,6 +2321,16 @@ class Subqueryable(Unionable): def named_selects(self) -> t.List[str]: raise NotImplementedError("Subqueryable objects must implement `named_selects`") + def select( + self, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Subqueryable: + raise NotImplementedError("Subqueryable objects must implement `select`") + def with_( self, alias: ExpOrStr, @@ -2323,6 +2372,7 @@ QUERY_MODIFIERS = { "match": False, "laterals": False, "joins": False, + "connect": False, "pivots": False, "where": False, "group": False, @@ -2363,6 +2413,7 @@ class Table(Expression): "pivots": False, "hints": False, "system_time": False, + "version": False, } @property @@ -2403,21 +2454,13 @@ class Table(Expression): return parts -# See the TSQL "Querying data in a system-versioned temporal table" page -class SystemTime(Expression): - arg_types = { - "this": False, - "expression": False, - "kind": True, - } - - class Union(Subqueryable): arg_types = { "with": False, "this": True, "expression": True, "distinct": False, + "by_name": False, **QUERY_MODIFIERS, } @@ -2529,6 +2572,7 @@ class Update(Expression): "from": False, "where": False, "returning": False, + "order": False, "limit": False, } @@ -2545,6 +2589,20 @@ class Var(Expression): pass +class Version(Expression): + """ + Time travel, iceberg, bigquery etc + https://trino.io/docs/current/connector/iceberg.html?highlight=snapshot#using-snapshots + https://www.databricks.com/blog/2019/02/04/introducing-delta-time-travel-for-large-scale-data-lakes.html + https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax#for_system_time_as_of + 
https://learn.microsoft.com/en-us/sql/relational-databases/tables/querying-data-in-a-system-versioned-temporal-table?view=sql-server-ver16 + this is either TIMESTAMP or VERSION + kind is ("AS OF", "BETWEEN") + """ + + arg_types = {"this": True, "kind": True, "expression": False} + + class Schema(Expression): arg_types = {"this": False, "expressions": False} @@ -3263,6 +3321,23 @@ class Subquery(DerivedTable, Unionable): expression = expression.this return expression + def unwrap(self) -> Subquery: + expression = self + while expression.same_parent and expression.is_wrapper: + expression = t.cast(Subquery, expression.parent) + return expression + + @property + def is_wrapper(self) -> bool: + """ + Whether this Subquery acts as a simple wrapper around another expression. + + SELECT * FROM (((SELECT * FROM t))) + ^ + This corresponds to a "wrapper" Subquery node + """ + return all(v is None for k, v in self.args.items() if k != "this") + @property def is_star(self) -> bool: return self.this.is_star @@ -3313,7 +3388,7 @@ class Pivot(Expression): } -class Window(Expression): +class Window(Condition): arg_types = { "this": True, "partition_by": False, @@ -3375,7 +3450,7 @@ class Boolean(Condition): pass -class DataTypeSize(Expression): +class DataTypeParam(Expression): arg_types = {"this": True, "expression": False} @@ -3386,6 +3461,7 @@ class DataType(Expression): "nested": False, "values": False, "prefix": False, + "kind": False, } class Type(AutoName): @@ -3432,6 +3508,7 @@ class DataType(Expression): LOWCARDINALITY = auto() MAP = auto() MEDIUMBLOB = auto() + MEDIUMINT = auto() MEDIUMTEXT = auto() MONEY = auto() NCHAR = auto() @@ -3475,6 +3552,7 @@ class DataType(Expression): VARCHAR = auto() VARIANT = auto() XML = auto() + YEAR = auto() TEXT_TYPES = { Type.CHAR, @@ -3498,7 +3576,10 @@ class DataType(Expression): Type.DOUBLE, } - NUMERIC_TYPES = {*INTEGER_TYPES, *FLOAT_TYPES} + NUMERIC_TYPES = { + *INTEGER_TYPES, + *FLOAT_TYPES, + } TEMPORAL_TYPES = { Type.TIME, @@ -3511,23 +3592,39 @@ class DataType(Expression): Type.DATETIME64, } - META_TYPES = {"UNKNOWN", "NULL"} - @classmethod def build( - cls, dtype: str | DataType | DataType.Type, dialect: DialectType = None, **kwargs + cls, + dtype: str | DataType | DataType.Type, + dialect: DialectType = None, + udt: bool = False, + **kwargs, ) -> DataType: + """ + Constructs a DataType object. + + Args: + dtype: the data type of interest. + dialect: the dialect to use for parsing `dtype`, in case it's a string. + udt: when set to True, `dtype` will be used as-is if it can't be parsed into a + DataType, thus creating a user-defined type. + kawrgs: additional arguments to pass in the constructor of DataType. + + Returns: + The constructed DataType object. 
+ """ from sqlglot import parse_one if isinstance(dtype, str): - upper = dtype.upper() - if upper in DataType.META_TYPES: - data_type_exp: t.Optional[Expression] = DataType(this=DataType.Type[upper]) - else: - data_type_exp = parse_one(dtype, read=dialect, into=DataType) + if dtype.upper() == "UNKNOWN": + return DataType(this=DataType.Type.UNKNOWN, **kwargs) - if data_type_exp is None: - raise ValueError(f"Unparsable data type value: {dtype}") + try: + data_type_exp = parse_one(dtype, read=dialect, into=DataType) + except ParseError: + if udt: + return DataType(this=DataType.Type.USERDEFINED, kind=dtype, **kwargs) + raise elif isinstance(dtype, DataType.Type): data_type_exp = DataType(this=dtype) elif isinstance(dtype, DataType): @@ -3538,7 +3635,31 @@ class DataType(Expression): return DataType(**{**data_type_exp.args, **kwargs}) def is_type(self, *dtypes: str | DataType | DataType.Type) -> bool: - return any(self.this == DataType.build(dtype).this for dtype in dtypes) + """ + Checks whether this DataType matches one of the provided data types. Nested types or precision + will be compared using "structural equivalence" semantics, so e.g. array != array. + + Args: + dtypes: the data types to compare this DataType to. + + Returns: + True, if and only if there is a type in `dtypes` which is equal to this DataType. + """ + for dtype in dtypes: + other = DataType.build(dtype, udt=True) + + if ( + other.expressions + or self.this == DataType.Type.USERDEFINED + or other.this == DataType.Type.USERDEFINED + ): + matches = self == other + else: + matches = self.this == other.this + + if matches: + return True + return False # https://www.postgresql.org/docs/15/datatype-pseudo.html @@ -3546,6 +3667,11 @@ class PseudoType(Expression): pass +# https://www.postgresql.org/docs/15/datatype-oid.html +class ObjectIdentifier(Expression): + pass + + # WHERE x EXISTS|ALL|ANY|SOME(SELECT ...) class SubqueryPredicate(Predicate): pass @@ -4005,6 +4131,7 @@ class ArrayAny(Func): class ArrayConcat(Func): + _sql_names = ["ARRAY_CONCAT", "ARRAY_CAT"] arg_types = {"this": True, "expressions": False} is_var_len_args = True @@ -4047,7 +4174,15 @@ class Avg(AggFunc): class AnyValue(AggFunc): - arg_types = {"this": True, "having": False, "max": False} + arg_types = {"this": True, "having": False, "max": False, "ignore_nulls": False} + + +class First(Func): + arg_types = {"this": True, "ignore_nulls": False} + + +class Last(Func): + arg_types = {"this": True, "ignore_nulls": False} class Case(Func): @@ -4086,18 +4221,29 @@ class Cast(Func): return self.name def is_type(self, *dtypes: str | DataType | DataType.Type) -> bool: - return self.to.is_type(*dtypes) + """ + Checks whether this Cast's DataType matches one of the provided data types. Nested types + like arrays or structs will be compared using "structural equivalence" semantics, so e.g. + array != array. + Args: + dtypes: the data types to compare this Cast's DataType to. -class CastToStrType(Func): - arg_types = {"this": True, "expression": True} + Returns: + True, if and only if there is a type in `dtypes` which is equal to this Cast's DataType. 
+ """ + return self.to.is_type(*dtypes) -class Collate(Binary): +class TryCast(Cast): pass -class TryCast(Cast): +class CastToStrType(Func): + arg_types = {"this": True, "to": True} + + +class Collate(Binary): pass @@ -4310,7 +4456,7 @@ class Greatest(Func): is_var_len_args = True -class GroupConcat(Func): +class GroupConcat(AggFunc): arg_types = {"this": True, "separator": False} @@ -4648,8 +4794,19 @@ class StrToUnix(Func): arg_types = {"this": False, "format": False} +# https://prestodb.io/docs/current/functions/string.html +# https://spark.apache.org/docs/latest/api/sql/index.html#str_to_map +class StrToMap(Func): + arg_types = { + "this": True, + "pair_delim": False, + "key_value_delim": False, + "duplicate_resolution_callback": False, + } + + class NumberToStr(Func): - arg_types = {"this": True, "format": True} + arg_types = {"this": True, "format": True, "culture": False} class FromBase(Func): @@ -4665,6 +4822,13 @@ class StructExtract(Func): arg_types = {"this": True, "expression": True} +# https://learn.microsoft.com/en-us/sql/t-sql/functions/stuff-transact-sql?view=sql-server-ver16 +# https://docs.snowflake.com/en/sql-reference/functions/insert +class Stuff(Func): + _sql_names = ["STUFF", "INSERT"] + arg_types = {"this": True, "start": True, "length": True, "expression": True} + + class Sum(AggFunc): pass @@ -4686,7 +4850,7 @@ class StddevSamp(AggFunc): class TimeToStr(Func): - arg_types = {"this": True, "format": True} + arg_types = {"this": True, "format": True, "culture": False} class TimeToTimeStr(Func): @@ -5724,9 +5888,9 @@ def table_( The new Table instance. """ return Table( - this=to_identifier(table, quoted=quoted), - db=to_identifier(db, quoted=quoted), - catalog=to_identifier(catalog, quoted=quoted), + this=to_identifier(table, quoted=quoted) if table else None, + db=to_identifier(db, quoted=quoted) if db else None, + catalog=to_identifier(catalog, quoted=quoted) if catalog else None, alias=TableAlias(this=to_identifier(alias)) if alias else None, ) @@ -5844,8 +6008,8 @@ def convert(value: t.Any, copy: bool = False) -> Expression: return Array(expressions=[convert(v, copy=copy) for v in value]) if isinstance(value, dict): return Map( - keys=[convert(k, copy=copy) for k in value], - values=[convert(v, copy=copy) for v in value.values()], + keys=Array(expressions=[convert(k, copy=copy) for k in value]), + values=Array(expressions=[convert(v, copy=copy) for v in value.values()]), ) raise ValueError(f"Cannot convert {value}") diff --git a/sqlglot/generator.py b/sqlglot/generator.py index f8d7d68..306df81 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -8,7 +8,7 @@ from sqlglot import exp from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages from sqlglot.helper import apply_index_offset, csv, seq_get from sqlglot.time import format_time -from sqlglot.tokens import TokenType +from sqlglot.tokens import Tokenizer, TokenType logger = logging.getLogger("sqlglot") @@ -61,6 +61,7 @@ class Generator: exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}", exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args.get('default') else ''}CHARACTER SET={self.sql(e, 'this')}", exp.CheckColumnConstraint: lambda self, e: f"CHECK ({self.sql(e, 'this')})", + exp.ClusteredColumnConstraint: lambda self, e: f"CLUSTERED ({self.expressions(e, 'this', indent=False)})", exp.CollateColumnConstraint: lambda self, e: f"COLLATE {self.sql(e, 'this')}", exp.CopyGrantsProperty: lambda self, e: "COPY GRANTS", 
exp.CommentColumnConstraint: lambda self, e: f"COMMENT {self.sql(e, 'this')}", @@ -78,7 +79,10 @@ class Generator: exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG", exp.MaterializedProperty: lambda self, e: "MATERIALIZED", exp.NoPrimaryIndexProperty: lambda self, e: "NO PRIMARY INDEX", + exp.NonClusteredColumnConstraint: lambda self, e: f"NONCLUSTERED ({self.expressions(e, 'this', indent=False)})", + exp.NotForReplicationColumnConstraint: lambda self, e: "NOT FOR REPLICATION", exp.OnCommitProperty: lambda self, e: f"ON COMMIT {'DELETE' if e.args.get('delete') else 'PRESERVE'} ROWS", + exp.OnProperty: lambda self, e: f"ON {self.sql(e, 'this')}", exp.OnUpdateColumnConstraint: lambda self, e: f"ON UPDATE {self.sql(e, 'this')}", exp.PathColumnConstraint: lambda self, e: f"PATH {self.sql(e, 'this')}", exp.ReturnsProperty: lambda self, e: self.naked_property(e), @@ -171,6 +175,9 @@ class Generator: # Whether or not TIMETZ / TIMESTAMPTZ will be generated using the "WITH TIME ZONE" syntax TZ_TO_WITH_TIME_ZONE = False + # Whether or not the NVL2 function is supported + NVL2_SUPPORTED = True + # https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax SELECT_KINDS: t.Tuple[str, ...] = ("STRUCT", "VALUE") @@ -179,6 +186,9 @@ class Generator: # SELECT * VALUES into SELECT UNION VALUES_AS_TABLE = True + # Whether or not the word COLUMN is included when adding a column with ALTER TABLE + ALTER_TABLE_ADD_COLUMN_KEYWORD = True + TYPE_MAPPING = { exp.DataType.Type.NCHAR: "CHAR", exp.DataType.Type.NVARCHAR: "VARCHAR", @@ -245,6 +255,7 @@ class Generator: exp.MaterializedProperty: exp.Properties.Location.POST_CREATE, exp.MergeBlockRatioProperty: exp.Properties.Location.POST_NAME, exp.NoPrimaryIndexProperty: exp.Properties.Location.POST_EXPRESSION, + exp.OnProperty: exp.Properties.Location.POST_SCHEMA, exp.OnCommitProperty: exp.Properties.Location.POST_EXPRESSION, exp.Order: exp.Properties.Location.POST_SCHEMA, exp.PartitionedByProperty: exp.Properties.Location.POST_WITH, @@ -317,8 +328,7 @@ class Generator: QUOTE_END = "'" IDENTIFIER_START = '"' IDENTIFIER_END = '"' - STRING_ESCAPE = "'" - IDENTIFIER_ESCAPE = '"' + TOKENIZER_CLASS = Tokenizer # Delimiters for bit, hex, byte and raw literals BIT_START: t.Optional[str] = None @@ -379,8 +389,10 @@ class Generator: ) self.unsupported_messages: t.List[str] = [] - self._escaped_quote_end: str = self.STRING_ESCAPE + self.QUOTE_END - self._escaped_identifier_end: str = self.IDENTIFIER_ESCAPE + self.IDENTIFIER_END + self._escaped_quote_end: str = self.TOKENIZER_CLASS.STRING_ESCAPES[0] + self.QUOTE_END + self._escaped_identifier_end: str = ( + self.TOKENIZER_CLASS.IDENTIFIER_ESCAPES[0] + self.IDENTIFIER_END + ) self._cache: t.Optional[t.Dict[int, str]] = None def generate( @@ -626,6 +638,16 @@ class Generator: kind_sql = self.sql(expression, "kind").strip() return f"CONSTRAINT {this} {kind_sql}" if this else kind_sql + def computedcolumnconstraint_sql(self, expression: exp.ComputedColumnConstraint) -> str: + this = self.sql(expression, "this") + if expression.args.get("not_null"): + persisted = " PERSISTED NOT NULL" + elif expression.args.get("persisted"): + persisted = " PERSISTED" + else: + persisted = "" + return f"AS {this}{persisted}" + def autoincrementcolumnconstraint_sql(self, _) -> str: return self.token_sql(TokenType.AUTO_INCREMENT) @@ -642,8 +664,8 @@ class Generator: ) -> str: this = "" if expression.this is not None: - on_null = "ON NULL " if expression.args.get("on_null") else "" - this = " ALWAYS " if 
expression.this else f" BY DEFAULT {on_null}" + on_null = " ON NULL" if expression.args.get("on_null") else "" + this = " ALWAYS" if expression.this else f" BY DEFAULT{on_null}" start = expression.args.get("start") start = f"START WITH {start}" if start else "" @@ -668,7 +690,7 @@ class Generator: expr = self.sql(expression, "expression") expr = f"({expr})" if expr else "IDENTITY" - return f"GENERATED{this}AS {expr}{sequence_opts}" + return f"GENERATED{this} AS {expr}{sequence_opts}" def notnullcolumnconstraint_sql(self, expression: exp.NotNullColumnConstraint) -> str: return f"{'' if expression.args.get('allow_null') else 'NOT '}NULL" @@ -774,14 +796,16 @@ class Generator: def clone_sql(self, expression: exp.Clone) -> str: this = self.sql(expression, "this") + shallow = "SHALLOW " if expression.args.get("shallow") else "" + this = f"{shallow}CLONE {this}" when = self.sql(expression, "when") if when: kind = self.sql(expression, "kind") expr = self.sql(expression, "expression") - return f"CLONE {this} {when} ({kind} => {expr})" + return f"{this} {when} ({kind} => {expr})" - return f"CLONE {this}" + return this def describe_sql(self, expression: exp.Describe) -> str: return f"DESCRIBE {self.sql(expression, 'this')}" @@ -830,7 +854,7 @@ class Generator: string = self.escape_str(expression.this.replace("\\", "\\\\")) return f"{self.QUOTE_START}{string}{self.QUOTE_END}" - def datatypesize_sql(self, expression: exp.DataTypeSize) -> str: + def datatypeparam_sql(self, expression: exp.DataTypeParam) -> str: this = self.sql(expression, "this") specifier = self.sql(expression, "expression") specifier = f" {specifier}" if specifier else "" @@ -839,11 +863,14 @@ class Generator: def datatype_sql(self, expression: exp.DataType) -> str: type_value = expression.this - type_sql = ( - self.TYPE_MAPPING.get(type_value, type_value.value) - if isinstance(type_value, exp.DataType.Type) - else type_value - ) + if type_value == exp.DataType.Type.USERDEFINED and expression.args.get("kind"): + type_sql = self.sql(expression, "kind") + else: + type_sql = ( + self.TYPE_MAPPING.get(type_value, type_value.value) + if isinstance(type_value, exp.DataType.Type) + else type_value + ) nested = "" interior = self.expressions(expression, flat=True) @@ -943,9 +970,9 @@ class Generator: name = self.sql(expression, "this") name = f"{name} " if name else "" table = self.sql(expression, "table") - table = f"{self.INDEX_ON} {table} " if table else "" + table = f"{self.INDEX_ON} {table}" if table else "" using = self.sql(expression, "using") - using = f"USING {using} " if using else "" + using = f" USING {using} " if using else "" index = "INDEX " if not table else "" columns = self.expressions(expression, key="columns", flat=True) columns = f"({columns})" if columns else "" @@ -1171,6 +1198,7 @@ class Generator: where = f"{self.sep()}REPLACE WHERE {where}" if where else "" expression_sql = f"{self.sep()}{self.sql(expression, 'expression')}" conflict = self.sql(expression, "conflict") + by_name = " BY NAME" if expression.args.get("by_name") else "" returning = self.sql(expression, "returning") if self.RETURNING_END: @@ -1178,7 +1206,7 @@ class Generator: else: expression_sql = f"{returning}{expression_sql}{conflict}" - sql = f"INSERT{alternative}{ignore}{this}{exists}{partition_sql}{where}{expression_sql}" + sql = f"INSERT{alternative}{ignore}{this}{by_name}{exists}{partition_sql}{where}{expression_sql}" return self.prepend_ctes(expression, sql) def intersect_sql(self, expression: exp.Intersect) -> str: @@ -1196,6 +1224,9 @@ class 
Generator: def pseudotype_sql(self, expression: exp.PseudoType) -> str: return expression.name.upper() + def objectidentifier_sql(self, expression: exp.ObjectIdentifier) -> str: + return expression.name.upper() + def onconflict_sql(self, expression: exp.OnConflict) -> str: conflict = "ON DUPLICATE KEY" if expression.args.get("duplicate") else "ON CONFLICT" constraint = self.sql(expression, "constraint") @@ -1248,6 +1279,8 @@ class Generator: if part ) + version = self.sql(expression, "version") + version = f" {version}" if version else "" alias = self.sql(expression, "alias") alias = f"{sep}{alias}" if alias else "" hints = self.expressions(expression, key="hints", sep=" ") @@ -1256,10 +1289,8 @@ class Generator: pivots = f" {pivots}" if pivots else "" joins = self.expressions(expression, key="joins", sep="", skip_first=True) laterals = self.expressions(expression, key="laterals", sep="") - system_time = expression.args.get("system_time") - system_time = f" {self.sql(expression, 'system_time')}" if system_time else "" - return f"{table}{system_time}{alias}{hints}{pivots}{joins}{laterals}" + return f"{table}{version}{alias}{hints}{pivots}{joins}{laterals}" def tablesample_sql( self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS " @@ -1314,6 +1345,12 @@ class Generator: nulls = "" return f"{direction}{nulls}({expressions} FOR {field}){alias}" + def version_sql(self, expression: exp.Version) -> str: + this = f"FOR {expression.name}" + kind = expression.text("kind") + expr = self.sql(expression, "expression") + return f"{this} {kind} {expr}" + def tuple_sql(self, expression: exp.Tuple) -> str: return f"({self.expressions(expression, flat=True)})" @@ -1323,12 +1360,13 @@ class Generator: from_sql = self.sql(expression, "from") where_sql = self.sql(expression, "where") returning = self.sql(expression, "returning") + order = self.sql(expression, "order") limit = self.sql(expression, "limit") if self.RETURNING_END: - expression_sql = f"{from_sql}{where_sql}{returning}{limit}" + expression_sql = f"{from_sql}{where_sql}{returning}" else: - expression_sql = f"{returning}{from_sql}{where_sql}{limit}" - sql = f"UPDATE {this} SET {set_sql}{expression_sql}" + expression_sql = f"{returning}{from_sql}{where_sql}" + sql = f"UPDATE {this} SET {set_sql}{expression_sql}{order}{limit}" return self.prepend_ctes(expression, sql) def values_sql(self, expression: exp.Values) -> str: @@ -1425,6 +1463,16 @@ class Generator: this = self.indent(self.sql(expression, "this")) return f"{self.seg('HAVING')}{self.sep()}{this}" + def connect_sql(self, expression: exp.Connect) -> str: + start = self.sql(expression, "start") + start = self.seg(f"START WITH {start}") if start else "" + connect = self.sql(expression, "connect") + connect = self.seg(f"CONNECT BY {connect}") + return start + connect + + def prior_sql(self, expression: exp.Prior) -> str: + return f"PRIOR {self.sql(expression, 'this')}" + def join_sql(self, expression: exp.Join) -> str: op_sql = " ".join( op @@ -1667,6 +1715,7 @@ class Generator: return csv( *sqls, *[self.sql(join) for join in expression.args.get("joins") or []], + self.sql(expression, "connect"), self.sql(expression, "match"), *[self.sql(lateral) for lateral in expression.args.get("laterals") or []], self.sql(expression, "where"), @@ -1801,7 +1850,8 @@ class Generator: def union_op(self, expression: exp.Union) -> str: kind = " DISTINCT" if self.EXPLICIT_UNION else "" kind = kind if expression.args.get("distinct") else " ALL" - return f"UNION{kind}" + by_name = " BY NAME" if 
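The `connect_sql` and `prior_sql` methods above, together with the new `connect` query modifier, give hierarchical `CONNECT BY` queries a first-class representation via `exp.Connect` / `exp.Prior`. Assuming the parser side of this patch handles the clause, a transpilation sketch with illustrative names:

```python
import sqlglot

sql = """
SELECT employee_id
FROM employees
START WITH manager_id IS NULL
CONNECT BY PRIOR employee_id = manager_id
"""
print(sqlglot.transpile(sql, read="oracle", write="snowflake")[0])
# Expected (one line):
# SELECT employee_id FROM employees START WITH manager_id IS NULL
# CONNECT BY PRIOR employee_id = manager_id
```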
expression.args.get("by_name") else "" + return f"UNION{kind}{by_name}" def unnest_sql(self, expression: exp.Unnest) -> str: args = self.expressions(expression, flat=True) @@ -2224,7 +2274,14 @@ class Generator: actions = expression.args["actions"] if isinstance(actions[0], exp.ColumnDef): - actions = self.expressions(expression, key="actions", prefix="ADD COLUMN ") + if self.ALTER_TABLE_ADD_COLUMN_KEYWORD: + actions = self.expressions( + expression, + key="actions", + prefix="ADD COLUMN ", + ) + else: + actions = f"ADD {self.expressions(expression, key='actions')}" elif isinstance(actions[0], exp.Schema): actions = self.expressions(expression, key="actions", prefix="ADD COLUMNS ") elif isinstance(actions[0], exp.Delete): @@ -2525,10 +2582,21 @@ class Generator: return f"WHEN {matched}{source}{condition} THEN {then}" def merge_sql(self, expression: exp.Merge) -> str: - this = self.sql(expression, "this") + table = expression.this + table_alias = "" + + hints = table.args.get("hints") + if hints and table.alias and isinstance(hints[0], exp.WithTableHint): + # T-SQL syntax is MERGE ... [WITH ()] [[AS] table_alias] + table = table.copy() + table_alias = f" AS {self.sql(table.args['alias'].pop())}" + + this = self.sql(table) using = f"USING {self.sql(expression, 'using')}" on = f"ON {self.sql(expression, 'on')}" - return f"MERGE INTO {this} {using} {on} {self.expressions(expression, sep=' ')}" + expressions = self.expressions(expression, sep=" ") + + return f"MERGE INTO {this}{table_alias} {using} {on} {expressions}" def tochar_sql(self, expression: exp.ToChar) -> str: if expression.args.get("format"): @@ -2631,6 +2699,29 @@ class Generator: options = f" {options}" if options else "" return f"{kind}{this}{type_}{schema}{options}" + def nvl2_sql(self, expression: exp.Nvl2) -> str: + if self.NVL2_SUPPORTED: + return self.function_fallback_sql(expression) + + case = exp.Case().when( + expression.this.is_(exp.null()).not_(copy=False), + expression.args["true"].copy(), + copy=False, + ) + else_cond = expression.args.get("false") + if else_cond: + case.else_(else_cond.copy(), copy=False) + + return self.sql(case) + + def comprehension_sql(self, expression: exp.Comprehension) -> str: + this = self.sql(expression, "this") + expr = self.sql(expression, "expression") + iterator = self.sql(expression, "iterator") + condition = self.sql(expression, "condition") + condition = f" IF {condition}" if condition else "" + return f"{this} FOR {expr} IN {iterator}{condition}" + def cached_generator( cache: t.Optional[t.Dict[int, str]] = None diff --git a/sqlglot/helper.py b/sqlglot/helper.py index a863017..7335d1e 100644 --- a/sqlglot/helper.py +++ b/sqlglot/helper.py @@ -33,6 +33,15 @@ class AutoName(Enum): return name +class classproperty(property): + """ + Similar to a normal property but works for class methods + """ + + def __get__(self, obj: t.Any, owner: t.Any = None) -> t.Any: + return classmethod(self.fget).__get__(None, owner)() # type: ignore + + def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]: """Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds.""" try: @@ -137,9 +146,9 @@ def subclasses( def apply_index_offset( this: exp.Expression, - expressions: t.List[t.Optional[E]], + expressions: t.List[E], offset: int, -) -> t.List[t.Optional[E]]: +) -> t.List[E]: """ Applies an offset to a given integer literal expression. 
@@ -170,15 +179,14 @@ def apply_index_offset( ): return expressions - if expression: - if not expression.type: - annotate_types(expression) - if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES: - logger.warning("Applying array index offset (%s)", offset) - expression = simplify( - exp.Add(this=expression.copy(), expression=exp.Literal.number(offset)) - ) - return [expression] + if not expression.type: + annotate_types(expression) + if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES: + logger.warning("Applying array index offset (%s)", offset) + expression = simplify( + exp.Add(this=expression.copy(), expression=exp.Literal.number(offset)) + ) + return [expression] return expressions diff --git a/sqlglot/optimizer/__init__.py b/sqlglot/optimizer/__init__.py index 719a77e..ee48006 100644 --- a/sqlglot/optimizer/__init__.py +++ b/sqlglot/optimizer/__init__.py @@ -1,2 +1,9 @@ from sqlglot.optimizer.optimizer import RULES, optimize -from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope +from sqlglot.optimizer.scope import ( + Scope, + build_scope, + find_all_in_scope, + find_in_scope, + traverse_scope, + walk_in_scope, +) diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index e7cb80b..a429655 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -203,10 +203,15 @@ class TypeAnnotator(metaclass=_TypeAnnotator): for expr_type in expressions }, exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), + exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True), + exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True), + exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"), exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]), exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()), + exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"), + exp.Filter: lambda self, e: self._annotate_by_args(e, "this"), exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL), exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"), @@ -220,6 +225,10 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP), } + NESTED_TYPES = { + exp.DataType.Type.ARRAY, + } + # Specifies what types a given type can be coerced into (autofilled) COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {} @@ -299,19 +308,22 @@ class TypeAnnotator(metaclass=_TypeAnnotator): def _maybe_coerce( self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type - ) -> exp.DataType.Type: - # We propagate the NULL / UNKNOWN types upwards if found - if isinstance(type1, exp.DataType): - type1 = type1.this - if isinstance(type2, exp.DataType): - type2 = type2.this + ) -> exp.DataType | exp.DataType.Type: + type1_value = type1.this if isinstance(type1, exp.DataType) else type1 + type2_value = type2.this if isinstance(type2, exp.DataType) else type2 - if exp.DataType.Type.NULL in (type1, type2): + # We propagate the NULL / UNKNOWN types upwards if found + if exp.DataType.Type.NULL in (type1_value, 
type2_value): return exp.DataType.Type.NULL - if exp.DataType.Type.UNKNOWN in (type1, type2): + if exp.DataType.Type.UNKNOWN in (type1_value, type2_value): return exp.DataType.Type.UNKNOWN - return type2 if type2 in self.coerces_to.get(type1, {}) else type1 # type: ignore + if type1_value in self.NESTED_TYPES: + return type1 + if type2_value in self.NESTED_TYPES: + return type2 + + return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value # type: ignore # Note: the following "no_type_check" decorators were added because mypy was yelling due # to assigning Type values to expression.type (since its getter returns Optional[DataType]). @@ -368,7 +380,9 @@ class TypeAnnotator(metaclass=_TypeAnnotator): return self._annotate_args(expression) @t.no_type_check - def _annotate_by_args(self, expression: E, *args: str, promote: bool = False) -> E: + def _annotate_by_args( + self, expression: E, *args: str, promote: bool = False, array: bool = False + ) -> E: self._annotate_args(expression) expressions: t.List[exp.Expression] = [] @@ -388,4 +402,9 @@ class TypeAnnotator(metaclass=_TypeAnnotator): elif expression.type.this in exp.DataType.FLOAT_TYPES: expression.type = exp.DataType.Type.DOUBLE + if array: + expression.type = exp.DataType( + this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True + ) + return expression diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index af42f25..1ab7768 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -142,13 +142,14 @@ def _eliminate_derived_table(scope, existing_ctes, taken): if scope.parent.pivots or isinstance(scope.parent.expression, exp.Lateral): return None - parent = scope.expression.parent + # Get rid of redundant exp.Subquery expressions, i.e. 
those that are just used as wrappers + to_replace = scope.expression.parent.unwrap() name, cte = _new_cte(scope, existing_ctes, taken) + table = exp.alias_(exp.table_(name), alias=to_replace.alias or name) + table.set("joins", to_replace.args.get("joins")) - table = exp.alias_(exp.table_(name), alias=parent.alias or name) - table.set("joins", parent.args.get("joins")) + to_replace.replace(table) - parent.replace(table) return cte diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py index 7b3b2b1..9d401fc 100644 --- a/sqlglot/optimizer/optimize_joins.py +++ b/sqlglot/optimizer/optimize_joins.py @@ -72,8 +72,13 @@ def normalize(expression): if not any(join.args.get(k) for k in JOIN_ATTRS): join.set("kind", "CROSS") - if join.kind != "CROSS": + if join.kind == "CROSS": + join.set("on", None) + else: join.set("kind", None) + + if not join.args.get("on") and not join.args.get("using"): + join.set("on", exp.true()) return expression diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py index 58b988d..f7348b5 100644 --- a/sqlglot/optimizer/pushdown_predicates.py +++ b/sqlglot/optimizer/pushdown_predicates.py @@ -1,6 +1,6 @@ from sqlglot import exp from sqlglot.optimizer.normalize import normalized -from sqlglot.optimizer.scope import build_scope +from sqlglot.optimizer.scope import build_scope, find_in_scope from sqlglot.optimizer.simplify import simplify @@ -81,7 +81,11 @@ def pushdown_cnf(predicates, scope, scope_ref_count): break if isinstance(node, exp.Select): predicate.replace(exp.true()) - node.where(replace_aliases(node, predicate), copy=False) + inner_predicate = replace_aliases(node, predicate) + if find_in_scope(inner_predicate, exp.AggFunc): + node.having(inner_predicate, copy=False) + else: + node.where(inner_predicate, copy=False) def pushdown_dnf(predicates, scope, scope_ref_count): @@ -142,7 +146,11 @@ def pushdown_dnf(predicates, scope, scope_ref_count): if isinstance(node, exp.Join): node.on(predicate, copy=False) elif isinstance(node, exp.Select): - node.where(replace_aliases(node, predicate), copy=False) + inner_predicate = replace_aliases(node, predicate) + if find_in_scope(inner_predicate, exp.AggFunc): + node.having(inner_predicate, copy=False) + else: + node.where(inner_predicate, copy=False) def nodes_for_predicate(predicate, sources, scope_ref_count): diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index fb12384..435899a 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -6,7 +6,7 @@ from enum import Enum, auto from sqlglot import exp from sqlglot.errors import OptimizeError -from sqlglot.helper import find_new_name +from sqlglot.helper import ensure_collection, find_new_name logger = logging.getLogger("sqlglot") @@ -141,38 +141,10 @@ class Scope: return walk_in_scope(self.expression, bfs=bfs) def find(self, *expression_types, bfs=True): - """ - Returns the first node in this scope which matches at least one of the specified types. - - This does NOT traverse into subscopes. - - Args: - expression_types (type): the expression type(s) to match. - bfs (bool): True to use breadth-first search, False to use depth-first. - - Returns: - exp.Expression: the node which matches the criteria or None if no node matching - the criteria was found. 
- """ - return next(self.find_all(*expression_types, bfs=bfs), None) + return find_in_scope(self.expression, expression_types, bfs=bfs) def find_all(self, *expression_types, bfs=True): - """ - Returns a generator object which visits all nodes in this scope and only yields those that - match at least one of the specified expression types. - - This does NOT traverse into subscopes. - - Args: - expression_types (type): the expression type(s) to match. - bfs (bool): True to use breadth-first search, False to use depth-first. - - Yields: - exp.Expression: nodes - """ - for expression, *_ in self.walk(bfs=bfs): - if isinstance(expression, expression_types): - yield expression + return find_all_in_scope(self.expression, expression_types, bfs=bfs) def replace(self, old, new): """ @@ -800,3 +772,41 @@ def walk_in_scope(expression, bfs=True): for key in ("joins", "laterals", "pivots"): for arg in node.args.get(key) or []: yield from walk_in_scope(arg, bfs=bfs) + + +def find_all_in_scope(expression, expression_types, bfs=True): + """ + Returns a generator object which visits all nodes in this scope and only yields those that + match at least one of the specified expression types. + + This does NOT traverse into subscopes. + + Args: + expression (exp.Expression): + expression_types (tuple[type]|type): the expression type(s) to match. + bfs (bool): True to use breadth-first search, False to use depth-first. + + Yields: + exp.Expression: nodes + """ + for expression, *_ in walk_in_scope(expression, bfs=bfs): + if isinstance(expression, tuple(ensure_collection(expression_types))): + yield expression + + +def find_in_scope(expression, expression_types, bfs=True): + """ + Returns the first node in this scope which matches at least one of the specified types. + + This does NOT traverse into subscopes. + + Args: + expression (exp.Expression): + expression_types (tuple[type]|type): the expression type(s) to match. + bfs (bool): True to use breadth-first search, False to use depth-first. + + Returns: + exp.Expression: the node which matches the criteria or None if no node matching + the criteria was found. 
+ """ + return next(find_all_in_scope(expression, expression_types, bfs=bfs), None) diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index e550603..3974ea4 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -69,10 +69,10 @@ def simplify(expression): node = flatten(node) node = simplify_connectors(node, root) node = remove_compliments(node, root) + node = simplify_coalesce(node) node.parent = expression.parent node = simplify_literals(node, root) node = simplify_parens(node) - node = simplify_coalesce(node) if root: expression.replace(node) @@ -350,7 +350,8 @@ def absorb_and_eliminate(expression, root=True): def simplify_literals(expression, root=True): if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector): return _flat_simplify(expression, _simplify_binary, root) - elif isinstance(expression, exp.Neg): + + if isinstance(expression, exp.Neg): this = expression.this if this.is_number: value = this.name @@ -430,13 +431,14 @@ def simplify_parens(expression): if not isinstance(this, exp.Select) and ( not isinstance(parent, (exp.Condition, exp.Binary)) - or isinstance(this, exp.Predicate) + or isinstance(parent, exp.Paren) or not isinstance(this, exp.Binary) + or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate)) or (isinstance(this, exp.Add) and isinstance(parent, exp.Add)) or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul)) or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub))) ): - return expression.this + return this return expression @@ -488,18 +490,20 @@ def simplify_coalesce(expression): coalesce = coalesce if coalesce.expressions else coalesce.this # This expression is more complex than when we started, but it will get simplified further - return exp.or_( - exp.and_( - coalesce.is_(exp.null()).not_(copy=False), - expression.copy(), - copy=False, - ), - exp.and_( - coalesce.is_(exp.null()), - type(expression)(this=arg.copy(), expression=other.copy()), + return exp.paren( + exp.or_( + exp.and_( + coalesce.is_(exp.null()).not_(copy=False), + expression.copy(), + copy=False, + ), + exp.and_( + coalesce.is_(exp.null()), + type(expression)(this=arg.copy(), expression=other.copy()), + copy=False, + ), copy=False, - ), - copy=False, + ) ) @@ -642,7 +646,7 @@ def _flat_simplify(expression, simplifier, root=True): for b in queue: result = simplifier(expression, a, b) - if result: + if result and result is not expression: queue.remove(b) queue.appendleft(result) break diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 3db4453..f8690d5 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -136,6 +136,7 @@ class Parser(metaclass=_Parser): TokenType.UINT128, TokenType.INT256, TokenType.UINT256, + TokenType.MEDIUMINT, TokenType.FIXEDSTRING, TokenType.FLOAT, TokenType.DOUBLE, @@ -186,6 +187,7 @@ class Parser(metaclass=_Parser): TokenType.SMALLSERIAL, TokenType.BIGSERIAL, TokenType.XML, + TokenType.YEAR, TokenType.UNIQUEIDENTIFIER, TokenType.USERDEFINED, TokenType.MONEY, @@ -194,9 +196,12 @@ class Parser(metaclass=_Parser): TokenType.IMAGE, TokenType.VARIANT, TokenType.OBJECT, + TokenType.OBJECT_IDENTIFIER, TokenType.INET, TokenType.IPADDRESS, TokenType.IPPREFIX, + TokenType.UNKNOWN, + TokenType.NULL, *ENUM_TYPE_TOKENS, *NESTED_TYPE_TOKENS, } @@ -332,6 +337,7 @@ class Parser(metaclass=_Parser): TokenType.INDEX, TokenType.ISNULL, TokenType.ILIKE, + TokenType.INSERT, TokenType.LIKE, TokenType.MERGE, TokenType.OFFSET, @@ -487,7 +493,7 @@ class 
Parser(metaclass=_Parser): exp.Cluster: lambda self: self._parse_sort(exp.Cluster, TokenType.CLUSTER_BY), exp.Column: lambda self: self._parse_column(), exp.Condition: lambda self: self._parse_conjunction(), - exp.DataType: lambda self: self._parse_types(), + exp.DataType: lambda self: self._parse_types(allow_identifiers=False), exp.Expression: lambda self: self._parse_statement(), exp.From: lambda self: self._parse_from(), exp.Group: lambda self: self._parse_group(), @@ -523,9 +529,6 @@ class Parser(metaclass=_Parser): TokenType.DESC: lambda self: self._parse_describe(), TokenType.DESCRIBE: lambda self: self._parse_describe(), TokenType.DROP: lambda self: self._parse_drop(), - TokenType.FROM: lambda self: exp.select("*").from_( - t.cast(exp.From, self._parse_from(skip_from_token=True)) - ), TokenType.INSERT: lambda self: self._parse_insert(), TokenType.LOAD: lambda self: self._parse_load(), TokenType.MERGE: lambda self: self._parse_merge(), @@ -578,7 +581,7 @@ class Parser(metaclass=_Parser): TokenType.PLACEHOLDER: lambda self: self.expression(exp.Placeholder), TokenType.PARAMETER: lambda self: self._parse_parameter(), TokenType.COLON: lambda self: self.expression(exp.Placeholder, this=self._prev.text) - if self._match_set((TokenType.NUMBER, TokenType.VAR)) + if self._match(TokenType.NUMBER) or self._match_set(self.ID_VAR_TOKENS) else None, } @@ -593,6 +596,7 @@ class Parser(metaclass=_Parser): TokenType.OVERLAPS: binary_range_parser(exp.Overlaps), TokenType.RLIKE: binary_range_parser(exp.RegexpLike), TokenType.SIMILAR_TO: binary_range_parser(exp.SimilarTo), + TokenType.FOR: lambda self, this: self._parse_comprehension(this), } PROPERTY_PARSERS: t.Dict[str, t.Callable] = { @@ -684,6 +688,12 @@ class Parser(metaclass=_Parser): exp.CommentColumnConstraint, this=self._parse_string() ), "COMPRESS": lambda self: self._parse_compress(), + "CLUSTERED": lambda self: self.expression( + exp.ClusteredColumnConstraint, this=self._parse_wrapped_csv(self._parse_ordered) + ), + "NONCLUSTERED": lambda self: self.expression( + exp.NonClusteredColumnConstraint, this=self._parse_wrapped_csv(self._parse_ordered) + ), "DEFAULT": lambda self: self.expression( exp.DefaultColumnConstraint, this=self._parse_bitwise() ), @@ -698,8 +708,11 @@ class Parser(metaclass=_Parser): "LIKE": lambda self: self._parse_create_like(), "NOT": lambda self: self._parse_not_constraint(), "NULL": lambda self: self.expression(exp.NotNullColumnConstraint, allow_null=True), - "ON": lambda self: self._match(TokenType.UPDATE) - and self.expression(exp.OnUpdateColumnConstraint, this=self._parse_function()), + "ON": lambda self: ( + self._match(TokenType.UPDATE) + and self.expression(exp.OnUpdateColumnConstraint, this=self._parse_function()) + ) + or self.expression(exp.OnProperty, this=self._parse_id_var()), "PATH": lambda self: self.expression(exp.PathColumnConstraint, this=self._parse_string()), "PRIMARY KEY": lambda self: self._parse_primary_key(), "REFERENCES": lambda self: self._parse_references(match=False), @@ -709,6 +722,9 @@ class Parser(metaclass=_Parser): "TTL": lambda self: self.expression(exp.MergeTreeTTL, expressions=[self._parse_bitwise()]), "UNIQUE": lambda self: self._parse_unique(), "UPPERCASE": lambda self: self.expression(exp.UppercaseColumnConstraint), + "WITH": lambda self: self.expression( + exp.Properties, expressions=self._parse_wrapped_csv(self._parse_property) + ), } ALTER_PARSERS = { @@ -728,6 +744,11 @@ class Parser(metaclass=_Parser): "NEXT": lambda self: self._parse_next_value_for(), } + 
INVALID_FUNC_NAME_TOKENS = { + TokenType.IDENTIFIER, + TokenType.STRING, + } + FUNCTIONS_WITH_ALIASED_ARGS = {"STRUCT"} FUNCTION_PARSERS = { @@ -774,6 +795,8 @@ class Parser(metaclass=_Parser): self._parse_sort(exp.Distribute, TokenType.DISTRIBUTE_BY), ), TokenType.SORT_BY: lambda self: ("sort", self._parse_sort(exp.Sort, TokenType.SORT_BY)), + TokenType.CONNECT_BY: lambda self: ("connect", self._parse_connect(skip_start_token=True)), + TokenType.START_WITH: lambda self: ("connect", self._parse_connect()), } SET_PARSERS = { @@ -815,6 +838,8 @@ class Parser(metaclass=_Parser): ADD_CONSTRAINT_TOKENS = {TokenType.CONSTRAINT, TokenType.PRIMARY_KEY, TokenType.FOREIGN_KEY} + DISTINCT_TOKENS = {TokenType.DISTINCT} + STRICT_CAST = True # A NULL arg in CONCAT yields NULL by default @@ -826,6 +851,11 @@ class Parser(metaclass=_Parser): LOG_BASE_FIRST = True LOG_DEFAULTS_TO_LN = False + SUPPORTS_USER_DEFINED_TYPES = True + + # Whether or not ADD is present for each column added by ALTER TABLE + ALTER_TABLE_ADD_COLUMN_KEYWORD = True + __slots__ = ( "error_level", "error_message_context", @@ -838,9 +868,11 @@ class Parser(metaclass=_Parser): "_next", "_prev", "_prev_comments", + "_tokenizer", ) # Autofilled + TOKENIZER_CLASS: t.Type[Tokenizer] = Tokenizer INDEX_OFFSET: int = 0 UNNEST_COLUMN_ONLY: bool = False ALIAS_POST_TABLESAMPLE: bool = False @@ -863,6 +895,7 @@ class Parser(metaclass=_Parser): self.error_level = error_level or ErrorLevel.IMMEDIATE self.error_message_context = error_message_context self.max_errors = max_errors + self._tokenizer = self.TOKENIZER_CLASS() self.reset() def reset(self): @@ -1148,7 +1181,7 @@ class Parser(metaclass=_Parser): expression = self._parse_set_operations(expression) if expression else self._parse_select() return self._parse_query_modifiers(expression) - def _parse_drop(self) -> exp.Drop | exp.Command: + def _parse_drop(self, exists: bool = False) -> exp.Drop | exp.Command: start = self._prev temporary = self._match(TokenType.TEMPORARY) materialized = self._match_text_seq("MATERIALIZED") @@ -1160,7 +1193,7 @@ class Parser(metaclass=_Parser): return self.expression( exp.Drop, comments=start.comments, - exists=self._parse_exists(), + exists=exists or self._parse_exists(), this=self._parse_table(schema=True), kind=kind, temporary=temporary, @@ -1274,6 +1307,8 @@ class Parser(metaclass=_Parser): if self._match_text_seq("WITH", "NO", "SCHEMA", "BINDING"): no_schema_binding = True + shallow = self._match_text_seq("SHALLOW") + if self._match_text_seq("CLONE"): clone = self._parse_table(schema=True) when = self._match_texts({"AT", "BEFORE"}) and self._prev.text.upper() @@ -1285,7 +1320,12 @@ class Parser(metaclass=_Parser): clone_expression = self._match(TokenType.FARROW) and self._parse_bitwise() self._match(TokenType.R_PAREN) clone = self.expression( - exp.Clone, this=clone, when=when, kind=clone_kind, expression=clone_expression + exp.Clone, + this=clone, + when=when, + kind=clone_kind, + shallow=shallow, + expression=clone_expression, ) return self.expression( @@ -1349,7 +1389,11 @@ class Parser(metaclass=_Parser): if assignment: key = self._parse_var_or_string() self._match(TokenType.EQ) - return self.expression(exp.Property, this=key, value=self._parse_column()) + return self.expression( + exp.Property, + this=key, + value=self._parse_column() or self._parse_var(any_token=True), + ) return None @@ -1409,7 +1453,7 @@ class Parser(metaclass=_Parser): def _parse_with_property( self, - ) -> t.Optional[exp.Expression] | t.List[t.Optional[exp.Expression]]: + ) -> 
t.Optional[exp.Expression] | t.List[exp.Expression]: if self._match(TokenType.L_PAREN, advance=False): return self._parse_wrapped_csv(self._parse_property) @@ -1622,7 +1666,7 @@ class Parser(metaclass=_Parser): override=override, ) - def _parse_partition_by(self) -> t.List[t.Optional[exp.Expression]]: + def _parse_partition_by(self) -> t.List[exp.Expression]: if self._match(TokenType.PARTITION_BY): return self._parse_csv(self._parse_conjunction) return [] @@ -1652,9 +1696,9 @@ class Parser(metaclass=_Parser): def _parse_on_property(self) -> t.Optional[exp.Expression]: if self._match_text_seq("COMMIT", "PRESERVE", "ROWS"): return exp.OnCommitProperty() - elif self._match_text_seq("COMMIT", "DELETE", "ROWS"): + if self._match_text_seq("COMMIT", "DELETE", "ROWS"): return exp.OnCommitProperty(delete=True) - return None + return self.expression(exp.OnProperty, this=self._parse_schema(self._parse_id_var())) def _parse_distkey(self) -> exp.DistKeyProperty: return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_id_var)) @@ -1709,8 +1753,10 @@ class Parser(metaclass=_Parser): def _parse_describe(self) -> exp.Describe: kind = self._match_set(self.CREATABLES) and self._prev.text - this = self._parse_table() - return self.expression(exp.Describe, this=this, kind=kind) + this = self._parse_table(schema=True) + properties = self._parse_properties() + expressions = properties.expressions if properties else None + return self.expression(exp.Describe, this=this, kind=kind, expressions=expressions) def _parse_insert(self) -> exp.Insert: comments = ensure_list(self._prev_comments) @@ -1741,6 +1787,7 @@ class Parser(metaclass=_Parser): exp.Insert, comments=comments, this=this, + by_name=self._match_text_seq("BY", "NAME"), exists=self._parse_exists(), partition=self._parse_partition(), where=self._match_pair(TokenType.REPLACE, TokenType.WHERE) @@ -1895,6 +1942,7 @@ class Parser(metaclass=_Parser): "from": self._parse_from(joins=True), "where": self._parse_where(), "returning": returning or self._parse_returning(), + "order": self._parse_order(), "limit": self._parse_limit(), }, ) @@ -1948,13 +1996,14 @@ class Parser(metaclass=_Parser): # https://prestodb.io/docs/current/sql/values.html return self.expression(exp.Tuple, expressions=[self._parse_conjunction()]) - def _parse_projections(self) -> t.List[t.Optional[exp.Expression]]: + def _parse_projections(self) -> t.List[exp.Expression]: return self._parse_expressions() def _parse_select( self, nested: bool = False, table: bool = False, parse_subquery_alias: bool = True ) -> t.Optional[exp.Expression]: cte = self._parse_with() + if cte: this = self._parse_statement() @@ -1967,12 +2016,18 @@ class Parser(metaclass=_Parser): else: self.raise_error(f"{this.key} does not support CTE") this = cte - elif self._match(TokenType.SELECT): + + return this + + # duckdb supports leading with FROM x + from_ = self._parse_from() if self._match(TokenType.FROM, advance=False) else None + + if self._match(TokenType.SELECT): comments = self._prev_comments hint = self._parse_hint() all_ = self._match(TokenType.ALL) - distinct = self._match(TokenType.DISTINCT) + distinct = self._match_set(self.DISTINCT_TOKENS) kind = ( self._match(TokenType.ALIAS) @@ -2006,7 +2061,9 @@ class Parser(metaclass=_Parser): if into: this.set("into", into) - from_ = self._parse_from() + if not from_: + from_ = self._parse_from() + if from_: this.set("from", from_) @@ -2033,6 +2090,8 @@ class Parser(metaclass=_Parser): expressions=self._parse_csv(self._parse_value), 
alias=self._parse_table_alias(), ) + elif from_: + this = exp.select("*").from_(from_.this, copy=False) else: this = None @@ -2491,6 +2550,11 @@ class Parser(metaclass=_Parser): if schema: return self._parse_schema(this=this) + version = self._parse_version() + + if version: + this.set("version", version) + if self.ALIAS_POST_TABLESAMPLE: table_sample = self._parse_table_sample() @@ -2498,11 +2562,11 @@ class Parser(metaclass=_Parser): if alias: this.set("alias", alias) + this.set("hints", self._parse_table_hints()) + if not this.args.get("pivots"): this.set("pivots", self._parse_pivots()) - this.set("hints", self._parse_table_hints()) - if not self.ALIAS_POST_TABLESAMPLE: table_sample = self._parse_table_sample() @@ -2516,6 +2580,37 @@ class Parser(metaclass=_Parser): return this + def _parse_version(self) -> t.Optional[exp.Version]: + if self._match(TokenType.TIMESTAMP_SNAPSHOT): + this = "TIMESTAMP" + elif self._match(TokenType.VERSION_SNAPSHOT): + this = "VERSION" + else: + return None + + if self._match_set((TokenType.FROM, TokenType.BETWEEN)): + kind = self._prev.text.upper() + start = self._parse_bitwise() + self._match_texts(("TO", "AND")) + end = self._parse_bitwise() + expression: t.Optional[exp.Expression] = self.expression( + exp.Tuple, expressions=[start, end] + ) + elif self._match_text_seq("CONTAINED", "IN"): + kind = "CONTAINED IN" + expression = self.expression( + exp.Tuple, expressions=self._parse_wrapped_csv(self._parse_bitwise) + ) + elif self._match(TokenType.ALL): + kind = "ALL" + expression = None + else: + self._match_text_seq("AS", "OF") + kind = "AS OF" + expression = self._parse_type() + + return self.expression(exp.Version, this=this, expression=expression, kind=kind) + def _parse_unnest(self, with_alias: bool = True) -> t.Optional[exp.Unnest]: if not self._match(TokenType.UNNEST): return None @@ -2760,7 +2855,7 @@ class Parser(metaclass=_Parser): return self.expression(exp.Group, **elements) # type: ignore - def _parse_grouping_sets(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]: + def _parse_grouping_sets(self) -> t.Optional[t.List[exp.Expression]]: if not self._match(TokenType.GROUPING_SETS): return None @@ -2784,6 +2879,22 @@ class Parser(metaclass=_Parser): return None return self.expression(exp.Qualify, this=self._parse_conjunction()) + def _parse_connect(self, skip_start_token: bool = False) -> t.Optional[exp.Connect]: + if skip_start_token: + start = None + elif self._match(TokenType.START_WITH): + start = self._parse_conjunction() + else: + return None + + self._match(TokenType.CONNECT_BY) + self.NO_PAREN_FUNCTION_PARSERS["PRIOR"] = lambda self: self.expression( + exp.Prior, this=self._parse_bitwise() + ) + connect = self._parse_conjunction() + self.NO_PAREN_FUNCTION_PARSERS.pop("PRIOR") + return self.expression(exp.Connect, start=start, connect=connect) + def _parse_order( self, this: t.Optional[exp.Expression] = None, skip_order_token: bool = False ) -> t.Optional[exp.Expression]: @@ -2929,6 +3040,7 @@ class Parser(metaclass=_Parser): expression, this=this, distinct=self._match(TokenType.DISTINCT) or not self._match(TokenType.ALL), + by_name=self._match_text_seq("BY", "NAME"), expression=self._parse_set_operations(self._parse_select(nested=True)), ) @@ -3017,6 +3129,8 @@ class Parser(metaclass=_Parser): return self.expression(exp.Escape, this=this, expression=self._parse_string()) def _parse_interval(self) -> t.Optional[exp.Interval]: + index = self._index + if not self._match(TokenType.INTERVAL): return None @@ -3025,7 +3139,11 @@ class 
Parser(metaclass=_Parser): else: this = self._parse_term() - unit = self._parse_function() or self._parse_var() + if not this: + self._retreat(index) + return None + + unit = self._parse_function() or self._parse_var(any_token=True) # Most dialects support, e.g., the form INTERVAL '5' day, thus we try to parse # each INTERVAL expression into this canonical form so it's easy to transpile @@ -3036,12 +3154,12 @@ class Parser(metaclass=_Parser): if len(parts) == 2: if unit: - # this is not actually a unit, it's something else + # This is not actually a unit, it's something else (e.g. a "window side") unit = None self._retreat(self._index - 1) - else: - this = exp.Literal.string(parts[0]) - unit = self.expression(exp.Var, this=parts[1]) + + this = exp.Literal.string(parts[0]) + unit = self.expression(exp.Var, this=parts[1]) return self.expression(exp.Interval, this=this, unit=unit) @@ -3087,7 +3205,7 @@ class Parser(metaclass=_Parser): return interval index = self._index - data_type = self._parse_types(check_func=True) + data_type = self._parse_types(check_func=True, allow_identifiers=False) this = self._parse_column() if data_type: @@ -3103,30 +3221,50 @@ class Parser(metaclass=_Parser): return this - def _parse_type_size(self) -> t.Optional[exp.DataTypeSize]: + def _parse_type_size(self) -> t.Optional[exp.DataTypeParam]: this = self._parse_type() if not this: return None return self.expression( - exp.DataTypeSize, this=this, expression=self._parse_var(any_token=True) + exp.DataTypeParam, this=this, expression=self._parse_var(any_token=True) ) def _parse_types( - self, check_func: bool = False, schema: bool = False + self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True ) -> t.Optional[exp.Expression]: index = self._index prefix = self._match_text_seq("SYSUDTLIB", ".") if not self._match_set(self.TYPE_TOKENS): - return None + identifier = allow_identifiers and self._parse_id_var( + any_token=False, tokens=(TokenType.VAR,) + ) + + if identifier: + tokens = self._tokenizer.tokenize(identifier.name) + + if len(tokens) != 1: + self.raise_error("Unexpected identifier", self._prev) + + if tokens[0].token_type in self.TYPE_TOKENS: + self._prev = tokens[0] + elif self.SUPPORTS_USER_DEFINED_TYPES: + return identifier + else: + return None + else: + return None type_token = self._prev.token_type if type_token == TokenType.PSEUDO_TYPE: return self.expression(exp.PseudoType, this=self._prev.text) + if type_token == TokenType.OBJECT_IDENTIFIER: + return self.expression(exp.ObjectIdentifier, this=self._prev.text) + nested = type_token in self.NESTED_TYPE_TOKENS is_struct = type_token in self.STRUCT_TYPE_TOKENS expressions = None @@ -3137,7 +3275,9 @@ class Parser(metaclass=_Parser): expressions = self._parse_csv(self._parse_struct_types) elif nested: expressions = self._parse_csv( - lambda: self._parse_types(check_func=check_func, schema=schema) + lambda: self._parse_types( + check_func=check_func, schema=schema, allow_identifiers=allow_identifiers + ) ) elif type_token in self.ENUM_TYPE_TOKENS: expressions = self._parse_csv(self._parse_equality) @@ -3151,14 +3291,16 @@ class Parser(metaclass=_Parser): maybe_func = True this: t.Optional[exp.Expression] = None - values: t.Optional[t.List[t.Optional[exp.Expression]]] = None + values: t.Optional[t.List[exp.Expression]] = None if nested and self._match(TokenType.LT): if is_struct: expressions = self._parse_csv(self._parse_struct_types) else: expressions = self._parse_csv( - lambda: self._parse_types(check_func=check_func, 
schema=schema) + lambda: self._parse_types( + check_func=check_func, schema=schema, allow_identifiers=allow_identifiers + ) ) if not self._match(TokenType.GT): @@ -3355,7 +3497,7 @@ class Parser(metaclass=_Parser): upper = this.upper() parser = self.NO_PAREN_FUNCTION_PARSERS.get(upper) - if optional_parens and parser: + if optional_parens and parser and token_type not in self.INVALID_FUNC_NAME_TOKENS: self._advance() return parser(self) @@ -3442,7 +3584,9 @@ class Parser(metaclass=_Parser): index = self._index if self._match(TokenType.L_PAREN): - expressions = self._parse_csv(self._parse_id_var) + expressions = t.cast( + t.List[t.Optional[exp.Expression]], self._parse_csv(self._parse_id_var) + ) if not self._match(TokenType.R_PAREN): self._retreat(index) @@ -3481,14 +3625,14 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.L_PAREN): return this - args = self._parse_csv( - lambda: self._parse_constraint() - or self._parse_column_def(self._parse_field(any_token=True)) - ) + args = self._parse_csv(lambda: self._parse_constraint() or self._parse_field_def()) self._match_r_paren() return self.expression(exp.Schema, this=this, expressions=args) + def _parse_field_def(self) -> t.Optional[exp.Expression]: + return self._parse_column_def(self._parse_field(any_token=True)) + def _parse_column_def(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: # column defs are not really columns, they're identifiers if isinstance(this, exp.Column): @@ -3499,7 +3643,18 @@ class Parser(metaclass=_Parser): if self._match_text_seq("FOR", "ORDINALITY"): return self.expression(exp.ColumnDef, this=this, ordinality=True) - constraints = [] + constraints: t.List[exp.Expression] = [] + + if not kind and self._match(TokenType.ALIAS): + constraints.append( + self.expression( + exp.ComputedColumnConstraint, + this=self._parse_conjunction(), + persisted=self._match_text_seq("PERSISTED"), + not_null=self._match_pair(TokenType.NOT, TokenType.NULL), + ) + ) + while True: constraint = self._parse_column_constraint() if not constraint: @@ -3553,7 +3708,7 @@ class Parser(metaclass=_Parser): identity = self._match_text_seq("IDENTITY") if self._match(TokenType.L_PAREN): - if self._match_text_seq("START", "WITH"): + if self._match(TokenType.START_WITH): this.set("start", self._parse_bitwise()) if self._match_text_seq("INCREMENT", "BY"): this.set("increment", self._parse_bitwise()) @@ -3580,11 +3735,13 @@ class Parser(metaclass=_Parser): def _parse_not_constraint( self, - ) -> t.Optional[exp.NotNullColumnConstraint | exp.CaseSpecificColumnConstraint]: + ) -> t.Optional[exp.Expression]: if self._match_text_seq("NULL"): return self.expression(exp.NotNullColumnConstraint) if self._match_text_seq("CASESPECIFIC"): return self.expression(exp.CaseSpecificColumnConstraint, not_=True) + if self._match_text_seq("FOR", "REPLICATION"): + return self.expression(exp.NotForReplicationColumnConstraint) return None def _parse_column_constraint(self) -> t.Optional[exp.Expression]: @@ -3729,7 +3886,7 @@ class Parser(metaclass=_Parser): bracket_kind = self._prev.token_type if self._match(TokenType.COLON): - expressions: t.List[t.Optional[exp.Expression]] = [ + expressions: t.List[exp.Expression] = [ self.expression(exp.Slice, expression=self._parse_conjunction()) ] else: @@ -3844,17 +4001,17 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.ALIAS): if self._match(TokenType.COMMA): - return self.expression( - exp.CastToStrType, this=this, expression=self._parse_string() - ) - else: - 
self.raise_error("Expected AS after CAST") + return self.expression(exp.CastToStrType, this=this, to=self._parse_string()) + + self.raise_error("Expected AS after CAST") fmt = None to = self._parse_types() if not to: self.raise_error("Expected TYPE after CAST") + elif isinstance(to, exp.Identifier): + to = exp.DataType.build(to.name, udt=True) elif to.this == exp.DataType.Type.CHAR: if self._match(TokenType.CHARACTER_SET): to = self.expression(exp.CharacterSet, this=self._parse_var_or_string()) @@ -3908,7 +4065,7 @@ class Parser(metaclass=_Parser): if self._match(TokenType.COMMA): args.extend(self._parse_csv(self._parse_conjunction)) else: - args = self._parse_csv(self._parse_conjunction) + args = self._parse_csv(self._parse_conjunction) # type: ignore index = self._index if not self._match(TokenType.R_PAREN) and args: @@ -3991,10 +4148,10 @@ class Parser(metaclass=_Parser): def _parse_json_key_value(self) -> t.Optional[exp.JSONKeyValue]: self._match_text_seq("KEY") - key = self._parse_field() - self._match(TokenType.COLON) + key = self._parse_column() + self._match_set((TokenType.COLON, TokenType.COMMA)) self._match_text_seq("VALUE") - value = self._parse_field() + value = self._parse_bitwise() if not key and not value: return None @@ -4116,7 +4273,7 @@ class Parser(metaclass=_Parser): # Postgres supports the form: substring(string [from int] [for int]) # https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6 - args = self._parse_csv(self._parse_bitwise) + args = t.cast(t.List[t.Optional[exp.Expression]], self._parse_csv(self._parse_bitwise)) if self._match(TokenType.FROM): args.append(self._parse_bitwise()) @@ -4149,7 +4306,7 @@ class Parser(metaclass=_Parser): exp.Trim, this=this, position=position, expression=expression, collation=collation ) - def _parse_window_clause(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]: + def _parse_window_clause(self) -> t.Optional[t.List[exp.Expression]]: return self._match(TokenType.WINDOW) and self._parse_csv(self._parse_named_window) def _parse_named_window(self) -> t.Optional[exp.Expression]: @@ -4216,8 +4373,7 @@ class Parser(metaclass=_Parser): if self._match_text_seq("LAST"): first = False - partition = self._parse_partition_by() - order = self._parse_order() + partition, order = self._parse_partition_and_order() kind = self._match_set((TokenType.ROWS, TokenType.RANGE)) and self._prev.text if kind: @@ -4256,6 +4412,11 @@ class Parser(metaclass=_Parser): return window + def _parse_partition_and_order( + self, + ) -> t.Tuple[t.List[exp.Expression], t.Optional[exp.Expression]]: + return self._parse_partition_by(), self._parse_order() + def _parse_window_spec(self) -> t.Dict[str, t.Optional[str | exp.Expression]]: self._match(TokenType.BETWEEN) @@ -4377,14 +4538,14 @@ class Parser(metaclass=_Parser): self._advance(-1) return None - def _parse_except(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]: + def _parse_except(self) -> t.Optional[t.List[exp.Expression]]: if not self._match(TokenType.EXCEPT): return None if self._match(TokenType.L_PAREN, advance=False): return self._parse_wrapped_csv(self._parse_column) return self._parse_csv(self._parse_column) - def _parse_replace(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]: + def _parse_replace(self) -> t.Optional[t.List[exp.Expression]]: if not self._match(TokenType.REPLACE): return None if self._match(TokenType.L_PAREN, advance=False): @@ -4393,7 +4554,7 @@ class Parser(metaclass=_Parser): def _parse_csv( self, parse_method: t.Callable, sep: TokenType = 
TokenType.COMMA - ) -> t.List[t.Optional[exp.Expression]]: + ) -> t.List[exp.Expression]: parse_result = parse_method() items = [parse_result] if parse_result is not None else [] @@ -4420,12 +4581,12 @@ class Parser(metaclass=_Parser): return this - def _parse_wrapped_id_vars(self, optional: bool = False) -> t.List[t.Optional[exp.Expression]]: + def _parse_wrapped_id_vars(self, optional: bool = False) -> t.List[exp.Expression]: return self._parse_wrapped_csv(self._parse_id_var, optional=optional) def _parse_wrapped_csv( self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA, optional: bool = False - ) -> t.List[t.Optional[exp.Expression]]: + ) -> t.List[exp.Expression]: return self._parse_wrapped( lambda: self._parse_csv(parse_method, sep=sep), optional=optional ) @@ -4439,7 +4600,7 @@ class Parser(metaclass=_Parser): self._match_r_paren() return parse_result - def _parse_expressions(self) -> t.List[t.Optional[exp.Expression]]: + def _parse_expressions(self) -> t.List[exp.Expression]: return self._parse_csv(self._parse_expression) def _parse_select_or_expression(self, alias: bool = False) -> t.Optional[exp.Expression]: @@ -4498,7 +4659,7 @@ class Parser(metaclass=_Parser): self._match(TokenType.COLUMN) exists_column = self._parse_exists(not_=True) - expression = self._parse_column_def(self._parse_field(any_token=True)) + expression = self._parse_field_def() if expression: expression.set("exists", exists_column) @@ -4549,13 +4710,16 @@ class Parser(metaclass=_Parser): return self.expression(exp.AddConstraint, this=this, expression=expression) - def _parse_alter_table_add(self) -> t.List[t.Optional[exp.Expression]]: + def _parse_alter_table_add(self) -> t.List[exp.Expression]: index = self._index - 1 if self._match_set(self.ADD_CONSTRAINT_TOKENS): return self._parse_csv(self._parse_add_constraint) self._retreat(index) + if not self.ALTER_TABLE_ADD_COLUMN_KEYWORD and self._match_text_seq("ADD"): + return self._parse_csv(self._parse_field_def) + return self._parse_csv(self._parse_add_column) def _parse_alter_table_alter(self) -> exp.AlterColumn: @@ -4576,7 +4740,7 @@ class Parser(metaclass=_Parser): using=self._match(TokenType.USING) and self._parse_conjunction(), ) - def _parse_alter_table_drop(self) -> t.List[t.Optional[exp.Expression]]: + def _parse_alter_table_drop(self) -> t.List[exp.Expression]: index = self._index - 1 partition_exists = self._parse_exists() @@ -4619,6 +4783,9 @@ class Parser(metaclass=_Parser): self._match(TokenType.INTO) target = self._parse_table() + if target and self._match(TokenType.ALIAS, advance=False): + target.set("alias", self._parse_table_alias()) + self._match(TokenType.USING) using = self._parse_table() @@ -4685,8 +4852,7 @@ class Parser(metaclass=_Parser): parser = self._find_parser(self.SHOW_PARSERS, self.SHOW_TRIE) if parser: return parser(self) - self._advance() - return self.expression(exp.Show, this=self._prev.text.upper()) + return self._parse_as_command(self._prev) def _parse_set_item_assignment( self, kind: t.Optional[str] = None @@ -4786,6 +4952,19 @@ class Parser(metaclass=_Parser): self._match_r_paren() return self.expression(exp.DictRange, this=this, min=min, max=max) + def _parse_comprehension(self, this: exp.Expression) -> exp.Comprehension: + expression = self._parse_column() + self._match(TokenType.IN) + iterator = self._parse_column() + condition = self._parse_conjunction() if self._match_text_seq("IF") else None + return self.expression( + exp.Comprehension, + this=this, + expression=expression, + iterator=iterator, + 
condition=condition, + ) + def _find_parser( self, parsers: t.Dict[str, t.Callable], trie: t.Dict ) -> t.Optional[t.Callable]: diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index d278dbf..83b97d6 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -48,6 +48,7 @@ class TokenType(AutoName): HASH_ARROW = auto() DHASH_ARROW = auto() LR_ARROW = auto() + DAT = auto() LT_AT = auto() AT_GT = auto() DOLLAR = auto() @@ -84,6 +85,7 @@ class TokenType(AutoName): UTINYINT = auto() SMALLINT = auto() USMALLINT = auto() + MEDIUMINT = auto() INT = auto() UINT = auto() BIGINT = auto() @@ -140,6 +142,7 @@ class TokenType(AutoName): SMALLSERIAL = auto() BIGSERIAL = auto() XML = auto() + YEAR = auto() UNIQUEIDENTIFIER = auto() USERDEFINED = auto() MONEY = auto() @@ -157,6 +160,7 @@ class TokenType(AutoName): FIXEDSTRING = auto() LOWCARDINALITY = auto() NESTED = auto() + UNKNOWN = auto() # keywords ALIAS = auto() @@ -180,6 +184,7 @@ class TokenType(AutoName): COMMAND = auto() COMMENT = auto() COMMIT = auto() + CONNECT_BY = auto() CONSTRAINT = auto() CREATE = auto() CROSS = auto() @@ -256,6 +261,7 @@ class TokenType(AutoName): NEXT = auto() NOTNULL = auto() NULL = auto() + OBJECT_IDENTIFIER = auto() OFFSET = auto() ON = auto() ORDER_BY = auto() @@ -298,6 +304,7 @@ class TokenType(AutoName): SIMILAR_TO = auto() SOME = auto() SORT_BY = auto() + START_WITH = auto() STRUCT = auto() TABLE_SAMPLE = auto() TEMPORARY = auto() @@ -319,6 +326,8 @@ class TokenType(AutoName): WINDOW = auto() WITH = auto() UNIQUE = auto() + VERSION_SNAPSHOT = auto() + TIMESTAMP_SNAPSHOT = auto() class Token: @@ -530,6 +539,7 @@ class Tokenizer(metaclass=_Tokenizer): "COLLATE": TokenType.COLLATE, "COLUMN": TokenType.COLUMN, "COMMIT": TokenType.COMMIT, + "CONNECT BY": TokenType.CONNECT_BY, "CONSTRAINT": TokenType.CONSTRAINT, "CREATE": TokenType.CREATE, "CROSS": TokenType.CROSS, @@ -636,6 +646,7 @@ class Tokenizer(metaclass=_Tokenizer): "SIMILAR TO": TokenType.SIMILAR_TO, "SOME": TokenType.SOME, "SORT BY": TokenType.SORT_BY, + "START WITH": TokenType.START_WITH, "TABLE": TokenType.TABLE, "TABLESAMPLE": TokenType.TABLE_SAMPLE, "TEMP": TokenType.TEMPORARY, @@ -643,6 +654,7 @@ class Tokenizer(metaclass=_Tokenizer): "THEN": TokenType.THEN, "TRUE": TokenType.TRUE, "UNION": TokenType.UNION, + "UNKNOWN": TokenType.UNKNOWN, "UNNEST": TokenType.UNNEST, "UNPIVOT": TokenType.UNPIVOT, "UPDATE": TokenType.UPDATE, @@ -739,6 +751,8 @@ class Tokenizer(metaclass=_Tokenizer): "TRUNCATE": TokenType.COMMAND, "VACUUM": TokenType.COMMAND, "USER-DEFINED": TokenType.USERDEFINED, + "FOR VERSION": TokenType.VERSION_SNAPSHOT, + "FOR TIMESTAMP": TokenType.TIMESTAMP_SNAPSHOT, } WHITE_SPACE: t.Dict[t.Optional[str], TokenType] = { @@ -941,8 +955,8 @@ class Tokenizer(metaclass=_Tokenizer): if result == TrieResult.EXISTS: word = chars + end = self._current + size size += 1 - end = self._current - 1 + size if end < self.size: char = self.sql[end] @@ -961,21 +975,20 @@ class Tokenizer(metaclass=_Tokenizer): char = "" chars = " " - if not word: - if self._char in self.SINGLE_TOKENS: - self._add(self.SINGLE_TOKENS[self._char], text=self._char) + if word: + if self._scan_string(word): return - self._scan_var() - return - - if self._scan_string(word): - return - if self._scan_comment(word): + if self._scan_comment(word): + return + if prev_space or single_token or not char: + self._advance(size - 1) + word = word.upper() + self._add(self.KEYWORDS[word], text=word) + return + if self._char in self.SINGLE_TOKENS: + self._add(self.SINGLE_TOKENS[self._char], text=self._char) 
return - - self._advance(size - 1) - word = word.upper() - self._add(self.KEYWORDS[word], text=word) + self._scan_var() def _scan_comment(self, comment_start: str) -> bool: if comment_start not in self._COMMENTS: @@ -1053,8 +1066,8 @@ class Tokenizer(metaclass=_Tokenizer): elif self.IDENTIFIERS_CAN_START_WITH_DIGIT: return self._add(TokenType.VAR) - self._add(TokenType.NUMBER, number_text) - return self._advance(-len(literal)) + self._advance(-len(literal)) + return self._add(TokenType.NUMBER, number_text) else: return self._add(TokenType.NUMBER) diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index 7c7c2a7..48ea8dc 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -68,11 +68,17 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression: if order: window.set("order", order.pop().copy()) + else: + window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols])) window = exp.alias_(window, row_number) expression.select(window, copy=False) - return exp.select(*outer_selects).from_(expression.subquery()).where(f'"{row_number}" = 1') + return ( + exp.select(*outer_selects) + .from_(expression.subquery()) + .where(exp.column(row_number).eq(1)) + ) return expression @@ -126,7 +132,7 @@ def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expr """ for node in expression.find_all(exp.DataType): node.set( - "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeSize)] + "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)] ) return expression -- cgit v1.2.3
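
Two short, hedged sketches of what the parser, tokenizer, and generator pieces above combine to support; the exact formatting of the printed output is illustrative, not verified against this release.

First, the START_WITH/CONNECT_BY tokens, _parse_connect, and connect_sql/prior_sql wire up Oracle-style hierarchical queries, with PRIOR recognized only while the CONNECT BY condition is being parsed:

    import sqlglot

    sql = """
    SELECT employee_id, manager_id
    FROM employees
    START WITH manager_id IS NULL
    CONNECT BY PRIOR employee_id = manager_id
    """

    # START WITH / CONNECT BY parse into an exp.Connect node attached to the select
    expr = sqlglot.parse_one(sql, read="oracle")
    print(expr.sql(dialect="oracle", pretty=True))

Second, the FOR VERSION / FOR TIMESTAMP keywords, _parse_version, and version_sql cover Trino-style table time travel; since both live in the base Parser and Generator, the clause should round-trip through any dialect:

    import sqlglot

    # Parses into an exp.Version node attached to the table and round-trips
    print(sqlglot.transpile(
        "SELECT * FROM t FOR TIMESTAMP AS OF TIMESTAMP '2023-01-01 00:00:00'",
        read="trino",
    )[0])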