From f818ab3b896d52e874634b7c4db3533078c1887f Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Mon, 10 Oct 2022 13:29:05 +0200 Subject: Merging upstream version 6.3.1. Signed-off-by: Daniel Baumann --- CHANGELOG.md | 22 ++ sqlglot/__init__.py | 5 +- sqlglot/dialects/bigquery.py | 6 +- sqlglot/dialects/clickhouse.py | 33 ++- sqlglot/dialects/dialect.py | 17 +- sqlglot/dialects/hive.py | 40 +--- sqlglot/dialects/oracle.py | 29 +++ sqlglot/dialects/postgres.py | 14 +- sqlglot/dialects/snowflake.py | 5 +- sqlglot/dialects/spark.py | 21 +- sqlglot/executor/python.py | 1 + sqlglot/expressions.py | 222 +++++++++++++++++++- sqlglot/generator.py | 29 ++- sqlglot/optimizer/annotate_types.py | 158 +++++++++++++-- sqlglot/optimizer/merge_subqueries.py | 44 +++- sqlglot/optimizer/pushdown_predicates.py | 38 ++-- sqlglot/optimizer/qualify_columns.py | 8 +- sqlglot/optimizer/schema.py | 63 +++++- sqlglot/optimizer/scope.py | 20 +- sqlglot/optimizer/simplify.py | 8 +- sqlglot/parser.py | 121 +++++++---- sqlglot/tokens.py | 8 + tests/dialects/test_bigquery.py | 18 ++ tests/dialects/test_clickhouse.py | 10 +- tests/dialects/test_dialect.py | 45 ++++- tests/dialects/test_hive.py | 18 +- tests/dialects/test_postgres.py | 1 + tests/dialects/test_snowflake.py | 32 +++ tests/dialects/test_spark.py | 58 ++++++ tests/fixtures/identity.sql | 6 + tests/fixtures/optimizer/merge_subqueries.sql | 168 +++++++++++---- tests/fixtures/optimizer/optimizer.sql | 140 ++++++++++++- tests/fixtures/optimizer/qualify_columns.sql | 43 ++-- .../optimizer/qualify_columns__with_invisible.sql | 35 ++++ tests/fixtures/optimizer/simplify.sql | 6 + tests/fixtures/optimizer/tpc-h/tpc-h.sql | 13 +- tests/helpers.py | 8 + tests/test_build.py | 63 +++++- tests/test_expressions.py | 22 ++ tests/test_optimizer.py | 225 ++++++++++++++++++--- tests/test_transpile.py | 2 +- 41 files changed, 1558 insertions(+), 267 deletions(-) create mode 100644 tests/fixtures/optimizer/qualify_columns__with_invisible.sql diff --git a/CHANGELOG.md b/CHANGELOG.md index f16fc70..aee697a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,28 @@ Changelog ========= +v6.3.0 +------ + +Changes: + +- New: Snowflake [table literals](https://docs.snowflake.com/en/sql-reference/literals-table.html) + +- New: Anti and semi joins + +- New: Vacuum as a command + +- New: Stored procedures + +- New: Reweriting derived tables as CTES + +- Improvement: Various clickhouse improvements + +- Improvement: Optimizer predicate pushdown + +- Breaking: DATE\_DIFF default renamed to DATEDIFF + + v6.2.0 ------ diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index 1f7b28c..0228bdd 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -8,7 +8,9 @@ from sqlglot.expressions import ( and_, column, condition, + except_, from_, + intersect, maybe_parse, not_, or_, @@ -16,11 +18,12 @@ from sqlglot.expressions import ( subquery, ) from sqlglot.expressions import table_ as table +from sqlglot.expressions import union from sqlglot.generator import Generator from sqlglot.parser import Parser from sqlglot.tokens import Tokenizer, TokenType -__version__ = "6.2.8" +__version__ = "6.3.1" pretty = False diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 40298e7..86e46cf 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -135,6 +135,7 @@ class BigQuery(Dialect): exp.DateSub: _date_add_sql("DATE", "SUB"), exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"), exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"), + exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})", exp.ILike: no_ilike_sql, exp.TimeAdd: _date_add_sql("TIME", "ADD"), exp.TimeSub: _date_add_sql("TIME", "SUB"), @@ -172,12 +173,11 @@ class BigQuery(Dialect): exp.AnonymousProperty, } + EXPLICIT_UNION = True + def in_unnest_op(self, unnest): return self.sql(unnest) - def union_op(self, expression): - return f"UNION{' DISTINCT' if expression.args.get('distinct') else ' ALL'}" - def except_op(self, expression): if not expression.args.get("distinct", False): self.unsupported("EXCEPT without DISTINCT is not supported in BigQuery") diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index 55dad7a..da5c856 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -1,10 +1,16 @@ from sqlglot import exp -from sqlglot.dialects.dialect import Dialect, inline_array_sql +from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql from sqlglot.generator import Generator -from sqlglot.parser import Parser +from sqlglot.helper import csv +from sqlglot.parser import Parser, parse_var_map from sqlglot.tokens import Tokenizer, TokenType +def _lower_func(sql): + index = sql.index("(") + return sql[:index].lower() + sql[index:] + + class ClickHouse(Dialect): normalize_functions = None null_ordering = "nulls_are_last" @@ -14,17 +20,23 @@ class ClickHouse(Dialect): KEYWORDS = { **Tokenizer.KEYWORDS, - "NULLABLE": TokenType.NULLABLE, "FINAL": TokenType.FINAL, + "DATETIME64": TokenType.DATETIME, "INT8": TokenType.TINYINT, "INT16": TokenType.SMALLINT, "INT32": TokenType.INT, "INT64": TokenType.BIGINT, "FLOAT32": TokenType.FLOAT, "FLOAT64": TokenType.DOUBLE, + "TUPLE": TokenType.STRUCT, } class Parser(Parser): + FUNCTIONS = { + **Parser.FUNCTIONS, + "MAP": parse_var_map, + } + def _parse_table(self, schema=False): this = super()._parse_table(schema) @@ -39,10 +51,25 @@ class ClickHouse(Dialect): TYPE_MAPPING = { **Generator.TYPE_MAPPING, exp.DataType.Type.NULLABLE: "Nullable", + exp.DataType.Type.DATETIME: "DateTime64", + exp.DataType.Type.MAP: "Map", + exp.DataType.Type.ARRAY: "Array", + exp.DataType.Type.STRUCT: "Tuple", + exp.DataType.Type.TINYINT: "Int8", + exp.DataType.Type.SMALLINT: "Int16", + exp.DataType.Type.INT: "Int32", + exp.DataType.Type.BIGINT: "Int64", + exp.DataType.Type.FLOAT: "Float32", + exp.DataType.Type.DOUBLE: "Float64", } TRANSFORMS = { **Generator.TRANSFORMS, exp.Array: inline_array_sql, + exp.StrPosition: lambda self, e: f"position({csv(self.sql(e, 'this'), self.sql(e, 'substr'), self.sql(e, 'position'))})", exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL", + exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)), + exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)), } + + EXPLICIT_UNION = True diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 98dc330..f7c6cb5 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -77,7 +77,6 @@ class Dialect(metaclass=_Dialect): alias_post_tablesample = False normalize_functions = "upper" null_ordering = "nulls_are_small" - wrap_derived_values = True date_format = "'%Y-%m-%d'" dateint_format = "'%Y%m%d'" @@ -170,7 +169,6 @@ class Dialect(metaclass=_Dialect): "alias_post_tablesample": self.alias_post_tablesample, "normalize_functions": self.normalize_functions, "null_ordering": self.null_ordering, - "wrap_derived_values": self.wrap_derived_values, **opts, } ) @@ -271,6 +269,21 @@ def struct_extract_sql(self, expression): return f"{this}.{struct_key}" +def var_map_sql(self, expression): + keys = expression.args["keys"] + values = expression.args["values"] + + if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): + self.unsupported("Cannot convert array columns into map.") + return f"MAP({self.sql(keys)}, {self.sql(values)})" + + args = [] + for key, value in zip(keys.expressions, values.expressions): + args.append(self.sql(key)) + args.append(self.sql(value)) + return f"MAP({csv(*args)})" + + def format_time_lambda(exp_class, dialect, default=None): """Helper used for time expressions. diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 7a27bb3..55d7bcc 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -11,40 +11,14 @@ from sqlglot.dialects.dialect import ( no_trycast_sql, rename_func, struct_extract_sql, + var_map_sql, ) from sqlglot.generator import Generator from sqlglot.helper import csv, list_get -from sqlglot.parser import Parser +from sqlglot.parser import Parser, parse_var_map from sqlglot.tokens import Tokenizer -def _parse_map(args): - keys = [] - values = [] - for i in range(0, len(args), 2): - keys.append(args[i]) - values.append(args[i + 1]) - return HiveMap( - keys=exp.Array(expressions=keys), - values=exp.Array(expressions=values), - ) - - -def _map_sql(self, expression): - keys = expression.args["keys"] - values = expression.args["values"] - - if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): - self.unsupported("Cannot convert array columns into map use SparkSQL instead.") - return f"MAP({self.sql(keys)}, {self.sql(values)})" - - args = [] - for key, value in zip(keys.expressions, values.expressions): - args.append(self.sql(key)) - args.append(self.sql(value)) - return f"MAP({csv(*args)})" - - def _array_sort(self, expression): if expression.expression: self.unsupported("Hive SORT_ARRAY does not support a comparator") @@ -122,10 +96,6 @@ def _index_sql(self, expression): return f"{this} ON TABLE {table} {columns}" -class HiveMap(exp.Map): - is_var_len_args = True - - class Hive(Dialect): alias_post_tablesample = True @@ -206,7 +176,7 @@ class Hive(Dialect): position=list_get(args, 2), ), "LOG": (lambda args: exp.Log.from_arg_list(args) if len(args) > 1 else exp.Ln.from_arg_list(args)), - "MAP": _parse_map, + "MAP": parse_var_map, "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)), "PERCENTILE": exp.Quantile.from_arg_list, "PERCENTILE_APPROX": exp.ApproxQuantile.from_arg_list, @@ -245,8 +215,8 @@ class Hive(Dialect): exp.Join: _unnest_to_explode_sql, exp.JSONExtract: rename_func("GET_JSON_OBJECT"), exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"), - exp.Map: _map_sql, - HiveMap: _map_sql, + exp.Map: var_map_sql, + exp.VarMap: var_map_sql, exp.Create: create_with_partitions_sql, exp.Quantile: rename_func("PERCENTILE"), exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"), diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 91e30b2..8041ff0 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -10,6 +10,32 @@ def _limit_sql(self, expression): class Oracle(Dialect): + # https://docs.oracle.com/database/121/SQLRF/sql_elements004.htm#SQLRF00212 + # https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes + time_mapping = { + "AM": "%p", # Meridian indicator with or without periods + "A.M.": "%p", # Meridian indicator with or without periods + "PM": "%p", # Meridian indicator with or without periods + "P.M.": "%p", # Meridian indicator with or without periods + "D": "%u", # Day of week (1-7) + "DAY": "%A", # name of day + "DD": "%d", # day of month (1-31) + "DDD": "%j", # day of year (1-366) + "DY": "%a", # abbreviated name of day + "HH": "%I", # Hour of day (1-12) + "HH12": "%I", # alias for HH + "HH24": "%H", # Hour of day (0-23) + "IW": "%V", # Calendar week of year (1-52 or 1-53), as defined by the ISO 8601 standard + "MI": "%M", # Minute (0-59) + "MM": "%m", # Month (01-12; January = 01) + "MON": "%b", # Abbreviated name of month + "MONTH": "%B", # Name of month + "SS": "%S", # Second (0-59) + "WW": "%W", # Week of year (1-53) + "YY": "%y", # 15 + "YYYY": "%Y", # 2015 + } + class Generator(Generator): TYPE_MAPPING = { **Generator.TYPE_MAPPING, @@ -30,6 +56,9 @@ class Oracle(Dialect): **transforms.UNALIAS_GROUP, exp.ILike: no_ilike_sql, exp.Limit: _limit_sql, + exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", + exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)", } def query_modifiers(self, expression, *sqls): diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index aaa07a1..731e28e 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -118,13 +118,22 @@ def _serial_to_generated(expression): return expression +def _to_timestamp(args): + # TO_TIMESTAMP accepts either a single double argument or (text, text) + if len(args) == 1 and args[0].is_number: + # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TABLE + return exp.UnixToTime.from_arg_list(args) + # https://www.postgresql.org/docs/current/functions-formatting.html + return format_time_lambda(exp.StrToTime, "postgres")(args) + + class Postgres(Dialect): null_ordering = "nulls_are_large" time_format = "'YYYY-MM-DD HH24:MI:SS'" time_mapping = { "AM": "%p", "PM": "%p", - "D": "%w", # 1-based day of week + "D": "%u", # 1-based day of week "DD": "%d", # day of month "DDD": "%j", # zero padded day of year "FMDD": "%-d", # - is no leading zero for Python; same for FM in postgres @@ -172,7 +181,7 @@ class Postgres(Dialect): FUNCTIONS = { **Parser.FUNCTIONS, - "TO_TIMESTAMP": format_time_lambda(exp.StrToTime, "postgres"), + "TO_TIMESTAMP": _to_timestamp, "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"), } @@ -211,4 +220,5 @@ class Postgres(Dialect): exp.TableSample: no_tablesample_sql, exp.Trim: _trim_sql, exp.TryCast: no_trycast_sql, + exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})", } diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index fb2d900..19a427c 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -121,6 +121,7 @@ class Snowflake(Dialect): FUNC_TOKENS = { *Parser.FUNC_TOKENS, TokenType.RLIKE, + TokenType.TABLE, } COLUMN_OPERATORS = { @@ -143,7 +144,7 @@ class Snowflake(Dialect): SINGLE_TOKENS = { **Tokenizer.SINGLE_TOKENS, - "$": TokenType.DOLLAR, # needed to break for quotes + "$": TokenType.PARAMETER, } KEYWORDS = { @@ -164,6 +165,8 @@ class Snowflake(Dialect): exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.UnixToTime: _unix_to_time, exp.Array: inline_array_sql, + exp.StrPosition: rename_func("POSITION"), + exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}", exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}", } diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index e8da07a..95a7ab4 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -4,8 +4,9 @@ from sqlglot.dialects.dialect import ( no_ilike_sql, rename_func, ) -from sqlglot.dialects.hive import Hive, HiveMap +from sqlglot.dialects.hive import Hive from sqlglot.helper import list_get +from sqlglot.parser import Parser def _create_sql(self, e): @@ -47,8 +48,6 @@ def _unix_to_time(self, expression): class Spark(Hive): - wrap_derived_values = False - class Parser(Hive.Parser): FUNCTIONS = { **Hive.Parser.FUNCTIONS, @@ -78,8 +77,19 @@ class Spark(Hive): "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list, } - class Generator(Hive.Generator): + FUNCTION_PARSERS = { + **Parser.FUNCTION_PARSERS, + "BROADCAST": lambda self: self._parse_join_hint("BROADCAST"), + "BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"), + "MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"), + "MERGE": lambda self: self._parse_join_hint("MERGE"), + "SHUFFLEMERGE": lambda self: self._parse_join_hint("SHUFFLEMERGE"), + "MERGEJOIN": lambda self: self._parse_join_hint("MERGEJOIN"), + "SHUFFLE_HASH": lambda self: self._parse_join_hint("SHUFFLE_HASH"), + "SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"), + } + class Generator(Hive.Generator): TYPE_MAPPING = { **Hive.Generator.TYPE_MAPPING, exp.DataType.Type.TINYINT: "BYTE", @@ -102,8 +112,9 @@ class Spark(Hive): exp.Map: _map_sql, exp.Reduce: rename_func("AGGREGATE"), exp.StructKwarg: lambda self, e: f"{self.sql(e, 'this')}: {self.sql(e, 'expression')}", - HiveMap: _map_sql, } + WRAP_DERIVED_VALUES = False + class Tokenizer(Hive.Tokenizer): HEX_STRINGS = [("X'", "'")] diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py index 610aa4b..8ef6cf0 100644 --- a/sqlglot/executor/python.py +++ b/sqlglot/executor/python.py @@ -326,6 +326,7 @@ class Python(Dialect): exp.Alias: lambda self, e: self.sql(e.this), exp.Array: inline_array_sql, exp.And: lambda self, e: self.binary(e, "and"), + exp.Boolean: lambda self, e: "True" if e.this else "False", exp.Cast: _cast_py, exp.Column: _column_py, exp.EQ: lambda self, e: self.binary(e, "=="), diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 8cdacce..f2ffd12 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -508,7 +508,69 @@ class DerivedTable(Expression): return [select.alias_or_name for select in self.selects] -class UDTF(DerivedTable): +class Unionable: + def union(self, expression, distinct=True, dialect=None, **opts): + """ + Builds a UNION expression. + + Example: + >>> import sqlglot + >>> sqlglot.parse_one("SELECT * FROM foo").union("SELECT * FROM bla").sql() + 'SELECT * FROM foo UNION SELECT * FROM bla' + + Args: + expression (str or Expression): the SQL code string. + If an `Expression` instance is passed, it will be used as-is. + distinct (bool): set the DISTINCT flag if and only if this is true. + dialect (str): the dialect used to parse the input expression. + opts (kwargs): other options to use to parse the input expressions. + Returns: + Union: the Union expression. + """ + return union(left=self, right=expression, distinct=distinct, dialect=dialect, **opts) + + def intersect(self, expression, distinct=True, dialect=None, **opts): + """ + Builds an INTERSECT expression. + + Example: + >>> import sqlglot + >>> sqlglot.parse_one("SELECT * FROM foo").intersect("SELECT * FROM bla").sql() + 'SELECT * FROM foo INTERSECT SELECT * FROM bla' + + Args: + expression (str or Expression): the SQL code string. + If an `Expression` instance is passed, it will be used as-is. + distinct (bool): set the DISTINCT flag if and only if this is true. + dialect (str): the dialect used to parse the input expression. + opts (kwargs): other options to use to parse the input expressions. + Returns: + Intersect: the Intersect expression + """ + return intersect(left=self, right=expression, distinct=distinct, dialect=dialect, **opts) + + def except_(self, expression, distinct=True, dialect=None, **opts): + """ + Builds an EXCEPT expression. + + Example: + >>> import sqlglot + >>> sqlglot.parse_one("SELECT * FROM foo").except_("SELECT * FROM bla").sql() + 'SELECT * FROM foo EXCEPT SELECT * FROM bla' + + Args: + expression (str or Expression): the SQL code string. + If an `Expression` instance is passed, it will be used as-is. + distinct (bool): set the DISTINCT flag if and only if this is true. + dialect (str): the dialect used to parse the input expression. + opts (kwargs): other options to use to parse the input expressions. + Returns: + Except: the Except expression + """ + return except_(left=self, right=expression, distinct=distinct, dialect=dialect, **opts) + + +class UDTF(DerivedTable, Unionable): pass @@ -518,6 +580,10 @@ class Annotation(Expression): "expression": True, } + @property + def alias(self): + return self.expression.alias_or_name + class Cache(Expression): arg_types = { @@ -700,6 +766,10 @@ class Hint(Expression): arg_types = {"expressions": True} +class JoinHint(Expression): + arg_types = {"this": True, "expressions": True} + + class Identifier(Expression): arg_types = {"this": True, "quoted": False} @@ -971,7 +1041,7 @@ class Tuple(Expression): arg_types = {"expressions": False} -class Subqueryable: +class Subqueryable(Unionable): def subquery(self, alias=None, copy=True): """ Convert this expression to an aliased expression that can be used as a Subquery. @@ -1654,7 +1724,7 @@ class Select(Subqueryable, Expression): return self.expressions -class Subquery(DerivedTable): +class Subquery(DerivedTable, Unionable): arg_types = { "this": True, "alias": False, @@ -1731,7 +1801,7 @@ class Parameter(Expression): class Placeholder(Expression): - arg_types = {} + arg_types = {"this": False} class Null(Condition): @@ -1791,6 +1861,8 @@ class DataType(Expression): IMAGE = auto() VARIANT = auto() OBJECT = auto() + NULL = auto() + UNKNOWN = auto() # Sentinel value, useful for type annotation @classmethod def build(cls, dtype, **kwargs): @@ -2007,7 +2079,7 @@ class Distinct(Expression): class In(Predicate): - arg_types = {"this": True, "expressions": False, "query": False, "unnest": False} + arg_types = {"this": True, "expressions": False, "query": False, "unnest": False, "field": False} class TimeUnit(Expression): @@ -2377,6 +2449,11 @@ class Map(Func): arg_types = {"keys": True, "values": True} +class VarMap(Func): + arg_types = {"keys": True, "values": True} + is_var_len_args = True + + class Max(AggFunc): pass @@ -2449,7 +2526,7 @@ class Substring(Func): class StrPosition(Func): - arg_types = {"this": True, "substr": True, "position": False} + arg_types = {"substr": True, "this": True, "position": False} class StrToDate(Func): @@ -2785,6 +2862,81 @@ def _wrap_operator(expression): return expression +def union(left, right, distinct=True, dialect=None, **opts): + """ + Initializes a syntax tree from one UNION expression. + + Example: + >>> union("SELECT * FROM foo", "SELECT * FROM bla").sql() + 'SELECT * FROM foo UNION SELECT * FROM bla' + + Args: + left (str or Expression): the SQL code string corresponding to the left-hand side. + If an `Expression` instance is passed, it will be used as-is. + right (str or Expression): the SQL code string corresponding to the right-hand side. + If an `Expression` instance is passed, it will be used as-is. + distinct (bool): set the DISTINCT flag if and only if this is true. + dialect (str): the dialect used to parse the input expression. + opts (kwargs): other options to use to parse the input expressions. + Returns: + Union: the syntax tree for the UNION expression. + """ + left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts) + right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts) + + return Union(this=left, expression=right, distinct=distinct) + + +def intersect(left, right, distinct=True, dialect=None, **opts): + """ + Initializes a syntax tree from one INTERSECT expression. + + Example: + >>> intersect("SELECT * FROM foo", "SELECT * FROM bla").sql() + 'SELECT * FROM foo INTERSECT SELECT * FROM bla' + + Args: + left (str or Expression): the SQL code string corresponding to the left-hand side. + If an `Expression` instance is passed, it will be used as-is. + right (str or Expression): the SQL code string corresponding to the right-hand side. + If an `Expression` instance is passed, it will be used as-is. + distinct (bool): set the DISTINCT flag if and only if this is true. + dialect (str): the dialect used to parse the input expression. + opts (kwargs): other options to use to parse the input expressions. + Returns: + Intersect: the syntax tree for the INTERSECT expression. + """ + left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts) + right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts) + + return Intersect(this=left, expression=right, distinct=distinct) + + +def except_(left, right, distinct=True, dialect=None, **opts): + """ + Initializes a syntax tree from one EXCEPT expression. + + Example: + >>> except_("SELECT * FROM foo", "SELECT * FROM bla").sql() + 'SELECT * FROM foo EXCEPT SELECT * FROM bla' + + Args: + left (str or Expression): the SQL code string corresponding to the left-hand side. + If an `Expression` instance is passed, it will be used as-is. + right (str or Expression): the SQL code string corresponding to the right-hand side. + If an `Expression` instance is passed, it will be used as-is. + distinct (bool): set the DISTINCT flag if and only if this is true. + dialect (str): the dialect used to parse the input expression. + opts (kwargs): other options to use to parse the input expressions. + Returns: + Except: the syntax tree for the EXCEPT statement. + """ + left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts) + right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts) + + return Except(this=left, expression=right, distinct=distinct) + + def select(*expressions, dialect=None, **opts): """ Initializes a syntax tree from one or multiple SELECT expressions. @@ -2991,7 +3143,7 @@ def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts): If an Expression instance is passed, this is used as-is. alias (str or Identifier): the alias name to use. If the name has special characters it is quoted. - table (boolean): create a table alias, default false + table (bool): create a table alias, default false dialect (str): the dialect used to parse the input expression. **opts: other options to use to parse the input expressions. @@ -3002,7 +3154,7 @@ def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts): alias = to_identifier(alias, quoted=quoted) alias = TableAlias(this=alias) if table else alias - if "alias" in exp.arg_types: + if "alias" in exp.arg_types and not isinstance(exp, Window): exp = exp.copy() exp.set("alias", alias) return exp @@ -3138,6 +3290,60 @@ def column_table_names(expression): return list(dict.fromkeys(column.table for column in expression.find_all(Column))) +def table_name(table): + """Get the full name of a table as a string. + + Args: + table (exp.Table | str): Table expression node or string. + + Examples: + >>> from sqlglot import exp, parse_one + >>> table_name(parse_one("select * from a.b.c").find(exp.Table)) + 'a.b.c' + + Returns: + str: the table name + """ + + table = maybe_parse(table, into=Table) + + return ".".join( + part + for part in ( + table.text("catalog"), + table.text("db"), + table.name, + ) + if part + ) + + +def replace_tables(expression, mapping): + """Replace all tables in expression according to the mapping. + + Args: + expression (sqlglot.Expression): Expression node to be transformed and replaced + mapping (Dict[str, str]): Mapping of table names + + Examples: + >>> from sqlglot import exp, parse_one + >>> replace_tables(parse_one("select * from a.b"), {"a.b": "c"}).sql() + 'SELECT * FROM "c"' + + Returns: + The mapped expression + """ + + def _replace_tables(node): + if isinstance(node, Table): + new_name = mapping.get(table_name(node)) + if new_name: + return table_(*reversed(new_name.split(".")), quoted=True) + return node + + return expression.transform(_replace_tables) + + TRUE = Boolean(this=True) FALSE = Boolean(this=False) NULL = Null() diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 8b356f3..b7e295d 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -48,8 +48,9 @@ class Generator: TRANSFORMS = { exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'value')}", exp.DateAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})", - exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.DateDiff: lambda self, e: f"DATEDIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')})", exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})", + exp.VarMap: lambda self, e: f"MAP({self.sql(e.args['keys'])}, {self.sql(e.args['values'])})", exp.LanguageProperty: lambda self, e: self.naked_property(e), exp.LocationProperty: lambda self, e: self.naked_property(e), exp.ReturnsProperty: lambda self, e: self.naked_property(e), @@ -57,7 +58,12 @@ class Generator: exp.VolatilityProperty: lambda self, e: self.sql(e.name), } + # whether or not null ordering is supported in order by NULL_ORDERING_SUPPORTED = True + # always do union distinct or union all + EXPLICIT_UNION = False + # wrap derived values in parens, usually standard but spark doesn't support it + WRAP_DERIVED_VALUES = True TYPE_MAPPING = { exp.DataType.Type.NCHAR: "CHAR", @@ -101,7 +107,6 @@ class Generator: "unsupported_messages", "null_ordering", "max_unsupported", - "wrap_derived_values", "_indent", "_replace_backslash", "_escaped_quote_end", @@ -130,7 +135,6 @@ class Generator: null_ordering=None, max_unsupported=3, leading_comma=False, - wrap_derived_values=True, ): import sqlglot @@ -154,7 +158,6 @@ class Generator: self.unsupported_messages = [] self.max_unsupported = max_unsupported self.null_ordering = null_ordering - self.wrap_derived_values = wrap_derived_values self._indent = indent self._replace_backslash = self.escape == "\\" self._escaped_quote_end = self.escape + self.quote_end @@ -595,7 +598,7 @@ class Generator: if not alias: return f"VALUES{self.seg('')}{args}" alias = f" AS {alias}" if alias else alias - if self.wrap_derived_values: + if self.WRAP_DERIVED_VALUES: return f"(VALUES{self.seg('')}{args}){alias}" return f"VALUES{self.seg('')}{args}{alias}" @@ -779,8 +782,8 @@ class Generator: def parameter_sql(self, expression): return f"@{self.sql(expression, 'this')}" - def placeholder_sql(self, *_): - return "?" + def placeholder_sql(self, expression): + return f":{expression.name}" if expression.name else "?" def subquery_sql(self, expression): alias = self.sql(expression, "alias") @@ -803,7 +806,9 @@ class Generator: ) def union_op(self, expression): - return f"UNION{'' if expression.args.get('distinct') else ' ALL'}" + kind = " DISTINCT" if self.EXPLICIT_UNION else "" + kind = kind if expression.args.get("distinct") else " ALL" + return f"UNION{kind}" def unnest_sql(self, expression): args = self.expressions(expression, flat=True) @@ -940,10 +945,13 @@ class Generator: def in_sql(self, expression): query = expression.args.get("query") unnest = expression.args.get("unnest") + field = expression.args.get("field") if query: in_sql = self.wrap(query) elif unnest: in_sql = self.in_unnest_op(unnest) + elif field: + in_sql = self.sql(field) else: in_sql = f"({self.expressions(expression, flat=True)})" return f"{self.sql(expression, 'this')} IN {in_sql}" @@ -1178,3 +1186,8 @@ class Generator: this = self.sql(expression, "this") kind = self.sql(expression, "kind") return f"{this} {kind}" + + def joinhint_sql(self, expression): + this = self.sql(expression, "this") + expressions = self.expressions(expression, flat=True) + return f"{this}({expressions})" diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index 3f5f089..a2cef37 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -1,16 +1,20 @@ from sqlglot import exp from sqlglot.helper import ensure_list, subclasses +from sqlglot.optimizer.schema import ensure_schema +from sqlglot.optimizer.scope import Scope, traverse_scope def annotate_types(expression, schema=None, annotators=None, coerces_to=None): """ Recursively infer & annotate types in an expression syntax tree against a schema. + Assumes that we've already executed the optimizer's qualify_columns step. - (TODO -- replace this with a better example after adding some functionality) Example: >>> import sqlglot - >>> annotated_expression = annotate_types(sqlglot.parse_one('5 + 5.3')) - >>> annotated_expression.type + >>> schema = {"y": {"cola": "SMALLINT"}} + >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x" + >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema) + >>> annotated_expr.expressions[0].type # Get the type of "x.cola + 2.5 AS cola" Args: @@ -22,6 +26,8 @@ def annotate_types(expression, schema=None, annotators=None, coerces_to=None): sqlglot.Expression: expression annotated with types """ + schema = ensure_schema(schema) + return TypeAnnotator(schema, annotators, coerces_to).annotate(expression) @@ -35,10 +41,81 @@ class TypeAnnotator: expr_type: lambda self, expr: self._annotate_binary(expr) for expr_type in subclasses(exp.__name__, exp.Binary) }, - exp.Cast: lambda self, expr: self._annotate_cast(expr), - exp.DataType: lambda self, expr: self._annotate_data_type(expr), + exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"].this), + exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.this), + exp.Alias: lambda self, expr: self._annotate_unary(expr), exp.Literal: lambda self, expr: self._annotate_literal(expr), - exp.Boolean: lambda self, expr: self._annotate_boolean(expr), + exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), + exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL), + exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN), + exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), + exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), + exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), + exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME), + exp.CurrentTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), + exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), + exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME), + exp.DatetimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME), + exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.TimestampAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.TimestampSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.DateStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), + exp.DateToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), + exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), + exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.If: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), + exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), + exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.Ln: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Log: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Log2: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Log10: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Lower: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), + exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.RegexpLike: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), + exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), + exp.StrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), + exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), + exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.UnixToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.VariancePop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), + exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), } # Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html @@ -97,43 +174,82 @@ class TypeAnnotator: }, } + TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery) + def __init__(self, schema=None, annotators=None, coerces_to=None): self.schema = schema self.annotators = annotators or self.ANNOTATORS self.coerces_to = coerces_to or self.COERCES_TO def annotate(self, expression): + if isinstance(expression, self.TRAVERSABLES): + for scope in traverse_scope(expression): + subscope_selects = { + name: {select.alias_or_name: select for select in source.selects} + for name, source in scope.sources.items() + if isinstance(source, Scope) + } + + # First annotate the current scope's column references + for col in scope.columns: + source = scope.sources[col.table] + if isinstance(source, exp.Table): + col.type = self.schema.get_column_type(source, col) + else: + col.type = subscope_selects[col.table][col.name].type + + # Then (possibly) annotate the remaining expressions in the scope + self._maybe_annotate(scope.expression) + + return self._maybe_annotate(expression) # This takes care of non-traversable expressions + + def _maybe_annotate(self, expression): if not isinstance(expression, exp.Expression): return None + if expression.type: + return expression # We've already inferred the expression's type + annotator = self.annotators.get(expression.__class__) - return annotator(self, expression) if annotator else self._annotate_args(expression) + return ( + annotator(self, expression) + if annotator + else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN) + ) def _annotate_args(self, expression): for value in expression.args.values(): for v in ensure_list(value): - self.annotate(v) + self._maybe_annotate(v) return expression - def _annotate_cast(self, expression): - expression.type = expression.args["to"].this - return self._annotate_args(expression) - - def _annotate_data_type(self, expression): - expression.type = expression.this - return self._annotate_args(expression) - def _maybe_coerce(self, type1, type2): + # We propagate the NULL / UNKNOWN types upwards if found + if exp.DataType.Type.NULL in (type1, type2): + return exp.DataType.Type.NULL + if exp.DataType.Type.UNKNOWN in (type1, type2): + return exp.DataType.Type.UNKNOWN + return type2 if type2 in self.coerces_to[type1] else type1 def _annotate_binary(self, expression): self._annotate_args(expression) - if isinstance(expression, (exp.Condition, exp.Predicate)): + left_type = expression.left.type + right_type = expression.right.type + + if isinstance(expression, (exp.And, exp.Or)): + if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL: + expression.type = exp.DataType.Type.NULL + elif exp.DataType.Type.NULL in (left_type, right_type): + expression.type = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")) + else: + expression.type = exp.DataType.Type.BOOLEAN + elif isinstance(expression, (exp.Condition, exp.Predicate)): expression.type = exp.DataType.Type.BOOLEAN else: - expression.type = self._maybe_coerce(expression.left.type, expression.right.type) + expression.type = self._maybe_coerce(left_type, right_type) return expression @@ -157,6 +273,6 @@ class TypeAnnotator: return expression - def _annotate_boolean(self, expression): - expression.type = exp.DataType.Type.BOOLEAN - return expression + def _annotate_with_type(self, expression, target_type): + expression.type = target_type + return self._annotate_args(expression) diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index d29c22b..3e435f5 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -44,6 +44,7 @@ UNMERGABLE_ARGS = set(exp.Select.arg_types) - { "joins", "where", "order", + "hint", } @@ -67,21 +68,22 @@ def merge_ctes(expression, leave_tables_isolated=False): singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1] for outer_scope, inner_scope, table in singular_cte_selections: inner_select = inner_scope.expression.unnest() - if _mergeable(outer_scope, inner_select, leave_tables_isolated): - from_or_join = table.find_ancestor(exp.From, exp.Join) - + from_or_join = table.find_ancestor(exp.From, exp.Join) + if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): node_to_replace = table if isinstance(node_to_replace.parent, exp.Alias): node_to_replace = node_to_replace.parent alias = node_to_replace.alias else: alias = table.name + _rename_inner_sources(outer_scope, inner_scope, alias) _merge_from(outer_scope, inner_scope, node_to_replace, alias) _merge_expressions(outer_scope, inner_scope, alias) _merge_joins(outer_scope, inner_scope, from_or_join) _merge_where(outer_scope, inner_scope, from_or_join) _merge_order(outer_scope, inner_scope) + _merge_hints(outer_scope, inner_scope) _pop_cte(inner_scope) return expression @@ -90,9 +92,9 @@ def merge_derived_tables(expression, leave_tables_isolated=False): for outer_scope in traverse_scope(expression): for subquery in outer_scope.derived_tables: inner_select = subquery.unnest() - if _mergeable(outer_scope, inner_select, leave_tables_isolated): + from_or_join = subquery.find_ancestor(exp.From, exp.Join) + if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): alias = subquery.alias_or_name - from_or_join = subquery.find_ancestor(exp.From, exp.Join) inner_scope = outer_scope.sources[alias] _rename_inner_sources(outer_scope, inner_scope, alias) @@ -101,10 +103,11 @@ def merge_derived_tables(expression, leave_tables_isolated=False): _merge_joins(outer_scope, inner_scope, from_or_join) _merge_where(outer_scope, inner_scope, from_or_join) _merge_order(outer_scope, inner_scope) + _merge_hints(outer_scope, inner_scope) return expression -def _mergeable(outer_scope, inner_select, leave_tables_isolated): +def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): """ Return True if `inner_select` can be merged into outer query. @@ -112,6 +115,7 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated): outer_scope (Scope) inner_select (exp.Select) leave_tables_isolated (bool) + from_or_join (exp.From|exp.Join) Returns: bool: True if can be merged """ @@ -123,6 +127,16 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated): and inner_select.args.get("from") and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions) and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1) + and not ( + isinstance(from_or_join, exp.Join) + and inner_select.args.get("where") + and from_or_join.side in {"FULL", "LEFT", "RIGHT"} + ) + and not ( + isinstance(from_or_join, exp.From) + and inner_select.args.get("where") + and any(j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", [])) + ) ) @@ -170,6 +184,12 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias): """ new_subquery = inner_scope.expression.args.get("from").expressions[0] node_to_replace.replace(new_subquery) + for join_hint in outer_scope.join_hints: + tables = join_hint.find_all(exp.Table) + for table in tables: + if table.alias_or_name == node_to_replace.alias_or_name: + new_table = new_subquery.this if isinstance(new_subquery, exp.Alias) else new_subquery + table.set("this", exp.to_identifier(new_table.alias_or_name)) outer_scope.remove_source(alias) outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name]) @@ -273,6 +293,18 @@ def _merge_order(outer_scope, inner_scope): outer_scope.expression.set("order", inner_scope.expression.args.get("order")) +def _merge_hints(outer_scope, inner_scope): + inner_scope_hint = inner_scope.expression.args.get("hint") + if not inner_scope_hint: + return + outer_scope_hint = outer_scope.expression.args.get("hint") + if outer_scope_hint: + for hint_expression in inner_scope_hint.expressions: + outer_scope_hint.append("expressions", hint_expression) + else: + outer_scope.expression.set("hint", inner_scope_hint) + + def _pop_cte(inner_scope): """ Remove CTE from the AST. diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py index a070d70..9c8d71d 100644 --- a/sqlglot/optimizer/pushdown_predicates.py +++ b/sqlglot/optimizer/pushdown_predicates.py @@ -1,3 +1,5 @@ +from collections import defaultdict + from sqlglot import exp from sqlglot.optimizer.normalize import normalized from sqlglot.optimizer.scope import traverse_scope @@ -20,22 +22,30 @@ def pushdown_predicates(expression): Returns: sqlglot.Expression: optimized expression """ - for scope in reversed(traverse_scope(expression)): + scope_ref_count = defaultdict(lambda: 0) + scopes = traverse_scope(expression) + scopes.reverse() + + for scope in scopes: + for _, source in scope.selected_sources.values(): + scope_ref_count[id(source)] += 1 + + for scope in scopes: select = scope.expression where = select.args.get("where") if where: - pushdown(where.this, scope.selected_sources) + pushdown(where.this, scope.selected_sources, scope_ref_count) # joins should only pushdown into itself, not to other joins # so we limit the selected sources to only itself for join in select.args.get("joins") or []: name = join.this.alias_or_name - pushdown(join.args.get("on"), {name: scope.selected_sources[name]}) + pushdown(join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count) return expression -def pushdown(condition, sources): +def pushdown(condition, sources, scope_ref_count): if not condition: return @@ -45,17 +55,17 @@ def pushdown(condition, sources): predicates = list(condition.flatten() if isinstance(condition, exp.And if cnf_like else exp.Or) else [condition]) if cnf_like: - pushdown_cnf(predicates, sources) + pushdown_cnf(predicates, sources, scope_ref_count) else: - pushdown_dnf(predicates, sources) + pushdown_dnf(predicates, sources, scope_ref_count) -def pushdown_cnf(predicates, scope): +def pushdown_cnf(predicates, scope, scope_ref_count): """ If the predicates are in CNF like form, we can simply replace each block in the parent. """ for predicate in predicates: - for node in nodes_for_predicate(predicate, scope).values(): + for node in nodes_for_predicate(predicate, scope, scope_ref_count).values(): if isinstance(node, exp.Join): predicate.replace(exp.TRUE) node.on(predicate, copy=False) @@ -65,7 +75,7 @@ def pushdown_cnf(predicates, scope): node.where(replace_aliases(node, predicate), copy=False) -def pushdown_dnf(predicates, scope): +def pushdown_dnf(predicates, scope, scope_ref_count): """ If the predicates are in DNF form, we can only push down conditions that are in all blocks. Additionally, we can't remove predicates from their original form. @@ -91,7 +101,7 @@ def pushdown_dnf(predicates, scope): # (a.x AND and a.y AND b.x) OR (a.z AND c.y) -> (a.x AND a.y) OR (a.z) for table in sorted(pushdown_tables): for predicate in predicates: - nodes = nodes_for_predicate(predicate, scope) + nodes = nodes_for_predicate(predicate, scope, scope_ref_count) if table not in nodes: continue @@ -120,7 +130,7 @@ def pushdown_dnf(predicates, scope): node.where(replace_aliases(node, predicate), copy=False) -def nodes_for_predicate(predicate, sources): +def nodes_for_predicate(predicate, sources, scope_ref_count): nodes = {} tables = exp.column_table_names(predicate) where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where) @@ -133,7 +143,7 @@ def nodes_for_predicate(predicate, sources): if node and where_condition: node = node.find_ancestor(exp.Join, exp.From) - # a node can reference a CTE which should be push down + # a node can reference a CTE which should be pushed down if isinstance(node, exp.From) and not isinstance(source, exp.Table): node = source.expression @@ -142,7 +152,9 @@ def nodes_for_predicate(predicate, sources): return {} nodes[table] = node elif isinstance(node, exp.Select) and len(tables) == 1: - if not node.args.get("group"): + # we can't push down predicates to select statements if they are referenced in + # multiple places. + if not node.args.get("group") and scope_ref_count[id(source)] < 2: nodes[table] = node return nodes diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 72ce256..7d77ef1 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -31,8 +31,8 @@ def qualify_columns(expression, schema): _pop_table_column_aliases(scope.derived_tables) _expand_using(scope, resolver) _expand_group_by(scope, resolver) - _expand_order_by(scope) _qualify_columns(scope, resolver) + _expand_order_by(scope) if not isinstance(scope.expression, exp.UDTF): _expand_stars(scope, resolver) _qualify_outputs(scope) @@ -235,7 +235,7 @@ def _expand_stars(scope, resolver): for table in tables: if table not in scope.sources: raise OptimizeError(f"Unknown table: {table}") - columns = resolver.get_source_columns(table) + columns = resolver.get_source_columns(table, only_visible=True) table_id = id(table) for name in columns: if name not in except_columns.get(table_id, set()): @@ -332,7 +332,7 @@ class _Resolver: self._all_columns = set(column for columns in self._get_all_source_columns().values() for column in columns) return self._all_columns - def get_source_columns(self, name): + def get_source_columns(self, name, only_visible=False): """Resolve the source columns for a given source `name`""" if name not in self.scope.sources: raise OptimizeError(f"Unknown table: {name}") @@ -342,7 +342,7 @@ class _Resolver: # If referencing a table, return the columns from the schema if isinstance(source, exp.Table): try: - return self.schema.column_names(source) + return self.schema.column_names(source, only_visible) except Exception as e: raise OptimizeError(str(e)) from e diff --git a/sqlglot/optimizer/schema.py b/sqlglot/optimizer/schema.py index 1bbd86a..d7743c9 100644 --- a/sqlglot/optimizer/schema.py +++ b/sqlglot/optimizer/schema.py @@ -9,16 +9,28 @@ class Schema(abc.ABC): """Abstract base class for database schemas""" @abc.abstractmethod - def column_names(self, table): + def column_names(self, table, only_visible=False): """ Get the column names for a table. - Args: table (sqlglot.expressions.Table): Table expression instance + only_visible (bool): Whether to include invisible columns Returns: list[str]: list of column names """ + @abc.abstractmethod + def get_column_type(self, table, column): + """ + Get the exp.DataType type of a column in the schema. + + Args: + table (sqlglot.expressions.Table): The source table. + column (sqlglot.expressions.Column): The target column. + Returns: + sqlglot.expressions.DataType.Type: The resulting column type. + """ + class MappingSchema(Schema): """ @@ -29,10 +41,19 @@ class MappingSchema(Schema): 1. {table: {col: type}} 2. {db: {table: {col: type}}} 3. {catalog: {db: {table: {col: type}}}} + visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns + are assumed to be visible. The nesting should mirror that of the schema: + 1. {table: set(*cols)}} + 2. {db: {table: set(*cols)}}} + 3. {catalog: {db: {table: set(*cols)}}}} + dialect (str): The dialect to be used for custom type mappings. """ - def __init__(self, schema): + def __init__(self, schema, visible=None, dialect=None): self.schema = schema + self.visible = visible + self.dialect = dialect + self._type_mapping_cache = {} depth = _dict_depth(schema) @@ -49,7 +70,7 @@ class MappingSchema(Schema): self.forbidden_args = {"catalog", "db", "this"} - set(self.supported_table_args) - def column_names(self, table): + def column_names(self, table, only_visible=False): if not isinstance(table.this, exp.Identifier): return fs_get(table) @@ -58,7 +79,39 @@ class MappingSchema(Schema): for forbidden in self.forbidden_args: if table.text(forbidden): raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}") - return list(_nested_get(self.schema, *zip(self.supported_table_args, args))) + + columns = list(_nested_get(self.schema, *zip(self.supported_table_args, args))) + if not only_visible or not self.visible: + return columns + + visible = _nested_get(self.visible, *zip(self.supported_table_args, args)) + return [col for col in columns if col in visible] + + def get_column_type(self, table, column): + try: + schema_type = self.schema.get(table.name, {}).get(column.name).upper() + return self._convert_type(schema_type) + except: + raise OptimizeError(f"Failed to get type for column {column.sql()}") + + def _convert_type(self, schema_type): + """ + Convert a type represented as a string to the corresponding exp.DataType.Type object. + + Args: + schema_type (str): The type we want to convert. + Returns: + sqlglot.expressions.DataType.Type: The resulting expression type. + """ + if schema_type not in self._type_mapping_cache: + try: + self._type_mapping_cache[schema_type] = exp.maybe_parse( + schema_type, into=exp.DataType, dialect=self.dialect + ).this + except AttributeError: + raise OptimizeError(f"Failed to convert type {schema_type}") + + return self._type_mapping_cache[schema_type] def ensure_schema(schema): diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 6332cdd..89de517 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -68,6 +68,7 @@ class Scope: self._selected_sources = None self._columns = None self._external_columns = None + self._join_hints = None def branch(self, expression, scope_type, chain_sources=None, **kwargs): """Branch from the current scope to a new, inner scope""" @@ -85,14 +86,17 @@ class Scope: self._subqueries = [] self._derived_tables = [] self._raw_columns = [] + self._join_hints = [] for node, parent, _ in self.walk(bfs=False): if node is self.expression: continue elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star): self._raw_columns.append(node) - elif isinstance(node, exp.Table): + elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint): self._tables.append(node) + elif isinstance(node, exp.JoinHint): + self._join_hints.append(node) elif isinstance(node, exp.UDTF): self._derived_tables.append(node) elif isinstance(node, exp.CTE): @@ -246,7 +250,7 @@ class Scope: table only becomes a selected source if it's included in a FROM or JOIN clause. Returns: - dict[str, (exp.Table|exp.Subquery, exp.Table|Scope)]: selected sources and nodes + dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes """ if self._selected_sources is None: referenced_names = [] @@ -310,6 +314,18 @@ class Scope: self._external_columns = [c for c in self.columns if c.table not in self.selected_sources] return self._external_columns + @property + def join_hints(self): + """ + Hints that exist in the scope that reference tables + + Returns: + list[exp.JoinHint]: Join hints that are referenced within the scope + """ + if self._join_hints is None: + return [] + return self._join_hints + def source_columns(self, source_name): """ Get all columns in the current scope for a particular source. diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 319e6b6..c077906 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -56,12 +56,16 @@ def simplify_not(expression): NOT (x AND y) -> NOT x OR NOT y """ if isinstance(expression, exp.Not): + if isinstance(expression.this, exp.Null): + return NULL if isinstance(expression.this, exp.Paren): condition = expression.this.unnest() if isinstance(condition, exp.And): return exp.or_(exp.not_(condition.left), exp.not_(condition.right)) if isinstance(condition, exp.Or): return exp.and_(exp.not_(condition.left), exp.not_(condition.right)) + if isinstance(condition, exp.Null): + return NULL if always_true(expression.this): return FALSE if expression.this == FALSE: @@ -95,10 +99,10 @@ def simplify_connectors(expression): return left if isinstance(expression, exp.And): - if NULL in (left, right): - return NULL if FALSE in (left, right): return FALSE + if NULL in (left, right): + return NULL if always_true(left) and always_true(right): return TRUE if always_true(left): diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 5f20afc..c29e520 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -8,6 +8,18 @@ from sqlglot.tokens import Token, Tokenizer, TokenType logger = logging.getLogger("sqlglot") +def parse_var_map(args): + keys = [] + values = [] + for i in range(0, len(args), 2): + keys.append(args[i]) + values.append(args[i + 1]) + return exp.VarMap( + keys=exp.Array(expressions=keys), + values=exp.Array(expressions=values), + ) + + class Parser: """ Parser consumes a list of tokens produced by the :class:`~sqlglot.tokens.Tokenizer` @@ -48,6 +60,7 @@ class Parser: start=exp.Literal.number(1), length=exp.Literal.number(10), ), + "VAR_MAP": parse_var_map, } NO_PAREN_FUNCTIONS = { @@ -117,6 +130,7 @@ class Parser: TokenType.VAR, TokenType.ALTER, TokenType.ALWAYS, + TokenType.ANTI, TokenType.BEGIN, TokenType.BOTH, TokenType.BUCKET, @@ -164,6 +178,7 @@ class Parser: TokenType.ROWS, TokenType.SCHEMA_COMMENT, TokenType.SEED, + TokenType.SEMI, TokenType.SET, TokenType.SHOW, TokenType.STABLE, @@ -273,6 +288,8 @@ class Parser: TokenType.INNER, TokenType.OUTER, TokenType.CROSS, + TokenType.SEMI, + TokenType.ANTI, } COLUMN_OPERATORS = { @@ -318,6 +335,8 @@ class Parser: exp.Properties: lambda self: self._parse_properties(), exp.Where: lambda self: self._parse_where(), exp.Ordered: lambda self: self._parse_ordered(), + exp.Having: lambda self: self._parse_having(), + exp.With: lambda self: self._parse_with(), "JOIN_TYPE": lambda self: self._parse_join_side_and_kind(), } @@ -338,7 +357,6 @@ class Parser: TokenType.NULL: lambda *_: exp.Null(), TokenType.TRUE: lambda *_: exp.Boolean(this=True), TokenType.FALSE: lambda *_: exp.Boolean(this=False), - TokenType.PLACEHOLDER: lambda *_: exp.Placeholder(), TokenType.PARAMETER: lambda self, _: exp.Parameter(this=self._parse_var() or self._parse_primary()), TokenType.BIT_STRING: lambda _, token: exp.BitString(this=token.text), TokenType.HEX_STRING: lambda _, token: exp.HexString(this=token.text), @@ -910,7 +928,20 @@ class Parser: return self.expression(exp.Tuple, expressions=expressions) def _parse_select(self, nested=False, table=False): - if self._match(TokenType.SELECT): + cte = self._parse_with() + if cte: + this = self._parse_statement() + + if not this: + self.raise_error("Failed to parse any statement following CTE") + return cte + + if "with" in this.arg_types: + this.set("with", cte) + else: + self.raise_error(f"{this.key} does not support CTE") + this = cte + elif self._match(TokenType.SELECT): hint = self._parse_hint() all_ = self._match(TokenType.ALL) distinct = self._match(TokenType.DISTINCT) @@ -938,39 +969,6 @@ class Parser: if from_: this.set("from", from_) self._parse_query_modifiers(this) - elif self._match(TokenType.WITH): - recursive = self._match(TokenType.RECURSIVE) - - expressions = [] - - while True: - expressions.append(self._parse_cte()) - - if not self._match(TokenType.COMMA): - break - - cte = self.expression( - exp.With, - expressions=expressions, - recursive=recursive, - ) - this = self._parse_statement() - - if not this: - self.raise_error("Failed to parse any statement following CTE") - return cte - - if "with" in this.arg_types: - this.set( - "with", - self.expression( - exp.With, - expressions=expressions, - recursive=recursive, - ), - ) - else: - self.raise_error(f"{this.key} does not support CTE") elif (table or nested) and self._match(TokenType.L_PAREN): this = self._parse_table() if table else self._parse_select(nested=True) self._parse_query_modifiers(this) @@ -986,6 +984,26 @@ class Parser: return self._parse_set_operations(this) if this else None + def _parse_with(self): + if not self._match(TokenType.WITH): + return None + + recursive = self._match(TokenType.RECURSIVE) + + expressions = [] + + while True: + expressions.append(self._parse_cte()) + + if not self._match(TokenType.COMMA): + break + + return self.expression( + exp.With, + expressions=expressions, + recursive=recursive, + ) + def _parse_cte(self): alias = self._parse_table_alias() if not alias or not alias.this: @@ -1485,8 +1503,7 @@ class Parser: unnest = self._parse_unnest() if unnest: this = self.expression(exp.In, this=this, unnest=unnest) - else: - self._match_l_paren() + elif self._match(TokenType.L_PAREN): expressions = self._parse_csv(self._parse_select_or_expression) if len(expressions) == 1 and isinstance(expressions[0], exp.Subqueryable): @@ -1495,6 +1512,9 @@ class Parser: this = self.expression(exp.In, this=this, expressions=expressions) self._match_r_paren() + else: + this = self.expression(exp.In, this=this, field=self._parse_field()) + return this def _parse_between(self, this): @@ -1591,7 +1611,7 @@ class Parser: elif nested: expressions = self._parse_csv(self._parse_types) else: - expressions = self._parse_csv(self._parse_number) + expressions = self._parse_csv(self._parse_type) if not expressions: self._retreat(index) @@ -1706,7 +1726,7 @@ class Parser: def _parse_field(self, any_token=False): return self._parse_primary() or self._parse_function() or self._parse_id_var(any_token) - def _parse_function(self): + def _parse_function(self, functions=None): if not self._curr: return None @@ -1742,7 +1762,9 @@ class Parser: self._match_r_paren() return this - function = self.FUNCTIONS.get(upper) + if functions is None: + functions = self.FUNCTIONS + function = functions.get(upper) args = self._parse_csv(self._parse_lambda) if function: @@ -2025,10 +2047,20 @@ class Parser: return self.expression(exp.Cast, this=this, to=to) def _parse_position(self): - substr = self._parse_bitwise() + args = self._parse_csv(self._parse_bitwise) + if self._match(TokenType.IN): - string = self._parse_bitwise() - return self.expression(exp.StrPosition, this=string, substr=substr) + args.append(self._parse_bitwise()) + + # Note: we're parsing in order needle, haystack, position + this = exp.StrPosition.from_arg_list(args) + self.validate_expression(this, args) + + return this + + def _parse_join_hint(self, func_name): + args = self._parse_csv(self._parse_table) + return exp.JoinHint(this=func_name.upper(), expressions=args) def _parse_substring(self): # Postgres supports the form: substring(string [from int] [for int]) @@ -2247,6 +2279,9 @@ class Parser: def _parse_placeholder(self): if self._match(TokenType.PLACEHOLDER): return exp.Placeholder() + elif self._match(TokenType.COLON): + self._advance() + return exp.Placeholder(this=self._prev.text) return None def _parse_except(self): diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 39bf421..17c038c 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -104,6 +104,7 @@ class TokenType(AutoName): ALL = auto() ALTER = auto() ANALYZE = auto() + ANTI = auto() ANY = auto() ARRAY = auto() ASC = auto() @@ -236,6 +237,7 @@ class TokenType(AutoName): SCHEMA_COMMENT = auto() SEED = auto() SELECT = auto() + SEMI = auto() SEPARATOR = auto() SET = auto() SHOW = auto() @@ -262,6 +264,7 @@ class TokenType(AutoName): USE = auto() USING = auto() VALUES = auto() + VACUUM = auto() VIEW = auto() VOLATILE = auto() WHEN = auto() @@ -406,6 +409,7 @@ class Tokenizer(metaclass=_Tokenizer): "ALTER": TokenType.ALTER, "ANALYZE": TokenType.ANALYZE, "AND": TokenType.AND, + "ANTI": TokenType.ANTI, "ANY": TokenType.ANY, "ASC": TokenType.ASC, "AS": TokenType.ALIAS, @@ -528,6 +532,7 @@ class Tokenizer(metaclass=_Tokenizer): "ROWS": TokenType.ROWS, "SEED": TokenType.SEED, "SELECT": TokenType.SELECT, + "SEMI": TokenType.SEMI, "SET": TokenType.SET, "SHOW": TokenType.SHOW, "SOME": TokenType.SOME, @@ -551,6 +556,7 @@ class Tokenizer(metaclass=_Tokenizer): "UPDATE": TokenType.UPDATE, "USE": TokenType.USE, "USING": TokenType.USING, + "VACUUM": TokenType.VACUUM, "VALUES": TokenType.VALUES, "VIEW": TokenType.VIEW, "VOLATILE": TokenType.VOLATILE, @@ -577,6 +583,7 @@ class Tokenizer(metaclass=_Tokenizer): "INT8": TokenType.BIGINT, "DECIMAL": TokenType.DECIMAL, "MAP": TokenType.MAP, + "NULLABLE": TokenType.NULLABLE, "NUMBER": TokenType.DECIMAL, "NUMERIC": TokenType.DECIMAL, "FIXED": TokenType.DECIMAL, @@ -629,6 +636,7 @@ class Tokenizer(metaclass=_Tokenizer): TokenType.SHOW, TokenType.TRUNCATE, TokenType.USE, + TokenType.VACUUM, } # handle numeric literals like in hive (3L = BIGINT) diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 7110eac..8921924 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -152,6 +152,10 @@ class TestBigQuery(Validator): "SELECT ROW() OVER (y ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM x WINDOW y AS (PARTITION BY CATEGORY)" ) + self.validate_identity( + "SELECT item, purchases, LAST_VALUE(item) OVER (item_window ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING) AS most_popular FROM Produce WINDOW item_window AS (ORDER BY purchases)" + ) + self.validate_identity( "SELECT LAST_VALUE(a IGNORE NULLS) OVER y FROM x WINDOW y AS (PARTITION BY CATEGORY)", ) @@ -222,6 +226,20 @@ class TestBigQuery(Validator): "spark": "DATE_ADD(CURRENT_DATE, 1)", }, ) + self.validate_all( + "DATE_DIFF(DATE '2010-07-07', DATE '2008-12-25', DAY)", + write={ + "bigquery": "DATE_DIFF(CAST('2010-07-07' AS DATE), CAST('2008-12-25' AS DATE), DAY)", + "mysql": "DATEDIFF(CAST('2010-07-07' AS DATE), CAST('2008-12-25' AS DATE))", + }, + ) + self.validate_all( + "DATE_DIFF(DATE '2010-07-07', DATE '2008-12-25', MINUTE)", + write={ + "bigquery": "DATE_DIFF(CAST('2010-07-07' AS DATE), CAST('2008-12-25' AS DATE), MINUTE)", + "mysql": "DATEDIFF(CAST('2010-07-07' AS DATE), CAST('2008-12-25' AS DATE))", + }, + ) self.validate_all( "CURRENT_DATE('UTC')", write={ diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index e5b1516..715bf10 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -8,6 +8,8 @@ class TestClickhouse(Validator): self.validate_identity("dictGet(x, 'y')") self.validate_identity("SELECT * FROM x FINAL") self.validate_identity("SELECT * FROM x AS y FINAL") + self.validate_identity("'a' IN mapKeys(map('a', 1, 'b', 2))") + self.validate_identity("CAST((1, 2) AS Tuple(a Int8, b Int16))") self.validate_all( "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", @@ -20,6 +22,12 @@ class TestClickhouse(Validator): self.validate_all( "CAST(1 AS NULLABLE(Int64))", write={ - "clickhouse": "CAST(1 AS Nullable(BIGINT))", + "clickhouse": "CAST(1 AS Nullable(Int64))", + }, + ) + self.validate_all( + "CAST(1 AS Nullable(DateTime64(6, 'UTC')))", + write={ + "clickhouse": "CAST(1 AS Nullable(DateTime64(6, 'UTC')))", }, ) diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index a9a313c..53edb42 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -81,6 +81,24 @@ class TestDialect(Validator): "starrocks": "CAST(a AS STRING)", }, ) + self.validate_all( + "CAST(MAP('a', '1') AS MAP(TEXT, TEXT))", + write={ + "clickhouse": "CAST(map('a', '1') AS Map(TEXT, TEXT))", + }, + ) + self.validate_all( + "CAST(ARRAY(1, 2) AS ARRAY)", + write={ + "clickhouse": "CAST([1, 2] AS Array(Int8))", + }, + ) + self.validate_all( + "CAST((1, 2) AS STRUCT)", + write={ + "clickhouse": "CAST((1, 2) AS Tuple(a Int8, b Int16, c Int32, d Int64))", + }, + ) self.validate_all( "CAST(a AS DATETIME)", write={ @@ -170,7 +188,7 @@ class TestDialect(Validator): "CAST(a AS DOUBLE)", write={ "bigquery": "CAST(a AS FLOAT64)", - "clickhouse": "CAST(a AS DOUBLE)", + "clickhouse": "CAST(a AS Float64)", "duckdb": "CAST(a AS DOUBLE)", "mysql": "CAST(a AS DOUBLE)", "hive": "CAST(a AS DOUBLE)", @@ -234,6 +252,8 @@ class TestDialect(Validator): write={ "duckdb": "STRPTIME('2020-01-01', '%Y-%m-%d')", "hive": "CAST('2020-01-01' AS TIMESTAMP)", + "oracle": "TO_TIMESTAMP('2020-01-01', 'YYYY-MM-DD')", + "postgres": "TO_TIMESTAMP('2020-01-01', 'YYYY-MM-DD')", "presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d')", "redshift": "TO_TIMESTAMP('2020-01-01', 'YYYY-MM-DD')", "spark": "TO_TIMESTAMP('2020-01-01', 'yyyy-MM-dd')", @@ -245,6 +265,8 @@ class TestDialect(Validator): "duckdb": "STRPTIME(x, '%y')", "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yy')) AS TIMESTAMP)", "presto": "DATE_PARSE(x, '%y')", + "oracle": "TO_TIMESTAMP(x, 'YY')", + "postgres": "TO_TIMESTAMP(x, 'YY')", "redshift": "TO_TIMESTAMP(x, 'YY')", "spark": "TO_TIMESTAMP(x, 'yy')", }, @@ -288,6 +310,8 @@ class TestDialect(Validator): write={ "duckdb": "STRFTIME(x, '%Y-%m-%d')", "hive": "DATE_FORMAT(x, 'yyyy-MM-dd')", + "oracle": "TO_CHAR(x, 'YYYY-MM-DD')", + "postgres": "TO_CHAR(x, 'YYYY-MM-DD')", "presto": "DATE_FORMAT(x, '%Y-%m-%d')", "redshift": "TO_CHAR(x, 'YYYY-MM-DD')", }, @@ -348,6 +372,8 @@ class TestDialect(Validator): write={ "duckdb": "TO_TIMESTAMP(CAST(x AS BIGINT))", "hive": "FROM_UNIXTIME(x)", + "oracle": "TO_DATE('1970-01-01','YYYY-MM-DD') + (x / 86400)", + "postgres": "TO_TIMESTAMP(x)", "presto": "FROM_UNIXTIME(x)", "starrocks": "FROM_UNIXTIME(x)", }, @@ -704,6 +730,7 @@ class TestDialect(Validator): "SELECT * FROM a UNION SELECT * FROM b", read={ "bigquery": "SELECT * FROM a UNION DISTINCT SELECT * FROM b", + "clickhouse": "SELECT * FROM a UNION DISTINCT SELECT * FROM b", "duckdb": "SELECT * FROM a UNION SELECT * FROM b", "presto": "SELECT * FROM a UNION SELECT * FROM b", "spark": "SELECT * FROM a UNION SELECT * FROM b", @@ -719,6 +746,7 @@ class TestDialect(Validator): "SELECT * FROM a UNION ALL SELECT * FROM b", read={ "bigquery": "SELECT * FROM a UNION ALL SELECT * FROM b", + "clickhouse": "SELECT * FROM a UNION ALL SELECT * FROM b", "duckdb": "SELECT * FROM a UNION ALL SELECT * FROM b", "presto": "SELECT * FROM a UNION ALL SELECT * FROM b", "spark": "SELECT * FROM a UNION ALL SELECT * FROM b", @@ -848,15 +876,28 @@ class TestDialect(Validator): "postgres": "STRPOS(x, ' ')", "presto": "STRPOS(x, ' ')", "spark": "LOCATE(' ', x)", + "clickhouse": "position(x, ' ')", + "snowflake": "POSITION(' ', x)", }, ) self.validate_all( - "STR_POSITION(x, 'a')", + "STR_POSITION('a', x)", write={ "duckdb": "STRPOS(x, 'a')", "postgres": "STRPOS(x, 'a')", "presto": "STRPOS(x, 'a')", "spark": "LOCATE('a', x)", + "clickhouse": "position(x, 'a')", + "snowflake": "POSITION('a', x)", + }, + ) + self.validate_all( + "POSITION('a', x, 3)", + write={ + "presto": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1", + "spark": "LOCATE('a', x, 3)", + "clickhouse": "position(x, 'a', 3)", + "snowflake": "POSITION('a', x, 3)", }, ) self.validate_all( diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py index d335921..acb3be9 100644 --- a/tests/dialects/test_hive.py +++ b/tests/dialects/test_hive.py @@ -247,7 +247,7 @@ class TestHive(Validator): "presto": "DATE_DIFF('day', CAST(SUBSTR(CAST(b AS VARCHAR), 1, 10) AS DATE), CAST(SUBSTR(CAST(a AS VARCHAR), 1, 10) AS DATE))", "hive": "DATEDIFF(TO_DATE(a), TO_DATE(b))", "spark": "DATEDIFF(TO_DATE(a), TO_DATE(b))", - "": "DATE_DIFF(TS_OR_DS_TO_DATE(a), TS_OR_DS_TO_DATE(b))", + "": "DATEDIFF(TS_OR_DS_TO_DATE(a), TS_OR_DS_TO_DATE(b))", }, ) self.validate_all( @@ -295,7 +295,7 @@ class TestHive(Validator): "presto": "DATE_DIFF('day', CAST(SUBSTR(CAST(x AS VARCHAR), 1, 10) AS DATE), CAST(SUBSTR(CAST(CAST(SUBSTR(CAST(y AS VARCHAR), 1, 10) AS DATE) AS VARCHAR), 1, 10) AS DATE))", "hive": "DATEDIFF(TO_DATE(TO_DATE(y)), TO_DATE(x))", "spark": "DATEDIFF(TO_DATE(TO_DATE(y)), TO_DATE(x))", - "": "DATE_DIFF(TS_OR_DS_TO_DATE(TS_OR_DS_TO_DATE(y)), TS_OR_DS_TO_DATE(x))", + "": "DATEDIFF(TS_OR_DS_TO_DATE(TS_OR_DS_TO_DATE(y)), TS_OR_DS_TO_DATE(x))", }, ) self.validate_all( @@ -450,11 +450,21 @@ class TestHive(Validator): ) self.validate_all( "MAP(a, b, c, d)", + read={ + "": "VAR_MAP(a, b, c, d)", + "clickhouse": "map(a, b, c, d)", + "duckdb": "MAP(LIST_VALUE(a, c), LIST_VALUE(b, d))", + "hive": "MAP(a, b, c, d)", + "presto": "MAP(ARRAY[a, c], ARRAY[b, d])", + "spark": "MAP(a, b, c, d)", + }, write={ + "": "MAP(ARRAY(a, c), ARRAY(b, d))", + "clickhouse": "map(a, b, c, d)", "duckdb": "MAP(LIST_VALUE(a, c), LIST_VALUE(b, d))", "presto": "MAP(ARRAY[a, c], ARRAY[b, d])", "hive": "MAP(a, b, c, d)", - "spark": "MAP_FROM_ARRAYS(ARRAY(a, c), ARRAY(b, d))", + "spark": "MAP(a, b, c, d)", }, ) self.validate_all( @@ -463,7 +473,7 @@ class TestHive(Validator): "duckdb": "MAP(LIST_VALUE(a), LIST_VALUE(b))", "presto": "MAP(ARRAY[a], ARRAY[b])", "hive": "MAP(a, b)", - "spark": "MAP_FROM_ARRAYS(ARRAY(a), ARRAY(b))", + "spark": "MAP(a, b)", }, ) self.validate_all( diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index e0934d7..dc93c3a 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -67,6 +67,7 @@ class TestPostgres(Validator): self.validate_identity("SELECT SUBSTRING('bla' + 'foo' || 'bar' FROM 3 - 1 + 5 FOR 4 + SOME_FUNC(arg1, arg2))") self.validate_identity("SELECT TRIM(' X' FROM ' XXX ')") self.validate_identity("SELECT TRIM(LEADING 'bla' FROM ' XXX ' COLLATE utf8_bin)") + self.validate_identity("SELECT TO_TIMESTAMP(1284352323.5), TO_TIMESTAMP('05 Dec 2000', 'DD Mon YYYY')") self.validate_all( "CREATE TABLE x (a UUID, b BYTEA)", diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 2145966..8a33e2d 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -305,3 +305,35 @@ class TestSnowflake(Validator): self.validate_identity( "CREATE PROCEDURE a.b.c(x INT, y VARIANT) RETURNS OBJECT EXECUTE AS CALLER AS 'BEGIN SELECT 1; END;'" ) + + def test_table_literal(self): + # All examples from https://docs.snowflake.com/en/sql-reference/literals-table.html + self.validate_all( + r"""SELECT * FROM TABLE('MYTABLE')""", write={"snowflake": r"""SELECT * FROM TABLE('MYTABLE')"""} + ) + + self.validate_all( + r"""SELECT * FROM TABLE('MYDB."MYSCHEMA"."MYTABLE"')""", + write={"snowflake": r"""SELECT * FROM TABLE('MYDB."MYSCHEMA"."MYTABLE"')"""}, + ) + + # Per Snowflake documentation at https://docs.snowflake.com/en/sql-reference/literals-table.html + # one can use either a " ' " or " $$ " to enclose the object identifier. + # Capturing the single tokens seems like lot of work. Hence adjusting tests to use these interchangeably, + self.validate_all( + r"""SELECT * FROM TABLE($$MYDB. "MYSCHEMA"."MYTABLE"$$)""", + write={"snowflake": r"""SELECT * FROM TABLE('MYDB. "MYSCHEMA"."MYTABLE"')"""}, + ) + + self.validate_all(r"""SELECT * FROM TABLE($MYVAR)""", write={"snowflake": r"""SELECT * FROM TABLE($MYVAR)"""}) + + self.validate_all(r"""SELECT * FROM TABLE(?)""", write={"snowflake": r"""SELECT * FROM TABLE(?)"""}) + + self.validate_all( + r"""SELECT * FROM TABLE(:BINDING)""", write={"snowflake": r"""SELECT * FROM TABLE(:BINDING)"""} + ) + + self.validate_all( + r"""SELECT * FROM TABLE($MYVAR) WHERE COL1 = 10""", + write={"snowflake": r"""SELECT * FROM TABLE($MYVAR) WHERE COL1 = 10"""}, + ) diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 8377e47..9a7e64c 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -111,12 +111,70 @@ TBLPROPERTIES ( "SELECT /*+ COALESCE(3) */ * FROM x", write={ "spark": "SELECT /*+ COALESCE(3) */ * FROM x", + "bigquery": "SELECT * FROM x", }, ) self.validate_all( "SELECT /*+ COALESCE(3), REPARTITION(1) */ * FROM x", write={ "spark": "SELECT /*+ COALESCE(3), REPARTITION(1) */ * FROM x", + "bigquery": "SELECT * FROM x", + }, + ) + self.validate_all( + "SELECT /*+ BROADCAST(table) */ cola FROM table", + write={ + "spark": "SELECT /*+ BROADCAST(table) */ cola FROM table", + "bigquery": "SELECT cola FROM table", + }, + ) + self.validate_all( + "SELECT /*+ BROADCASTJOIN(table) */ cola FROM table", + write={ + "spark": "SELECT /*+ BROADCASTJOIN(table) */ cola FROM table", + "bigquery": "SELECT cola FROM table", + }, + ) + self.validate_all( + "SELECT /*+ MAPJOIN(table) */ cola FROM table", + write={ + "spark": "SELECT /*+ MAPJOIN(table) */ cola FROM table", + "bigquery": "SELECT cola FROM table", + }, + ) + self.validate_all( + "SELECT /*+ MERGE(table) */ cola FROM table", + write={ + "spark": "SELECT /*+ MERGE(table) */ cola FROM table", + "bigquery": "SELECT cola FROM table", + }, + ) + self.validate_all( + "SELECT /*+ SHUFFLEMERGE(table) */ cola FROM table", + write={ + "spark": "SELECT /*+ SHUFFLEMERGE(table) */ cola FROM table", + "bigquery": "SELECT cola FROM table", + }, + ) + self.validate_all( + "SELECT /*+ MERGEJOIN(table) */ cola FROM table", + write={ + "spark": "SELECT /*+ MERGEJOIN(table) */ cola FROM table", + "bigquery": "SELECT cola FROM table", + }, + ) + self.validate_all( + "SELECT /*+ SHUFFLE_HASH(table) */ cola FROM table", + write={ + "spark": "SELECT /*+ SHUFFLE_HASH(table) */ cola FROM table", + "bigquery": "SELECT cola FROM table", + }, + ) + self.validate_all( + "SELECT /*+ SHUFFLE_REPLICATE_NL(table) */ cola FROM table", + write={ + "spark": "SELECT /*+ SHUFFLE_REPLICATE_NL(table) */ cola FROM table", + "bigquery": "SELECT cola FROM table", }, ) diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index a0de281..40e7cc1 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -321,6 +321,10 @@ SELECT 1 FROM a INNER JOIN b ON a.x = b.x SELECT 1 FROM a LEFT JOIN b ON a.x = b.x SELECT 1 FROM a RIGHT JOIN b ON a.x = b.x SELECT 1 FROM a CROSS JOIN b ON a.x = b.x +SELECT 1 FROM a LEFT SEMI JOIN b ON a.x = b.x +SELECT 1 FROM a LEFT ANTI JOIN b ON a.x = b.x +SELECT 1 FROM a RIGHT SEMI JOIN b ON a.x = b.x +SELECT 1 FROM a RIGHT ANTI JOIN b ON a.x = b.x SELECT 1 FROM a JOIN b USING (x) SELECT 1 FROM a JOIN b USING (x, y, z) SELECT 1 FROM a JOIN (SELECT a FROM c) AS b ON a.x = b.x AND a.x < 2 @@ -529,12 +533,14 @@ UPDATE db.tbl_name SET foo = 123 WHERE tbl_name.bar = 234 UPDATE db.tbl_name SET foo = 123, foo_1 = 234 WHERE tbl_name.bar = 234 TRUNCATE TABLE x OPTIMIZE TABLE y +VACUUM FREEZE my_table WITH a AS (SELECT 1) INSERT INTO b SELECT * FROM a WITH a AS (SELECT * FROM b) UPDATE a SET col = 1 WITH a AS (SELECT * FROM b) CREATE TABLE b AS SELECT * FROM a WITH a AS (SELECT * FROM b) DELETE FROM a WITH a AS (SELECT * FROM b) CACHE TABLE a SELECT ? AS ? FROM x WHERE b BETWEEN ? AND ? GROUP BY ?, 1 LIMIT ? +SELECT :hello, ? FROM x LIMIT :my_limit WITH a AS ((SELECT b.foo AS foo, b.bar AS bar FROM b) UNION ALL (SELECT c.foo AS foo, c.bar AS bar FROM c)) SELECT * FROM a WITH a AS ((SELECT 1 AS b) UNION ALL (SELECT 1 AS b)) SELECT * FROM a SELECT (WITH x AS (SELECT 1 AS y) SELECT * FROM x) AS z diff --git a/tests/fixtures/optimizer/merge_subqueries.sql b/tests/fixtures/optimizer/merge_subqueries.sql index e13d3b3..c8186cc 100644 --- a/tests/fixtures/optimizer/merge_subqueries.sql +++ b/tests/fixtures/optimizer/merge_subqueries.sql @@ -1,107 +1,189 @@ --- Simple +# title: Simple SELECT a, b FROM (SELECT a, b FROM x); SELECT x.a AS a, x.b AS b FROM x AS x; --- Inner table alias is merged +# title: Inner table alias is merged SELECT a, b FROM (SELECT a, b FROM x AS q) AS r; SELECT q.a AS a, q.b AS b FROM x AS q; --- Double nesting +# title: Double nesting SELECT a, b FROM (SELECT a, b FROM (SELECT a, b FROM x)); SELECT x.a AS a, x.b AS b FROM x AS x; --- WHERE clause is merged -SELECT a, SUM(b) FROM (SELECT a, b FROM x WHERE a > 1) GROUP BY a; -SELECT x.a AS a, SUM(x.b) AS "_col_1" FROM x AS x WHERE x.a > 1 GROUP BY x.a; +# title: WHERE clause is merged +SELECT a, SUM(b) AS b FROM (SELECT a, b FROM x WHERE a > 1) GROUP BY a; +SELECT x.a AS a, SUM(x.b) AS b FROM x AS x WHERE x.a > 1 GROUP BY x.a; --- Outer query has join -SELECT a, c FROM (SELECT a, b FROM x WHERE a > 1) AS x JOIN y ON x.b = y.b; -SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b WHERE x.a > 1; - --- Outer query has join +# title: Outer query has join SELECT a, c FROM (SELECT a, b FROM x WHERE a > 1) AS x JOIN y ON x.b = y.b; SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b WHERE x.a > 1; +# title: Leave tables isolated # leave_tables_isolated: true SELECT a, c FROM (SELECT a, b FROM x WHERE a > 1) AS x JOIN y ON x.b = y.b; SELECT x.a AS a, y.c AS c FROM (SELECT x.a AS a, x.b AS b FROM x AS x WHERE x.a > 1) AS x JOIN y AS y ON x.b = y.b; --- Join on derived table +# title: Join on derived table SELECT a, c FROM x JOIN (SELECT b, c FROM y) AS y ON x.b = y.b; SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b; --- Inner query has a join +# title: Inner query has a join SELECT a, c FROM (SELECT a, c FROM x JOIN y ON x.b = y.b); SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b; --- Inner query has conflicting name in outer query +# title: Inner query has conflicting name in outer query SELECT a, c FROM (SELECT q.a, q.b FROM x AS q) AS x JOIN y AS q ON x.b = q.b; SELECT q_2.a AS a, q.c AS c FROM x AS q_2 JOIN y AS q ON q_2.b = q.b; --- Inner query has conflicting name in joined source +# title: Inner query has conflicting name in joined source SELECT x.a, q.c FROM (SELECT a, x.b FROM x JOIN y AS q ON x.b = q.b) AS x JOIN y AS q ON x.b = q.b; SELECT x.a AS a, q.c AS c FROM x AS x JOIN y AS q_2 ON x.b = q_2.b JOIN y AS q ON x.b = q.b; --- Inner query has multiple conflicting names -SELECT x.a, q.c, r.c FROM (SELECT q.a, r.b FROM x AS q JOIN y AS r ON q.b = r.b) AS x JOIN y AS q ON x.b = q.b JOIN y AS r ON x.b = r.b; -SELECT q_2.a AS a, q.c AS c, r.c AS c FROM x AS q_2 JOIN y AS r_2 ON q_2.b = r_2.b JOIN y AS q ON r_2.b = q.b JOIN y AS r ON r_2.b = r.b; +# title: Inner query has multiple conflicting names +SELECT x.a, q.c, r.c FROM (SELECT q.a, r.b FROM x AS q JOIN y AS r ON q.b = r.b) AS x JOIN y AS q ON x.b = q.b JOIN y AS r ON x.b = r.b ORDER BY x.a, q.c, r.c; +SELECT q_2.a AS a, q.c AS c, r.c AS c FROM x AS q_2 JOIN y AS r_2 ON q_2.b = r_2.b JOIN y AS q ON r_2.b = q.b JOIN y AS r ON r_2.b = r.b ORDER BY q_2.a, q.c, r.c; --- Inner queries have conflicting names with each other +# title: Inner queries have conflicting names with each other SELECT r.b FROM (SELECT b FROM x AS x) AS q JOIN (SELECT b FROM x) AS r ON q.b = r.b; SELECT x_2.b AS b FROM x AS x JOIN x AS x_2 ON x.b = x_2.b; --- WHERE clause in joined derived table is merged to ON clause -SELECT x.a, y.c FROM x JOIN (SELECT b, c FROM y WHERE c > 1) AS y; -SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON y.c > 1; +# title: WHERE clause in joined derived table is merged to ON clause +SELECT x.a, y.c FROM x JOIN (SELECT b, c FROM y WHERE c > 1) AS y ON x.b = y.b; +SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b AND y.c > 1; --- Comma JOIN in outer query +# title: Comma JOIN in outer query SELECT x.a, y.c FROM (SELECT a FROM x) AS x, (SELECT c FROM y) AS y; SELECT x.a AS a, y.c AS c FROM x AS x, y AS y; --- Comma JOIN in inner query +# title: Comma JOIN in inner query SELECT x.a, x.c FROM (SELECT x.a, z.c FROM x, y AS z) AS x; SELECT x.a AS a, z.c AS c FROM x AS x CROSS JOIN y AS z; --- (Regression) Column in ORDER BY +# title: (Regression) Column in ORDER BY SELECT * FROM (SELECT * FROM (SELECT * FROM x)) ORDER BY a LIMIT 1; SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY x.a LIMIT 1; --- CTE +# title: CTE WITH x AS (SELECT a, b FROM x) SELECT a, b FROM x; SELECT x.a AS a, x.b AS b FROM x AS x; --- CTE with outer table alias +# title: CTE with outer table alias WITH y AS (SELECT a, b FROM x) SELECT a, b FROM y AS z; SELECT x.a AS a, x.b AS b FROM x AS x; --- Nested CTE -WITH x AS (SELECT a FROM x), x2 AS (SELECT a FROM x) SELECT a FROM x2; +# title: Nested CTE +WITH x2 AS (SELECT a FROM x), x3 AS (SELECT a FROM x2) SELECT a FROM x3; SELECT x.a AS a FROM x AS x; --- CTE WHERE clause is merged -WITH x AS (SELECT a, b FROM x WHERE a > 1) SELECT a, SUM(b) FROM x GROUP BY a; -SELECT x.a AS a, SUM(x.b) AS "_col_1" FROM x AS x WHERE x.a > 1 GROUP BY x.a; +# title: CTE WHERE clause is merged +WITH x AS (SELECT a, b FROM x WHERE a > 1) SELECT a, SUM(b) AS b FROM x GROUP BY a; +SELECT x.a AS a, SUM(x.b) AS b FROM x AS x WHERE x.a > 1 GROUP BY x.a; --- CTE Outer query has join -WITH x AS (SELECT a, b FROM x WHERE a > 1) SELECT a, c FROM x AS x JOIN y ON x.b = y.b; +# title: CTE Outer query has join +WITH x2 AS (SELECT a, b FROM x WHERE a > 1) SELECT a, c FROM x2 AS x JOIN y ON x.b = y.b; SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b WHERE x.a > 1; --- CTE with inner table alias +# title: CTE with inner table alias WITH y AS (SELECT a, b FROM x AS q) SELECT a, b FROM y AS z; SELECT q.a AS a, q.b AS b FROM x AS q; --- Duplicate queries to CTE -WITH x AS (SELECT a, b FROM x) SELECT x.a, y.b FROM x JOIN x AS y; -WITH x AS (SELECT x.a AS a, x.b AS b FROM x AS x) SELECT x.a AS a, y.b AS b FROM x JOIN x AS y; - --- Nested CTE +# title: Nested CTE SELECT * FROM (WITH x AS (SELECT a, b FROM x) SELECT a, b FROM x); SELECT x.a AS a, x.b AS b FROM x AS x; --- Inner select is an expression +# title: Inner select is an expression SELECT a FROM (SELECT a FROM (SELECT COALESCE(a) AS a FROM x LEFT JOIN y ON x.a = y.b) AS x) AS x; SELECT COALESCE(x.a) AS a FROM x AS x LEFT JOIN y AS y ON x.a = y.b; --- CTE select is an expression -WITH x AS (SELECT COALESCE(a) AS a FROM x LEFT JOIN y ON x.a = y.b) SELECT a FROM (SELECT a FROM x AS x) AS x; +# title: CTE select is an expression +WITH x2 AS (SELECT COALESCE(a) AS a FROM x LEFT JOIN y ON x.a = y.b) SELECT a FROM (SELECT a FROM x2 AS x) AS x; SELECT COALESCE(x.a) AS a FROM x AS x LEFT JOIN y AS y ON x.a = y.b; + +# title: Full outer join +SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x WHERE x.b = 1) AS x FULL OUTER JOIN (SELECT y.b AS b FROM y AS y WHERE y.b = 2) AS y ON x.b = y.b; +SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x WHERE x.b = 1) AS x FULL OUTER JOIN (SELECT y.b AS b FROM y AS y WHERE y.b = 2) AS y ON x.b = y.b; + +# title: Full outer join, no predicates +SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x) AS x FULL OUTER JOIN (SELECT y.b AS b FROM y AS y) AS y ON x.b = y.b; +SELECT x.b AS b, y.b AS b2 FROM x AS x FULL OUTER JOIN y AS y ON x.b = y.b; + +# title: Left join +SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x WHERE x.b = 1) AS x LEFT JOIN (SELECT y.b AS b FROM y AS y WHERE y.b = 2) AS y ON x.b = y.b; +SELECT x.b AS b, y.b AS b2 FROM x AS x LEFT JOIN (SELECT y.b AS b FROM y AS y WHERE y.b = 2) AS y ON x.b = y.b WHERE x.b = 1; + +# title: Left join, no predicates +SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x) AS x LEFT JOIN (SELECT y.b AS b FROM y AS y) AS y ON x.b = y.b; +SELECT x.b AS b, y.b AS b2 FROM x AS x LEFT JOIN y AS y ON x.b = y.b; + +# title: Right join +SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x WHERE x.b = 1) AS x RIGHT JOIN (SELECT y.b AS b FROM y AS y WHERE y.b = 2) AS y ON x.b = y.b; +SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x WHERE x.b = 1) AS x RIGHT JOIN (SELECT y.b AS b FROM y AS y WHERE y.b = 2) AS y ON x.b = y.b; + +# title: Right join, no predicates +SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x) AS x RIGHT JOIN (SELECT y.b AS b FROM y AS y) AS y ON x.b = y.b; +SELECT x.b AS b, y.b AS b2 FROM x AS x RIGHT JOIN y AS y ON x.b = y.b; + +# title: Inner join +SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x WHERE x.b = 1) AS x INNER JOIN (SELECT y.b AS b FROM y AS y WHERE y.b = 2) AS y ON x.b = y.b; +SELECT x.b AS b, y.b AS b2 FROM x AS x INNER JOIN y AS y ON x.b = y.b AND y.b = 2 WHERE x.b = 1; + +# title: Inner join, no predicates +SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x) AS x INNER JOIN (SELECT y.b AS b FROM y AS y) AS y ON x.b = y.b; +SELECT x.b AS b, y.b AS b2 FROM x AS x INNER JOIN y AS y ON x.b = y.b; + +# title: Cross join +SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x WHERE x.b = 1) AS x CROSS JOIN (SELECT y.b AS b FROM y AS y WHERE y.b = 2) AS y; +SELECT x.b AS b, y.b AS b2 FROM x AS x JOIN y AS y ON y.b = 2 WHERE x.b = 1; + +# title: Cross join, no predicates +SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x) AS x CROSS JOIN (SELECT y.b AS b FROM y AS y) AS y; +SELECT x.b AS b, y.b AS b2 FROM x AS x CROSS JOIN y AS y; + +# title: Broadcast hint +# dialect: spark +WITH m AS (SELECT x.a, x.b FROM x), n AS (SELECT y.b, y.c FROM y), joined as (SELECT /*+ BROADCAST(k) */ m.a, k.c FROM m JOIN n AS k ON m.b = k.b) SELECT joined.a, joined.c FROM joined; +SELECT /*+ BROADCAST(y) */ x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b; + +# title: Broadcast hint multiple tables +# dialect: spark +WITH m AS (SELECT x.a, x.b FROM x), n AS (SELECT y.b, y.c FROM y), joined as (SELECT /*+ BROADCAST(m, n) */ m.a, n.c FROM m JOIN n ON m.b = n.b) SELECT joined.a, joined.c FROM joined; +SELECT /*+ BROADCAST(x, y) */ x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b; + +# title: Multiple Table Hints +# dialect: spark +WITH m AS (SELECT x.a, x.b FROM x), n AS (SELECT y.b, y.c FROM y), joined as (SELECT /*+ BROADCAST(m), MERGE(m, n) */ m.a, n.c FROM m JOIN n ON m.b = n.b) SELECT joined.a, joined.c FROM joined; +SELECT /*+ BROADCAST(x), MERGE(x, y) */ x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b; + +# title: Mix Table and Column Hints +# dialect: spark +WITH m AS (SELECT x.a, x.b FROM x), n AS (SELECT y.b, y.c FROM y), joined as (SELECT /*+ BROADCAST(m), MERGE(m, n) */ m.a, n.c FROM m JOIN n ON m.b = n.b) SELECT /*+ COALESCE(3) */ joined.a, joined.c FROM joined; +SELECT /*+ COALESCE(3), BROADCAST(x), MERGE(x, y) */ x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b; + +# title: Hint Subquery +# dialect: spark +SELECT + subquery.a, + subquery.c +FROM ( + SELECT /*+ BROADCAST(m), MERGE(m, n) */ m.a, n.c FROM (SELECT x.a, x.b FROM x) AS m JOIN (SELECT y.b, y.c FROM y) AS n ON m.b = n.b +) AS subquery; +SELECT /*+ BROADCAST(x), MERGE(x, y) */ x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b; + +# title: Subquery Test +# dialect: spark +SELECT /*+ BROADCAST(x) */ + x.a, + x.c +FROM ( + SELECT + x.a, + x.c + FROM ( + SELECT + x.a, + COUNT(1) AS c + FROM x + GROUP BY x.a + ) AS x +) AS x; +SELECT /*+ BROADCAST(x) */ x.a AS a, x.c AS c FROM (SELECT x.a AS a, COUNT(1) AS c FROM x AS x GROUP BY x.a) AS x; diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql index eb6761a..ab4f769 100644 --- a/tests/fixtures/optimizer/optimizer.sql +++ b/tests/fixtures/optimizer/optimizer.sql @@ -1,3 +1,5 @@ +# title: lateral +# execute: false SELECT a, m FROM z LATERAL VIEW EXPLODE([1, 2]) q AS m; SELECT "z"."a" AS "a", @@ -6,11 +8,13 @@ FROM "z" AS "z" LATERAL VIEW EXPLODE(ARRAY(1, 2)) q AS "m"; +# title: unnest SELECT x FROM UNNEST([1, 2]) AS q(x, y); SELECT "q"."x" AS "x" FROM UNNEST(ARRAY(1, 2)) AS "q"("x", "y"); +# title: Union in CTE WITH cte AS ( ( SELECT @@ -21,7 +25,7 @@ WITH cte AS ( UNION ALL ( SELECT - a + b AS a FROM y ) @@ -39,7 +43,7 @@ WITH "cte" AS ( UNION ALL ( SELECT - "y"."a" AS "a" + "y"."b" AS "a" FROM "y" AS "y" ) ) @@ -47,6 +51,7 @@ SELECT "cte"."a" AS "a" FROM "cte"; +# title: Chained CTEs WITH cte1 AS ( SELECT a FROM x @@ -74,30 +79,31 @@ SELECT "cte1"."a" + 1 AS "a" FROM "cte1"; -SELECT a, SUM(b) +# title: Correlated subquery +SELECT a, SUM(b) AS sum_b FROM ( SELECT x.a, y.b FROM x, y - WHERE (SELECT max(b) FROM y WHERE x.a = y.a) >= 0 AND x.a = y.a + WHERE (SELECT max(b) FROM y WHERE x.b = y.b) >= 0 AND x.b = y.b ) d WHERE (TRUE AND TRUE OR 'a' = 'b') AND a > 1 GROUP BY a; WITH "_u_0" AS ( SELECT MAX("y"."b") AS "_col_0", - "y"."a" AS "_u_1" + "y"."b" AS "_u_1" FROM "y" AS "y" GROUP BY - "y"."a" + "y"."b" ) SELECT "x"."a" AS "a", - SUM("y"."b") AS "_col_1" + SUM("y"."b") AS "sum_b" FROM "x" AS "x" LEFT JOIN "_u_0" AS "_u_0" - ON "x"."a" = "_u_0"."_u_1" + ON "x"."b" = "_u_0"."_u_1" JOIN "y" AS "y" - ON "x"."a" = "y"."a" + ON "x"."b" = "y"."b" WHERE "_u_0"."_col_0" >= 0 AND "x"."a" > 1 @@ -105,6 +111,7 @@ WHERE GROUP BY "x"."a"; +# title: Root subquery (SELECT a FROM x) LIMIT 1; ( SELECT @@ -113,6 +120,7 @@ GROUP BY ) LIMIT 1; +# title: Root subquery is union (SELECT b FROM x UNION SELECT b FROM y) LIMIT 1; ( SELECT @@ -125,6 +133,7 @@ LIMIT 1; ) LIMIT 1; +# title: broadcast # dialect: spark SELECT /*+ BROADCAST(y) */ x.b FROM x JOIN y ON x.b = y.b; SELECT /*+ BROADCAST(`y`) */ @@ -133,11 +142,14 @@ FROM `x` AS `x` JOIN `y` AS `y` ON `x`.`b` = `y`.`b`; +# title: aggregate +# execute: false SELECT AGGREGATE(ARRAY(x.a, x.b), 0, (x, acc) -> x + acc + a) AS sum_agg FROM x; SELECT AGGREGATE(ARRAY("x"."a", "x"."b"), 0, ("x", "acc") -> "x" + "acc" + "x"."a") AS "sum_agg" FROM "x" AS "x"; +# title: values SELECT cola, colb FROM (VALUES (1, 'test'), (2, 'test2')) AS tab(cola, colb); SELECT "tab"."cola" AS "cola", @@ -146,6 +158,7 @@ FROM (VALUES (1, 'test'), (2, 'test2')) AS "tab"("cola", "colb"); +# title: spark values # dialect: spark SELECT cola, colb FROM (VALUES (1, 'test'), (2, 'test2')) AS tab(cola, colb); SELECT @@ -154,3 +167,112 @@ SELECT FROM VALUES (1, 'test'), (2, 'test2') AS `tab`(`cola`, `colb`); + +# title: complex CTE dependencies +WITH m AS ( + SELECT a, b FROM (VALUES (1, 2)) AS a1(a, b) +), n AS ( + SELECT a, b FROM m WHERE m.a = 1 +), o AS ( + SELECT a, b FROM m WHERE m.a = 2 +) SELECT + n.a, + n.b, + o.b +FROM n +FULL OUTER JOIN o ON n.a = o.a +CROSS JOIN n AS n2 +WHERE o.b > 0 AND n.a = n2.a; +WITH "m" AS ( + SELECT + "a1"."a" AS "a", + "a1"."b" AS "b" + FROM (VALUES + (1, 2)) AS "a1"("a", "b") +), "n" AS ( + SELECT + "m"."a" AS "a", + "m"."b" AS "b" + FROM "m" + WHERE + "m"."a" = 1 +), "o" AS ( + SELECT + "m"."a" AS "a", + "m"."b" AS "b" + FROM "m" + WHERE + "m"."a" = 2 +) +SELECT + "n"."a" AS "a", + "n"."b" AS "b", + "o"."b" AS "b" +FROM "n" +FULL JOIN "o" + ON "n"."a" = "o"."a" +JOIN "n" AS "n2" + ON "n"."a" = "n2"."a" +WHERE + "o"."b" > 0; + +# title: Broadcast hint +# dialect: spark +WITH m AS ( + SELECT + x.a, + x.b + FROM x +), n AS ( + SELECT + y.b, + y.c + FROM y +), joined as ( + SELECT /*+ BROADCAST(n) */ + m.a, + n.c + FROM m JOIN n ON m.b = n.b +) +SELECT + joined.a, + joined.c +FROM joined; +SELECT /*+ BROADCAST(`y`) */ + `x`.`a` AS `a`, + `y`.`c` AS `c` +FROM `x` AS `x` +JOIN `y` AS `y` + ON `x`.`b` = `y`.`b`; + +# title: Mix Table and Column Hints +# dialect: spark +WITH m AS ( + SELECT + x.a, + x.b + FROM x +), n AS ( + SELECT + y.b, + y.c + FROM y +), joined as ( + SELECT /*+ BROADCAST(m), MERGE(m, n) */ + m.a, + n.c + FROM m JOIN n ON m.b = n.b +) +SELECT + /*+ COALESCE(3) */ + joined.a, + joined.c +FROM joined; +SELECT /*+ COALESCE(3), + BROADCAST(`x`), + MERGE(`x`, `y`) */ + `x`.`a` AS `a`, + `y`.`c` AS `c` +FROM `x` AS `x` +JOIN `y` AS `y` + ON `x`.`b` = `y`.`b`; diff --git a/tests/fixtures/optimizer/qualify_columns.sql b/tests/fixtures/optimizer/qualify_columns.sql index f848e7a..83a3bf8 100644 --- a/tests/fixtures/optimizer/qualify_columns.sql +++ b/tests/fixtures/optimizer/qualify_columns.sql @@ -19,38 +19,49 @@ SELECT x.a AS a FROM x AS x; SELECT a AS b FROM x; SELECT x.a AS b FROM x AS x; +# execute: false SELECT 1, 2 FROM x; SELECT 1 AS "_col_0", 2 AS "_col_1" FROM x AS x; +# execute: false SELECT a + b FROM x; SELECT x.a + x.b AS "_col_0" FROM x AS x; -SELECT a + b FROM x; -SELECT x.a + x.b AS "_col_0" FROM x AS x; - +# execute: false SELECT a, SUM(b) FROM x WHERE a > 1 AND b > 1 GROUP BY a; SELECT x.a AS a, SUM(x.b) AS "_col_1" FROM x AS x WHERE x.a > 1 AND x.b > 1 GROUP BY x.a; SELECT a AS j, b FROM x ORDER BY j; SELECT x.a AS j, x.b AS b FROM x AS x ORDER BY j; -SELECT a AS j, b FROM x GROUP BY j; -SELECT x.a AS j, x.b AS b FROM x AS x GROUP BY x.a; +SELECT a AS j, b AS a FROM x ORDER BY 1; +SELECT x.a AS j, x.b AS a FROM x AS x ORDER BY x.a; + +SELECT SUM(a) AS c, SUM(b) AS d FROM x ORDER BY 1, 2; +SELECT SUM(x.a) AS c, SUM(x.b) AS d FROM x AS x ORDER BY SUM(x.a), SUM(x.b); + +# execute: false +SELECT SUM(a), SUM(b) AS c FROM x ORDER BY 1, 2; +SELECT SUM(x.a) AS "_col_0", SUM(x.b) AS c FROM x AS x ORDER BY SUM(x.a), SUM(x.b); + +SELECT a AS j, b FROM x GROUP BY j, b; +SELECT x.a AS j, x.b AS b FROM x AS x GROUP BY x.a, x.b; SELECT a, b FROM x GROUP BY 1, 2; SELECT x.a AS a, x.b AS b FROM x AS x GROUP BY x.a, x.b; SELECT a, b FROM x ORDER BY 1, 2; -SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY a, b; +SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY x.a, x.b; +# execute: false SELECT DATE(a), DATE(b) AS c FROM x GROUP BY 1, 2; SELECT DATE(x.a) AS "_col_0", DATE(x.b) AS c FROM x AS x GROUP BY DATE(x.a), DATE(x.b); -SELECT x.a AS c FROM x JOIN y ON x.b = y.b GROUP BY c; -SELECT x.a AS c FROM x AS x JOIN y AS y ON x.b = y.b GROUP BY y.c; +SELECT SUM(x.a) AS c FROM x JOIN y ON x.b = y.b GROUP BY c; +SELECT SUM(x.a) AS c FROM x AS x JOIN y AS y ON x.b = y.b GROUP BY y.c; -SELECT DATE(x.a) AS d FROM x JOIN y ON x.b = y.b GROUP BY d; -SELECT DATE(x.a) AS d FROM x AS x JOIN y AS y ON x.b = y.b GROUP BY DATE(x.a); +SELECT COALESCE(x.a) AS d FROM x JOIN y ON x.b = y.b GROUP BY d; +SELECT COALESCE(x.a) AS d FROM x AS x JOIN y AS y ON x.b = y.b GROUP BY COALESCE(x.a); SELECT a AS a, b FROM x ORDER BY a; SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY a; @@ -69,6 +80,7 @@ SELECT ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.b) AS row_num FROM x AS x SELECT x.b, x.a FROM x LEFT JOIN y ON x.b = y.b QUALIFY ROW_NUMBER() OVER(PARTITION BY x.b ORDER BY x.a DESC) = 1; SELECT x.b AS b, x.a AS a FROM x AS x LEFT JOIN y AS y ON x.b = y.b QUALIFY ROW_NUMBER() OVER (PARTITION BY x.b ORDER BY x.a DESC) = 1; +# execute: false SELECT AGGREGATE(ARRAY(a, x.b), 0, (x, acc) -> x + acc + a) AS sum_agg FROM x; SELECT AGGREGATE(ARRAY(x.a, x.b), 0, (x, acc) -> x + acc + x.a) AS sum_agg FROM x AS x; @@ -93,8 +105,8 @@ SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0"; SELECT a FROM (SELECT a FROM (SELECT a FROM x)); SELECT "_q_1".a AS a FROM (SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0") AS "_q_1"; -SELECT x.a FROM x AS x JOIN (SELECT * FROM x); -SELECT x.a AS a FROM x AS x JOIN (SELECT x.a AS a, x.b AS b FROM x AS x) AS "_q_0"; +SELECT x.a FROM x AS x JOIN (SELECT * FROM x) AS y ON x.a = y.a; +SELECT x.a AS a FROM x AS x JOIN (SELECT x.a AS a, x.b AS b FROM x AS x) AS y ON x.a = y.a; -------------------------------------- -- Joins @@ -123,6 +135,7 @@ SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x UNION SELECT x.a AS a FRO SELECT a FROM x WHERE b IN (SELECT c FROM y); SELECT x.a AS a FROM x AS x WHERE x.b IN (SELECT y.c AS c FROM y AS y); +# execute: false SELECT (SELECT c FROM y) FROM x; SELECT (SELECT y.c AS c FROM y AS y) AS "_col_0" FROM x AS x; @@ -144,10 +157,12 @@ SELECT x.a AS a FROM x AS x WHERE x.b IN (SELECT x.b AS b FROM y AS x); SELECT a FROM x AS i WHERE b IN (SELECT b FROM y AS j WHERE j.b IN (SELECT c FROM y AS k WHERE k.b = j.b)); SELECT i.a AS a FROM x AS i WHERE i.b IN (SELECT j.b AS b FROM y AS j WHERE j.b IN (SELECT k.c AS c FROM y AS k WHERE k.b = j.b)); +# execute: false # dialect: bigquery SELECT aa FROM x, UNNEST(a) AS aa; SELECT aa AS aa FROM x AS x, UNNEST(x.a) AS aa; +# execute: false SELECT aa FROM x, UNNEST(a) AS t(aa); SELECT t.aa AS aa FROM x AS x, UNNEST(x.a) AS t(aa); @@ -205,15 +220,19 @@ WITH z AS ((SELECT x.b AS b FROM x AS x UNION ALL SELECT y.b AS b FROM y AS y) O -------------------------------------- -- Except and Replace -------------------------------------- +# execute: false SELECT * REPLACE(a AS d) FROM x; SELECT x.a AS d, x.b AS b FROM x AS x; +# execute: false SELECT * EXCEPT(b) REPLACE(a AS d) FROM x; SELECT x.a AS d FROM x AS x; +# execute: false SELECT x.* EXCEPT(a), y.* FROM x, y; SELECT x.b AS b, y.b AS b, y.c AS c FROM x AS x, y AS y; +# execute: false SELECT * EXCEPT(a) FROM x; SELECT x.b AS b FROM x AS x; diff --git a/tests/fixtures/optimizer/qualify_columns__with_invisible.sql b/tests/fixtures/optimizer/qualify_columns__with_invisible.sql new file mode 100644 index 0000000..ee46c23 --- /dev/null +++ b/tests/fixtures/optimizer/qualify_columns__with_invisible.sql @@ -0,0 +1,35 @@ +-------------------------------------- +-- Qualify columns +-------------------------------------- +SELECT a FROM x; +SELECT x.a AS a FROM x AS x; + +SELECT b FROM x; +SELECT x.b AS b FROM x AS x; + +-------------------------------------- +-- Derived tables +-------------------------------------- +SELECT x.a FROM x AS x JOIN (SELECT * FROM x); +SELECT x.a AS a FROM x AS x JOIN (SELECT x.a AS a FROM x AS x) AS "_q_0"; + +SELECT x.b FROM x AS x JOIN (SELECT b FROM x); +SELECT x.b AS b FROM x AS x JOIN (SELECT x.b AS b FROM x AS x) AS "_q_0"; + +-------------------------------------- +-- Expand * +-------------------------------------- +SELECT * FROM x; +SELECT x.a AS a FROM x AS x; + +SELECT * FROM y JOIN z ON y.b = z.b; +SELECT y.b AS b, z.b AS b FROM y AS y JOIN z AS z ON y.b = z.b; + +SELECT * FROM y JOIN z ON y.c = z.c; +SELECT y.b AS b, z.b AS b FROM y AS y JOIN z AS z ON y.c = z.c; + +SELECT a FROM (SELECT * FROM x); +SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0"; + +SELECT * FROM (SELECT a FROM x); +SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0"; diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql index d7217cf..07e818f 100644 --- a/tests/fixtures/optimizer/simplify.sql +++ b/tests/fixtures/optimizer/simplify.sql @@ -52,6 +52,9 @@ TRUE; NULL AND TRUE; NULL; +NULL AND FALSE; +FALSE; + NULL AND NULL; NULL; @@ -70,6 +73,9 @@ FALSE; NOT FALSE; TRUE; +NOT NULL; +NULL; + NULL = NULL; NULL; diff --git a/tests/fixtures/optimizer/tpc-h/tpc-h.sql b/tests/fixtures/optimizer/tpc-h/tpc-h.sql index d2f10fc..936a0af 100644 --- a/tests/fixtures/optimizer/tpc-h/tpc-h.sql +++ b/tests/fixtures/optimizer/tpc-h/tpc-h.sql @@ -769,13 +769,20 @@ group by order by custdist desc, c_count desc; -WITH "c_orders" AS ( +WITH "orders_2" AS ( + SELECT + "orders"."o_orderkey" AS "o_orderkey", + "orders"."o_custkey" AS "o_custkey", + "orders"."o_comment" AS "o_comment" + FROM "orders" AS "orders" + WHERE + NOT "orders"."o_comment" LIKE '%special%requests%' +), "c_orders" AS ( SELECT COUNT("orders"."o_orderkey") AS "c_count" FROM "customer" AS "customer" - LEFT JOIN "orders" AS "orders" + LEFT JOIN "orders_2" AS "orders" ON "customer"."c_custkey" = "orders"."o_custkey" - AND NOT "orders"."o_comment" LIKE '%special%requests%' GROUP BY "customer"."c_custkey" ) diff --git a/tests/helpers.py b/tests/helpers.py index ad50483..2d200f6 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -45,6 +45,14 @@ def load_sql_fixture_pairs(filename): yield meta, sql, expected +def string_to_bool(string): + if string is None: + return False + if string in (True, False): + return string + return string and string.lower() in ("true", "1") + + TPCH_SCHEMA = { "lineitem": { "l_orderkey": "uint64", diff --git a/tests/test_build.py b/tests/test_build.py index b5d657c..fa9e7f8 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -1,6 +1,19 @@ import unittest -from sqlglot import and_, condition, exp, from_, not_, or_, parse_one, select +from sqlglot import ( + alias, + and_, + condition, + except_, + exp, + from_, + intersect, + not_, + or_, + parse_one, + select, + union, +) class TestBuild(unittest.TestCase): @@ -320,6 +333,54 @@ class TestBuild(unittest.TestCase): lambda: exp.update("tbl", {"x": 1}, from_="tbl2"), "UPDATE tbl SET x = 1 FROM tbl2", ), + ( + lambda: union("SELECT * FROM foo", "SELECT * FROM bla"), + "SELECT * FROM foo UNION SELECT * FROM bla", + ), + ( + lambda: parse_one("SELECT * FROM foo").union("SELECT * FROM bla"), + "SELECT * FROM foo UNION SELECT * FROM bla", + ), + ( + lambda: intersect("SELECT * FROM foo", "SELECT * FROM bla"), + "SELECT * FROM foo INTERSECT SELECT * FROM bla", + ), + ( + lambda: parse_one("SELECT * FROM foo").intersect("SELECT * FROM bla"), + "SELECT * FROM foo INTERSECT SELECT * FROM bla", + ), + ( + lambda: except_("SELECT * FROM foo", "SELECT * FROM bla"), + "SELECT * FROM foo EXCEPT SELECT * FROM bla", + ), + ( + lambda: parse_one("SELECT * FROM foo").except_("SELECT * FROM bla"), + "SELECT * FROM foo EXCEPT SELECT * FROM bla", + ), + ( + lambda: parse_one("(SELECT * FROM foo)").union("SELECT * FROM bla"), + "(SELECT * FROM foo) UNION SELECT * FROM bla", + ), + ( + lambda: parse_one("(SELECT * FROM foo)").union("SELECT * FROM bla", distinct=False), + "(SELECT * FROM foo) UNION ALL SELECT * FROM bla", + ), + ( + lambda: alias(parse_one("LAG(x) OVER (PARTITION BY y)"), "a"), + "LAG(x) OVER (PARTITION BY y) AS a", + ), + ( + lambda: alias(parse_one("LAG(x) OVER (ORDER BY z)"), "a"), + "LAG(x) OVER (ORDER BY z) AS a", + ), + ( + lambda: alias(parse_one("LAG(x) OVER (PARTITION BY y ORDER BY z)"), "a"), + "LAG(x) OVER (PARTITION BY y ORDER BY z) AS a", + ), + ( + lambda: alias(parse_one("LAG(x) OVER ()"), "a"), + "LAG(x) OVER () AS a", + ), ]: with self.subTest(sql): self.assertEqual(expression().sql(dialect[0] if dialect else None), sql) diff --git a/tests/test_expressions.py b/tests/test_expressions.py index cc41307..abc95cb 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -115,6 +115,21 @@ class TestExpressions(unittest.TestCase): ["first", "second", "third"], ) + def test_table_name(self): + self.assertEqual(exp.table_name(parse_one("a", into=exp.Table)), "a") + self.assertEqual(exp.table_name(parse_one("a.b", into=exp.Table)), "a.b") + self.assertEqual(exp.table_name(parse_one("a.b.c", into=exp.Table)), "a.b.c") + self.assertEqual(exp.table_name("a.b.c"), "a.b.c") + + def test_replace_tables(self): + self.assertEqual( + exp.replace_tables( + parse_one("select * from a join b join c.a join d.a join e.a"), + {"a": "a1", "b": "b.a", "c.a": "c.a2", "d.a": "d2"}, + ).sql(), + 'SELECT * FROM "a1" JOIN "b"."a" JOIN "c"."a2" JOIN "d2" JOIN e.a', + ) + def test_named_selects(self): expression = parse_one("SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz") self.assertEqual(expression.named_selects, ["a", "B", "e", "*", "zz", "z"]) @@ -474,3 +489,10 @@ class TestExpressions(unittest.TestCase): ]: with self.subTest(value): self.assertEqual(exp.convert(value).sql(), expected) + + def test_annotation_alias(self): + expression = parse_one("SELECT a, b AS B, c #comment, d AS D #another_comment FROM foo") + self.assertEqual( + [e.alias_or_name for e in expression.expressions], + ["a", "B", "c", "D"], + ) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index aad84ed..36a7785 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -1,17 +1,55 @@ import unittest from functools import partial +import duckdb +from pandas.testing import assert_frame_equal + +import sqlglot from sqlglot import exp, optimizer, parse_one, table from sqlglot.errors import OptimizeError from sqlglot.optimizer.annotate_types import annotate_types from sqlglot.optimizer.schema import MappingSchema, ensure_schema from sqlglot.optimizer.scope import build_scope, traverse_scope, walk_in_scope -from tests.helpers import TPCH_SCHEMA, load_sql_fixture_pairs, load_sql_fixtures +from tests.helpers import ( + TPCH_SCHEMA, + load_sql_fixture_pairs, + load_sql_fixtures, + string_to_bool, +) class TestOptimizer(unittest.TestCase): maxDiff = None + @classmethod + def setUpClass(cls): + cls.conn = duckdb.connect() + cls.conn.execute( + """ + CREATE TABLE x (a INT, b INT); + CREATE TABLE y (b INT, c INT); + CREATE TABLE z (b INT, c INT); + + INSERT INTO x VALUES (1, 1); + INSERT INTO x VALUES (2, 2); + INSERT INTO x VALUES (2, 2); + INSERT INTO x VALUES (3, 3); + INSERT INTO x VALUES (null, null); + + INSERT INTO y VALUES (2, 2); + INSERT INTO y VALUES (2, 2); + INSERT INTO y VALUES (3, 3); + INSERT INTO y VALUES (4, 4); + INSERT INTO y VALUES (null, null); + + INSERT INTO y VALUES (3, 3); + INSERT INTO y VALUES (3, 3); + INSERT INTO y VALUES (4, 4); + INSERT INTO y VALUES (5, 5); + INSERT INTO y VALUES (null, null); + """ + ) + def setUp(self): self.schema = { "x": { @@ -28,29 +66,42 @@ class TestOptimizer(unittest.TestCase): }, } - def check_file(self, file, func, pretty=False, **kwargs): + def check_file(self, file, func, pretty=False, execute=False, **kwargs): for i, (meta, sql, expected) in enumerate(load_sql_fixture_pairs(f"optimizer/{file}.sql"), start=1): + title = meta.get("title") or f"{i}, {sql}" dialect = meta.get("dialect") leave_tables_isolated = meta.get("leave_tables_isolated") func_kwargs = {**kwargs} if leave_tables_isolated is not None: - func_kwargs["leave_tables_isolated"] = leave_tables_isolated.lower() in ("true", "1") + func_kwargs["leave_tables_isolated"] = string_to_bool(leave_tables_isolated) + + optimized = func(parse_one(sql, read=dialect), **func_kwargs) - with self.subTest(f"{i}, {sql}"): + with self.subTest(title): self.assertEqual( - func(parse_one(sql, read=dialect), **func_kwargs).sql(pretty=pretty, dialect=dialect), + optimized.sql(pretty=pretty, dialect=dialect), expected, ) + should_execute = meta.get("execute") + if should_execute is None: + should_execute = execute + + if string_to_bool(should_execute): + with self.subTest(f"(execute) {title}"): + df1 = self.conn.execute(sqlglot.transpile(sql, read=dialect, write="duckdb")[0]).df() + df2 = self.conn.execute(optimized.sql(pretty=pretty, dialect="duckdb")).df() + assert_frame_equal(df1, df2) + def test_optimize(self): schema = { "x": {"a": "INT", "b": "INT"}, - "y": {"a": "INT", "b": "INT"}, + "y": {"b": "INT", "c": "INT"}, "z": {"a": "INT", "c": "INT"}, } - self.check_file("optimizer", optimizer.optimize, pretty=True, schema=schema) + self.check_file("optimizer", optimizer.optimize, pretty=True, execute=True, schema=schema) def test_isolate_table_selects(self): self.check_file( @@ -86,7 +137,16 @@ class TestOptimizer(unittest.TestCase): expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs) return expression - self.check_file("qualify_columns", qualify_columns, schema=self.schema) + self.check_file("qualify_columns", qualify_columns, execute=True, schema=self.schema) + + def test_qualify_columns__with_invisible(self): + def qualify_columns(expression, **kwargs): + expression = optimizer.qualify_tables.qualify_tables(expression) + expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs) + return expression + + schema = MappingSchema(self.schema, {"x": {"a"}, "y": {"b"}, "z": {"b"}}) + self.check_file("qualify_columns__with_invisible", qualify_columns, schema=schema) def test_qualify_columns__invalid(self): for sql in load_sql_fixtures("optimizer/qualify_columns__invalid.sql"): @@ -141,7 +201,7 @@ class TestOptimizer(unittest.TestCase): ], ) - self.check_file("merge_subqueries", optimize, schema=self.schema) + self.check_file("merge_subqueries", optimize, execute=True, schema=self.schema) def test_eliminate_subqueries(self): self.check_file("eliminate_subqueries", optimizer.eliminate_subqueries.eliminate_subqueries) @@ -301,10 +361,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') } for sql, target_type in tests.items(): - expression = parse_one(sql) - annotated_expression = annotate_types(expression) - - self.assertEqual(annotated_expression.find(exp.Literal).type, target_type) + expression = annotate_types(parse_one(sql)) + self.assertEqual(expression.find(exp.Literal).type, target_type) def test_boolean_type_annotation(self): tests = { @@ -313,14 +371,11 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') } for sql, target_type in tests.items(): - expression = parse_one(sql) - annotated_expression = annotate_types(expression) - - self.assertEqual(annotated_expression.find(exp.Boolean).type, target_type) + expression = annotate_types(parse_one(sql)) + self.assertEqual(expression.find(exp.Boolean).type, target_type) def test_cast_type_annotation(self): - expression = parse_one("CAST('2020-01-01' AS TIMESTAMPTZ(9))") - annotate_types(expression) + expression = annotate_types(parse_one("CAST('2020-01-01' AS TIMESTAMPTZ(9))")) self.assertEqual(expression.type, exp.DataType.Type.TIMESTAMPTZ) self.assertEqual(expression.this.type, exp.DataType.Type.VARCHAR) @@ -328,16 +383,11 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(expression.args["to"].expressions[0].type, exp.DataType.Type.INT) def test_cache_annotation(self): - expression = parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1") - annotated_expression = annotate_types(expression) - - self.assertEqual(annotated_expression.expression.expressions[0].type, exp.DataType.Type.INT) + expression = annotate_types(parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1")) + self.assertEqual(expression.expression.expressions[0].type, exp.DataType.Type.INT) def test_binary_annotation(self): - expression = parse_one("SELECT 0.0 + (2 + 3)") - annotate_types(expression) - - expression = expression.expressions[0] + expression = annotate_types(parse_one("SELECT 0.0 + (2 + 3)")).expressions[0] self.assertEqual(expression.type, exp.DataType.Type.DOUBLE) self.assertEqual(expression.left.type, exp.DataType.Type.DOUBLE) @@ -345,3 +395,124 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(expression.right.this.type, exp.DataType.Type.INT) self.assertEqual(expression.right.this.left.type, exp.DataType.Type.INT) self.assertEqual(expression.right.this.right.type, exp.DataType.Type.INT) + + def test_derived_tables_column_annotation(self): + schema = {"x": {"cola": "INT"}, "y": {"cola": "FLOAT"}} + sql = """ + SELECT a.cola AS cola + FROM ( + SELECT x.cola + y.cola AS cola + FROM ( + SELECT x.cola AS cola + FROM x AS x + ) AS x + JOIN ( + SELECT y.cola AS cola + FROM y AS y + ) AS y + ) AS a + """ + + expression = annotate_types(parse_one(sql), schema=schema) + self.assertEqual(expression.expressions[0].type, exp.DataType.Type.FLOAT) # a.cola AS cola + + addition_alias = expression.args["from"].expressions[0].this.expressions[0] + self.assertEqual(addition_alias.type, exp.DataType.Type.FLOAT) # x.cola + y.cola AS cola + + addition = addition_alias.this + self.assertEqual(addition.type, exp.DataType.Type.FLOAT) + self.assertEqual(addition.this.type, exp.DataType.Type.INT) + self.assertEqual(addition.expression.type, exp.DataType.Type.FLOAT) + + def test_cte_column_annotation(self): + schema = {"x": {"cola": "CHAR"}, "y": {"colb": "TEXT"}} + sql = """ + WITH tbl AS ( + SELECT x.cola + 'bla' AS cola, y.colb AS colb + FROM ( + SELECT x.cola AS cola + FROM x AS x + ) AS x + JOIN ( + SELECT y.colb AS colb + FROM y AS y + ) AS y + ) + SELECT tbl.cola + tbl.colb + 'foo' AS col + FROM tbl AS tbl + """ + + expression = annotate_types(parse_one(sql), schema=schema) + self.assertEqual(expression.expressions[0].type, exp.DataType.Type.TEXT) # tbl.cola + tbl.colb + 'foo' AS col + + outer_addition = expression.expressions[0].this # (tbl.cola + tbl.colb) + 'foo' + self.assertEqual(outer_addition.type, exp.DataType.Type.TEXT) + self.assertEqual(outer_addition.left.type, exp.DataType.Type.TEXT) + self.assertEqual(outer_addition.right.type, exp.DataType.Type.VARCHAR) + + inner_addition = expression.expressions[0].this.left # tbl.cola + tbl.colb + self.assertEqual(inner_addition.left.type, exp.DataType.Type.VARCHAR) + self.assertEqual(inner_addition.right.type, exp.DataType.Type.TEXT) + + cte_select = expression.args["with"].expressions[0].this + self.assertEqual(cte_select.expressions[0].type, exp.DataType.Type.VARCHAR) # x.cola + 'bla' AS cola + self.assertEqual(cte_select.expressions[1].type, exp.DataType.Type.TEXT) # y.colb AS colb + + cte_select_addition = cte_select.expressions[0].this # x.cola + 'bla' + self.assertEqual(cte_select_addition.type, exp.DataType.Type.VARCHAR) + self.assertEqual(cte_select_addition.left.type, exp.DataType.Type.CHAR) + self.assertEqual(cte_select_addition.right.type, exp.DataType.Type.VARCHAR) + + # Check that x.cola AS cola and y.colb AS colb have types CHAR and TEXT, respectively + for d, t in zip(cte_select.find_all(exp.Subquery), [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT]): + self.assertEqual(d.this.expressions[0].this.type, t) + + def test_function_annotation(self): + schema = {"x": {"cola": "VARCHAR", "colb": "CHAR"}} + sql = "SELECT x.cola || TRIM(x.colb) AS col FROM x AS x" + + concat_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0] + self.assertEqual(concat_expr_alias.type, exp.DataType.Type.VARCHAR) + + concat_expr = concat_expr_alias.this + self.assertEqual(concat_expr.type, exp.DataType.Type.VARCHAR) + self.assertEqual(concat_expr.left.type, exp.DataType.Type.VARCHAR) # x.cola + self.assertEqual(concat_expr.right.type, exp.DataType.Type.VARCHAR) # TRIM(x.colb) + self.assertEqual(concat_expr.right.this.type, exp.DataType.Type.CHAR) # x.colb + + def test_unknown_annotation(self): + schema = {"x": {"cola": "VARCHAR"}} + sql = "SELECT x.cola || SOME_ANONYMOUS_FUNC(x.cola) AS col FROM x AS x" + + concat_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0] + self.assertEqual(concat_expr_alias.type, exp.DataType.Type.UNKNOWN) + + concat_expr = concat_expr_alias.this + self.assertEqual(concat_expr.type, exp.DataType.Type.UNKNOWN) + self.assertEqual(concat_expr.left.type, exp.DataType.Type.VARCHAR) # x.cola + self.assertEqual(concat_expr.right.type, exp.DataType.Type.UNKNOWN) # SOME_ANONYMOUS_FUNC(x.cola) + self.assertEqual(concat_expr.right.expressions[0].type, exp.DataType.Type.VARCHAR) # x.cola (arg) + + def test_null_annotation(self): + expression = annotate_types(parse_one("SELECT NULL + 2 AS col")).expressions[0].this + self.assertEqual(expression.left.type, exp.DataType.Type.NULL) + self.assertEqual(expression.right.type, exp.DataType.Type.INT) + + # NULL UNKNOWN should yield NULL + sql = "SELECT NULL || SOME_ANONYMOUS_FUNC() AS result" + + concat_expr_alias = annotate_types(parse_one(sql)).expressions[0] + self.assertEqual(concat_expr_alias.type, exp.DataType.Type.NULL) + + concat_expr = concat_expr_alias.this + self.assertEqual(concat_expr.type, exp.DataType.Type.NULL) + self.assertEqual(concat_expr.left.type, exp.DataType.Type.NULL) + self.assertEqual(concat_expr.right.type, exp.DataType.Type.UNKNOWN) + + def test_nullable_annotation(self): + nullable = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")) + expression = annotate_types(parse_one("NULL AND FALSE")) + + self.assertEqual(expression.type, nullable) + self.assertEqual(expression.left.type, exp.DataType.Type.NULL) + self.assertEqual(expression.right.type, exp.DataType.Type.BOOLEAN) diff --git a/tests/test_transpile.py b/tests/test_transpile.py index 4bec2ac..01b8205 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -338,7 +338,7 @@ class TestTranspile(unittest.TestCase): unsupported_level=level, ) - error = "Cannot convert array columns into map use SparkSQL instead." + error = "Cannot convert array columns into map." unsupported(ErrorLevel.WARN) assert_logger_contains("\n".join([error] * 4), logger, level="warning") -- cgit v1.2.3