diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-09-30 05:07:13 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-09-30 05:07:13 +0000 |
commit | 276f5723c8ab7e0b2938693db546dbd30be3f01a (patch) | |
tree | e6294c63de34a03e373245ec4cb1efbca1edfe61 | |
parent | Adding upstream version 6.2.1. (diff) | |
download | sqlglot-276f5723c8ab7e0b2938693db546dbd30be3f01a.tar.xz sqlglot-276f5723c8ab7e0b2938693db546dbd30be3f01a.zip |
Adding upstream version 6.2.6.upstream/6.2.6
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
49 files changed, 1739 insertions, 564 deletions
diff --git a/run_checks.sh b/run_checks.sh index 770f443..b6e559d 100755 --- a/run_checks.sh +++ b/run_checks.sh @@ -1,6 +1,8 @@ #!/bin/bash -e -python -m autoflake -i -r \ +[[ -z "${GITHUB_ACTIONS}" ]] && RETURN_ERROR_CODE='' || RETURN_ERROR_CODE='--check' + +python -m autoflake -i -r ${RETURN_ERROR_CODE} \ --expand-star-imports \ --remove-all-unused-imports \ --ignore-init-module-imports \ @@ -8,5 +10,5 @@ python -m autoflake -i -r \ --remove-unused-variables \ sqlglot/ tests/ python -m isort --profile black sqlglot/ tests/ -python -m black --line-length 120 sqlglot/ tests/ +python -m black ${RETURN_ERROR_CODE} --line-length 120 sqlglot/ tests/ python -m unittest diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index 9ac1759..befbc8a 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -20,7 +20,7 @@ from sqlglot.generator import Generator from sqlglot.parser import Parser from sqlglot.tokens import Tokenizer, TokenType -__version__ = "6.2.1" +__version__ = "6.2.6" pretty = False diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 1f1f90a..432fd8c 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -33,6 +33,49 @@ def _date_add_sql(data_type, kind): return func +def _subquery_to_unnest_if_values(self, expression): + if not isinstance(expression.this, exp.Values): + return self.subquery_sql(expression) + rows = [list(tuple_exp.find_all(exp.Literal)) for tuple_exp in expression.this.find_all(exp.Tuple)] + structs = [] + for row in rows: + aliases = [ + exp.alias_(value, column_name) for value, column_name in zip(row, expression.args["alias"].args["columns"]) + ] + structs.append(exp.Struct(expressions=aliases)) + unnest_exp = exp.Unnest(expressions=[exp.Array(expressions=structs)]) + return self.unnest_sql(unnest_exp) + + +def _returnsproperty_sql(self, expression): + value = expression.args.get("value") + if isinstance(value, exp.Schema): + value = f"{value.this} <{self.expressions(value)}>" + else: + value = self.sql(value) + return f"RETURNS {value}" + + +def _create_sql(self, expression): + kind = expression.args.get("kind") + returns = expression.find(exp.ReturnsProperty) + if kind.upper() == "FUNCTION" and returns and returns.args.get("is_table"): + expression = expression.copy() + expression.set("kind", "TABLE FUNCTION") + if isinstance( + expression.expression, + ( + exp.Subquery, + exp.Literal, + ), + ): + expression.set("expression", expression.expression.this) + + return self.create_sql(expression) + + return self.create_sql(expression) + + class BigQuery(Dialect): unnest_column_only = True @@ -77,8 +120,14 @@ class BigQuery(Dialect): TokenType.CURRENT_TIME: exp.CurrentTime, } + NESTED_TYPE_TOKENS = { + *Parser.NESTED_TYPE_TOKENS, + TokenType.TABLE, + } + class Generator(Generator): TRANSFORMS = { + **Generator.TRANSFORMS, exp.Array: inline_array_sql, exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.DateAdd: _date_add_sql("DATE", "ADD"), @@ -91,6 +140,9 @@ class BigQuery(Dialect): exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"), exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"), exp.VariancePop: rename_func("VAR_POP"), + exp.Subquery: _subquery_to_unnest_if_values, + exp.ReturnsProperty: _returnsproperty_sql, + exp.Create: _create_sql, } TYPE_MAPPING = { diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 0120e71..0ab584e 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -245,6 +245,11 @@ def no_tablesample_sql(self, expression): return self.sql(expression.this) +def no_pivot_sql(self, expression): + self.unsupported("PIVOT unsupported") + return self.sql(expression) + + def no_trycast_sql(self, expression): return self.cast_sql(expression) @@ -282,3 +287,30 @@ def format_time_lambda(exp_class, dialect, default=None): ) return _format_time + + +def create_with_partitions_sql(self, expression): + """ + In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the + PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding + columns are removed from the create statement. + """ + has_schema = isinstance(expression.this, exp.Schema) + is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW") + + if has_schema and is_partitionable: + expression = expression.copy() + prop = expression.find(exp.PartitionedByProperty) + value = prop and prop.args.get("value") + if prop and not isinstance(value, exp.Schema): + schema = expression.this + columns = {v.name.upper() for v in value.expressions} + partitions = [col for col in schema.expressions if col.name.upper() in columns] + schema.set( + "expressions", + [e for e in schema.expressions if e not in partitions], + ) + prop.replace(exp.PartitionedByProperty(this=prop.this, value=exp.Schema(expressions=partitions))) + expression.set("this", schema) + + return self.create_sql(expression) diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index 4ca9e84..e09c3dd 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -5,6 +5,7 @@ from sqlglot.dialects.dialect import ( arrow_json_extract_scalar_sql, arrow_json_extract_sql, format_time_lambda, + no_pivot_sql, no_safe_divide_sql, no_tablesample_sql, rename_func, @@ -122,6 +123,7 @@ class DuckDB(Dialect): exp.JSONExtractScalar: arrow_json_extract_scalar_sql, exp.JSONBExtract: arrow_json_extract_sql, exp.JSONBExtractScalar: arrow_json_extract_scalar_sql, + exp.Pivot: no_pivot_sql, exp.RegexpLike: rename_func("REGEXP_MATCHES"), exp.RegexpSplit: rename_func("STR_SPLIT_REGEX"), exp.SafeDivide: no_safe_divide_sql, diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 59aa8fa..7a27bb3 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -2,6 +2,7 @@ from sqlglot import exp, transforms from sqlglot.dialects.dialect import ( Dialect, approx_count_distinct_sql, + create_with_partitions_sql, format_time_lambda, if_sql, no_ilike_sql, @@ -53,7 +54,7 @@ def _array_sort(self, expression): def _property_sql(self, expression): key = expression.name value = self.sql(expression, "value") - return f"'{key}' = {value}" + return f"'{key}'={value}" def _str_to_unix(self, expression): @@ -218,15 +219,6 @@ class Hive(Dialect): } class Generator(Generator): - ROOT_PROPERTIES = [ - exp.PartitionedByProperty, - exp.FileFormatProperty, - exp.SchemaCommentProperty, - exp.LocationProperty, - exp.TableFormatProperty, - ] - WITH_PROPERTIES = [exp.AnonymousProperty] - TYPE_MAPPING = { **Generator.TYPE_MAPPING, exp.DataType.Type.TEXT: "STRING", @@ -255,13 +247,13 @@ class Hive(Dialect): exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"), exp.Map: _map_sql, HiveMap: _map_sql, - exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e.args['value'])}", + exp.Create: create_with_partitions_sql, exp.Quantile: rename_func("PERCENTILE"), exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"), exp.RegexpLike: lambda self, e: self.binary(e, "RLIKE"), exp.RegexpSplit: rename_func("SPLIT"), exp.SafeDivide: no_safe_divide_sql, - exp.SchemaCommentProperty: lambda self, e: f"COMMENT {self.sql(e.args['value'])}", + exp.SchemaCommentProperty: lambda self, e: self.naked_property(e), exp.SetAgg: rename_func("COLLECT_SET"), exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))", exp.StrPosition: lambda self, e: f"LOCATE({csv(self.sql(e, 'substr'), self.sql(e, 'this'), self.sql(e, 'position'))})", @@ -282,6 +274,17 @@ class Hive(Dialect): exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({csv(self.sql(e, 'this'), _time_format(self, e))})", exp.UnixToTime: rename_func("FROM_UNIXTIME"), exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"), + exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'value')}", + } + + WITH_PROPERTIES = {exp.AnonymousProperty} + + ROOT_PROPERTIES = { + exp.PartitionedByProperty, + exp.FileFormatProperty, + exp.SchemaCommentProperty, + exp.LocationProperty, + exp.TableFormatProperty, } def with_properties(self, properties): diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 87a2c41..8449379 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -172,6 +172,11 @@ class MySQL(Dialect): ), } + PROPERTY_PARSERS = { + **Parser.PROPERTY_PARSERS, + TokenType.ENGINE: lambda self: self._parse_property_assignment(exp.EngineProperty), + } + class Generator(Generator): NULL_ORDERING_SUPPORTED = False @@ -190,3 +195,13 @@ class MySQL(Dialect): exp.StrToTime: _str_to_date_sql, exp.Trim: _trim_sql, } + + ROOT_PROPERTIES = { + exp.EngineProperty, + exp.AutoIncrementProperty, + exp.CharacterSetProperty, + exp.CollateProperty, + exp.SchemaCommentProperty, + } + + WITH_PROPERTIES = {} diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index c796839..aaa07a1 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -7,6 +7,7 @@ from sqlglot.dialects.dialect import ( no_paren_current_date_sql, no_tablesample_sql, no_trycast_sql, + str_position_sql, ) from sqlglot.generator import Generator from sqlglot.parser import Parser @@ -158,7 +159,6 @@ class Postgres(Dialect): "ALWAYS": TokenType.ALWAYS, "BY DEFAULT": TokenType.BY_DEFAULT, "IDENTITY": TokenType.IDENTITY, - "FOR": TokenType.FOR, "GENERATED": TokenType.GENERATED, "DOUBLE PRECISION": TokenType.DOUBLE, "BIGSERIAL": TokenType.BIGSERIAL, @@ -204,6 +204,7 @@ class Postgres(Dialect): exp.DateAdd: _date_add_sql("+"), exp.DateSub: _date_add_sql("-"), exp.Lateral: _lateral_sql, + exp.StrPosition: str_position_sql, exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.Substring: _substring_sql, exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 7253f7e..85647c5 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -146,13 +146,16 @@ class Presto(Dialect): STRUCT_DELIMITER = ("(", ")") - WITH_PROPERTIES = [ + ROOT_PROPERTIES = { + exp.SchemaCommentProperty, + } + + WITH_PROPERTIES = { exp.PartitionedByProperty, exp.FileFormatProperty, - exp.SchemaCommentProperty, exp.AnonymousProperty, exp.TableFormatProperty, - ] + } TYPE_MAPPING = { **Generator.TYPE_MAPPING, @@ -184,13 +187,11 @@ class Presto(Dialect): exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.date_format}) AS DATE)", exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.dateint_format}) AS INT)", exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)", - exp.FileFormatProperty: lambda self, e: self.property_sql(e), exp.If: if_sql, exp.ILike: no_ilike_sql, exp.Initcap: _initcap_sql, exp.Lateral: _explode_to_unnest_sql, exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"), - exp.PartitionedByProperty: lambda self, e: f"PARTITIONED_BY = {self.sql(e.args['value'])}", exp.Quantile: _quantile_sql, exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"), exp.SafeDivide: no_safe_divide_sql, diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index b5d4f0a..1b718f7 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -1,5 +1,10 @@ from sqlglot import exp -from sqlglot.dialects.dialect import Dialect, format_time_lambda, rename_func +from sqlglot.dialects.dialect import ( + Dialect, + format_time_lambda, + inline_array_sql, + rename_func, +) from sqlglot.expressions import Literal from sqlglot.generator import Generator from sqlglot.helper import list_get @@ -104,6 +109,8 @@ class Snowflake(Dialect): "ARRAYAGG": exp.ArrayAgg.from_arg_list, "IFF": exp.If.from_arg_list, "TO_TIMESTAMP": _snowflake_to_timestamp, + "ARRAY_CONSTRUCT": exp.Array.from_arg_list, + "RLIKE": exp.RegexpLike.from_arg_list, } FUNCTION_PARSERS = { @@ -111,6 +118,11 @@ class Snowflake(Dialect): "DATE_PART": lambda self: self._parse_extract(), } + FUNC_TOKENS = { + *Parser.FUNC_TOKENS, + TokenType.RLIKE, + } + COLUMN_OPERATORS = { **Parser.COLUMN_OPERATORS, TokenType.COLON: lambda self, this, path: self.expression( @@ -120,6 +132,11 @@ class Snowflake(Dialect): ), } + PROPERTY_PARSERS = { + **Parser.PROPERTY_PARSERS, + TokenType.PARTITION_BY: lambda self: self._parse_partitioned_by(), + } + class Tokenizer(Tokenizer): QUOTES = ["'", "$$"] ESCAPE = "\\" @@ -137,6 +154,7 @@ class Snowflake(Dialect): "TIMESTAMP_NTZ": TokenType.TIMESTAMP, "TIMESTAMP_TZ": TokenType.TIMESTAMPTZ, "TIMESTAMPNTZ": TokenType.TIMESTAMP, + "SAMPLE": TokenType.TABLE_SAMPLE, } class Generator(Generator): @@ -145,6 +163,8 @@ class Snowflake(Dialect): exp.If: rename_func("IFF"), exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.UnixToTime: _unix_to_time, + exp.Array: inline_array_sql, + exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}", } TYPE_MAPPING = { @@ -152,6 +172,13 @@ class Snowflake(Dialect): exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ", } + ROOT_PROPERTIES = { + exp.PartitionedByProperty, + exp.ReturnsProperty, + exp.LanguageProperty, + exp.SchemaCommentProperty, + } + def except_op(self, expression): if not expression.args.get("distinct", False): self.unsupported("EXCEPT with All is not supported in Snowflake") diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index c051178..5446e83 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -1,5 +1,9 @@ from sqlglot import exp -from sqlglot.dialects.dialect import no_ilike_sql, rename_func +from sqlglot.dialects.dialect import ( + create_with_partitions_sql, + no_ilike_sql, + rename_func, +) from sqlglot.dialects.hive import Hive, HiveMap from sqlglot.helper import list_get @@ -10,7 +14,7 @@ def _create_sql(self, e): if kind.upper() == "TABLE" and temporary is True: return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}" - return self.create_sql(e) + return create_with_partitions_sql(self, e) def _map_sql(self, expression): @@ -73,6 +77,7 @@ class Spark(Hive): } class Generator(Hive.Generator): + TYPE_MAPPING = { **Hive.Generator.TYPE_MAPPING, exp.DataType.Type.TINYINT: "BYTE", diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py index b9cd584..ef8c82d 100644 --- a/sqlglot/dialects/starrocks.py +++ b/sqlglot/dialects/starrocks.py @@ -1,4 +1,5 @@ from sqlglot import exp +from sqlglot.dialects.dialect import rename_func from sqlglot.dialects.mysql import MySQL @@ -10,3 +11,12 @@ class StarRocks(MySQL): exp.DataType.Type.TIMESTAMP: "DATETIME", exp.DataType.Type.TIMESTAMPTZ: "DATETIME", } + + TRANSFORMS = { + **MySQL.Generator.TRANSFORMS, + exp.DateDiff: rename_func("DATEDIFF"), + exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.TimeStrToDate: rename_func("TO_DATE"), + exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.sql(e, 'this')}, {self.format_time(e)})", + exp.UnixToTime: rename_func("FROM_UNIXTIME"), + } diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 68bb9bd..73b232e 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -1,6 +1,7 @@ from sqlglot import exp from sqlglot.dialects.dialect import Dialect from sqlglot.generator import Generator +from sqlglot.parser import Parser from sqlglot.tokens import Tokenizer, TokenType @@ -17,6 +18,7 @@ class TSQL(Dialect): "REAL": TokenType.FLOAT, "NTEXT": TokenType.TEXT, "SMALLDATETIME": TokenType.DATETIME, + "DATETIME2": TokenType.DATETIME, "DATETIMEOFFSET": TokenType.TIMESTAMPTZ, "TIME": TokenType.TIMESTAMP, "VARBINARY": TokenType.BINARY, @@ -24,15 +26,24 @@ class TSQL(Dialect): "MONEY": TokenType.MONEY, "SMALLMONEY": TokenType.SMALLMONEY, "ROWVERSION": TokenType.ROWVERSION, - "SQL_VARIANT": TokenType.SQL_VARIANT, "UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER, "XML": TokenType.XML, + "SQL_VARIANT": TokenType.VARIANT, } + class Parser(Parser): + def _parse_convert(self): + to = self._parse_types() + self._match(TokenType.COMMA) + this = self._parse_field() + return self.expression(exp.Cast, this=this, to=to) + class Generator(Generator): TYPE_MAPPING = { **Generator.TYPE_MAPPING, exp.DataType.Type.BOOLEAN: "BIT", exp.DataType.Type.INT: "INTEGER", exp.DataType.Type.DECIMAL: "NUMERIC", + exp.DataType.Type.DATETIME: "DATETIME2", + exp.DataType.Type.VARIANT: "SQL_VARIANT", } diff --git a/sqlglot/executor/__init__.py b/sqlglot/executor/__init__.py index bca9f3e..e765616 100644 --- a/sqlglot/executor/__init__.py +++ b/sqlglot/executor/__init__.py @@ -3,17 +3,11 @@ import time from sqlglot import parse_one from sqlglot.executor.python import PythonExecutor -from sqlglot.optimizer import RULES, optimize -from sqlglot.optimizer.merge_derived_tables import merge_derived_tables +from sqlglot.optimizer import optimize from sqlglot.planner import Plan logger = logging.getLogger("sqlglot") -OPTIMIZER_RULES = list(RULES) - -# The executor needs isolated table selects -OPTIMIZER_RULES.remove(merge_derived_tables) - def execute(sql, schema, read=None): """ @@ -34,7 +28,7 @@ def execute(sql, schema, read=None): """ expression = parse_one(sql, read=read) now = time.time() - expression = optimize(expression, schema, rules=OPTIMIZER_RULES) + expression = optimize(expression, schema, leave_tables_isolated=True) logger.debug("Optimization finished: %f", time.time() - now) logger.debug("Optimized SQL: %s", expression.sql(pretty=True)) plan = Plan(expression) diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index de615d6..599c7db 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -1,13 +1,17 @@ -import inspect import numbers import re -import sys from collections import deque from copy import deepcopy from enum import auto from sqlglot.errors import ParseError -from sqlglot.helper import AutoName, camel_to_snake_case, ensure_list, list_get +from sqlglot.helper import ( + AutoName, + camel_to_snake_case, + ensure_list, + list_get, + subclasses, +) class _Expression(type): @@ -31,12 +35,13 @@ class Expression(metaclass=_Expression): key = None arg_types = {"this": True} - __slots__ = ("args", "parent", "arg_key") + __slots__ = ("args", "parent", "arg_key", "type") def __init__(self, **args): self.args = args self.parent = None self.arg_key = None + self.type = None for arg_key, value in self.args.items(): self._set_parent(arg_key, value) @@ -384,7 +389,7 @@ class Expression(metaclass=_Expression): 'SELECT y FROM tbl' Args: - expression (Expression): new node + expression (Expression|None): new node Returns : the new expression or expressions @@ -398,6 +403,12 @@ class Expression(metaclass=_Expression): replace_children(parent, lambda child: expression if child is self else child) return expression + def pop(self): + """ + Remove this expression from its AST. + """ + self.replace(None) + def assert_is(self, type_): """ Assert that this `Expression` is an instance of `type_`. @@ -527,9 +538,18 @@ class Create(Expression): "temporary": False, "replace": False, "unique": False, + "materialized": False, } +class UserDefinedFunction(Expression): + arg_types = {"this": True, "expressions": False} + + +class UserDefinedFunctionKwarg(Expression): + arg_types = {"this": True, "kind": True, "default": False} + + class CharacterSet(Expression): arg_types = {"this": True, "default": False} @@ -887,6 +907,14 @@ class AnonymousProperty(Property): pass +class ReturnsProperty(Property): + arg_types = {"this": True, "value": True, "is_table": False} + + +class LanguageProperty(Property): + pass + + class Properties(Expression): arg_types = {"expressions": True} @@ -907,25 +935,9 @@ class Properties(Expression): expressions = [] for key, value in properties_dict.items(): property_cls = cls.PROPERTY_KEY_MAPPING.get(key.upper(), AnonymousProperty) - expressions.append(property_cls(this=Literal.string(key), value=cls._convert_value(value))) + expressions.append(property_cls(this=Literal.string(key), value=convert(value))) return cls(expressions=expressions) - @staticmethod - def _convert_value(value): - if value is None: - return NULL - if isinstance(value, Expression): - return value - if isinstance(value, bool): - return Boolean(this=value) - if isinstance(value, str): - return Literal.string(value) - if isinstance(value, numbers.Number): - return Literal.number(value) - if isinstance(value, list): - return Tuple(expressions=[Properties._convert_value(v) for v in value]) - raise ValueError(f"Unsupported type '{type(value)}' for value '{value}'") - class Qualify(Expression): pass @@ -1030,6 +1042,7 @@ class Subqueryable: QUERY_MODIFIERS = { "laterals": False, "joins": False, + "pivots": False, "where": False, "group": False, "having": False, @@ -1051,6 +1064,7 @@ class Table(Expression): "catalog": False, "laterals": False, "joins": False, + "pivots": False, } @@ -1643,6 +1657,16 @@ class TableSample(Expression): "percent": False, "rows": False, "size": False, + "seed": False, + } + + +class Pivot(Expression): + arg_types = { + "this": False, + "expressions": True, + "field": True, + "unpivot": True, } @@ -1741,7 +1765,8 @@ class DataType(Expression): SMALLMONEY = auto() ROWVERSION = auto() IMAGE = auto() - SQL_VARIANT = auto() + VARIANT = auto() + OBJECT = auto() @classmethod def build(cls, dtype, **kwargs): @@ -2124,6 +2149,7 @@ class TryCast(Cast): class Ceil(Func): + arg_types = {"this": True, "decimals": False} _sql_names = ["CEIL", "CEILING"] @@ -2254,7 +2280,7 @@ class Explode(Func): class Floor(Func): - pass + arg_types = {"this": True, "decimals": False} class Greatest(Func): @@ -2371,7 +2397,7 @@ class Reduce(Func): class RegexpLike(Func): - arg_types = {"this": True, "expression": True} + arg_types = {"this": True, "expression": True, "flag": False} class RegexpSplit(Func): @@ -2540,6 +2566,8 @@ def _norm_args(expression): for k, arg in expression.args.items(): if isinstance(arg, list): arg = [_norm_arg(a) for a in arg] + if not arg: + arg = None else: arg = _norm_arg(arg) @@ -2553,17 +2581,7 @@ def _norm_arg(arg): return arg.lower() if isinstance(arg, str) else arg -def _all_functions(): - return [ - obj - for _, obj in inspect.getmembers( - sys.modules[__name__], - lambda obj: inspect.isclass(obj) and issubclass(obj, Func) and obj not in (AggFunc, Anonymous, Func), - ) - ] - - -ALL_FUNCTIONS = _all_functions() +ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func)) def maybe_parse( @@ -2793,6 +2811,37 @@ def from_(*expressions, dialect=None, **opts): return Select().from_(*expressions, dialect=dialect, **opts) +def update(table, properties, where=None, from_=None, dialect=None, **opts): + """ + Creates an update statement. + + Example: + >>> update("my_table", {"x": 1, "y": "2", "z": None}, from_="baz", where="id > 1").sql() + "UPDATE my_table SET x = 1, y = '2', z = NULL FROM baz WHERE id > 1" + + Args: + *properties (Dict[str, Any]): dictionary of properties to set which are + auto converted to sql objects eg None -> NULL + where (str): sql conditional parsed into a WHERE statement + from_ (str): sql statement parsed into a FROM statement + dialect (str): the dialect used to parse the input expressions. + **opts: other options to use to parse the input expressions. + + Returns: + Update: the syntax tree for the UPDATE statement. + """ + update = Update(this=maybe_parse(table, into=Table, dialect=dialect)) + update.set( + "expressions", + [EQ(this=maybe_parse(k, dialect=dialect, **opts), expression=convert(v)) for k, v in properties.items()], + ) + if from_: + update.set("from", maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts)) + if where: + update.set("where", maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts)) + return update + + def condition(expression, dialect=None, **opts): """ Initialize a logical condition expression. @@ -2980,12 +3029,13 @@ def column(col, table=None, quoted=None): def table_(table, db=None, catalog=None, quoted=None): - """ - Build a Table. + """Build a Table. + Args: table (str or Expression): column name db (str or Expression): db name catalog (str or Expression): catalog name + Returns: Table: table instance """ @@ -2996,6 +3046,39 @@ def table_(table, db=None, catalog=None, quoted=None): ) +def convert(value): + """Convert a python value into an expression object. + + Raises an error if a conversion is not possible. + + Args: + value (Any): a python object + + Returns: + Expression: the equivalent expression object + """ + if isinstance(value, Expression): + return value + if value is None: + return NULL + if isinstance(value, bool): + return Boolean(this=value) + if isinstance(value, str): + return Literal.string(value) + if isinstance(value, numbers.Number): + return Literal.number(value) + if isinstance(value, tuple): + return Tuple(expressions=[convert(v) for v in value]) + if isinstance(value, list): + return Array(expressions=[convert(v) for v in value]) + if isinstance(value, dict): + return Map( + keys=[convert(k) for k in value.keys()], + values=[convert(v) for v in value.values()], + ) + raise ValueError(f"Cannot convert {value}") + + def replace_children(expression, fun): """ Replace children of an expression with the result of a lambda fun(child) -> exp. diff --git a/sqlglot/generator.py b/sqlglot/generator.py index d264e59..9099307 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -46,18 +46,12 @@ class Generator: """ TRANSFORMS = { - exp.AnonymousProperty: lambda self, e: self.property_sql(e), - exp.AutoIncrementProperty: lambda self, e: f"AUTO_INCREMENT={self.sql(e, 'value')}", exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'value')}", - exp.CollateProperty: lambda self, e: f"COLLATE={self.sql(e, 'value')}", exp.DateAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})", exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')})", - exp.EngineProperty: lambda self, e: f"ENGINE={self.sql(e, 'value')}", - exp.FileFormatProperty: lambda self, e: f"FORMAT={self.sql(e, 'value')}", - exp.LocationProperty: lambda self, e: f"LOCATION {self.sql(e, 'value')}", - exp.PartitionedByProperty: lambda self, e: f"PARTITIONED_BY={self.sql(e.args['value'])}", - exp.SchemaCommentProperty: lambda self, e: f"COMMENT={self.sql(e, 'value')}", - exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT={self.sql(e, 'value')}", + exp.LanguageProperty: lambda self, e: self.naked_property(e), + exp.LocationProperty: lambda self, e: self.naked_property(e), + exp.ReturnsProperty: lambda self, e: self.naked_property(e), exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})", } @@ -72,19 +66,17 @@ class Generator: STRUCT_DELIMITER = ("<", ">") - ROOT_PROPERTIES = [ - exp.AutoIncrementProperty, - exp.CharacterSetProperty, - exp.CollateProperty, - exp.EngineProperty, - exp.SchemaCommentProperty, - ] - WITH_PROPERTIES = [ + ROOT_PROPERTIES = { + exp.ReturnsProperty, + exp.LanguageProperty, + } + + WITH_PROPERTIES = { exp.AnonymousProperty, exp.FileFormatProperty, exp.PartitionedByProperty, exp.TableFormatProperty, - ] + } __slots__ = ( "time_mapping", @@ -188,6 +180,7 @@ class Generator: return sql def unsupported(self, message): + if self.unsupported_level == ErrorLevel.IMMEDIATE: raise UnsupportedError(message) self.unsupported_messages.append(message) @@ -261,6 +254,9 @@ class Generator: if isinstance(expression, exp.Func): return self.function_fallback_sql(expression) + if isinstance(expression, exp.Property): + return self.property_sql(expression) + raise ValueError(f"Unsupported expression type {expression.__class__.__name__}") def annotation_sql(self, expression): @@ -352,9 +348,12 @@ class Generator: replace = " OR REPLACE" if expression.args.get("replace") else "" exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else "" unique = " UNIQUE" if expression.args.get("unique") else "" + materialized = " MATERIALIZED" if expression.args.get("materialized") else "" properties = self.sql(expression, "properties") - expression_sql = f"CREATE{replace}{temporary}{unique} {kind}{exists_sql} {this}{properties} {expression_sql}" + expression_sql = ( + f"CREATE{replace}{temporary}{unique}{materialized} {kind}{exists_sql} {this}{properties} {expression_sql}" + ) return self.prepend_ctes(expression, expression_sql) def prepend_ctes(self, expression, sql): @@ -461,10 +460,10 @@ class Generator: for p in expression.expressions: p_class = p.__class__ - if p_class in self.ROOT_PROPERTIES: - root_properties.append(p) - elif p_class in self.WITH_PROPERTIES: + if p_class in self.WITH_PROPERTIES: with_properties.append(p) + elif p_class in self.ROOT_PROPERTIES: + root_properties.append(p) return self.root_properties(exp.Properties(expressions=root_properties)) + self.with_properties( exp.Properties(expressions=with_properties) @@ -496,9 +495,12 @@ class Generator: ) def property_sql(self, expression): - key = expression.name + if isinstance(expression.this, exp.Literal): + key = expression.this.this + else: + key = expression.name value = self.sql(expression, "value") - return f"{key} = {value}" + return f"{key}={value}" def insert_sql(self, expression): kind = "OVERWRITE TABLE" if expression.args.get("overwrite") else "INTO" @@ -535,7 +537,8 @@ class Generator: laterals = self.expressions(expression, key="laterals", sep="") joins = self.expressions(expression, key="joins", sep="") - return f"{table}{laterals}{joins}" + pivots = self.expressions(expression, key="pivots", sep="") + return f"{table}{laterals}{joins}{pivots}" def tablesample_sql(self, expression): if self.alias_post_tablesample and isinstance(expression.this, exp.Alias): @@ -556,7 +559,17 @@ class Generator: rows = self.sql(expression, "rows") rows = f"{rows} ROWS" if rows else "" size = self.sql(expression, "size") - return f"{this} TABLESAMPLE{method}({bucket}{percent}{rows}{size}){alias}" + seed = self.sql(expression, "seed") + seed = f" SEED ({seed})" if seed else "" + return f"{this} TABLESAMPLE{method}({bucket}{percent}{rows}{size}){seed}{alias}" + + def pivot_sql(self, expression): + this = self.sql(expression, "this") + unpivot = expression.args.get("unpivot") + direction = "UNPIVOT" if unpivot else "PIVOT" + expressions = self.expressions(expression, key="expressions") + field = self.sql(expression, "field") + return f"{this} {direction}({expressions} FOR {field})" def tuple_sql(self, expression): return f"({self.expressions(expression, flat=True)})" @@ -681,6 +694,7 @@ class Generator: def ordered_sql(self, expression): desc = expression.args.get("desc") asc = not desc + nulls_first = expression.args.get("nulls_first") nulls_last = not nulls_first nulls_are_large = self.null_ordering == "nulls_are_large" @@ -760,6 +774,7 @@ class Generator: return self.query_modifiers( expression, self.wrap(expression), + self.expressions(expression, key="pivots", sep=" "), f" AS {alias}" if alias else "", ) @@ -1129,6 +1144,9 @@ class Generator: return f"{op} {expressions_sql}" return f"{self.seg(op)}{self.sep() if expressions_sql else ''}{expressions_sql}" + def naked_property(self, expression): + return f"{expression.name} {self.sql(expression, 'value')}" + def set_operation(self, expression, op): this = self.sql(expression, "this") op = self.seg(op) @@ -1136,3 +1154,13 @@ class Generator: def token_sql(self, token_type): return self.TOKEN_MAPPING.get(token_type, token_type.name) + + def userdefinedfunction_sql(self, expression): + this = self.sql(expression, "this") + expressions = self.no_identify(lambda: self.expressions(expression)) + return f"{this}({expressions})" + + def userdefinedfunctionkwarg_sql(self, expression): + this = self.sql(expression, "this") + kind = self.sql(expression, "kind") + return f"{this} {kind}" diff --git a/sqlglot/helper.py b/sqlglot/helper.py index 5d90c49..c4dd91e 100644 --- a/sqlglot/helper.py +++ b/sqlglot/helper.py @@ -1,5 +1,7 @@ +import inspect import logging import re +import sys from contextlib import contextmanager from enum import Enum @@ -29,6 +31,26 @@ def csv(*args, sep=", "): return sep.join(arg for arg in args if arg) +def subclasses(module_name, classes, exclude=()): + """ + Returns a list of all subclasses for a specified class set, posibly excluding some of them. + + Args: + module_name (str): The name of the module to search for subclasses in. + classes (type|tuple[type]): Class(es) we want to find the subclasses of. + exclude (type|tuple[type]): Class(es) we want to exclude from the returned list. + Returns: + A list of all the target subclasses. + """ + return [ + obj + for _, obj in inspect.getmembers( + sys.modules[module_name], + lambda obj: inspect.isclass(obj) and issubclass(obj, classes) and obj not in exclude, + ) + ] + + def apply_index_offset(expressions, offset): if not offset or len(expressions) != 1: return expressions @@ -100,7 +122,7 @@ def csv_reader(table): Returns a csv reader given the expression READ_CSV(name, ['delimiter', '|', ...]) Args: - expression (Expression): An anonymous function READ_CSV + table (exp.Table): A table expression with an anonymous function READ_CSV in it Returns: A python csv reader. @@ -121,3 +143,22 @@ def csv_reader(table): yield csv_.reader(file, delimiter=delimiter) finally: file.close() + + +def find_new_name(taken, base): + """ + Searches for a new name. + + Args: + taken (Sequence[str]): set of taken names + base (str): base name to alter + """ + if base not in taken: + return base + + i = 2 + new = f"{base}_{i}" + while new in taken: + i += 1 + new = f"{base}_{i}" + return new diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py new file mode 100644 index 0000000..3f5f089 --- /dev/null +++ b/sqlglot/optimizer/annotate_types.py @@ -0,0 +1,162 @@ +from sqlglot import exp +from sqlglot.helper import ensure_list, subclasses + + +def annotate_types(expression, schema=None, annotators=None, coerces_to=None): + """ + Recursively infer & annotate types in an expression syntax tree against a schema. + + (TODO -- replace this with a better example after adding some functionality) + Example: + >>> import sqlglot + >>> annotated_expression = annotate_types(sqlglot.parse_one('5 + 5.3')) + >>> annotated_expression.type + <Type.DOUBLE: 'DOUBLE'> + + Args: + expression (sqlglot.Expression): Expression to annotate. + schema (dict|sqlglot.optimizer.Schema): Database schema. + annotators (dict): Maps expression type to corresponding annotation function. + coerces_to (dict): Maps expression type to set of types that it can be coerced into. + Returns: + sqlglot.Expression: expression annotated with types + """ + + return TypeAnnotator(schema, annotators, coerces_to).annotate(expression) + + +class TypeAnnotator: + ANNOTATORS = { + **{ + expr_type: lambda self, expr: self._annotate_unary(expr) + for expr_type in subclasses(exp.__name__, exp.Unary) + }, + **{ + expr_type: lambda self, expr: self._annotate_binary(expr) + for expr_type in subclasses(exp.__name__, exp.Binary) + }, + exp.Cast: lambda self, expr: self._annotate_cast(expr), + exp.DataType: lambda self, expr: self._annotate_data_type(expr), + exp.Literal: lambda self, expr: self._annotate_literal(expr), + exp.Boolean: lambda self, expr: self._annotate_boolean(expr), + } + + # Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html + COERCES_TO = { + # CHAR < NCHAR < VARCHAR < NVARCHAR < TEXT + exp.DataType.Type.TEXT: set(), + exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT}, + exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT}, + exp.DataType.Type.NCHAR: {exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT}, + exp.DataType.Type.CHAR: { + exp.DataType.Type.NCHAR, + exp.DataType.Type.VARCHAR, + exp.DataType.Type.NVARCHAR, + exp.DataType.Type.TEXT, + }, + # TINYINT < SMALLINT < INT < BIGINT < DECIMAL < FLOAT < DOUBLE + exp.DataType.Type.DOUBLE: set(), + exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE}, + exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE}, + exp.DataType.Type.BIGINT: {exp.DataType.Type.DECIMAL, exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE}, + exp.DataType.Type.INT: { + exp.DataType.Type.BIGINT, + exp.DataType.Type.DECIMAL, + exp.DataType.Type.FLOAT, + exp.DataType.Type.DOUBLE, + }, + exp.DataType.Type.SMALLINT: { + exp.DataType.Type.INT, + exp.DataType.Type.BIGINT, + exp.DataType.Type.DECIMAL, + exp.DataType.Type.FLOAT, + exp.DataType.Type.DOUBLE, + }, + exp.DataType.Type.TINYINT: { + exp.DataType.Type.SMALLINT, + exp.DataType.Type.INT, + exp.DataType.Type.BIGINT, + exp.DataType.Type.DECIMAL, + exp.DataType.Type.FLOAT, + exp.DataType.Type.DOUBLE, + }, + # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ + exp.DataType.Type.TIMESTAMPLTZ: set(), + exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ}, + exp.DataType.Type.TIMESTAMP: {exp.DataType.Type.TIMESTAMPTZ, exp.DataType.Type.TIMESTAMPLTZ}, + exp.DataType.Type.DATETIME: { + exp.DataType.Type.TIMESTAMP, + exp.DataType.Type.TIMESTAMPTZ, + exp.DataType.Type.TIMESTAMPLTZ, + }, + exp.DataType.Type.DATE: { + exp.DataType.Type.DATETIME, + exp.DataType.Type.TIMESTAMP, + exp.DataType.Type.TIMESTAMPTZ, + exp.DataType.Type.TIMESTAMPLTZ, + }, + } + + def __init__(self, schema=None, annotators=None, coerces_to=None): + self.schema = schema + self.annotators = annotators or self.ANNOTATORS + self.coerces_to = coerces_to or self.COERCES_TO + + def annotate(self, expression): + if not isinstance(expression, exp.Expression): + return None + + annotator = self.annotators.get(expression.__class__) + return annotator(self, expression) if annotator else self._annotate_args(expression) + + def _annotate_args(self, expression): + for value in expression.args.values(): + for v in ensure_list(value): + self.annotate(v) + + return expression + + def _annotate_cast(self, expression): + expression.type = expression.args["to"].this + return self._annotate_args(expression) + + def _annotate_data_type(self, expression): + expression.type = expression.this + return self._annotate_args(expression) + + def _maybe_coerce(self, type1, type2): + return type2 if type2 in self.coerces_to[type1] else type1 + + def _annotate_binary(self, expression): + self._annotate_args(expression) + + if isinstance(expression, (exp.Condition, exp.Predicate)): + expression.type = exp.DataType.Type.BOOLEAN + else: + expression.type = self._maybe_coerce(expression.left.type, expression.right.type) + + return expression + + def _annotate_unary(self, expression): + self._annotate_args(expression) + + if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren): + expression.type = exp.DataType.Type.BOOLEAN + else: + expression.type = expression.this.type + + return expression + + def _annotate_literal(self, expression): + if expression.is_string: + expression.type = exp.DataType.Type.VARCHAR + elif expression.is_int: + expression.type = exp.DataType.Type.INT + else: + expression.type = exp.DataType.Type.DOUBLE + + return expression + + def _annotate_boolean(self, expression): + expression.type = exp.DataType.Type.BOOLEAN + return expression diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index 4bfb733..38e1299 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -1,48 +1,144 @@ import itertools -from sqlglot import alias, exp, select, table -from sqlglot.optimizer.scope import traverse_scope +from sqlglot import expressions as exp +from sqlglot.helper import find_new_name +from sqlglot.optimizer.scope import build_scope from sqlglot.optimizer.simplify import simplify def eliminate_subqueries(expression): """ - Rewrite duplicate subqueries from sqlglot AST. + Rewrite subqueries as CTES, deduplicating if possible. Example: >>> import sqlglot - >>> expression = sqlglot.parse_one("SELECT 1 AS x, 2 AS y UNION ALL SELECT 1 AS x, 2 AS y") + >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y") >>> eliminate_subqueries(expression).sql() - 'WITH _e_0 AS (SELECT 1 AS x, 2 AS y) SELECT * FROM _e_0 UNION ALL SELECT * FROM _e_0' + 'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y' + + This also deduplicates common subqueries: + >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y JOIN (SELECT * FROM x) AS z") + >>> eliminate_subqueries(expression).sql() + 'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y JOIN y AS z' Args: - expression (sqlglot.Expression): expression to qualify - schema (dict|sqlglot.optimizer.Schema): Database schema + expression (sqlglot.Expression): expression Returns: - sqlglot.Expression: qualified expression + sqlglot.Expression: expression """ + if isinstance(expression, exp.Subquery): + # It's possible to have subqueries at the root, e.g. (SELECT * FROM x) LIMIT 1 + eliminate_subqueries(expression.this) + return expression + expression = simplify(expression) - queries = {} + root = build_scope(expression) + + # Map of alias->Scope|Table + # These are all aliases that are already used in the expression. + # We don't want to create new CTEs that conflict with these names. + taken = {} + + # All CTE aliases in the root scope are taken + for scope in root.cte_scopes: + taken[scope.expression.parent.alias] = scope + + # All table names are taken + for scope in root.traverse(): + taken.update({source.name: source for _, source in scope.sources.items() if isinstance(source, exp.Table)}) - for scope in traverse_scope(expression): - query = scope.expression - queries[query] = queries.get(query, []) + [query] + # Map of Expression->alias + # Existing CTES in the root expression. We'll use this for deduplication. + existing_ctes = {} - sequence = itertools.count() + with_ = root.expression.args.get("with") + if with_: + for cte in with_.expressions: + existing_ctes[cte.this] = cte.alias + new_ctes = [] - for query, duplicates in queries.items(): - if len(duplicates) == 1: - continue + # We're adding more CTEs, but we want to maintain the DAG order. + # Derived tables within an existing CTE need to come before the existing CTE. + for cte_scope in root.cte_scopes: + # Append all the new CTEs from this existing CTE + for scope in cte_scope.traverse(): + new_cte = _eliminate(scope, existing_ctes, taken) + if new_cte: + new_ctes.append(new_cte) - alias_ = f"_e_{next(sequence)}" + # Append the existing CTE itself + new_ctes.append(cte_scope.expression.parent) - for dup in duplicates: - parent = dup.parent - if isinstance(parent, exp.Subquery): - parent.replace(alias(table(alias_), parent.alias_or_name, table=True)) - elif isinstance(parent, exp.Union): - dup.replace(select("*").from_(alias_)) + # Now append the rest + for scope in itertools.chain(root.union_scopes, root.subquery_scopes, root.derived_table_scopes): + for child_scope in scope.traverse(): + new_cte = _eliminate(child_scope, existing_ctes, taken) + if new_cte: + new_ctes.append(new_cte) - expression.with_(alias_, as_=query, copy=False) + if new_ctes: + expression.set("with", exp.With(expressions=new_ctes)) return expression + + +def _eliminate(scope, existing_ctes, taken): + if scope.is_union: + return _eliminate_union(scope, existing_ctes, taken) + + if scope.is_derived_table and not isinstance(scope.expression, (exp.Unnest, exp.Lateral)): + return _eliminate_derived_table(scope, existing_ctes, taken) + + +def _eliminate_union(scope, existing_ctes, taken): + duplicate_cte_alias = existing_ctes.get(scope.expression) + + alias = duplicate_cte_alias or find_new_name(taken=taken, base="cte") + + taken[alias] = scope + + # Try to maintain the selections + expressions = scope.expression.args.get("expressions") + selects = [ + exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name) + for e in expressions + if e.alias_or_name + ] + # If not all selections have an alias, just select * + if len(selects) != len(expressions): + selects = ["*"] + + scope.expression.replace(exp.select(*selects).from_(exp.alias_(exp.table_(alias), alias=alias))) + + if not duplicate_cte_alias: + existing_ctes[scope.expression] = alias + return exp.CTE( + this=scope.expression, + alias=exp.TableAlias(this=exp.to_identifier(alias)), + ) + + +def _eliminate_derived_table(scope, existing_ctes, taken): + duplicate_cte_alias = existing_ctes.get(scope.expression) + parent = scope.expression.parent + name = alias = parent.alias + + if not alias: + name = alias = find_new_name(taken=taken, base="cte") + + if duplicate_cte_alias: + name = duplicate_cte_alias + elif taken.get(alias): + name = find_new_name(taken=taken, base=alias) + + taken[name] = scope + + table = exp.alias_(exp.table_(name), alias=alias) + parent.replace(table) + + if not duplicate_cte_alias: + existing_ctes[scope.expression] = name + return exp.CTE( + this=scope.expression, + alias=exp.TableAlias(this=exp.to_identifier(name)), + ) diff --git a/sqlglot/optimizer/merge_derived_tables.py b/sqlglot/optimizer/merge_subqueries.py index 8b161fb..9d966b7 100644 --- a/sqlglot/optimizer/merge_derived_tables.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -1,72 +1,127 @@ from collections import defaultdict from sqlglot import expressions as exp -from sqlglot.optimizer.scope import traverse_scope +from sqlglot.helper import find_new_name +from sqlglot.optimizer.scope import Scope, traverse_scope from sqlglot.optimizer.simplify import simplify -def merge_derived_tables(expression): +def merge_subqueries(expression, leave_tables_isolated=False): """ Rewrite sqlglot AST to merge derived tables into the outer query. + This also merges CTEs if they are selected from only once. + Example: >>> import sqlglot - >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x)") - >>> merge_derived_tables(expression).sql() - 'SELECT x.a FROM x' + >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y") + >>> merge_subqueries(expression).sql() + 'SELECT x.a FROM x JOIN y' + + If `leave_tables_isolated` is True, this will not merge inner queries into outer + queries if it would result in multiple table selects in a single query: + >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y") + >>> merge_subqueries(expression, leave_tables_isolated=True).sql() + 'SELECT a FROM (SELECT x.a FROM x) JOIN y' Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html Args: expression (sqlglot.Expression): expression to optimize + leave_tables_isolated (bool): Returns: sqlglot.Expression: optimized expression """ + merge_ctes(expression, leave_tables_isolated) + merge_derived_tables(expression, leave_tables_isolated) + return expression + + +# If a derived table has these Select args, it can't be merged +UNMERGABLE_ARGS = set(exp.Select.arg_types) - { + "expressions", + "from", + "joins", + "where", + "order", +} + + +def merge_ctes(expression, leave_tables_isolated=False): + scopes = traverse_scope(expression) + + # All places where we select from CTEs. + # We key on the CTE scope so we can detect CTES that are selected from multiple times. + cte_selections = defaultdict(list) + for outer_scope in scopes: + for table, inner_scope in outer_scope.selected_sources.values(): + if isinstance(inner_scope, Scope) and inner_scope.is_cte: + cte_selections[id(inner_scope)].append( + ( + outer_scope, + inner_scope, + table, + ) + ) + + singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1] + for outer_scope, inner_scope, table in singular_cte_selections: + inner_select = inner_scope.expression.unnest() + if _mergeable(outer_scope, inner_select, leave_tables_isolated): + from_or_join = table.find_ancestor(exp.From, exp.Join) + + node_to_replace = table + if isinstance(node_to_replace.parent, exp.Alias): + node_to_replace = node_to_replace.parent + alias = node_to_replace.alias + else: + alias = table.name + + _rename_inner_sources(outer_scope, inner_scope, alias) + _merge_from(outer_scope, inner_scope, node_to_replace, alias) + _merge_joins(outer_scope, inner_scope, from_or_join) + _merge_expressions(outer_scope, inner_scope, alias) + _merge_where(outer_scope, inner_scope, from_or_join) + _merge_order(outer_scope, inner_scope) + _pop_cte(inner_scope) + + +def merge_derived_tables(expression, leave_tables_isolated=False): for outer_scope in traverse_scope(expression): for subquery in outer_scope.derived_tables: inner_select = subquery.unnest() - if ( - isinstance(outer_scope.expression, exp.Select) - and isinstance(inner_select, exp.Select) - and _mergeable(inner_select) - ): + if _mergeable(outer_scope, inner_select, leave_tables_isolated): alias = subquery.alias_or_name from_or_join = subquery.find_ancestor(exp.From, exp.Join) inner_scope = outer_scope.sources[alias] _rename_inner_sources(outer_scope, inner_scope, alias) - _merge_from(outer_scope, inner_scope, subquery) + _merge_from(outer_scope, inner_scope, subquery, alias) _merge_joins(outer_scope, inner_scope, from_or_join) _merge_expressions(outer_scope, inner_scope, alias) _merge_where(outer_scope, inner_scope, from_or_join) _merge_order(outer_scope, inner_scope) - return expression -# If a derived table has these Select args, it can't be merged -UNMERGABLE_ARGS = set(exp.Select.arg_types) - { - "expressions", - "from", - "joins", - "where", - "order", -} - - -def _mergeable(inner_select): +def _mergeable(outer_scope, inner_select, leave_tables_isolated): """ Return True if `inner_select` can be merged into outer query. Args: + outer_scope (Scope) inner_select (exp.Select) + leave_tables_isolated (bool) Returns: bool: True if can be merged """ return ( - isinstance(inner_select, exp.Select) + isinstance(outer_scope.expression, exp.Select) + and isinstance(inner_select, exp.Select) + and isinstance(inner_select, exp.Select) and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS) and inner_select.args.get("from") and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions) + and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1) ) @@ -84,7 +139,7 @@ def _rename_inner_sources(outer_scope, inner_scope, alias): conflicts = conflicts - {alias} for conflict in conflicts: - new_name = _find_new_name(taken, conflict) + new_name = find_new_name(taken, conflict) source, _ = inner_scope.selected_sources[conflict] new_alias = exp.to_identifier(new_name) @@ -102,34 +157,19 @@ def _rename_inner_sources(outer_scope, inner_scope, alias): inner_scope.rename_source(conflict, new_name) -def _find_new_name(taken, base): - """ - Searches for a new source name. - - Args: - taken (set[str]): set of taken names - base (str): base name to alter - """ - i = 2 - new = f"{base}_{i}" - while new in taken: - i += 1 - new = f"{base}_{i}" - return new - - -def _merge_from(outer_scope, inner_scope, subquery): +def _merge_from(outer_scope, inner_scope, node_to_replace, alias): """ Merge FROM clause of inner query into outer query. Args: outer_scope (sqlglot.optimizer.scope.Scope) inner_scope (sqlglot.optimizer.scope.Scope) - subquery (exp.Subquery) + node_to_replace (exp.Subquery|exp.Table) + alias (str) """ new_subquery = inner_scope.expression.args.get("from").expressions[0] - subquery.replace(new_subquery) - outer_scope.remove_source(subquery.alias_or_name) + node_to_replace.replace(new_subquery) + outer_scope.remove_source(alias) outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name]) @@ -176,7 +216,7 @@ def _merge_expressions(outer_scope, inner_scope, alias): inner_scope (sqlglot.optimizer.scope.Scope) alias (str) """ - # Collect all columns that for the alias of the inner query + # Collect all columns that reference the alias of the inner query outer_columns = defaultdict(list) for column in outer_scope.columns: if column.table == alias: @@ -205,7 +245,7 @@ def _merge_where(outer_scope, inner_scope, from_or_join): if not where or not where.this: return - if isinstance(from_or_join, exp.Join) and from_or_join.side: + if isinstance(from_or_join, exp.Join): # Merge predicates from an outer join to the ON clause from_or_join.on(where.this, copy=False) from_or_join.set("on", simplify(from_or_join.args.get("on"))) @@ -230,3 +270,18 @@ def _merge_order(outer_scope, inner_scope): return outer_scope.expression.set("order", inner_scope.expression.args.get("order")) + + +def _pop_cte(inner_scope): + """ + Remove CTE from the AST. + + Args: + inner_scope (sqlglot.optimizer.scope.Scope) + """ + cte = inner_scope.expression.parent + with_ = cte.parent + if len(with_.expressions) == 1: + with_.pop() + else: + cte.pop() diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index c8c2403..9a09327 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -1,7 +1,7 @@ from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects from sqlglot.optimizer.isolate_table_selects import isolate_table_selects -from sqlglot.optimizer.merge_derived_tables import merge_derived_tables +from sqlglot.optimizer.merge_subqueries import merge_subqueries from sqlglot.optimizer.normalize import normalize from sqlglot.optimizer.optimize_joins import optimize_joins from sqlglot.optimizer.pushdown_predicates import pushdown_predicates @@ -22,7 +22,7 @@ RULES = ( pushdown_predicates, optimize_joins, eliminate_subqueries, - merge_derived_tables, + merge_subqueries, quote_identities, ) diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py index 097ce04..5584830 100644 --- a/sqlglot/optimizer/pushdown_projections.py +++ b/sqlglot/optimizer/pushdown_projections.py @@ -37,7 +37,7 @@ def pushdown_projections(expression): parent_selections = {SELECT_ALL} if isinstance(scope.expression, exp.Union): - left, right = scope.union + left, right = scope.union_scopes referenced_columns[left] = parent_selections referenced_columns[right] = parent_selections diff --git a/sqlglot/optimizer/schema.py b/sqlglot/optimizer/schema.py index 1761228..1bbd86a 100644 --- a/sqlglot/optimizer/schema.py +++ b/sqlglot/optimizer/schema.py @@ -69,7 +69,7 @@ def ensure_schema(schema): def fs_get(table): - name = table.this.name.upper() + name = table.this.name if name.upper() == "READ_CSV": with csv_reader(table) as reader: diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index e816e10..be6cfb9 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -1,3 +1,4 @@ +import itertools from copy import copy from enum import Enum, auto @@ -32,10 +33,11 @@ class Scope: The inner query would have `["col1", "col2"]` for its `outer_column_list` parent (Scope): Parent scope scope_type (ScopeType): Type of this scope, relative to it's parent - subquery_scopes (list[Scope]): List of all child scopes for subqueries. - This does not include derived tables or CTEs. - union (tuple[Scope, Scope]): If this Scope is for a Union expression, this will be - a tuple of the left and right child scopes. + subquery_scopes (list[Scope]): List of all child scopes for subqueries + cte_scopes = (list[Scope]) List of all child scopes for CTEs + derived_table_scopes = (list[Scope]) List of all child scopes for derived_tables + union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be + a list of the left and right child scopes. """ def __init__( @@ -52,7 +54,9 @@ class Scope: self.parent = parent self.scope_type = scope_type self.subquery_scopes = [] - self.union = None + self.derived_table_scopes = [] + self.cte_scopes = [] + self.union_scopes = [] self.clear_cache() def clear_cache(self): @@ -197,11 +201,16 @@ class Scope: named_outputs = {e.alias_or_name for e in self.expression.expressions} - self._columns = [ - c - for c in columns + external_columns - if not (c.find_ancestor(exp.Qualify, exp.Order) and not c.table and c.name in named_outputs) - ] + self._columns = [] + for column in columns + external_columns: + ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Hint) + if ( + not ancestor + or column.table + or (column.name not in named_outputs and not isinstance(ancestor, exp.Hint)) + ): + self._columns.append(column) + return self._columns @property @@ -284,6 +293,26 @@ class Scope: return self.scope_type == ScopeType.SUBQUERY @property + def is_derived_table(self): + """Determine if this scope is a derived table""" + return self.scope_type == ScopeType.DERIVED_TABLE + + @property + def is_union(self): + """Determine if this scope is a union""" + return self.scope_type == ScopeType.UNION + + @property + def is_cte(self): + """Determine if this scope is a common table expression""" + return self.scope_type == ScopeType.CTE + + @property + def is_root(self): + """Determine if this is the root scope""" + return self.scope_type == ScopeType.ROOT + + @property def is_unnest(self): """Determine if this scope is an unnest""" return self.scope_type == ScopeType.UNNEST @@ -308,6 +337,22 @@ class Scope: self.sources.pop(name, None) self.clear_cache() + def __repr__(self): + return f"Scope<{self.expression.sql()}>" + + def traverse(self): + """ + Traverse the scope tree from this node. + + Yields: + Scope: scope instances in depth-first-search post-order + """ + for child_scope in itertools.chain( + self.cte_scopes, self.union_scopes, self.subquery_scopes, self.derived_table_scopes + ): + yield from child_scope.traverse() + yield self + def traverse_scope(expression): """ @@ -337,6 +382,18 @@ def traverse_scope(expression): return list(_traverse_scope(Scope(expression))) +def build_scope(expression): + """ + Build a scope tree. + + Args: + expression (exp.Expression): expression to build the scope tree for + Returns: + Scope: root scope + """ + return traverse_scope(expression)[-1] + + def _traverse_scope(scope): if isinstance(scope.expression, exp.Select): yield from _traverse_select(scope) @@ -370,13 +427,14 @@ def _traverse_union(scope): for right in _traverse_scope(scope.branch(scope.expression.right, scope_type=ScopeType.UNION)): yield right - scope.union = (left, right) + scope.union_scopes = [left, right] def _traverse_derived_tables(derived_tables, scope, scope_type): sources = {} for derived_table in derived_tables: + top = None for child_scope in _traverse_scope( scope.branch( derived_table if isinstance(derived_table, (exp.Unnest, exp.Lateral)) else derived_table.this, @@ -386,11 +444,16 @@ def _traverse_derived_tables(derived_tables, scope, scope_type): ) ): yield child_scope + top = child_scope # Tables without aliases will be set as "" # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything. # Until then, this means that only a single, unaliased derived table is allowed (rather, # the latest one wins. sources[derived_table.alias] = child_scope + if scope_type == ScopeType.CTE: + scope.cte_scopes.append(top) + else: + scope.derived_table_scopes.append(top) scope.sources.update(sources) @@ -407,8 +470,6 @@ def _add_table_sources(scope): if table_name in scope.sources: # This is a reference to a parent source (e.g. a CTE), not an actual table. scope.sources[source_name] = scope.sources[table_name] - elif source_name in scope.sources: - raise OptimizeError(f"Duplicate table name: {source_name}") else: sources[source_name] = table diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 6ad6391..72bad92 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -99,7 +99,8 @@ class Parser: TokenType.SMALLMONEY, TokenType.ROWVERSION, TokenType.IMAGE, - TokenType.SQL_VARIANT, + TokenType.VARIANT, + TokenType.OBJECT, *NESTED_TYPE_TOKENS, } @@ -131,7 +132,6 @@ class Parser: TokenType.FALSE, TokenType.FIRST, TokenType.FOLLOWING, - TokenType.FOR, TokenType.FORMAT, TokenType.FUNCTION, TokenType.GENERATED, @@ -141,20 +141,26 @@ class Parser: TokenType.ISNULL, TokenType.INTERVAL, TokenType.LAZY, + TokenType.LANGUAGE, TokenType.LEADING, TokenType.LOCATION, + TokenType.MATERIALIZED, TokenType.NATURAL, TokenType.NEXT, TokenType.ONLY, TokenType.OPTIMIZE, TokenType.OPTIONS, TokenType.ORDINALITY, + TokenType.PARTITIONED_BY, TokenType.PERCENT, + TokenType.PIVOT, TokenType.PRECEDING, TokenType.RANGE, TokenType.REFERENCES, + TokenType.RETURNS, TokenType.ROWS, TokenType.SCHEMA_COMMENT, + TokenType.SEED, TokenType.SET, TokenType.SHOW, TokenType.STORED, @@ -167,6 +173,7 @@ class Parser: TokenType.TRUE, TokenType.UNBOUNDED, TokenType.UNIQUE, + TokenType.UNPIVOT, TokenType.PROPERTIES, *SUBQUERY_PREDICATES, *TYPE_TOKENS, @@ -303,6 +310,8 @@ class Parser: exp.Condition: lambda self: self._parse_conjunction(), exp.Expression: lambda self: self._parse_statement(), exp.Properties: lambda self: self._parse_properties(), + exp.Where: lambda self: self._parse_where(), + exp.Ordered: lambda self: self._parse_ordered(), "JOIN_TYPE": lambda self: self._parse_join_side_and_kind(), } @@ -355,23 +364,21 @@ class Parser: PROPERTY_PARSERS = { TokenType.AUTO_INCREMENT: lambda self: self._parse_auto_increment(), TokenType.CHARACTER_SET: lambda self: self._parse_character_set(), - TokenType.COLLATE: lambda self: self._parse_collate(), - TokenType.ENGINE: lambda self: self._parse_engine(), - TokenType.FORMAT: lambda self: self._parse_format(), TokenType.LOCATION: lambda self: self.expression( exp.LocationProperty, this=exp.Literal.string("LOCATION"), value=self._parse_string(), ), - TokenType.PARTITIONED_BY: lambda self: self.expression( - exp.PartitionedByProperty, - this=exp.Literal.string("PARTITIONED_BY"), - value=self._parse_schema(), - ), + TokenType.PARTITIONED_BY: lambda self: self._parse_partitioned_by(), TokenType.SCHEMA_COMMENT: lambda self: self._parse_schema_comment(), TokenType.STORED: lambda self: self._parse_stored(), - TokenType.TABLE_FORMAT: lambda self: self._parse_table_format(), - TokenType.USING: lambda self: self._parse_table_format(), + TokenType.RETURNS: lambda self: self._parse_returns(), + TokenType.COLLATE: lambda self: self._parse_property_assignment(exp.CollateProperty), + TokenType.COMMENT: lambda self: self._parse_property_assignment(exp.SchemaCommentProperty), + TokenType.FORMAT: lambda self: self._parse_property_assignment(exp.FileFormatProperty), + TokenType.TABLE_FORMAT: lambda self: self._parse_property_assignment(exp.TableFormatProperty), + TokenType.USING: lambda self: self._parse_property_assignment(exp.TableFormatProperty), + TokenType.LANGUAGE: lambda self: self._parse_property_assignment(exp.LanguageProperty), } CONSTRAINT_PARSERS = { @@ -388,6 +395,7 @@ class Parser: FUNCTION_PARSERS = { "CONVERT": lambda self: self._parse_convert(), "EXTRACT": lambda self: self._parse_extract(), + "POSITION": lambda self: self._parse_position(), "SUBSTRING": lambda self: self._parse_substring(), "TRIM": lambda self: self._parse_trim(), "CAST": lambda self: self._parse_cast(self.STRICT_CAST), @@ -628,6 +636,10 @@ class Parser: replace = self._match(TokenType.OR) and self._match(TokenType.REPLACE) temporary = self._match(TokenType.TEMPORARY) unique = self._match(TokenType.UNIQUE) + materialized = self._match(TokenType.MATERIALIZED) + + if self._match_pair(TokenType.TABLE, TokenType.FUNCTION, advance=False): + self._match(TokenType.TABLE) create_token = self._match_set(self.CREATABLES) and self._prev @@ -640,14 +652,15 @@ class Parser: properties = None if create_token.token_type == TokenType.FUNCTION: - this = self._parse_var() + this = self._parse_user_defined_function() + properties = self._parse_properties() if self._match(TokenType.ALIAS): - expression = self._parse_string() + expression = self._parse_select_or_expression() elif create_token.token_type == TokenType.INDEX: this = self._parse_index() elif create_token.token_type in (TokenType.TABLE, TokenType.VIEW): this = self._parse_table(schema=True) - properties = self._parse_properties(this if isinstance(this, exp.Schema) else None) + properties = self._parse_properties() if self._match(TokenType.ALIAS): expression = self._parse_select(nested=True) @@ -661,9 +674,10 @@ class Parser: temporary=temporary, replace=replace, unique=unique, + materialized=materialized, ) - def _parse_property(self, schema): + def _parse_property(self): if self._match_set(self.PROPERTY_PARSERS): return self.PROPERTY_PARSERS[self._prev.token_type](self) if self._match_pair(TokenType.DEFAULT, TokenType.CHARACTER_SET): @@ -673,54 +687,34 @@ class Parser: key = self._parse_var().this self._match(TokenType.EQ) - if key.upper() == "PARTITIONED_BY": - expression = exp.PartitionedByProperty - value = self._parse_schema() or self._parse_bracket(self._parse_field()) - - if schema and not isinstance(value, exp.Schema): - columns = {v.name.upper() for v in value.expressions} - partitions = [ - expression for expression in schema.expressions if expression.this.name.upper() in columns - ] - schema.set( - "expressions", - [e for e in schema.expressions if e not in partitions], - ) - value = self.expression(exp.Schema, expressions=partitions) - else: - value = self._parse_column() - expression = exp.AnonymousProperty - return self.expression( - expression, + exp.AnonymousProperty, this=exp.Literal.string(key), - value=value, + value=self._parse_column(), ) + return None - def _parse_stored(self): - self._match(TokenType.ALIAS) + def _parse_property_assignment(self, exp_class): + prop = self._prev.text self._match(TokenType.EQ) - return self.expression( - exp.FileFormatProperty, - this=exp.Literal.string("FORMAT"), - value=exp.Literal.string(self._parse_var().name), - ) + return self.expression(exp_class, this=prop, value=self._parse_var_or_string()) - def _parse_format(self): + def _parse_partitioned_by(self): self._match(TokenType.EQ) return self.expression( - exp.FileFormatProperty, - this=exp.Literal.string("FORMAT"), - value=self._parse_string() or self._parse_var(), + exp.PartitionedByProperty, + this=exp.Literal.string("PARTITIONED_BY"), + value=self._parse_schema() or self._parse_bracket(self._parse_field()), ) - def _parse_engine(self): + def _parse_stored(self): + self._match(TokenType.ALIAS) self._match(TokenType.EQ) return self.expression( - exp.EngineProperty, - this=exp.Literal.string("ENGINE"), - value=self._parse_var_or_string(), + exp.FileFormatProperty, + this=exp.Literal.string("FORMAT"), + value=exp.Literal.string(self._parse_var().name), ) def _parse_auto_increment(self): @@ -731,14 +725,6 @@ class Parser: value=self._parse_var() or self._parse_number(), ) - def _parse_collate(self): - self._match(TokenType.EQ) - return self.expression( - exp.CollateProperty, - this=exp.Literal.string("COLLATE"), - value=self._parse_var_or_string(), - ) - def _parse_schema_comment(self): self._match(TokenType.EQ) return self.expression( @@ -756,26 +742,34 @@ class Parser: default=default, ) - def _parse_table_format(self): - self._match(TokenType.EQ) + def _parse_returns(self): + is_table = self._match(TokenType.TABLE) + if is_table: + if self._match(TokenType.LT): + value = self.expression( + exp.Schema, this="TABLE", expressions=self._parse_csv(self._parse_struct_kwargs) + ) + if not self._match(TokenType.GT): + self.raise_error("Expecting >") + else: + value = self._parse_schema("TABLE") + else: + value = self._parse_types() + return self.expression( - exp.TableFormatProperty, - this=exp.Literal.string("TABLE_FORMAT"), - value=self._parse_var_or_string(), + exp.ReturnsProperty, + this=exp.Literal.string("RETURNS"), + value=value, + is_table=is_table, ) - def _parse_properties(self, schema=None): - """ - Schema is included since if the table schema is defined and we later get a partition by expression - then we will define those columns in the partition by section and not in with the rest of the - columns - """ + def _parse_properties(self): properties = [] while True: if self._match(TokenType.WITH): self._match_l_paren() - properties.extend(self._parse_csv(lambda: self._parse_property(schema))) + properties.extend(self._parse_csv(lambda: self._parse_property())) self._match_r_paren() elif self._match(TokenType.PROPERTIES): self._match_l_paren() @@ -790,7 +784,7 @@ class Parser: ) self._match_r_paren() else: - identified_property = self._parse_property(schema) + identified_property = self._parse_property() if not identified_property: break properties.append(identified_property) @@ -1003,7 +997,7 @@ class Parser: ) def _parse_subquery(self, this): - return self.expression(exp.Subquery, this=this, alias=self._parse_table_alias()) + return self.expression(exp.Subquery, this=this, pivots=self._parse_pivots(), alias=self._parse_table_alias()) def _parse_query_modifiers(self, this): if not isinstance(this, self.MODIFIABLES): @@ -1134,14 +1128,18 @@ class Parser: table = (not schema and self._parse_function()) or self._parse_id_var(False) while self._match(TokenType.DOT): - catalog = db - db = table - table = self._parse_id_var() + if catalog: + # This allows nesting the table in arbitrarily many dot expressions if needed + table = self.expression(exp.Dot, this=table, expression=self._parse_id_var()) + else: + catalog = db + db = table + table = self._parse_id_var() if not table: self.raise_error("Expected table name") - this = self.expression(exp.Table, this=table, db=db, catalog=catalog) + this = self.expression(exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots()) if schema: return self._parse_schema(this=this) @@ -1199,6 +1197,7 @@ class Parser: percent = None rows = None size = None + seed = None self._match_l_paren() @@ -1220,6 +1219,11 @@ class Parser: self._match_r_paren() + if self._match(TokenType.SEED): + self._match_l_paren() + seed = self._parse_number() + self._match_r_paren() + return self.expression( exp.TableSample, method=method, @@ -1229,6 +1233,51 @@ class Parser: percent=percent, rows=rows, size=size, + seed=seed, + ) + + def _parse_pivots(self): + return list(iter(self._parse_pivot, None)) + + def _parse_pivot(self): + index = self._index + + if self._match(TokenType.PIVOT): + unpivot = False + elif self._match(TokenType.UNPIVOT): + unpivot = True + else: + return None + + expressions = [] + field = None + + if not self._match(TokenType.L_PAREN): + self._retreat(index) + return None + + if unpivot: + expressions = self._parse_csv(self._parse_column) + else: + expressions = self._parse_csv(lambda: self._parse_alias(self._parse_function())) + + if not self._match(TokenType.FOR): + self.raise_error("Expecting FOR") + + value = self._parse_column() + + if not self._match(TokenType.IN): + self.raise_error("Expecting IN") + + field = self._parse_in(value) + + self._match_r_paren() + + return self.expression( + exp.Pivot, + expressions=expressions, + field=field, + unpivot=unpivot, ) def _parse_where(self): @@ -1384,7 +1433,7 @@ class Parser: this = self.expression(exp.In, this=this, unnest=unnest) else: self._match_l_paren() - expressions = self._parse_csv(lambda: self._parse_select() or self._parse_expression()) + expressions = self._parse_csv(self._parse_select_or_expression) if len(expressions) == 1 and isinstance(expressions[0], exp.Subqueryable): this = self.expression(exp.In, this=this, query=expressions[0]) @@ -1577,6 +1626,9 @@ class Parser: if self._match_set(self.PRIMARY_PARSERS): return self.PRIMARY_PARSERS[self._prev.token_type](self, self._prev) + if self._match_pair(TokenType.DOT, TokenType.NUMBER): + return exp.Literal.number(f"0.{self._prev.text}") + if self._match(TokenType.L_PAREN): query = self._parse_select() @@ -1647,6 +1699,23 @@ class Parser: self._match_r_paren() return self._parse_window(this) + def _parse_user_defined_function(self): + this = self._parse_var() + if not self._match(TokenType.L_PAREN): + return this + expressions = self._parse_csv(self._parse_udf_kwarg) + self._match_r_paren() + return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions) + + def _parse_udf_kwarg(self): + this = self._parse_id_var() + kind = self._parse_types() + + if not kind: + return this + + return self.expression(exp.UserDefinedFunctionKwarg, this=this, kind=kind) + def _parse_lambda(self): index = self._index @@ -1672,9 +1741,10 @@ class Parser: return self._parse_alias(self._parse_limit(self._parse_order(this))) + conjunction = self._parse_conjunction().transform(self._replace_lambda, {node.name for node in expressions}) return self.expression( exp.Lambda, - this=self._parse_conjunction(), + this=conjunction, expressions=expressions, ) @@ -1896,6 +1966,12 @@ class Parser: to = None return self.expression(exp.Cast, this=this, to=to) + def _parse_position(self): + substr = self._parse_bitwise() + if self._match(TokenType.IN): + string = self._parse_bitwise() + return self.expression(exp.StrPosition, this=string, substr=substr) + def _parse_substring(self): # Postgres supports the form: substring(string [from int] [for int]) # https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6 @@ -2155,6 +2231,9 @@ class Parser: self._match_r_paren() return expressions + def _parse_select_or_expression(self): + return self._parse_select() or self._parse_expression() + def _match(self, token_type): if not self._curr: return None @@ -2208,3 +2287,9 @@ class Parser: elif isinstance(this, exp.Identifier): this = self.expression(exp.Var, this=this.name) return this + + def _replace_lambda(self, node, lambda_variables): + if isinstance(node, exp.Column): + if node.name in lambda_variables: + return node.this + return node diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 7a50fc3..c81f0db 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -94,7 +94,8 @@ class TokenType(AutoName): SMALLMONEY = auto() ROWVERSION = auto() IMAGE = auto() - SQL_VARIANT = auto() + VARIANT = auto() + OBJECT = auto() # keywords ADD_FILE = auto() @@ -177,6 +178,7 @@ class TokenType(AutoName): IS = auto() ISNULL = auto() JOIN = auto() + LANGUAGE = auto() LATERAL = auto() LAZY = auto() LEADING = auto() @@ -185,6 +187,7 @@ class TokenType(AutoName): LIMIT = auto() LOCATION = auto() MAP = auto() + MATERIALIZED = auto() MOD = auto() NATURAL = auto() NEXT = auto() @@ -208,6 +211,7 @@ class TokenType(AutoName): PARTITION_BY = auto() PARTITIONED_BY = auto() PERCENT = auto() + PIVOT = auto() PLACEHOLDER = auto() PRECEDING = auto() PRIMARY_KEY = auto() @@ -219,12 +223,14 @@ class TokenType(AutoName): REPLACE = auto() RESPECT_NULLS = auto() REFERENCES = auto() + RETURNS = auto() RIGHT = auto() RLIKE = auto() ROLLUP = auto() ROW = auto() ROWS = auto() SCHEMA_COMMENT = auto() + SEED = auto() SELECT = auto() SEPARATOR = auto() SET = auto() @@ -246,6 +252,7 @@ class TokenType(AutoName): UNCACHE = auto() UNION = auto() UNNEST = auto() + UNPIVOT = auto() UPDATE = auto() USE = auto() USING = auto() @@ -440,6 +447,7 @@ class Tokenizer(metaclass=_Tokenizer): "FULL": TokenType.FULL, "FUNCTION": TokenType.FUNCTION, "FOLLOWING": TokenType.FOLLOWING, + "FOR": TokenType.FOR, "FOREIGN KEY": TokenType.FOREIGN_KEY, "FORMAT": TokenType.FORMAT, "FROM": TokenType.FROM, @@ -459,6 +467,7 @@ class Tokenizer(metaclass=_Tokenizer): "IS": TokenType.IS, "ISNULL": TokenType.ISNULL, "JOIN": TokenType.JOIN, + "LANGUAGE": TokenType.LANGUAGE, "LATERAL": TokenType.LATERAL, "LAZY": TokenType.LAZY, "LEADING": TokenType.LEADING, @@ -466,6 +475,7 @@ class Tokenizer(metaclass=_Tokenizer): "LIKE": TokenType.LIKE, "LIMIT": TokenType.LIMIT, "LOCATION": TokenType.LOCATION, + "MATERIALIZED": TokenType.MATERIALIZED, "NATURAL": TokenType.NATURAL, "NEXT": TokenType.NEXT, "NO ACTION": TokenType.NO_ACTION, @@ -473,6 +483,7 @@ class Tokenizer(metaclass=_Tokenizer): "NULL": TokenType.NULL, "NULLS FIRST": TokenType.NULLS_FIRST, "NULLS LAST": TokenType.NULLS_LAST, + "OBJECT": TokenType.OBJECT, "OFFSET": TokenType.OFFSET, "ON": TokenType.ON, "ONLY": TokenType.ONLY, @@ -488,7 +499,9 @@ class Tokenizer(metaclass=_Tokenizer): "PARTITION": TokenType.PARTITION, "PARTITION BY": TokenType.PARTITION_BY, "PARTITIONED BY": TokenType.PARTITIONED_BY, + "PARTITIONED_BY": TokenType.PARTITIONED_BY, "PERCENT": TokenType.PERCENT, + "PIVOT": TokenType.PIVOT, "PRECEDING": TokenType.PRECEDING, "PRIMARY KEY": TokenType.PRIMARY_KEY, "RANGE": TokenType.RANGE, @@ -497,11 +510,13 @@ class Tokenizer(metaclass=_Tokenizer): "REPLACE": TokenType.REPLACE, "RESPECT NULLS": TokenType.RESPECT_NULLS, "REFERENCES": TokenType.REFERENCES, + "RETURNS": TokenType.RETURNS, "RIGHT": TokenType.RIGHT, "RLIKE": TokenType.RLIKE, "ROLLUP": TokenType.ROLLUP, "ROW": TokenType.ROW, "ROWS": TokenType.ROWS, + "SEED": TokenType.SEED, "SELECT": TokenType.SELECT, "SET": TokenType.SET, "SHOW": TokenType.SHOW, @@ -520,6 +535,7 @@ class Tokenizer(metaclass=_Tokenizer): "TRUNCATE": TokenType.TRUNCATE, "UNBOUNDED": TokenType.UNBOUNDED, "UNION": TokenType.UNION, + "UNPIVOT": TokenType.UNPIVOT, "UNNEST": TokenType.UNNEST, "UPDATE": TokenType.UPDATE, "USE": TokenType.USE, @@ -577,6 +593,7 @@ class Tokenizer(metaclass=_Tokenizer): "DATETIME": TokenType.DATETIME, "UNIQUE": TokenType.UNIQUE, "STRUCT": TokenType.STRUCT, + "VARIANT": TokenType.VARIANT, } WHITE_SPACE = { diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index 7fc71dd..014ae00 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -12,15 +12,20 @@ def unalias_group(expression): """ if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select): aliased_selects = { - e.alias: i for i, e in enumerate(expression.parent.expressions, start=1) if isinstance(e, exp.Alias) + e.alias: (i, e.this) + for i, e in enumerate(expression.parent.expressions, start=1) + if isinstance(e, exp.Alias) } expression = expression.copy() - for col in expression.find_all(exp.Column): - alias_index = aliased_selects.get(col.name) - if not col.table and alias_index: - col.replace(exp.Literal.number(alias_index)) + top_level_expression = None + for item, parent, _ in expression.walk(bfs=False): + top_level_expression = item if isinstance(parent, exp.Group) else top_level_expression + if isinstance(item, exp.Column) and not item.table: + alias_index, col_expression = aliased_selects.get(item.name, (None, None)) + if alias_index and top_level_expression != col_expression: + item.replace(exp.Literal.number(alias_index)) return expression diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 1337c3d..c929e59 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -236,3 +236,24 @@ class TestBigQuery(Validator): "snowflake": "SELECT a FROM test WHERE a = 1 GROUP BY a HAVING a = 2 QUALIFY z ORDER BY a NULLS FIRST LIMIT 10", }, ) + self.validate_all( + "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)", + write={ + "spark": "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)", + "bigquery": "SELECT cola, colb FROM UNNEST([STRUCT(1 AS cola, 'test' AS colb)])", + "snowflake": "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)", + }, + ) + self.validate_all( + "SELECT * FROM (SELECT a, b, c FROM test) PIVOT(SUM(b) d, COUNT(*) e FOR c IN ('x', 'y'))", + write={ + "bigquery": "SELECT * FROM (SELECT a, b, c FROM test) PIVOT(SUM(b) AS d, COUNT(*) AS e FOR c IN ('x', 'y'))", + }, + ) + + def test_user_defined_functions(self): + self.validate_identity( + "CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) RETURNS FLOAT64 LANGUAGE js AS 'return x*y;'" + ) + self.validate_identity("CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) AS ((x + 4) / y)") + self.validate_identity("CREATE TABLE FUNCTION a(x INT64) RETURNS TABLE <q STRING, r INT64> AS SELECT s, t") diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 4e0a3c6..e0ec824 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -13,9 +13,6 @@ from sqlglot import ( class Validator(unittest.TestCase): dialect = None - def validate(self, sql, target, **kwargs): - self.assertEqual(transpile(sql, **kwargs)[0], target) - def validate_identity(self, sql): self.assertEqual(transpile(sql, read=self.dialect, write=self.dialect)[0], sql) @@ -258,6 +255,7 @@ class TestDialect(Validator): "duckdb": "EPOCH(STRPTIME('2020-01-01', '%Y-%M-%d'))", "hive": "UNIX_TIMESTAMP('2020-01-01', 'yyyy-mm-dd')", "presto": "TO_UNIXTIME(DATE_PARSE('2020-01-01', '%Y-%i-%d'))", + "starrocks": "UNIX_TIMESTAMP('2020-01-01', '%Y-%i-%d')", }, ) self.validate_all( @@ -266,6 +264,7 @@ class TestDialect(Validator): "duckdb": "CAST('2020-01-01' AS DATE)", "hive": "TO_DATE('2020-01-01')", "presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%s')", + "starrocks": "TO_DATE('2020-01-01')", }, ) self.validate_all( @@ -341,6 +340,7 @@ class TestDialect(Validator): "duckdb": "STRFTIME(TO_TIMESTAMP(CAST(x AS BIGINT)), y)", "hive": "FROM_UNIXTIME(x, y)", "presto": "DATE_FORMAT(FROM_UNIXTIME(x), y)", + "starrocks": "FROM_UNIXTIME(x, y)", }, ) self.validate_all( @@ -349,6 +349,7 @@ class TestDialect(Validator): "duckdb": "TO_TIMESTAMP(CAST(x AS BIGINT))", "hive": "FROM_UNIXTIME(x)", "presto": "FROM_UNIXTIME(x)", + "starrocks": "FROM_UNIXTIME(x)", }, ) self.validate_all( @@ -841,9 +842,19 @@ class TestDialect(Validator): }, ) self.validate_all( + "POSITION(' ' in x)", + write={ + "duckdb": "STRPOS(x, ' ')", + "postgres": "STRPOS(x, ' ')", + "presto": "STRPOS(x, ' ')", + "spark": "LOCATE(' ', x)", + }, + ) + self.validate_all( "STR_POSITION(x, 'a')", write={ "duckdb": "STRPOS(x, 'a')", + "postgres": "STRPOS(x, 'a')", "presto": "STRPOS(x, 'a')", "spark": "LOCATE('a', x)", }, diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index f52decb..96e51df 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -1,3 +1,4 @@ +from sqlglot import ErrorLevel, UnsupportedError, transpile from tests.dialects.test_dialect import Validator @@ -250,3 +251,10 @@ class TestDuckDB(Validator): "spark": "MONTH('2021-03-01')", }, ) + + with self.assertRaises(UnsupportedError): + transpile( + "SELECT a FROM b PIVOT(SUM(x) FOR y IN ('z', 'q'))", + read="duckdb", + unsupported_level=ErrorLevel.IMMEDIATE, + ) diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py index a9b5168..d335921 100644 --- a/tests/dialects/test_hive.py +++ b/tests/dialects/test_hive.py @@ -127,17 +127,17 @@ class TestHive(Validator): def test_ddl(self): self.validate_all( - "CREATE TABLE test STORED AS parquet TBLPROPERTIES ('x' = '1', 'Z' = '2') AS SELECT 1", + "CREATE TABLE test STORED AS parquet TBLPROPERTIES ('x'='1', 'Z'='2') AS SELECT 1", write={ - "presto": "CREATE TABLE test WITH (FORMAT = 'parquet', x = '1', Z = '2') AS SELECT 1", - "hive": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('x' = '1', 'Z' = '2') AS SELECT 1", - "spark": "CREATE TABLE test USING PARQUET TBLPROPERTIES ('x' = '1', 'Z' = '2') AS SELECT 1", + "presto": "CREATE TABLE test WITH (FORMAT='parquet', x='1', Z='2') AS SELECT 1", + "hive": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('x'='1', 'Z'='2') AS SELECT 1", + "spark": "CREATE TABLE test USING PARQUET TBLPROPERTIES ('x'='1', 'Z'='2') AS SELECT 1", }, ) self.validate_all( "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)", write={ - "presto": "CREATE TABLE x (w VARCHAR, y INTEGER, z INTEGER) WITH (PARTITIONED_BY = ARRAY['y', 'z'])", + "presto": "CREATE TABLE x (w VARCHAR, y INTEGER, z INTEGER) WITH (PARTITIONED_BY=ARRAY['y', 'z'])", "hive": "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)", "spark": "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)", }, diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index 87a3d64..02dc1ad 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -119,3 +119,39 @@ class TestMySQL(Validator): "sqlite": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC, '')", }, ) + self.validate_identity( + "CREATE TABLE z (a INT) ENGINE=InnoDB AUTO_INCREMENT=1 CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'" + ) + self.validate_identity( + "CREATE TABLE z (a INT) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'" + ) + self.validate_identity( + "CREATE TABLE z (a INT DEFAULT NULL, PRIMARY KEY(a)) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'" + ) + + self.validate_all( + """ + CREATE TABLE `t_customer_account` ( + "id" int(11) NOT NULL AUTO_INCREMENT, + "customer_id" int(11) DEFAULT NULL COMMENT '客户id', + "bank" varchar(100) COLLATE utf8_bin DEFAULT NULL COMMENT '行别', + "account_no" varchar(100) COLLATE utf8_bin DEFAULT NULL COMMENT '账号', + PRIMARY KEY ("id") + ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='客户账户表' + """, + write={ + "mysql": """CREATE TABLE `t_customer_account` ( + 'id' INT(11) NOT NULL AUTO_INCREMENT, + 'customer_id' INT(11) DEFAULT NULL COMMENT '客户id', + 'bank' VARCHAR(100) COLLATE utf8_bin DEFAULT NULL COMMENT '行别', + 'account_no' VARCHAR(100) COLLATE utf8_bin DEFAULT NULL COMMENT '账号', + PRIMARY KEY('id') +) +ENGINE=InnoDB +AUTO_INCREMENT=1 +DEFAULT CHARACTER SET=utf8 +COLLATE=utf8_bin +COMMENT='客户账户表'""" + }, + pretty=True, + ) diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index 96c299d..b0d9ad9 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -171,7 +171,7 @@ class TestPresto(Validator): self.validate_all( "CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1", write={ - "presto": "CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1", + "presto": "CREATE TABLE test WITH (FORMAT='PARQUET') AS SELECT 1", "hive": "CREATE TABLE test STORED AS PARQUET AS SELECT 1", "spark": "CREATE TABLE test USING PARQUET AS SELECT 1", }, @@ -179,15 +179,15 @@ class TestPresto(Validator): self.validate_all( "CREATE TABLE test WITH (FORMAT = 'PARQUET', X = '1', Z = '2') AS SELECT 1", write={ - "presto": "CREATE TABLE test WITH (FORMAT = 'PARQUET', X = '1', Z = '2') AS SELECT 1", - "hive": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('X' = '1', 'Z' = '2') AS SELECT 1", - "spark": "CREATE TABLE test USING PARQUET TBLPROPERTIES ('X' = '1', 'Z' = '2') AS SELECT 1", + "presto": "CREATE TABLE test WITH (FORMAT='PARQUET', X='1', Z='2') AS SELECT 1", + "hive": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('X'='1', 'Z'='2') AS SELECT 1", + "spark": "CREATE TABLE test USING PARQUET TBLPROPERTIES ('X'='1', 'Z'='2') AS SELECT 1", }, ) self.validate_all( - "CREATE TABLE x (w VARCHAR, y INTEGER, z INTEGER) WITH (PARTITIONED_BY = ARRAY['y', 'z'])", + "CREATE TABLE x (w VARCHAR, y INTEGER, z INTEGER) WITH (PARTITIONED_BY=ARRAY['y', 'z'])", write={ - "presto": "CREATE TABLE x (w VARCHAR, y INTEGER, z INTEGER) WITH (PARTITIONED_BY = ARRAY['y', 'z'])", + "presto": "CREATE TABLE x (w VARCHAR, y INTEGER, z INTEGER) WITH (PARTITIONED_BY=ARRAY['y', 'z'])", "hive": "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)", "spark": "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)", }, @@ -195,9 +195,9 @@ class TestPresto(Validator): self.validate_all( "CREATE TABLE x WITH (bucket_by = ARRAY['y'], bucket_count = 64) AS SELECT 1 AS y", write={ - "presto": "CREATE TABLE x WITH (bucket_by = ARRAY['y'], bucket_count = 64) AS SELECT 1 AS y", - "hive": "CREATE TABLE x TBLPROPERTIES ('bucket_by' = ARRAY('y'), 'bucket_count' = 64) AS SELECT 1 AS y", - "spark": "CREATE TABLE x TBLPROPERTIES ('bucket_by' = ARRAY('y'), 'bucket_count' = 64) AS SELECT 1 AS y", + "presto": "CREATE TABLE x WITH (bucket_by=ARRAY['y'], bucket_count=64) AS SELECT 1 AS y", + "hive": "CREATE TABLE x TBLPROPERTIES ('bucket_by'=ARRAY('y'), 'bucket_count'=64) AS SELECT 1 AS y", + "spark": "CREATE TABLE x TBLPROPERTIES ('bucket_by'=ARRAY('y'), 'bucket_count'=64) AS SELECT 1 AS y", }, ) self.validate_all( @@ -217,11 +217,12 @@ class TestPresto(Validator): }, ) - self.validate( + self.validate_all( "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", - "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname", - read="presto", - write="presto", + write={ + "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname", + "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST", + }, ) def test_quotes(self): diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 165f8e2..b7e39a7 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -143,6 +143,31 @@ class TestSnowflake(Validator): "snowflake": r"SELECT 'a \' \\ \\t \\x21 z $ '", }, ) + self.validate_identity("SELECT REGEXP_LIKE(a, b, c)") + self.validate_all( + "SELECT RLIKE(a, b)", + write={ + "snowflake": "SELECT REGEXP_LIKE(a, b)", + }, + ) + self.validate_all( + "SELECT a FROM test SAMPLE BLOCK (0.5) SEED (42)", + write={ + "snowflake": "SELECT a FROM test TABLESAMPLE BLOCK (0.5) SEED (42)", + }, + ) + self.validate_all( + "SELECT a FROM test pivot", + write={ + "snowflake": "SELECT a FROM test AS pivot", + }, + ) + self.validate_all( + "SELECT a FROM test unpivot", + write={ + "snowflake": "SELECT a FROM test AS unpivot", + }, + ) def test_null_treatment(self): self.validate_all( @@ -220,3 +245,51 @@ class TestSnowflake(Validator): "snowflake": "SELECT EXTRACT(month FROM CAST(a AS DATETIME))", }, ) + + def test_semi_structured_types(self): + self.validate_identity("SELECT CAST(a AS VARIANT)") + self.validate_all( + "SELECT a::VARIANT", + write={ + "snowflake": "SELECT CAST(a AS VARIANT)", + "tsql": "SELECT CAST(a AS SQL_VARIANT)", + }, + ) + self.validate_identity("SELECT CAST(a AS ARRAY)") + self.validate_all( + "ARRAY_CONSTRUCT(0, 1, 2)", + write={ + "snowflake": "[0, 1, 2]", + "bigquery": "[0, 1, 2]", + "duckdb": "LIST_VALUE(0, 1, 2)", + "presto": "ARRAY[0, 1, 2]", + "spark": "ARRAY(0, 1, 2)", + }, + ) + self.validate_all( + "SELECT a::OBJECT", + write={ + "snowflake": "SELECT CAST(a AS OBJECT)", + }, + ) + + def test_ddl(self): + self.validate_identity( + "CREATE TABLE a (x DATE, y BIGINT) WITH (PARTITION BY (x), integration='q', auto_refresh=TRUE, file_format=(type = parquet))" + ) + self.validate_identity("CREATE MATERIALIZED VIEW a COMMENT='...' AS SELECT 1 FROM x") + + def test_user_defined_functions(self): + self.validate_all( + "CREATE FUNCTION a(x DATE, y BIGINT) RETURNS ARRAY LANGUAGE JAVASCRIPT AS $$ SELECT 1 $$", + write={ + "snowflake": "CREATE FUNCTION a(x DATE, y BIGINT) RETURNS ARRAY LANGUAGE JAVASCRIPT AS ' SELECT 1 '", + }, + ) + self.validate_all( + "CREATE FUNCTION a() RETURNS TABLE (b INT) AS 'SELECT 1'", + write={ + "snowflake": "CREATE FUNCTION a() RETURNS TABLE (b INT) AS 'SELECT 1'", + "bigquery": "CREATE TABLE FUNCTION a() RETURNS TABLE <b INT64> AS SELECT 1", + }, + ) diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 22f6947..8377e47 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -34,7 +34,7 @@ class TestSpark(Validator): self.validate_all( "CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'", write={ - "presto": "CREATE TABLE x WITH (TABLE_FORMAT = 'ICEBERG', PARTITIONED_BY = ARRAY['MONTHS'])", + "presto": "CREATE TABLE x WITH (TABLE_FORMAT = 'ICEBERG', PARTITIONED_BY=ARRAY['MONTHS'])", "hive": "CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'", "spark": "CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'", }, @@ -42,7 +42,7 @@ class TestSpark(Validator): self.validate_all( "CREATE TABLE test STORED AS PARQUET AS SELECT 1", write={ - "presto": "CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1", + "presto": "CREATE TABLE test WITH (FORMAT='PARQUET') AS SELECT 1", "hive": "CREATE TABLE test STORED AS PARQUET AS SELECT 1", "spark": "CREATE TABLE test USING PARQUET AS SELECT 1", }, @@ -56,9 +56,9 @@ class TestSpark(Validator): ) COMMENT='Test comment: blah' WITH ( - PARTITIONED_BY = ARRAY['date'], - FORMAT = 'ICEBERG', - x = '1' + PARTITIONED_BY=ARRAY['date'], + FORMAT='ICEBERG', + x='1' )""", "hive": """CREATE TABLE blah ( col_a INT @@ -69,7 +69,7 @@ PARTITIONED BY ( ) STORED AS ICEBERG TBLPROPERTIES ( - 'x' = '1' + 'x'='1' )""", "spark": """CREATE TABLE blah ( col_a INT @@ -80,7 +80,7 @@ PARTITIONED BY ( ) USING ICEBERG TBLPROPERTIES ( - 'x' = '1' + 'x'='1' )""", }, pretty=True, diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index 0619eaa..6b0b39b 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -15,6 +15,14 @@ class TestTSQL(Validator): }, ) + self.validate_all( + "CONVERT(INT, CONVERT(NUMERIC, '444.75'))", + write={ + "mysql": "CAST(CAST('444.75' AS DECIMAL) AS INT)", + "tsql": "CAST(CAST('444.75' AS NUMERIC) AS INTEGER)", + }, + ) + def test_types(self): self.validate_identity("CAST(x AS XML)") self.validate_identity("CAST(x AS UNIQUEIDENTIFIER)") @@ -24,3 +32,13 @@ class TestTSQL(Validator): self.validate_identity("CAST(x AS IMAGE)") self.validate_identity("CAST(x AS SQL_VARIANT)") self.validate_identity("CAST(x AS BIT)") + self.validate_all( + "CAST(x AS DATETIME2)", + read={ + "": "CAST(x AS DATETIME)", + }, + write={ + "mysql": "CAST(x AS DATETIME)", + "tsql": "CAST(x AS DATETIME2)", + }, + ) diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index 1b4168c..2654be1 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -8,6 +8,7 @@ SUM(CASE WHEN x > 1 THEN 1 ELSE 0 END) / y 1.1E10 1.12e-10 -11.023E7 * 3 +0.2 (1 * 2) / (3 - 5) ((TRUE)) '' @@ -167,7 +168,7 @@ SELECT LEAD(a) OVER (ORDER BY b) AS a SELECT LEAD(a, 1) OVER (PARTITION BY a ORDER BY a) AS x SELECT LEAD(a, 1, b) OVER (PARTITION BY a ORDER BY a) AS x SELECT X((a, b) -> a + b, z -> z) AS x -SELECT X(a -> "a" + ("z" - 1)) +SELECT X(a -> a + ("z" - 1)) SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0) SELECT test.* FROM test SELECT a AS b FROM test @@ -258,15 +259,24 @@ SELECT a FROM test TABLESAMPLE(100) SELECT a FROM test TABLESAMPLE(100 ROWS) SELECT a FROM test TABLESAMPLE BERNOULLI (50) SELECT a FROM test TABLESAMPLE SYSTEM (75) +SELECT a FROM test PIVOT(SUM(x) FOR y IN ('z', 'q')) +SELECT a FROM test PIVOT(SOMEAGG(x, y, z) FOR q IN (1)) +SELECT a FROM test PIVOT(SUM(x) FOR y IN ('z', 'q')) PIVOT(MAX(b) FOR c IN ('d')) +SELECT a FROM (SELECT a, b FROM test) PIVOT(SUM(x) FOR y IN ('z', 'q')) +SELECT a FROM test UNPIVOT(x FOR y IN (z, q)) AS x +SELECT a FROM test PIVOT(SUM(x) FOR y IN ('z', 'q')) AS x TABLESAMPLE(0.1) +SELECT a FROM test PIVOT(SUM(x) FOR y IN ('z', 'q')) UNPIVOT(x FOR y IN (z, q)) AS x SELECT ABS(a) FROM test SELECT AVG(a) FROM test SELECT CEIL(a) FROM test +SELECT CEIL(a, b) FROM test SELECT COUNT(a) FROM test SELECT COUNT(1) FROM test SELECT COUNT(*) FROM test SELECT COUNT(DISTINCT a) FROM test SELECT EXP(a) FROM test SELECT FLOOR(a) FROM test +SELECT FLOOR(a, b) FROM test SELECT FIRST(a) FROM test SELECT GREATEST(a, b, c) FROM test SELECT LAST(a) FROM test @@ -299,6 +309,7 @@ SELECT CAST(a AS MAP<INT, INT>) FROM test SELECT CAST(a AS TIMESTAMP) FROM test SELECT CAST(a AS DATE) FROM test SELECT CAST(a AS ARRAY<INT>) FROM test +SELECT CAST(a AS VARIANT) FROM test SELECT TRY_CAST(a AS INT) FROM test SELECT COALESCE(a, b, c) FROM test SELECT IFNULL(a, b) FROM test @@ -442,13 +453,10 @@ CREATE TABLE z (a INT(11) DEFAULT NULL COMMENT '客户id') CREATE TABLE z (a INT(11) NOT NULL DEFAULT 1) CREATE TABLE z (a INT(11) NOT NULL COLLATE utf8_bin AUTO_INCREMENT) CREATE TABLE z (a INT, PRIMARY KEY(a)) -CREATE TABLE z (a INT) ENGINE=InnoDB AUTO_INCREMENT=1 CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x' -CREATE TABLE z (a INT) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x' -CREATE TABLE z (a INT DEFAULT NULL, PRIMARY KEY(a)) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x' CREATE TABLE z WITH (FORMAT='parquet') AS SELECT 1 -CREATE TABLE z WITH (FORMAT='ORC', x = '2') AS SELECT 1 +CREATE TABLE z WITH (FORMAT='ORC', x='2') AS SELECT 1 CREATE TABLE z WITH (TABLE_FORMAT='iceberg', FORMAT='parquet') AS SELECT 1 -CREATE TABLE z WITH (TABLE_FORMAT='iceberg', FORMAT='ORC', x = '2') AS SELECT 1 +CREATE TABLE z WITH (TABLE_FORMAT='iceberg', FORMAT='ORC', x='2') AS SELECT 1 CREATE TABLE z (z INT) WITH (PARTITIONED_BY=(x INT, y INT)) CREATE TABLE z (z INT) WITH (PARTITIONED_BY=(x INT)) AS SELECT 1 CREATE TABLE z AS (WITH cte AS (SELECT 1) SELECT * FROM cte) @@ -460,6 +468,9 @@ CREATE TEMPORARY FUNCTION f CREATE TEMPORARY FUNCTION f AS 'g' CREATE FUNCTION f CREATE FUNCTION f AS 'g' +CREATE FUNCTION a(b INT, c VARCHAR) AS 'SELECT 1' +CREATE FUNCTION a() LANGUAGE sql +CREATE FUNCTION a() LANGUAGE sql RETURNS INT CREATE INDEX abc ON t (a) CREATE INDEX abc ON t (a, b, b) CREATE UNIQUE INDEX abc ON t (a, b, b) @@ -519,3 +530,4 @@ WITH a AS ((SELECT b.foo AS foo, b.bar AS bar FROM b) UNION ALL (SELECT c.foo AS WITH a AS ((SELECT 1 AS b) UNION ALL (SELECT 1 AS b)) SELECT * FROM a SELECT (WITH x AS (SELECT 1 AS y) SELECT * FROM x) AS z SELECT ((SELECT 1) + 1) +SELECT * FROM project.dataset.INFORMATION_SCHEMA.TABLES diff --git a/tests/fixtures/optimizer/eliminate_subqueries.sql b/tests/fixtures/optimizer/eliminate_subqueries.sql index aae5f2a..f395c0a 100644 --- a/tests/fixtures/optimizer/eliminate_subqueries.sql +++ b/tests/fixtures/optimizer/eliminate_subqueries.sql @@ -1,42 +1,79 @@ -SELECT 1 AS x, 2 AS y -UNION ALL -SELECT 1 AS x, 2 AS y; -WITH _e_0 AS ( - SELECT - 1 AS x, - 2 AS y -) -SELECT - * -FROM _e_0 -UNION ALL -SELECT - * -FROM _e_0; - -SELECT x.id -FROM ( - SELECT * - FROM x AS x - JOIN y AS y - ON x.id = y.id -) AS x -JOIN ( - SELECT * - FROM x AS x - JOIN y AS y - ON x.id = y.id -) AS y -ON x.id = y.id; -WITH _e_0 AS ( - SELECT - * - FROM x AS x - JOIN y AS y - ON x.id = y.id -) -SELECT - x.id -FROM "_e_0" AS x -JOIN "_e_0" AS y - ON x.id = y.id; +-- No derived tables +SELECT * FROM x; +SELECT * FROM x; + +-- Unaliased derived tables +SELECT a FROM (SELECT b FROM (SELECT c FROM x)); +WITH cte AS (SELECT c FROM x), cte_2 AS (SELECT b FROM cte AS cte) SELECT a FROM cte_2 AS cte_2; + +-- Joined derived table inside nested derived table +SELECT b FROM (SELECT b FROM (SELECT b FROM x JOIN (SELECT b FROM y) AS y ON x.b = y.b)); +WITH y_2 AS (SELECT b FROM y), cte AS (SELECT b FROM x JOIN y_2 AS y ON x.b = y.b), cte_2 AS (SELECT b FROM cte AS cte) SELECT b FROM cte_2 AS cte_2; + +-- Aliased derived tables +SELECT a FROM (SELECT b FROM (SELECT c FROM x) AS y) AS z; +WITH y AS (SELECT c FROM x), z AS (SELECT b FROM y AS y) SELECT a FROM z AS z; + +-- Existing CTEs +WITH q AS (SELECT c FROM x) SELECT a FROM (SELECT b FROM q AS y) AS z; +WITH q AS (SELECT c FROM x), z AS (SELECT b FROM q AS y) SELECT a FROM z AS z; + +-- Derived table inside CTE +WITH x AS (SELECT a FROM (SELECT a FROM x) AS y) SELECT a FROM x; +WITH y AS (SELECT a FROM x), x AS (SELECT a FROM y AS y) SELECT a FROM x; + +-- Name conflicts with existing outer derived table +SELECT a FROM (SELECT b FROM (SELECT c FROM x) AS y) AS y; +WITH y AS (SELECT c FROM x), y_2 AS (SELECT b FROM y AS y) SELECT a FROM y_2 AS y; + +-- Name conflicts with outer join +SELECT a, b FROM (SELECT c FROM (SELECT d FROM x) AS x) AS y JOIN x ON x.a = y.a; +WITH x_2 AS (SELECT d FROM x), y AS (SELECT c FROM x_2 AS x) SELECT a, b FROM y AS y JOIN x ON x.a = y.a; + +-- Name conflicts with table name that is selected in another branch +SELECT * FROM (SELECT * FROM (SELECT a FROM x) AS x) AS y JOIN (SELECT * FROM x) AS z ON x.a = y.a; +WITH x_2 AS (SELECT a FROM x), y AS (SELECT * FROM x_2 AS x), z AS (SELECT * FROM x) SELECT * FROM y AS y JOIN z AS z ON x.a = y.a; + +-- Name conflicts with table alias +SELECT a FROM (SELECT a FROM (SELECT a FROM x) AS y) AS z JOIN q AS y; +WITH y AS (SELECT a FROM x), z AS (SELECT a FROM y AS y) SELECT a FROM z AS z JOIN q AS y; + +-- Name conflicts with existing CTE +WITH y AS (SELECT a FROM (SELECT a FROM x) AS y) SELECT a FROM y; +WITH y_2 AS (SELECT a FROM x), y AS (SELECT a FROM y_2 AS y) SELECT a FROM y; + +-- Union +SELECT 1 AS x, 2 AS y UNION ALL SELECT 1 AS x, 2 AS y; +WITH cte AS (SELECT 1 AS x, 2 AS y) SELECT cte.x AS x, cte.y AS y FROM cte AS cte UNION ALL SELECT cte.x AS x, cte.y AS y FROM cte AS cte; + +-- Union of selects with derived tables +(SELECT a FROM (SELECT b FROM x)) UNION (SELECT a FROM (SELECT b FROM y)); +WITH cte AS (SELECT b FROM x), cte_2 AS (SELECT a FROM cte AS cte), cte_3 AS (SELECT b FROM y), cte_4 AS (SELECT a FROM cte_3 AS cte_3) (SELECT cte_2.a AS a FROM cte_2 AS cte_2) UNION (SELECT cte_4.a AS a FROM cte_4 AS cte_4); + +-- Subquery +SELECT a FROM x WHERE b = (SELECT y.c FROM y); +SELECT a FROM x WHERE b = (SELECT y.c FROM y); + +-- Correlated subquery +SELECT a FROM x WHERE b = (SELECT c FROM y WHERE y.a = x.a); +SELECT a FROM x WHERE b = (SELECT c FROM y WHERE y.a = x.a); + +-- Duplicate CTE +SELECT a FROM (SELECT b FROM x) AS y JOIN (SELECT b FROM x) AS z; +WITH y AS (SELECT b FROM x) SELECT a FROM y AS y JOIN y AS z; + +-- Doubly duplicate CTE +SELECT * FROM (SELECT * FROM x JOIN (SELECT * FROM x) AS y) AS z JOIN (SELECT * FROM x JOIN (SELECT * FROM x) AS y) AS q; +WITH y AS (SELECT * FROM x), z AS (SELECT * FROM x JOIN y AS y) SELECT * FROM z AS z JOIN z AS q; + +-- Another duplicate... +SELECT x.id FROM (SELECT * FROM x AS x JOIN y AS y ON x.id = y.id) AS x JOIN (SELECT * FROM x AS x JOIN y AS y ON x.id = y.id) AS y ON x.id = y.id; +WITH x_2 AS (SELECT * FROM x AS x JOIN y AS y ON x.id = y.id) SELECT x.id FROM x_2 AS x JOIN x_2 AS y ON x.id = y.id; + +-- Root subquery +(SELECT * FROM (SELECT * FROM x)) LIMIT 1; +(WITH cte AS (SELECT * FROM x) SELECT * FROM cte AS cte) LIMIT 1; + +-- Existing duplicate CTE +WITH y AS (SELECT a FROM x) SELECT a FROM (SELECT a FROM x) AS y JOIN y AS z; +WITH y AS (SELECT a FROM x) SELECT a FROM y AS y JOIN y AS z; diff --git a/tests/fixtures/optimizer/merge_derived_tables.sql b/tests/fixtures/optimizer/merge_subqueries.sql index c5aa7e9..35aed3b 100644 --- a/tests/fixtures/optimizer/merge_derived_tables.sql +++ b/tests/fixtures/optimizer/merge_subqueries.sql @@ -18,6 +18,14 @@ SELECT x.a AS a, SUM(x.b) AS "_col_1" FROM x AS x WHERE x.a > 1 GROUP BY x.a; SELECT a, c FROM (SELECT a, b FROM x WHERE a > 1) AS x JOIN y ON x.b = y.b; SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b WHERE x.a > 1; +-- Outer query has join +SELECT a, c FROM (SELECT a, b FROM x WHERE a > 1) AS x JOIN y ON x.b = y.b; +SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b WHERE x.a > 1; + +# leave_tables_isolated: true +SELECT a, c FROM (SELECT a, b FROM x WHERE a > 1) AS x JOIN y ON x.b = y.b; +SELECT x.a AS a, y.c AS c FROM (SELECT x.a AS a, x.b AS b FROM x AS x WHERE x.a > 1) AS x JOIN y AS y ON x.b = y.b; + -- Join on derived table SELECT a, c FROM x JOIN (SELECT b, c FROM y) AS y ON x.b = y.b; SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b; @@ -42,13 +50,9 @@ SELECT q_2.a AS a, q.c AS c, r.c AS c FROM x AS q_2 JOIN y AS r_2 ON q_2.b = r_2 SELECT r.b FROM (SELECT b FROM x AS x) AS q JOIN (SELECT b FROM x) AS r ON q.b = r.b; SELECT x_2.b AS b FROM x AS x JOIN x AS x_2 ON x.b = x_2.b; --- WHERE clause in joined derived table is merged +-- WHERE clause in joined derived table is merged to ON clause SELECT x.a, y.c FROM x JOIN (SELECT b, c FROM y WHERE c > 1) AS y; -SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y WHERE y.c > 1; - --- WHERE clause in outer joined derived table is merged to ON clause -SELECT x.a, y.c FROM x LEFT JOIN (SELECT b, c FROM y WHERE c > 1) AS y; -SELECT x.a AS a, y.c AS c FROM x AS x LEFT JOIN y AS y ON y.c > 1; +SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON y.c > 1; -- Comma JOIN in outer query SELECT x.a, y.c FROM (SELECT a FROM x) AS x, (SELECT c FROM y) AS y; @@ -61,3 +65,35 @@ SELECT x.a AS a, z.c AS c FROM x AS x CROSS JOIN y AS z; -- (Regression) Column in ORDER BY SELECT * FROM (SELECT * FROM (SELECT * FROM x)) ORDER BY a LIMIT 1; SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY x.a LIMIT 1; + +-- CTE +WITH x AS (SELECT a, b FROM x) SELECT a, b FROM x; +SELECT x.a AS a, x.b AS b FROM x AS x; + +-- CTE with outer table alias +WITH y AS (SELECT a, b FROM x) SELECT a, b FROM y AS z; +SELECT x.a AS a, x.b AS b FROM x AS x; + +-- Nested CTE +WITH x AS (SELECT a FROM x), x2 AS (SELECT a FROM x) SELECT a FROM x2; +SELECT x.a AS a FROM x AS x; + +-- CTE WHERE clause is merged +WITH x AS (SELECT a, b FROM x WHERE a > 1) SELECT a, SUM(b) FROM x GROUP BY a; +SELECT x.a AS a, SUM(x.b) AS "_col_1" FROM x AS x WHERE x.a > 1 GROUP BY x.a; + +-- CTE Outer query has join +WITH x AS (SELECT a, b FROM x WHERE a > 1) SELECT a, c FROM x AS x JOIN y ON x.b = y.b; +SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b WHERE x.a > 1; + +-- CTE with inner table alias +WITH y AS (SELECT a, b FROM x AS q) SELECT a, b FROM y AS z; +SELECT q.a AS a, q.b AS b FROM x AS q; + +-- Duplicate queries to CTE +WITH x AS (SELECT a, b FROM x) SELECT x.a, y.b FROM x JOIN x AS y; +WITH x AS (SELECT x.a AS a, x.b AS b FROM x AS x) SELECT x.a AS a, y.b AS b FROM x JOIN x AS y; + +-- Nested CTE +SELECT * FROM (WITH x AS (SELECT a, b FROM x) SELECT a, b FROM x); +SELECT x.a AS a, x.b AS b FROM x AS x; diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql index f1d0f7d..0bb742b 100644 --- a/tests/fixtures/optimizer/optimizer.sql +++ b/tests/fixtures/optimizer/optimizer.sql @@ -65,18 +65,14 @@ WITH "cte1" AS ( SELECT "x"."a" AS "a" FROM "x" AS "x" -), "cte2" AS ( - SELECT - "cte1"."a" + 1 AS "a" - FROM "cte1" ) SELECT "cte1"."a" AS "a" FROM "cte1" UNION ALL SELECT - "cte2"."a" AS "a" -FROM "cte2"; + "cte1"."a" + 1 AS "a" +FROM "cte1"; SELECT a, SUM(b) FROM ( @@ -86,18 +82,19 @@ FROM ( ) d WHERE (TRUE AND TRUE OR 'a' = 'b') AND a > 1 GROUP BY a; -SELECT - "x"."a" AS "a", - SUM("y"."b") AS "_col_1" -FROM "x" AS "x" -LEFT JOIN ( +WITH "_u_0" AS ( SELECT MAX("y"."b") AS "_col_0", "y"."a" AS "_u_1" FROM "y" AS "y" GROUP BY "y"."a" -) AS "_u_0" +) +SELECT + "x"."a" AS "a", + SUM("y"."b") AS "_col_1" +FROM "x" AS "x" +LEFT JOIN "_u_0" AS "_u_0" ON "x"."a" = "_u_0"."_u_1" JOIN "y" AS "y" ON "x"."a" = "y"."a" @@ -127,3 +124,16 @@ LIMIT 1; FROM "y" AS "y" ) LIMIT 1; + +# dialect: spark +SELECT /*+ BROADCAST(y) */ x.b FROM x JOIN y ON x.b = y.b; +SELECT /*+ BROADCAST(`y`) */ + `x`.`b` AS `b` +FROM `x` AS `x` +JOIN `y` AS `y` + ON `x`.`b` = `y`.`b`; + +SELECT AGGREGATE(ARRAY(x.a, x.b), 0, (x, acc) -> x + acc + a) AS sum_agg FROM x; +SELECT + AGGREGATE(ARRAY("x"."a", "x"."b"), 0, ("x", "acc") -> "x" + "acc" + "x"."a") AS "sum_agg" +FROM "x" AS "x"; diff --git a/tests/fixtures/optimizer/qualify_columns.sql b/tests/fixtures/optimizer/qualify_columns.sql index 004c57c..f848e7a 100644 --- a/tests/fixtures/optimizer/qualify_columns.sql +++ b/tests/fixtures/optimizer/qualify_columns.sql @@ -69,6 +69,9 @@ SELECT ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.b) AS row_num FROM x AS x SELECT x.b, x.a FROM x LEFT JOIN y ON x.b = y.b QUALIFY ROW_NUMBER() OVER(PARTITION BY x.b ORDER BY x.a DESC) = 1; SELECT x.b AS b, x.a AS a FROM x AS x LEFT JOIN y AS y ON x.b = y.b QUALIFY ROW_NUMBER() OVER (PARTITION BY x.b ORDER BY x.a DESC) = 1; +SELECT AGGREGATE(ARRAY(a, x.b), 0, (x, acc) -> x + acc + a) AS sum_agg FROM x; +SELECT AGGREGATE(ARRAY(x.a, x.b), 0, (x, acc) -> x + acc + x.a) AS sum_agg FROM x AS x; + -------------------------------------- -- Derived tables -------------------------------------- @@ -231,3 +234,10 @@ SELECT COALESCE(x.b, y.b) AS b FROM x AS x JOIN y AS y ON x.b = y.b WHERE COALES SELECT b FROM x JOIN y USING (b) JOIN z USING (b); SELECT COALESCE(x.b, y.b, z.b) AS b FROM x AS x JOIN y AS y ON x.b = y.b JOIN z AS z ON x.b = z.b; + +-------------------------------------- +-- Hint with table reference +-------------------------------------- +# dialect: spark +SELECT /*+ BROADCAST(y) */ x.b FROM x JOIN y ON x.b = y.b; +SELECT /*+ BROADCAST(y) */ x.b AS b FROM x AS x JOIN y AS y ON x.b = y.b; diff --git a/tests/fixtures/optimizer/qualify_columns__invalid.sql b/tests/fixtures/optimizer/qualify_columns__invalid.sql index 056b0e9..1104b6e 100644 --- a/tests/fixtures/optimizer/qualify_columns__invalid.sql +++ b/tests/fixtures/optimizer/qualify_columns__invalid.sql @@ -5,7 +5,6 @@ SELECT z.* FROM x; SELECT x FROM x; INSERT INTO x VALUES (1, 2); SELECT a FROM x AS z JOIN y AS z; -WITH z AS (SELECT * FROM x) SELECT * FROM x AS z; SELECT a FROM x JOIN (SELECT b FROM y WHERE y.b = x.c); SELECT a FROM x AS y JOIN (SELECT a FROM y) AS q ON y.a = q.a; SELECT q.a FROM (SELECT x.b FROM x) AS z JOIN (SELECT a FROM z) AS q ON z.b = q.a; diff --git a/tests/fixtures/optimizer/tpc-h/tpc-h.sql b/tests/fixtures/optimizer/tpc-h/tpc-h.sql index 0b6d382..d2f10fc 100644 --- a/tests/fixtures/optimizer/tpc-h/tpc-h.sql +++ b/tests/fixtures/optimizer/tpc-h/tpc-h.sql @@ -97,19 +97,32 @@ order by p_partkey limit 100; -WITH "_e_0" AS ( +WITH "partsupp_2" AS ( SELECT "partsupp"."ps_partkey" AS "ps_partkey", "partsupp"."ps_suppkey" AS "ps_suppkey", "partsupp"."ps_supplycost" AS "ps_supplycost" FROM "partsupp" AS "partsupp" -), "_e_1" AS ( +), "region_2" AS ( SELECT "region"."r_regionkey" AS "r_regionkey", "region"."r_name" AS "r_name" FROM "region" AS "region" WHERE "region"."r_name" = 'EUROPE' +), "_u_0" AS ( + SELECT + MIN("partsupp"."ps_supplycost") AS "_col_0", + "partsupp"."ps_partkey" AS "_u_1" + FROM "partsupp_2" AS "partsupp" + CROSS JOIN "region_2" AS "region" + JOIN "nation" AS "nation" + ON "nation"."n_regionkey" = "region"."r_regionkey" + JOIN "supplier" AS "supplier" + ON "supplier"."s_nationkey" = "nation"."n_nationkey" + AND "supplier"."s_suppkey" = "partsupp"."ps_suppkey" + GROUP BY + "partsupp"."ps_partkey" ) SELECT "supplier"."s_acctbal" AS "s_acctbal", @@ -121,25 +134,12 @@ SELECT "supplier"."s_phone" AS "s_phone", "supplier"."s_comment" AS "s_comment" FROM "part" AS "part" -LEFT JOIN ( - SELECT - MIN("partsupp"."ps_supplycost") AS "_col_0", - "partsupp"."ps_partkey" AS "_u_1" - FROM "_e_0" AS "partsupp" - CROSS JOIN "_e_1" AS "region" - JOIN "nation" AS "nation" - ON "nation"."n_regionkey" = "region"."r_regionkey" - JOIN "supplier" AS "supplier" - ON "supplier"."s_nationkey" = "nation"."n_nationkey" - AND "supplier"."s_suppkey" = "partsupp"."ps_suppkey" - GROUP BY - "partsupp"."ps_partkey" -) AS "_u_0" +LEFT JOIN "_u_0" AS "_u_0" ON "part"."p_partkey" = "_u_0"."_u_1" -CROSS JOIN "_e_1" AS "region" +CROSS JOIN "region_2" AS "region" JOIN "nation" AS "nation" ON "nation"."n_regionkey" = "region"."r_regionkey" -JOIN "_e_0" AS "partsupp" +JOIN "partsupp_2" AS "partsupp" ON "part"."p_partkey" = "partsupp"."ps_partkey" JOIN "supplier" AS "supplier" ON "supplier"."s_nationkey" = "nation"."n_nationkey" @@ -193,12 +193,12 @@ SELECT FROM "customer" AS "customer" JOIN "orders" AS "orders" ON "customer"."c_custkey" = "orders"."o_custkey" + AND "orders"."o_orderdate" < '1995-03-15' JOIN "lineitem" AS "lineitem" ON "lineitem"."l_orderkey" = "orders"."o_orderkey" + AND "lineitem"."l_shipdate" > '1995-03-15' WHERE "customer"."c_mktsegment" = 'BUILDING' - AND "lineitem"."l_shipdate" > '1995-03-15' - AND "orders"."o_orderdate" < '1995-03-15' GROUP BY "lineitem"."l_orderkey", "orders"."o_orderdate", @@ -232,11 +232,7 @@ group by o_orderpriority order by o_orderpriority; -SELECT - "orders"."o_orderpriority" AS "o_orderpriority", - COUNT(*) AS "order_count" -FROM "orders" AS "orders" -LEFT JOIN ( +WITH "_u_0" AS ( SELECT "lineitem"."l_orderkey" AS "l_orderkey" FROM "lineitem" AS "lineitem" @@ -244,7 +240,12 @@ LEFT JOIN ( "lineitem"."l_commitdate" < "lineitem"."l_receiptdate" GROUP BY "lineitem"."l_orderkey" -) AS "_u_0" +) +SELECT + "orders"."o_orderpriority" AS "o_orderpriority", + COUNT(*) AS "order_count" +FROM "orders" AS "orders" +LEFT JOIN "_u_0" AS "_u_0" ON "_u_0"."l_orderkey" = "orders"."o_orderkey" WHERE "orders"."o_orderdate" < CAST('1993-10-01' AS DATE) @@ -290,7 +291,10 @@ SELECT FROM "customer" AS "customer" JOIN "orders" AS "orders" ON "customer"."c_custkey" = "orders"."o_custkey" -CROSS JOIN "region" AS "region" + AND "orders"."o_orderdate" < CAST('1995-01-01' AS DATE) + AND "orders"."o_orderdate" >= CAST('1994-01-01' AS DATE) +JOIN "region" AS "region" + ON "region"."r_name" = 'ASIA' JOIN "nation" AS "nation" ON "nation"."n_regionkey" = "region"."r_regionkey" JOIN "supplier" AS "supplier" @@ -299,10 +303,6 @@ JOIN "supplier" AS "supplier" JOIN "lineitem" AS "lineitem" ON "lineitem"."l_orderkey" = "orders"."o_orderkey" AND "lineitem"."l_suppkey" = "supplier"."s_suppkey" -WHERE - "orders"."o_orderdate" < CAST('1995-01-01' AS DATE) - AND "orders"."o_orderdate" >= CAST('1994-01-01' AS DATE) - AND "region"."r_name" = 'ASIA' GROUP BY "nation"."n_name" ORDER BY @@ -371,7 +371,7 @@ order by supp_nation, cust_nation, l_year; -WITH "_e_0" AS ( +WITH "n1" AS ( SELECT "nation"."n_nationkey" AS "n_nationkey", "nation"."n_name" AS "n_name" @@ -389,14 +389,15 @@ SELECT )) AS "revenue" FROM "supplier" AS "supplier" JOIN "lineitem" AS "lineitem" - ON "supplier"."s_suppkey" = "lineitem"."l_suppkey" + ON "lineitem"."l_shipdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE) + AND "supplier"."s_suppkey" = "lineitem"."l_suppkey" JOIN "orders" AS "orders" ON "orders"."o_orderkey" = "lineitem"."l_orderkey" JOIN "customer" AS "customer" ON "customer"."c_custkey" = "orders"."o_custkey" -JOIN "_e_0" AS "n1" +JOIN "n1" AS "n1" ON "supplier"."s_nationkey" = "n1"."n_nationkey" -JOIN "_e_0" AS "n2" +JOIN "n1" AS "n2" ON "customer"."c_nationkey" = "n2"."n_nationkey" AND ( "n1"."n_name" = 'FRANCE' @@ -406,8 +407,6 @@ JOIN "_e_0" AS "n2" "n1"."n_name" = 'GERMANY' OR "n2"."n_name" = 'GERMANY' ) -WHERE - "lineitem"."l_shipdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE) GROUP BY "n1"."n_name", "n2"."n_name", @@ -469,13 +468,15 @@ SELECT 1 - "lineitem"."l_discount" )) AS "mkt_share" FROM "part" AS "part" -CROSS JOIN "region" AS "region" +JOIN "region" AS "region" + ON "region"."r_name" = 'AMERICA' JOIN "nation" AS "nation" ON "nation"."n_regionkey" = "region"."r_regionkey" JOIN "customer" AS "customer" ON "customer"."c_nationkey" = "nation"."n_nationkey" JOIN "orders" AS "orders" ON "orders"."o_custkey" = "customer"."c_custkey" + AND "orders"."o_orderdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE) JOIN "lineitem" AS "lineitem" ON "lineitem"."l_orderkey" = "orders"."o_orderkey" AND "part"."p_partkey" = "lineitem"."l_partkey" @@ -484,9 +485,7 @@ JOIN "supplier" AS "supplier" JOIN "nation" AS "nation_2" ON "supplier"."s_nationkey" = "nation_2"."n_nationkey" WHERE - "orders"."o_orderdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE) - AND "part"."p_type" = 'ECONOMY ANODIZED STEEL' - AND "region"."r_name" = 'AMERICA' + "part"."p_type" = 'ECONOMY ANODIZED STEEL' GROUP BY EXTRACT(year FROM "orders"."o_orderdate") ORDER BY @@ -604,14 +603,13 @@ SELECT FROM "customer" AS "customer" JOIN "orders" AS "orders" ON "customer"."c_custkey" = "orders"."o_custkey" + AND "orders"."o_orderdate" < CAST('1994-01-01' AS DATE) + AND "orders"."o_orderdate" >= CAST('1993-10-01' AS DATE) JOIN "lineitem" AS "lineitem" ON "lineitem"."l_orderkey" = "orders"."o_orderkey" + AND "lineitem"."l_returnflag" = 'R' JOIN "nation" AS "nation" ON "customer"."c_nationkey" = "nation"."n_nationkey" -WHERE - "lineitem"."l_returnflag" = 'R' - AND "orders"."o_orderdate" < CAST('1994-01-01' AS DATE) - AND "orders"."o_orderdate" >= CAST('1993-10-01' AS DATE) GROUP BY "customer"."c_custkey", "customer"."c_name", @@ -654,12 +652,12 @@ group by ) order by value desc; -WITH "_e_0" AS ( +WITH "supplier_2" AS ( SELECT "supplier"."s_suppkey" AS "s_suppkey", "supplier"."s_nationkey" AS "s_nationkey" FROM "supplier" AS "supplier" -), "_e_1" AS ( +), "nation_2" AS ( SELECT "nation"."n_nationkey" AS "n_nationkey", "nation"."n_name" AS "n_name" @@ -671,9 +669,9 @@ SELECT "partsupp"."ps_partkey" AS "ps_partkey", SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") AS "value" FROM "partsupp" AS "partsupp" -JOIN "_e_0" AS "supplier" +JOIN "supplier_2" AS "supplier" ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey" -JOIN "_e_1" AS "nation" +JOIN "nation_2" AS "nation" ON "supplier"."s_nationkey" = "nation"."n_nationkey" GROUP BY "partsupp"."ps_partkey" @@ -682,9 +680,9 @@ HAVING SELECT SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") * 0.0001 AS "_col_0" FROM "partsupp" AS "partsupp" - JOIN "_e_0" AS "supplier" + JOIN "supplier_2" AS "supplier" ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey" - JOIN "_e_1" AS "nation" + JOIN "nation_2" AS "nation" ON "supplier"."s_nationkey" = "nation"."n_nationkey" ) ORDER BY @@ -737,13 +735,12 @@ SELECT END) AS "low_line_count" FROM "orders" AS "orders" JOIN "lineitem" AS "lineitem" - ON "orders"."o_orderkey" = "lineitem"."l_orderkey" -WHERE - "lineitem"."l_commitdate" < "lineitem"."l_receiptdate" + ON "lineitem"."l_commitdate" < "lineitem"."l_receiptdate" AND "lineitem"."l_receiptdate" < CAST('1995-01-01' AS DATE) AND "lineitem"."l_receiptdate" >= CAST('1994-01-01' AS DATE) AND "lineitem"."l_shipdate" < "lineitem"."l_commitdate" AND "lineitem"."l_shipmode" IN ('MAIL', 'SHIP') + AND "orders"."o_orderkey" = "lineitem"."l_orderkey" GROUP BY "lineitem"."l_shipmode" ORDER BY @@ -772,10 +769,7 @@ group by order by custdist desc, c_count desc; -SELECT - "c_orders"."c_count" AS "c_count", - COUNT(*) AS "custdist" -FROM ( +WITH "c_orders" AS ( SELECT COUNT("orders"."o_orderkey") AS "c_count" FROM "customer" AS "customer" @@ -784,7 +778,11 @@ FROM ( AND NOT "orders"."o_comment" LIKE '%special%requests%' GROUP BY "customer"."c_custkey" -) AS "c_orders" +) +SELECT + "c_orders"."c_count" AS "c_count", + COUNT(*) AS "custdist" +FROM "c_orders" AS "c_orders" GROUP BY "c_orders"."c_count" ORDER BY @@ -920,13 +918,7 @@ order by p_brand, p_type, p_size; -SELECT - "part"."p_brand" AS "p_brand", - "part"."p_type" AS "p_type", - "part"."p_size" AS "p_size", - COUNT(DISTINCT "partsupp"."ps_suppkey") AS "supplier_cnt" -FROM "partsupp" AS "partsupp" -LEFT JOIN ( +WITH "_u_0" AS ( SELECT "supplier"."s_suppkey" AS "s_suppkey" FROM "supplier" AS "supplier" @@ -934,15 +926,22 @@ LEFT JOIN ( "supplier"."s_comment" LIKE '%Customer%Complaints%' GROUP BY "supplier"."s_suppkey" -) AS "_u_0" +) +SELECT + "part"."p_brand" AS "p_brand", + "part"."p_type" AS "p_type", + "part"."p_size" AS "p_size", + COUNT(DISTINCT "partsupp"."ps_suppkey") AS "supplier_cnt" +FROM "partsupp" AS "partsupp" +LEFT JOIN "_u_0" AS "_u_0" ON "partsupp"."ps_suppkey" = "_u_0"."s_suppkey" JOIN "part" AS "part" - ON "part"."p_partkey" = "partsupp"."ps_partkey" -WHERE - "_u_0"."s_suppkey" IS NULL - AND "part"."p_brand" <> 'Brand#45' + ON "part"."p_brand" <> 'Brand#45' + AND "part"."p_partkey" = "partsupp"."ps_partkey" AND "part"."p_size" IN (49, 14, 23, 45, 19, 3, 36, 9) AND NOT "part"."p_type" LIKE 'MEDIUM POLISHED%' +WHERE + "_u_0"."s_suppkey" IS NULL GROUP BY "part"."p_brand", "part"."p_type", @@ -973,24 +972,25 @@ where where l_partkey = p_partkey ); -SELECT - SUM("lineitem"."l_extendedprice") / 7.0 AS "avg_yearly" -FROM "lineitem" AS "lineitem" -JOIN "part" AS "part" - ON "part"."p_partkey" = "lineitem"."l_partkey" -LEFT JOIN ( +WITH "_u_0" AS ( SELECT 0.2 * AVG("lineitem"."l_quantity") AS "_col_0", "lineitem"."l_partkey" AS "_u_1" FROM "lineitem" AS "lineitem" GROUP BY "lineitem"."l_partkey" -) AS "_u_0" +) +SELECT + SUM("lineitem"."l_extendedprice") / 7.0 AS "avg_yearly" +FROM "lineitem" AS "lineitem" +JOIN "part" AS "part" + ON "part"."p_brand" = 'Brand#23' + AND "part"."p_container" = 'MED BOX' + AND "part"."p_partkey" = "lineitem"."l_partkey" +LEFT JOIN "_u_0" AS "_u_0" ON "_u_0"."_u_1" = "part"."p_partkey" WHERE "lineitem"."l_quantity" < "_u_0"."_col_0" - AND "part"."p_brand" = 'Brand#23' - AND "part"."p_container" = 'MED BOX' AND NOT "_u_0"."_u_1" IS NULL; -------------------------------------- @@ -1030,6 +1030,16 @@ order by o_orderdate limit 100; +WITH "_u_0" AS ( + SELECT + "lineitem"."l_orderkey" AS "l_orderkey" + FROM "lineitem" AS "lineitem" + GROUP BY + "lineitem"."l_orderkey", + "lineitem"."l_orderkey" + HAVING + SUM("lineitem"."l_quantity") > 300 +) SELECT "customer"."c_name" AS "c_name", "customer"."c_custkey" AS "c_custkey", @@ -1040,16 +1050,7 @@ SELECT FROM "customer" AS "customer" JOIN "orders" AS "orders" ON "customer"."c_custkey" = "orders"."o_custkey" -LEFT JOIN ( - SELECT - "lineitem"."l_orderkey" AS "l_orderkey" - FROM "lineitem" AS "lineitem" - GROUP BY - "lineitem"."l_orderkey", - "lineitem"."l_orderkey" - HAVING - SUM("lineitem"."l_quantity") > 300 -) AS "_u_0" +LEFT JOIN "_u_0" AS "_u_0" ON "orders"."o_orderkey" = "_u_0"."l_orderkey" JOIN "lineitem" AS "lineitem" ON "orders"."o_orderkey" = "lineitem"."l_orderkey" @@ -1200,38 +1201,34 @@ where and n_name = 'CANADA' order by s_name; -SELECT - "supplier"."s_name" AS "s_name", - "supplier"."s_address" AS "s_address" -FROM "supplier" AS "supplier" -LEFT JOIN ( +WITH "_u_0" AS ( + SELECT + 0.5 * SUM("lineitem"."l_quantity") AS "_col_0", + "lineitem"."l_partkey" AS "_u_1", + "lineitem"."l_suppkey" AS "_u_2" + FROM "lineitem" AS "lineitem" + WHERE + "lineitem"."l_shipdate" < CAST('1995-01-01' AS DATE) + AND "lineitem"."l_shipdate" >= CAST('1994-01-01' AS DATE) + GROUP BY + "lineitem"."l_partkey", + "lineitem"."l_suppkey" +), "_u_3" AS ( + SELECT + "part"."p_partkey" AS "p_partkey" + FROM "part" AS "part" + WHERE + "part"."p_name" LIKE 'forest%' + GROUP BY + "part"."p_partkey" +), "_u_4" AS ( SELECT "partsupp"."ps_suppkey" AS "ps_suppkey" FROM "partsupp" AS "partsupp" - LEFT JOIN ( - SELECT - 0.5 * SUM("lineitem"."l_quantity") AS "_col_0", - "lineitem"."l_partkey" AS "_u_1", - "lineitem"."l_suppkey" AS "_u_2" - FROM "lineitem" AS "lineitem" - WHERE - "lineitem"."l_shipdate" < CAST('1995-01-01' AS DATE) - AND "lineitem"."l_shipdate" >= CAST('1994-01-01' AS DATE) - GROUP BY - "lineitem"."l_partkey", - "lineitem"."l_suppkey" - ) AS "_u_0" + LEFT JOIN "_u_0" AS "_u_0" ON "_u_0"."_u_1" = "partsupp"."ps_partkey" AND "_u_0"."_u_2" = "partsupp"."ps_suppkey" - LEFT JOIN ( - SELECT - "part"."p_partkey" AS "p_partkey" - FROM "part" AS "part" - WHERE - "part"."p_name" LIKE 'forest%' - GROUP BY - "part"."p_partkey" - ) AS "_u_3" + LEFT JOIN "_u_3" AS "_u_3" ON "partsupp"."ps_partkey" = "_u_3"."p_partkey" WHERE "partsupp"."ps_availqty" > "_u_0"."_col_0" @@ -1240,13 +1237,18 @@ LEFT JOIN ( AND NOT "_u_3"."p_partkey" IS NULL GROUP BY "partsupp"."ps_suppkey" -) AS "_u_4" +) +SELECT + "supplier"."s_name" AS "s_name", + "supplier"."s_address" AS "s_address" +FROM "supplier" AS "supplier" +LEFT JOIN "_u_4" AS "_u_4" ON "supplier"."s_suppkey" = "_u_4"."ps_suppkey" JOIN "nation" AS "nation" - ON "supplier"."s_nationkey" = "nation"."n_nationkey" + ON "nation"."n_name" = 'CANADA' + AND "supplier"."s_nationkey" = "nation"."n_nationkey" WHERE - "nation"."n_name" = 'CANADA' - AND NOT "_u_4"."ps_suppkey" IS NULL + NOT "_u_4"."ps_suppkey" IS NULL ORDER BY "s_name"; @@ -1294,22 +1296,14 @@ order by s_name limit 100; -SELECT - "supplier"."s_name" AS "s_name", - COUNT(*) AS "numwait" -FROM "supplier" AS "supplier" -JOIN "lineitem" AS "lineitem" - ON "supplier"."s_suppkey" = "lineitem"."l_suppkey" -LEFT JOIN ( +WITH "_u_0" AS ( SELECT "l2"."l_orderkey" AS "l_orderkey", ARRAY_AGG("l2"."l_suppkey") AS "_u_1" FROM "lineitem" AS "l2" GROUP BY "l2"."l_orderkey" -) AS "_u_0" - ON "_u_0"."l_orderkey" = "lineitem"."l_orderkey" -LEFT JOIN ( +), "_u_2" AS ( SELECT "l3"."l_orderkey" AS "l_orderkey", ARRAY_AGG("l3"."l_suppkey") AS "_u_3" @@ -1318,20 +1312,29 @@ LEFT JOIN ( "l3"."l_receiptdate" > "l3"."l_commitdate" GROUP BY "l3"."l_orderkey" -) AS "_u_2" +) +SELECT + "supplier"."s_name" AS "s_name", + COUNT(*) AS "numwait" +FROM "supplier" AS "supplier" +JOIN "lineitem" AS "lineitem" + ON "lineitem"."l_receiptdate" > "lineitem"."l_commitdate" + AND "supplier"."s_suppkey" = "lineitem"."l_suppkey" +LEFT JOIN "_u_0" AS "_u_0" + ON "_u_0"."l_orderkey" = "lineitem"."l_orderkey" +LEFT JOIN "_u_2" AS "_u_2" ON "_u_2"."l_orderkey" = "lineitem"."l_orderkey" JOIN "orders" AS "orders" ON "orders"."o_orderkey" = "lineitem"."l_orderkey" + AND "orders"."o_orderstatus" = 'F' JOIN "nation" AS "nation" - ON "supplier"."s_nationkey" = "nation"."n_nationkey" + ON "nation"."n_name" = 'SAUDI ARABIA' + AND "supplier"."s_nationkey" = "nation"."n_nationkey" WHERE ( "_u_2"."l_orderkey" IS NULL OR NOT ARRAY_ANY("_u_2"."_u_3", "_x" -> "_x" <> "lineitem"."l_suppkey") ) - AND "lineitem"."l_receiptdate" > "lineitem"."l_commitdate" - AND "nation"."n_name" = 'SAUDI ARABIA' - AND "orders"."o_orderstatus" = 'F' AND ARRAY_ANY("_u_0"."_u_1", "_x" -> "_x" <> "lineitem"."l_suppkey") AND NOT "_u_0"."l_orderkey" IS NULL GROUP BY @@ -1381,18 +1384,19 @@ group by cntrycode order by cntrycode; -SELECT - SUBSTRING("customer"."c_phone", 1, 2) AS "cntrycode", - COUNT(*) AS "numcust", - SUM("customer"."c_acctbal") AS "totacctbal" -FROM "customer" AS "customer" -LEFT JOIN ( +WITH "_u_0" AS ( SELECT "orders"."o_custkey" AS "_u_1" FROM "orders" AS "orders" GROUP BY "orders"."o_custkey" -) AS "_u_0" +) +SELECT + SUBSTRING("customer"."c_phone", 1, 2) AS "cntrycode", + COUNT(*) AS "numcust", + SUM("customer"."c_acctbal") AS "totacctbal" +FROM "customer" AS "customer" +LEFT JOIN "_u_0" AS "_u_0" ON "_u_0"."_u_1" = "customer"."c_custkey" WHERE "_u_0"."_u_1" IS NULL diff --git a/tests/fixtures/pretty.sql b/tests/fixtures/pretty.sql index 5ed74f4..19a7451 100644 --- a/tests/fixtures/pretty.sql +++ b/tests/fixtures/pretty.sql @@ -264,22 +264,3 @@ CREATE TABLE "t_customer_account" ( "account_no" VARCHAR(100) ); -CREATE TABLE "t_customer_account" ( - "id" int(11) NOT NULL AUTO_INCREMENT, - "customer_id" int(11) DEFAULT NULL COMMENT '客户id', - "bank" varchar(100) COLLATE utf8_bin DEFAULT NULL COMMENT '行别', - "account_no" varchar(100) COLLATE utf8_bin DEFAULT NULL COMMENT '账号', - PRIMARY KEY ("id") -) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='客户账户表'; -CREATE TABLE "t_customer_account" ( - "id" INT(11) NOT NULL AUTO_INCREMENT, - "customer_id" INT(11) DEFAULT NULL COMMENT '客户id', - "bank" VARCHAR(100) COLLATE utf8_bin DEFAULT NULL COMMENT '行别', - "account_no" VARCHAR(100) COLLATE utf8_bin DEFAULT NULL COMMENT '账号', - PRIMARY KEY("id") -) -ENGINE=InnoDB -AUTO_INCREMENT=1 -DEFAULT CHARACTER SET=utf8 -COLLATE=utf8_bin -COMMENT='客户账户表'; diff --git a/tests/test_build.py b/tests/test_build.py index 18c0e47..b5d657c 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -270,7 +270,7 @@ class TestBuild(unittest.TestCase): lambda: parse_one("SELECT * FROM y") .assert_is(exp.Select) .ctas("foo.x", properties={"format": "parquet", "y": "2"}), - "CREATE TABLE foo.x STORED AS PARQUET TBLPROPERTIES ('y' = '2') AS SELECT * FROM y", + "CREATE TABLE foo.x STORED AS PARQUET TBLPROPERTIES ('y'='2') AS SELECT * FROM y", "hive", ), (lambda: and_("x=1", "y=1"), "x = 1 AND y = 1"), @@ -308,6 +308,18 @@ class TestBuild(unittest.TestCase): lambda: exp.subquery("select x from tbl UNION select x from bar", "unioned").select("x"), "SELECT x FROM (SELECT x FROM tbl UNION SELECT x FROM bar) AS unioned", ), + ( + lambda: exp.update("tbl", {"x": None, "y": {"x": 1}}), + "UPDATE tbl SET x = NULL, y = MAP('x', 1)", + ), + ( + lambda: exp.update("tbl", {"x": 1}, where="y > 0"), + "UPDATE tbl SET x = 1 WHERE y > 0", + ), + ( + lambda: exp.update("tbl", {"x": 1}, from_="tbl2"), + "UPDATE tbl SET x = 1 FROM tbl2", + ), ]: with self.subTest(sql): self.assertEqual(expression().sql(dialect[0] if dialect else None), sql) diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 64ad02d..cc41307 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -27,6 +27,8 @@ class TestExpressions(unittest.TestCase): parse_one("ROW() OVER (partition BY y)"), ) self.assertEqual(parse_one("TO_DATE(x)", read="hive"), parse_one("ts_or_ds_to_date(x)")) + self.assertEqual(exp.Table(pivots=[]), exp.Table()) + self.assertNotEqual(exp.Table(pivots=[None]), exp.Table()) def test_find(self): expression = parse_one("CREATE TABLE x STORED AS PARQUET AS SELECT * FROM y") @@ -280,6 +282,19 @@ class TestExpressions(unittest.TestCase): expression.find(exp.Table).replace(parse_one("y")) self.assertEqual(expression.sql(), "SELECT c, b FROM y") + def test_pop(self): + expression = parse_one("SELECT a, b FROM x") + expression.find(exp.Column).pop() + self.assertEqual(expression.sql(), "SELECT b FROM x") + expression.find(exp.Column).pop() + self.assertEqual(expression.sql(), "SELECT FROM x") + expression.pop() + self.assertEqual(expression.sql(), "SELECT FROM x") + + expression = parse_one("WITH x AS (SELECT a FROM x) SELECT * FROM x") + expression.find(exp.With).pop() + self.assertEqual(expression.sql(), "SELECT * FROM x") + def test_walk(self): expression = parse_one("SELECT * FROM (SELECT * FROM x)") self.assertEqual(len(list(expression.walk())), 9) @@ -316,6 +331,7 @@ class TestExpressions(unittest.TestCase): self.assertIsInstance(parse_one("MAX(a)"), exp.Max) self.assertIsInstance(parse_one("MIN(a)"), exp.Min) self.assertIsInstance(parse_one("MONTH(a)"), exp.Month) + self.assertIsInstance(parse_one("POSITION(' ' IN a)"), exp.StrPosition) self.assertIsInstance(parse_one("POW(a, 2)"), exp.Pow) self.assertIsInstance(parse_one("POWER(a, 2)"), exp.Pow) self.assertIsInstance(parse_one("QUANTILE(a, 0.90)"), exp.Quantile) @@ -420,7 +436,7 @@ class TestExpressions(unittest.TestCase): exp.Properties.from_dict( { "FORMAT": "parquet", - "PARTITIONED_BY": [exp.to_identifier("a"), exp.to_identifier("b")], + "PARTITIONED_BY": (exp.to_identifier("a"), exp.to_identifier("b")), "custom": 1, "TABLE_FORMAT": exp.to_identifier("test_format"), "ENGINE": None, @@ -444,4 +460,17 @@ class TestExpressions(unittest.TestCase): ), ) - self.assertRaises(ValueError, exp.Properties.from_dict, {"FORMAT": {"key": "value"}}) + self.assertRaises(ValueError, exp.Properties.from_dict, {"FORMAT": object}) + + def test_convert(self): + for value, expected in [ + (1, "1"), + ("1", "'1'"), + (None, "NULL"), + (True, "TRUE"), + ((1, "2", None), "(1, '2', NULL)"), + ([1, "2", None], "ARRAY(1, '2', NULL)"), + ({"x": None}, "MAP('x', NULL)"), + ]: + with self.subTest(value): + self.assertEqual(exp.convert(value).sql(), expected) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 102e141..8d4aecc 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -1,9 +1,11 @@ import unittest +from functools import partial -from sqlglot import optimizer, parse_one, table +from sqlglot import exp, optimizer, parse_one, table from sqlglot.errors import OptimizeError +from sqlglot.optimizer.annotate_types import annotate_types from sqlglot.optimizer.schema import MappingSchema, ensure_schema -from sqlglot.optimizer.scope import traverse_scope +from sqlglot.optimizer.scope import build_scope, traverse_scope from tests.helpers import TPCH_SCHEMA, load_sql_fixture_pairs, load_sql_fixtures @@ -27,11 +29,17 @@ class TestOptimizer(unittest.TestCase): } def check_file(self, file, func, pretty=False, **kwargs): - for meta, sql, expected in load_sql_fixture_pairs(f"optimizer/{file}.sql"): + for i, (meta, sql, expected) in enumerate(load_sql_fixture_pairs(f"optimizer/{file}.sql"), start=1): dialect = meta.get("dialect") - with self.subTest(sql): + leave_tables_isolated = meta.get("leave_tables_isolated") + + func_kwargs = {**kwargs} + if leave_tables_isolated is not None: + func_kwargs["leave_tables_isolated"] = leave_tables_isolated.lower() in ("true", "1") + + with self.subTest(f"{i}, {sql}"): self.assertEqual( - func(parse_one(sql, read=dialect), **kwargs).sql(pretty=pretty, dialect=dialect), + func(parse_one(sql, read=dialect), **func_kwargs).sql(pretty=pretty, dialect=dialect), expected, ) @@ -123,21 +131,20 @@ class TestOptimizer(unittest.TestCase): optimizer.optimize_joins.optimize_joins, ) - def test_eliminate_subqueries(self): - self.check_file( - "eliminate_subqueries", - optimizer.eliminate_subqueries.eliminate_subqueries, - pretty=True, + def test_merge_subqueries(self): + optimize = partial( + optimizer.optimize, + rules=[ + optimizer.qualify_tables.qualify_tables, + optimizer.qualify_columns.qualify_columns, + optimizer.merge_subqueries.merge_subqueries, + ], ) - def test_merge_derived_tables(self): - def optimize(expression, **kwargs): - expression = optimizer.qualify_tables.qualify_tables(expression) - expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs) - expression = optimizer.merge_derived_tables.merge_derived_tables(expression) - return expression + self.check_file("merge_subqueries", optimize, schema=self.schema) - self.check_file("merge_derived_tables", optimize, schema=self.schema) + def test_eliminate_subqueries(self): + self.check_file("eliminate_subqueries", optimizer.eliminate_subqueries.eliminate_subqueries) def test_tpch(self): self.check_file("tpc-h/tpc-h", optimizer.optimize, schema=TPCH_SCHEMA, pretty=True) @@ -257,17 +264,73 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') ON s.b = r.b WHERE s.b > (SELECT MAX(x.a) FROM x WHERE x.b = s.b) """ - scopes = traverse_scope(parse_one(sql)) - self.assertEqual(len(scopes), 5) - self.assertEqual(scopes[0].expression.sql(), "SELECT x.b FROM x") - self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y") - self.assertEqual(scopes[2].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b") - self.assertEqual(scopes[3].expression.sql(), "SELECT y.c AS b FROM y") - self.assertEqual(scopes[4].expression.sql(), parse_one(sql).sql()) - - self.assertEqual(set(scopes[4].sources), {"q", "r", "s"}) - self.assertEqual(len(scopes[4].columns), 6) - self.assertEqual(set(c.table for c in scopes[4].columns), {"r", "s"}) - self.assertEqual(scopes[4].source_columns("q"), []) - self.assertEqual(len(scopes[4].source_columns("r")), 2) - self.assertEqual(set(c.table for c in scopes[4].source_columns("r")), {"r"}) + for scopes in traverse_scope(parse_one(sql)), list(build_scope(parse_one(sql)).traverse()): + self.assertEqual(len(scopes), 5) + self.assertEqual(scopes[0].expression.sql(), "SELECT x.b FROM x") + self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y") + self.assertEqual(scopes[2].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b") + self.assertEqual(scopes[3].expression.sql(), "SELECT y.c AS b FROM y") + self.assertEqual(scopes[4].expression.sql(), parse_one(sql).sql()) + + self.assertEqual(set(scopes[4].sources), {"q", "r", "s"}) + self.assertEqual(len(scopes[4].columns), 6) + self.assertEqual(set(c.table for c in scopes[4].columns), {"r", "s"}) + self.assertEqual(scopes[4].source_columns("q"), []) + self.assertEqual(len(scopes[4].source_columns("r")), 2) + self.assertEqual(set(c.table for c in scopes[4].source_columns("r")), {"r"}) + + def test_literal_type_annotation(self): + tests = { + "SELECT 5": exp.DataType.Type.INT, + "SELECT 5.3": exp.DataType.Type.DOUBLE, + "SELECT 'bla'": exp.DataType.Type.VARCHAR, + "5": exp.DataType.Type.INT, + "5.3": exp.DataType.Type.DOUBLE, + "'bla'": exp.DataType.Type.VARCHAR, + } + + for sql, target_type in tests.items(): + expression = parse_one(sql) + annotated_expression = annotate_types(expression) + + self.assertEqual(annotated_expression.find(exp.Literal).type, target_type) + + def test_boolean_type_annotation(self): + tests = { + "SELECT TRUE": exp.DataType.Type.BOOLEAN, + "FALSE": exp.DataType.Type.BOOLEAN, + } + + for sql, target_type in tests.items(): + expression = parse_one(sql) + annotated_expression = annotate_types(expression) + + self.assertEqual(annotated_expression.find(exp.Boolean).type, target_type) + + def test_cast_type_annotation(self): + expression = parse_one("CAST('2020-01-01' AS TIMESTAMPTZ(9))") + annotate_types(expression) + + self.assertEqual(expression.type, exp.DataType.Type.TIMESTAMPTZ) + self.assertEqual(expression.this.type, exp.DataType.Type.VARCHAR) + self.assertEqual(expression.args["to"].type, exp.DataType.Type.TIMESTAMPTZ) + self.assertEqual(expression.args["to"].expressions[0].type, exp.DataType.Type.INT) + + def test_cache_annotation(self): + expression = parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1") + annotated_expression = annotate_types(expression) + + self.assertEqual(annotated_expression.expression.expressions[0].type, exp.DataType.Type.INT) + + def test_binary_annotation(self): + expression = parse_one("SELECT 0.0 + (2 + 3)") + annotate_types(expression) + + expression = expression.expressions[0] + + self.assertEqual(expression.type, exp.DataType.Type.DOUBLE) + self.assertEqual(expression.left.type, exp.DataType.Type.DOUBLE) + self.assertEqual(expression.right.type, exp.DataType.Type.INT) + self.assertEqual(expression.right.this.type, exp.DataType.Type.INT) + self.assertEqual(expression.right.this.left.type, exp.DataType.Type.INT) + self.assertEqual(expression.right.this.right.type, exp.DataType.Type.INT) diff --git a/tests/test_parser.py b/tests/test_parser.py index 9e430e2..4c46531 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -21,6 +21,11 @@ class TestParser(unittest.TestCase): self.assertIsNotNone(parse_one("date").find(exp.Column)) + def test_float(self): + self.assertEqual(parse_one(".2"), parse_one("0.2")) + self.assertEqual(parse_one("int 1"), parse_one("CAST(1 AS INT)")) + self.assertEqual(parse_one("int.5"), parse_one("CAST(0.5 AS INT)")) + def test_table(self): tables = [t.sql() for t in parse_one("select * from a, b.c, .d").find_all(exp.Table)] self.assertEqual(tables, ["a", "b.c", "d"]) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 2030109..1928d2c 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -6,11 +6,32 @@ from sqlglot.transforms import unalias_group class TestTime(unittest.TestCase): def validate(self, transform, sql, target): - self.assertEqual(parse_one(sql).transform(transform).sql(), target) + with self.subTest(sql): + self.assertEqual(parse_one(sql).transform(transform).sql(), target) def test_unalias_group(self): self.validate( unalias_group, "SELECT a, b AS b, c AS c, 4 FROM x GROUP BY a, b, x.c, 4", - "SELECT a, b AS b, c AS c, 4 FROM x GROUP BY a, 2, x.c, 4", + "SELECT a, b AS b, c AS c, 4 FROM x GROUP BY a, b, x.c, 4", + ) + self.validate( + unalias_group, + "SELECT TO_DATE(the_date) AS the_date, CUSTOM_UDF(other_col) AS other_col, last_col AS aliased_last, COUNT(*) AS the_count FROM x GROUP BY TO_DATE(the_date), CUSTOM_UDF(other_col), aliased_last", + "SELECT TO_DATE(the_date) AS the_date, CUSTOM_UDF(other_col) AS other_col, last_col AS aliased_last, COUNT(*) AS the_count FROM x GROUP BY TO_DATE(the_date), CUSTOM_UDF(other_col), 3", + ) + self.validate( + unalias_group, + "SELECT SOME_UDF(TO_DATE(the_date)) AS the_date, COUNT(*) AS the_count FROM x GROUP BY SOME_UDF(TO_DATE(the_date))", + "SELECT SOME_UDF(TO_DATE(the_date)) AS the_date, COUNT(*) AS the_count FROM x GROUP BY SOME_UDF(TO_DATE(the_date))", + ) + self.validate( + unalias_group, + "SELECT SOME_UDF(TO_DATE(the_date)) AS new_date, COUNT(*) AS the_count FROM x GROUP BY new_date", + "SELECT SOME_UDF(TO_DATE(the_date)) AS new_date, COUNT(*) AS the_count FROM x GROUP BY 1", + ) + self.validate( + unalias_group, + "SELECT the_date AS the_date, COUNT(*) AS the_count FROM x GROUP BY the_date", + "SELECT the_date AS the_date, COUNT(*) AS the_count FROM x GROUP BY the_date", ) |