From f818ab3b896d52e874634b7c4db3533078c1887f Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Mon, 10 Oct 2022 13:29:05 +0200 Subject: Merging upstream version 6.3.1. Signed-off-by: Daniel Baumann --- sqlglot/__init__.py | 5 +- sqlglot/dialects/bigquery.py | 6 +- sqlglot/dialects/clickhouse.py | 33 ++++- sqlglot/dialects/dialect.py | 17 ++- sqlglot/dialects/hive.py | 40 +----- sqlglot/dialects/oracle.py | 29 ++++ sqlglot/dialects/postgres.py | 14 +- sqlglot/dialects/snowflake.py | 5 +- sqlglot/dialects/spark.py | 21 ++- sqlglot/executor/python.py | 1 + sqlglot/expressions.py | 222 +++++++++++++++++++++++++++++-- sqlglot/generator.py | 29 ++-- sqlglot/optimizer/annotate_types.py | 158 +++++++++++++++++++--- sqlglot/optimizer/merge_subqueries.py | 44 +++++- sqlglot/optimizer/pushdown_predicates.py | 38 ++++-- sqlglot/optimizer/qualify_columns.py | 8 +- sqlglot/optimizer/schema.py | 63 ++++++++- sqlglot/optimizer/scope.py | 20 ++- sqlglot/optimizer/simplify.py | 8 +- sqlglot/parser.py | 121 +++++++++++------ sqlglot/tokens.py | 8 ++ 21 files changed, 726 insertions(+), 164 deletions(-) (limited to 'sqlglot') diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index 1f7b28c..0228bdd 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -8,7 +8,9 @@ from sqlglot.expressions import ( and_, column, condition, + except_, from_, + intersect, maybe_parse, not_, or_, @@ -16,11 +18,12 @@ from sqlglot.expressions import ( subquery, ) from sqlglot.expressions import table_ as table +from sqlglot.expressions import union from sqlglot.generator import Generator from sqlglot.parser import Parser from sqlglot.tokens import Tokenizer, TokenType -__version__ = "6.2.8" +__version__ = "6.3.1" pretty = False diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 40298e7..86e46cf 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -135,6 +135,7 @@ class BigQuery(Dialect): exp.DateSub: _date_add_sql("DATE", "SUB"), exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"), exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"), + exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})", exp.ILike: no_ilike_sql, exp.TimeAdd: _date_add_sql("TIME", "ADD"), exp.TimeSub: _date_add_sql("TIME", "SUB"), @@ -172,12 +173,11 @@ class BigQuery(Dialect): exp.AnonymousProperty, } + EXPLICIT_UNION = True + def in_unnest_op(self, unnest): return self.sql(unnest) - def union_op(self, expression): - return f"UNION{' DISTINCT' if expression.args.get('distinct') else ' ALL'}" - def except_op(self, expression): if not expression.args.get("distinct", False): self.unsupported("EXCEPT without DISTINCT is not supported in BigQuery") diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index 55dad7a..da5c856 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -1,10 +1,16 @@ from sqlglot import exp -from sqlglot.dialects.dialect import Dialect, inline_array_sql +from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql from sqlglot.generator import Generator -from sqlglot.parser import Parser +from sqlglot.helper import csv +from sqlglot.parser import Parser, parse_var_map from sqlglot.tokens import Tokenizer, TokenType +def _lower_func(sql): + index = sql.index("(") + return sql[:index].lower() + sql[index:] + + class ClickHouse(Dialect): normalize_functions = None null_ordering = "nulls_are_last" @@ -14,17 +20,23 @@ class 
ClickHouse(Dialect): KEYWORDS = { **Tokenizer.KEYWORDS, - "NULLABLE": TokenType.NULLABLE, "FINAL": TokenType.FINAL, + "DATETIME64": TokenType.DATETIME, "INT8": TokenType.TINYINT, "INT16": TokenType.SMALLINT, "INT32": TokenType.INT, "INT64": TokenType.BIGINT, "FLOAT32": TokenType.FLOAT, "FLOAT64": TokenType.DOUBLE, + "TUPLE": TokenType.STRUCT, } class Parser(Parser): + FUNCTIONS = { + **Parser.FUNCTIONS, + "MAP": parse_var_map, + } + def _parse_table(self, schema=False): this = super()._parse_table(schema) @@ -39,10 +51,25 @@ class ClickHouse(Dialect): TYPE_MAPPING = { **Generator.TYPE_MAPPING, exp.DataType.Type.NULLABLE: "Nullable", + exp.DataType.Type.DATETIME: "DateTime64", + exp.DataType.Type.MAP: "Map", + exp.DataType.Type.ARRAY: "Array", + exp.DataType.Type.STRUCT: "Tuple", + exp.DataType.Type.TINYINT: "Int8", + exp.DataType.Type.SMALLINT: "Int16", + exp.DataType.Type.INT: "Int32", + exp.DataType.Type.BIGINT: "Int64", + exp.DataType.Type.FLOAT: "Float32", + exp.DataType.Type.DOUBLE: "Float64", } TRANSFORMS = { **Generator.TRANSFORMS, exp.Array: inline_array_sql, + exp.StrPosition: lambda self, e: f"position({csv(self.sql(e, 'this'), self.sql(e, 'substr'), self.sql(e, 'position'))})", exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL", + exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)), + exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)), } + + EXPLICIT_UNION = True diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 98dc330..f7c6cb5 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -77,7 +77,6 @@ class Dialect(metaclass=_Dialect): alias_post_tablesample = False normalize_functions = "upper" null_ordering = "nulls_are_small" - wrap_derived_values = True date_format = "'%Y-%m-%d'" dateint_format = "'%Y%m%d'" @@ -170,7 +169,6 @@ class Dialect(metaclass=_Dialect): "alias_post_tablesample": self.alias_post_tablesample, "normalize_functions": self.normalize_functions, "null_ordering": self.null_ordering, - "wrap_derived_values": self.wrap_derived_values, **opts, } ) @@ -271,6 +269,21 @@ def struct_extract_sql(self, expression): return f"{this}.{struct_key}" +def var_map_sql(self, expression): + keys = expression.args["keys"] + values = expression.args["values"] + + if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): + self.unsupported("Cannot convert array columns into map.") + return f"MAP({self.sql(keys)}, {self.sql(values)})" + + args = [] + for key, value in zip(keys.expressions, values.expressions): + args.append(self.sql(key)) + args.append(self.sql(value)) + return f"MAP({csv(*args)})" + + def format_time_lambda(exp_class, dialect, default=None): """Helper used for time expressions. 
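A quick way to see the new var_map_sql helper and the VarMap plumbing in action is a Hive-to-ClickHouse transpile; a minimal sketch, assuming the Hive and ClickHouse wiring in this patch:

    import sqlglot

    # Hive parses MAP(k1, v1, k2, v2, ...) into the variadic exp.VarMap;
    # ClickHouse renders it via var_map_sql and lower-cases the function
    # name through _lower_func.
    sql = "SELECT MAP('a', 1, 'b', 2) FROM t"
    print(sqlglot.transpile(sql, read="hive", write="clickhouse")[0])
    # expected: SELECT map('a', 1, 'b', 2) FROM t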
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 7a27bb3..55d7bcc 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -11,40 +11,14 @@ from sqlglot.dialects.dialect import ( no_trycast_sql, rename_func, struct_extract_sql, + var_map_sql, ) from sqlglot.generator import Generator from sqlglot.helper import csv, list_get -from sqlglot.parser import Parser +from sqlglot.parser import Parser, parse_var_map from sqlglot.tokens import Tokenizer -def _parse_map(args): - keys = [] - values = [] - for i in range(0, len(args), 2): - keys.append(args[i]) - values.append(args[i + 1]) - return HiveMap( - keys=exp.Array(expressions=keys), - values=exp.Array(expressions=values), - ) - - -def _map_sql(self, expression): - keys = expression.args["keys"] - values = expression.args["values"] - - if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): - self.unsupported("Cannot convert array columns into map use SparkSQL instead.") - return f"MAP({self.sql(keys)}, {self.sql(values)})" - - args = [] - for key, value in zip(keys.expressions, values.expressions): - args.append(self.sql(key)) - args.append(self.sql(value)) - return f"MAP({csv(*args)})" - - def _array_sort(self, expression): if expression.expression: self.unsupported("Hive SORT_ARRAY does not support a comparator") @@ -122,10 +96,6 @@ def _index_sql(self, expression): return f"{this} ON TABLE {table} {columns}" -class HiveMap(exp.Map): - is_var_len_args = True - - class Hive(Dialect): alias_post_tablesample = True @@ -206,7 +176,7 @@ class Hive(Dialect): position=list_get(args, 2), ), "LOG": (lambda args: exp.Log.from_arg_list(args) if len(args) > 1 else exp.Ln.from_arg_list(args)), - "MAP": _parse_map, + "MAP": parse_var_map, "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)), "PERCENTILE": exp.Quantile.from_arg_list, "PERCENTILE_APPROX": exp.ApproxQuantile.from_arg_list, @@ -245,8 +215,8 @@ class Hive(Dialect): exp.Join: _unnest_to_explode_sql, exp.JSONExtract: rename_func("GET_JSON_OBJECT"), exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"), - exp.Map: _map_sql, - HiveMap: _map_sql, + exp.Map: var_map_sql, + exp.VarMap: var_map_sql, exp.Create: create_with_partitions_sql, exp.Quantile: rename_func("PERCENTILE"), exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"), diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 91e30b2..8041ff0 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -10,6 +10,32 @@ def _limit_sql(self, expression): class Oracle(Dialect): + # https://docs.oracle.com/database/121/SQLRF/sql_elements004.htm#SQLRF00212 + # https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes + time_mapping = { + "AM": "%p", # Meridian indicator with or without periods + "A.M.": "%p", # Meridian indicator with or without periods + "PM": "%p", # Meridian indicator with or without periods + "P.M.": "%p", # Meridian indicator with or without periods + "D": "%u", # Day of week (1-7) + "DAY": "%A", # name of day + "DD": "%d", # day of month (1-31) + "DDD": "%j", # day of year (1-366) + "DY": "%a", # abbreviated name of day + "HH": "%I", # Hour of day (1-12) + "HH12": "%I", # alias for HH + "HH24": "%H", # Hour of day (0-23) + "IW": "%V", # Calendar week of year (1-52 or 1-53), as defined by the ISO 8601 standard + "MI": "%M", # Minute (0-59) + "MM": "%m", # Month (01-12; January = 01) + "MON": "%b", # Abbreviated name of month + "MONTH": "%B", # Name of month + "SS": "%S", # Second (0-59) + 
"WW": "%W", # Week of year (1-53) + "YY": "%y", # 15 + "YYYY": "%Y", # 2015 + } + class Generator(Generator): TYPE_MAPPING = { **Generator.TYPE_MAPPING, @@ -30,6 +56,9 @@ class Oracle(Dialect): **transforms.UNALIAS_GROUP, exp.ILike: no_ilike_sql, exp.Limit: _limit_sql, + exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", + exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)", } def query_modifiers(self, expression, *sqls): diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index aaa07a1..731e28e 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -118,13 +118,22 @@ def _serial_to_generated(expression): return expression +def _to_timestamp(args): + # TO_TIMESTAMP accepts either a single double argument or (text, text) + if len(args) == 1 and args[0].is_number: + # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TABLE + return exp.UnixToTime.from_arg_list(args) + # https://www.postgresql.org/docs/current/functions-formatting.html + return format_time_lambda(exp.StrToTime, "postgres")(args) + + class Postgres(Dialect): null_ordering = "nulls_are_large" time_format = "'YYYY-MM-DD HH24:MI:SS'" time_mapping = { "AM": "%p", "PM": "%p", - "D": "%w", # 1-based day of week + "D": "%u", # 1-based day of week "DD": "%d", # day of month "DDD": "%j", # zero padded day of year "FMDD": "%-d", # - is no leading zero for Python; same for FM in postgres @@ -172,7 +181,7 @@ class Postgres(Dialect): FUNCTIONS = { **Parser.FUNCTIONS, - "TO_TIMESTAMP": format_time_lambda(exp.StrToTime, "postgres"), + "TO_TIMESTAMP": _to_timestamp, "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"), } @@ -211,4 +220,5 @@ class Postgres(Dialect): exp.TableSample: no_tablesample_sql, exp.Trim: _trim_sql, exp.TryCast: no_trycast_sql, + exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})", } diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index fb2d900..19a427c 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -121,6 +121,7 @@ class Snowflake(Dialect): FUNC_TOKENS = { *Parser.FUNC_TOKENS, TokenType.RLIKE, + TokenType.TABLE, } COLUMN_OPERATORS = { @@ -143,7 +144,7 @@ class Snowflake(Dialect): SINGLE_TOKENS = { **Tokenizer.SINGLE_TOKENS, - "$": TokenType.DOLLAR, # needed to break for quotes + "$": TokenType.PARAMETER, } KEYWORDS = { @@ -164,6 +165,8 @@ class Snowflake(Dialect): exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.UnixToTime: _unix_to_time, exp.Array: inline_array_sql, + exp.StrPosition: rename_func("POSITION"), + exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}", exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}", } diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index e8da07a..95a7ab4 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -4,8 +4,9 @@ from sqlglot.dialects.dialect import ( no_ilike_sql, rename_func, ) -from sqlglot.dialects.hive import Hive, HiveMap +from sqlglot.dialects.hive import Hive from sqlglot.helper import list_get +from sqlglot.parser import Parser def _create_sql(self, e): @@ -47,8 +48,6 @@ def _unix_to_time(self, expression): class Spark(Hive): - wrap_derived_values = False - class Parser(Hive.Parser): FUNCTIONS = { 
**Hive.Parser.FUNCTIONS, @@ -78,8 +77,19 @@ class Spark(Hive): "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list, } - class Generator(Hive.Generator): + FUNCTION_PARSERS = { + **Parser.FUNCTION_PARSERS, + "BROADCAST": lambda self: self._parse_join_hint("BROADCAST"), + "BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"), + "MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"), + "MERGE": lambda self: self._parse_join_hint("MERGE"), + "SHUFFLEMERGE": lambda self: self._parse_join_hint("SHUFFLEMERGE"), + "MERGEJOIN": lambda self: self._parse_join_hint("MERGEJOIN"), + "SHUFFLE_HASH": lambda self: self._parse_join_hint("SHUFFLE_HASH"), + "SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"), + } + class Generator(Hive.Generator): TYPE_MAPPING = { **Hive.Generator.TYPE_MAPPING, exp.DataType.Type.TINYINT: "BYTE", @@ -102,8 +112,9 @@ class Spark(Hive): exp.Map: _map_sql, exp.Reduce: rename_func("AGGREGATE"), exp.StructKwarg: lambda self, e: f"{self.sql(e, 'this')}: {self.sql(e, 'expression')}", - HiveMap: _map_sql, } + WRAP_DERIVED_VALUES = False + class Tokenizer(Hive.Tokenizer): HEX_STRINGS = [("X'", "'")] diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py index 610aa4b..8ef6cf0 100644 --- a/sqlglot/executor/python.py +++ b/sqlglot/executor/python.py @@ -326,6 +326,7 @@ class Python(Dialect): exp.Alias: lambda self, e: self.sql(e.this), exp.Array: inline_array_sql, exp.And: lambda self, e: self.binary(e, "and"), + exp.Boolean: lambda self, e: "True" if e.this else "False", exp.Cast: _cast_py, exp.Column: _column_py, exp.EQ: lambda self, e: self.binary(e, "=="), diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 8cdacce..f2ffd12 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -508,7 +508,69 @@ class DerivedTable(Expression): return [select.alias_or_name for select in self.selects] -class UDTF(DerivedTable): +class Unionable: + def union(self, expression, distinct=True, dialect=None, **opts): + """ + Builds a UNION expression. + + Example: + >>> import sqlglot + >>> sqlglot.parse_one("SELECT * FROM foo").union("SELECT * FROM bla").sql() + 'SELECT * FROM foo UNION SELECT * FROM bla' + + Args: + expression (str or Expression): the SQL code string. + If an `Expression` instance is passed, it will be used as-is. + distinct (bool): set the DISTINCT flag if and only if this is true. + dialect (str): the dialect used to parse the input expression. + opts (kwargs): other options to use to parse the input expressions. + Returns: + Union: the Union expression. + """ + return union(left=self, right=expression, distinct=distinct, dialect=dialect, **opts) + + def intersect(self, expression, distinct=True, dialect=None, **opts): + """ + Builds an INTERSECT expression. + + Example: + >>> import sqlglot + >>> sqlglot.parse_one("SELECT * FROM foo").intersect("SELECT * FROM bla").sql() + 'SELECT * FROM foo INTERSECT SELECT * FROM bla' + + Args: + expression (str or Expression): the SQL code string. + If an `Expression` instance is passed, it will be used as-is. + distinct (bool): set the DISTINCT flag if and only if this is true. + dialect (str): the dialect used to parse the input expression. + opts (kwargs): other options to use to parse the input expressions. 
+ Returns: + Intersect: the Intersect expression + """ + return intersect(left=self, right=expression, distinct=distinct, dialect=dialect, **opts) + + def except_(self, expression, distinct=True, dialect=None, **opts): + """ + Builds an EXCEPT expression. + + Example: + >>> import sqlglot + >>> sqlglot.parse_one("SELECT * FROM foo").except_("SELECT * FROM bla").sql() + 'SELECT * FROM foo EXCEPT SELECT * FROM bla' + + Args: + expression (str or Expression): the SQL code string. + If an `Expression` instance is passed, it will be used as-is. + distinct (bool): set the DISTINCT flag if and only if this is true. + dialect (str): the dialect used to parse the input expression. + opts (kwargs): other options to use to parse the input expressions. + Returns: + Except: the Except expression + """ + return except_(left=self, right=expression, distinct=distinct, dialect=dialect, **opts) + + +class UDTF(DerivedTable, Unionable): pass @@ -518,6 +580,10 @@ class Annotation(Expression): "expression": True, } + @property + def alias(self): + return self.expression.alias_or_name + class Cache(Expression): arg_types = { @@ -700,6 +766,10 @@ class Hint(Expression): arg_types = {"expressions": True} +class JoinHint(Expression): + arg_types = {"this": True, "expressions": True} + + class Identifier(Expression): arg_types = {"this": True, "quoted": False} @@ -971,7 +1041,7 @@ class Tuple(Expression): arg_types = {"expressions": False} -class Subqueryable: +class Subqueryable(Unionable): def subquery(self, alias=None, copy=True): """ Convert this expression to an aliased expression that can be used as a Subquery. @@ -1654,7 +1724,7 @@ class Select(Subqueryable, Expression): return self.expressions -class Subquery(DerivedTable): +class Subquery(DerivedTable, Unionable): arg_types = { "this": True, "alias": False, @@ -1731,7 +1801,7 @@ class Parameter(Expression): class Placeholder(Expression): - arg_types = {} + arg_types = {"this": False} class Null(Condition): @@ -1791,6 +1861,8 @@ class DataType(Expression): IMAGE = auto() VARIANT = auto() OBJECT = auto() + NULL = auto() + UNKNOWN = auto() # Sentinel value, useful for type annotation @classmethod def build(cls, dtype, **kwargs): @@ -2007,7 +2079,7 @@ class Distinct(Expression): class In(Predicate): - arg_types = {"this": True, "expressions": False, "query": False, "unnest": False} + arg_types = {"this": True, "expressions": False, "query": False, "unnest": False, "field": False} class TimeUnit(Expression): @@ -2377,6 +2449,11 @@ class Map(Func): arg_types = {"keys": True, "values": True} +class VarMap(Func): + arg_types = {"keys": True, "values": True} + is_var_len_args = True + + class Max(AggFunc): pass @@ -2449,7 +2526,7 @@ class Substring(Func): class StrPosition(Func): - arg_types = {"this": True, "substr": True, "position": False} + arg_types = {"substr": True, "this": True, "position": False} class StrToDate(Func): @@ -2785,6 +2862,81 @@ def _wrap_operator(expression): return expression +def union(left, right, distinct=True, dialect=None, **opts): + """ + Initializes a syntax tree from one UNION expression. + + Example: + >>> union("SELECT * FROM foo", "SELECT * FROM bla").sql() + 'SELECT * FROM foo UNION SELECT * FROM bla' + + Args: + left (str or Expression): the SQL code string corresponding to the left-hand side. + If an `Expression` instance is passed, it will be used as-is. + right (str or Expression): the SQL code string corresponding to the right-hand side. + If an `Expression` instance is passed, it will be used as-is. 
+ distinct (bool): set the DISTINCT flag if and only if this is true. + dialect (str): the dialect used to parse the input expression. + opts (kwargs): other options to use to parse the input expressions. + Returns: + Union: the syntax tree for the UNION expression. + """ + left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts) + right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts) + + return Union(this=left, expression=right, distinct=distinct) + + +def intersect(left, right, distinct=True, dialect=None, **opts): + """ + Initializes a syntax tree from one INTERSECT expression. + + Example: + >>> intersect("SELECT * FROM foo", "SELECT * FROM bla").sql() + 'SELECT * FROM foo INTERSECT SELECT * FROM bla' + + Args: + left (str or Expression): the SQL code string corresponding to the left-hand side. + If an `Expression` instance is passed, it will be used as-is. + right (str or Expression): the SQL code string corresponding to the right-hand side. + If an `Expression` instance is passed, it will be used as-is. + distinct (bool): set the DISTINCT flag if and only if this is true. + dialect (str): the dialect used to parse the input expression. + opts (kwargs): other options to use to parse the input expressions. + Returns: + Intersect: the syntax tree for the INTERSECT expression. + """ + left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts) + right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts) + + return Intersect(this=left, expression=right, distinct=distinct) + + +def except_(left, right, distinct=True, dialect=None, **opts): + """ + Initializes a syntax tree from one EXCEPT expression. + + Example: + >>> except_("SELECT * FROM foo", "SELECT * FROM bla").sql() + 'SELECT * FROM foo EXCEPT SELECT * FROM bla' + + Args: + left (str or Expression): the SQL code string corresponding to the left-hand side. + If an `Expression` instance is passed, it will be used as-is. + right (str or Expression): the SQL code string corresponding to the right-hand side. + If an `Expression` instance is passed, it will be used as-is. + distinct (bool): set the DISTINCT flag if and only if this is true. + dialect (str): the dialect used to parse the input expression. + opts (kwargs): other options to use to parse the input expressions. + Returns: + Except: the syntax tree for the EXCEPT statement. + """ + left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts) + right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts) + + return Except(this=left, expression=right, distinct=distinct) + + def select(*expressions, dialect=None, **opts): """ Initializes a syntax tree from one or multiple SELECT expressions. @@ -2991,7 +3143,7 @@ def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts): If an Expression instance is passed, this is used as-is. alias (str or Identifier): the alias name to use. If the name has special characters it is quoted. - table (boolean): create a table alias, default false + table (bool): create a table alias, default false dialect (str): the dialect used to parse the input expression. **opts: other options to use to parse the input expressions. 
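The set-operation builders added in this file are also re-exported at the package level (see the __init__.py hunk above) and mirror the new Unionable methods; a small usage sketch taken from the docstrings:

    from sqlglot import parse_one, union

    # Module-level builder.
    print(union("SELECT * FROM foo", "SELECT * FROM bla", distinct=False).sql())
    # SELECT * FROM foo UNION ALL SELECT * FROM bla

    # The same operation chained off a parsed query via the Unionable mixin.
    print(parse_one("SELECT * FROM foo").union("SELECT * FROM bla").sql())
    # SELECT * FROM foo UNION SELECT * FROM bla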
@@ -3002,7 +3154,7 @@ def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts): alias = to_identifier(alias, quoted=quoted) alias = TableAlias(this=alias) if table else alias - if "alias" in exp.arg_types: + if "alias" in exp.arg_types and not isinstance(exp, Window): exp = exp.copy() exp.set("alias", alias) return exp @@ -3138,6 +3290,60 @@ def column_table_names(expression): return list(dict.fromkeys(column.table for column in expression.find_all(Column))) +def table_name(table): + """Get the full name of a table as a string. + + Args: + table (exp.Table | str): Table expression node or string. + + Examples: + >>> from sqlglot import exp, parse_one + >>> table_name(parse_one("select * from a.b.c").find(exp.Table)) + 'a.b.c' + + Returns: + str: the table name + """ + + table = maybe_parse(table, into=Table) + + return ".".join( + part + for part in ( + table.text("catalog"), + table.text("db"), + table.name, + ) + if part + ) + + +def replace_tables(expression, mapping): + """Replace all tables in expression according to the mapping. + + Args: + expression (sqlglot.Expression): Expression node to be transformed and replaced + mapping (Dict[str, str]): Mapping of table names + + Examples: + >>> from sqlglot import exp, parse_one + >>> replace_tables(parse_one("select * from a.b"), {"a.b": "c"}).sql() + 'SELECT * FROM "c"' + + Returns: + The mapped expression + """ + + def _replace_tables(node): + if isinstance(node, Table): + new_name = mapping.get(table_name(node)) + if new_name: + return table_(*reversed(new_name.split(".")), quoted=True) + return node + + return expression.transform(_replace_tables) + + TRUE = Boolean(this=True) FALSE = Boolean(this=False) NULL = Null() diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 8b356f3..b7e295d 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -48,8 +48,9 @@ class Generator: TRANSFORMS = { exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'value')}", exp.DateAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})", - exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.DateDiff: lambda self, e: f"DATEDIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')})", exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})", + exp.VarMap: lambda self, e: f"MAP({self.sql(e.args['keys'])}, {self.sql(e.args['values'])})", exp.LanguageProperty: lambda self, e: self.naked_property(e), exp.LocationProperty: lambda self, e: self.naked_property(e), exp.ReturnsProperty: lambda self, e: self.naked_property(e), @@ -57,7 +58,12 @@ class Generator: exp.VolatilityProperty: lambda self, e: self.sql(e.name), } + # whether or not null ordering is supported in order by NULL_ORDERING_SUPPORTED = True + # always do union distinct or union all + EXPLICIT_UNION = False + # wrap derived values in parens, usually standard but spark doesn't support it + WRAP_DERIVED_VALUES = True TYPE_MAPPING = { exp.DataType.Type.NCHAR: "CHAR", @@ -101,7 +107,6 @@ class Generator: "unsupported_messages", "null_ordering", "max_unsupported", - "wrap_derived_values", "_indent", "_replace_backslash", "_escaped_quote_end", @@ -130,7 +135,6 @@ class Generator: null_ordering=None, max_unsupported=3, leading_comma=False, - wrap_derived_values=True, ): import sqlglot @@ -154,7 +158,6 @@ class Generator: self.unsupported_messages = [] 
self.max_unsupported = max_unsupported self.null_ordering = null_ordering - self.wrap_derived_values = wrap_derived_values self._indent = indent self._replace_backslash = self.escape == "\\" self._escaped_quote_end = self.escape + self.quote_end @@ -595,7 +598,7 @@ class Generator: if not alias: return f"VALUES{self.seg('')}{args}" alias = f" AS {alias}" if alias else alias - if self.wrap_derived_values: + if self.WRAP_DERIVED_VALUES: return f"(VALUES{self.seg('')}{args}){alias}" return f"VALUES{self.seg('')}{args}{alias}" @@ -779,8 +782,8 @@ class Generator: def parameter_sql(self, expression): return f"@{self.sql(expression, 'this')}" - def placeholder_sql(self, *_): - return "?" + def placeholder_sql(self, expression): + return f":{expression.name}" if expression.name else "?" def subquery_sql(self, expression): alias = self.sql(expression, "alias") @@ -803,7 +806,9 @@ class Generator: ) def union_op(self, expression): - return f"UNION{'' if expression.args.get('distinct') else ' ALL'}" + kind = " DISTINCT" if self.EXPLICIT_UNION else "" + kind = kind if expression.args.get("distinct") else " ALL" + return f"UNION{kind}" def unnest_sql(self, expression): args = self.expressions(expression, flat=True) @@ -940,10 +945,13 @@ class Generator: def in_sql(self, expression): query = expression.args.get("query") unnest = expression.args.get("unnest") + field = expression.args.get("field") if query: in_sql = self.wrap(query) elif unnest: in_sql = self.in_unnest_op(unnest) + elif field: + in_sql = self.sql(field) else: in_sql = f"({self.expressions(expression, flat=True)})" return f"{self.sql(expression, 'this')} IN {in_sql}" @@ -1178,3 +1186,8 @@ class Generator: this = self.sql(expression, "this") kind = self.sql(expression, "kind") return f"{this} {kind}" + + def joinhint_sql(self, expression): + this = self.sql(expression, "this") + expressions = self.expressions(expression, flat=True) + return f"{this}({expressions})" diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index 3f5f089..a2cef37 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -1,16 +1,20 @@ from sqlglot import exp from sqlglot.helper import ensure_list, subclasses +from sqlglot.optimizer.schema import ensure_schema +from sqlglot.optimizer.scope import Scope, traverse_scope def annotate_types(expression, schema=None, annotators=None, coerces_to=None): """ Recursively infer & annotate types in an expression syntax tree against a schema. + Assumes that we've already executed the optimizer's qualify_columns step. 
- (TODO -- replace this with a better example after adding some functionality) Example: >>> import sqlglot - >>> annotated_expression = annotate_types(sqlglot.parse_one('5 + 5.3')) - >>> annotated_expression.type + >>> schema = {"y": {"cola": "SMALLINT"}} + >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x" + >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema) + >>> annotated_expr.expressions[0].type # Get the type of "x.cola + 2.5 AS cola" Args: @@ -22,6 +26,8 @@ def annotate_types(expression, schema=None, annotators=None, coerces_to=None): sqlglot.Expression: expression annotated with types """ + schema = ensure_schema(schema) + return TypeAnnotator(schema, annotators, coerces_to).annotate(expression) @@ -35,10 +41,81 @@ class TypeAnnotator: expr_type: lambda self, expr: self._annotate_binary(expr) for expr_type in subclasses(exp.__name__, exp.Binary) }, - exp.Cast: lambda self, expr: self._annotate_cast(expr), - exp.DataType: lambda self, expr: self._annotate_data_type(expr), + exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"].this), + exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.this), + exp.Alias: lambda self, expr: self._annotate_unary(expr), exp.Literal: lambda self, expr: self._annotate_literal(expr), - exp.Boolean: lambda self, expr: self._annotate_boolean(expr), + exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), + exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL), + exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN), + exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), + exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), + exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), + exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME), + exp.CurrentTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), + exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), + exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME), + exp.DatetimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME), + exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.TimestampAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.TimestampSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, 
exp.DataType.Type.TIMESTAMP), + exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.DateStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), + exp.DateToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), + exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), + exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.If: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), + exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), + exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.Ln: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Log: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Log2: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Log10: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Lower: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), + exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.RegexpLike: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), + exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), + exp.StrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), + exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + 
exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), + exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.UnixToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.VariancePop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), + exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), } # Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html @@ -97,43 +174,82 @@ class TypeAnnotator: }, } + TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery) + def __init__(self, schema=None, annotators=None, coerces_to=None): self.schema = schema self.annotators = annotators or self.ANNOTATORS self.coerces_to = coerces_to or self.COERCES_TO def annotate(self, expression): + if isinstance(expression, self.TRAVERSABLES): + for scope in traverse_scope(expression): + subscope_selects = { + name: {select.alias_or_name: select for select in source.selects} + for name, source in scope.sources.items() + if isinstance(source, Scope) + } + + # First annotate the current scope's column references + for col in scope.columns: + source = scope.sources[col.table] + if isinstance(source, exp.Table): + col.type = self.schema.get_column_type(source, col) + else: + col.type = subscope_selects[col.table][col.name].type + + # Then (possibly) annotate the remaining expressions in the scope + self._maybe_annotate(scope.expression) + + return self._maybe_annotate(expression) # This takes care of non-traversable expressions + + def _maybe_annotate(self, expression): if not isinstance(expression, exp.Expression): return None + if expression.type: + return expression # We've already inferred the expression's type + annotator = self.annotators.get(expression.__class__) - return annotator(self, expression) if annotator else self._annotate_args(expression) + return ( + annotator(self, expression) + if annotator + else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN) + ) def _annotate_args(self, expression): for value in expression.args.values(): for v in ensure_list(value): - self.annotate(v) + self._maybe_annotate(v) return expression - def _annotate_cast(self, expression): - expression.type = expression.args["to"].this - return self._annotate_args(expression) - - def _annotate_data_type(self, expression): - expression.type = expression.this - return self._annotate_args(expression) - def _maybe_coerce(self, type1, type2): + # We propagate the NULL / UNKNOWN types upwards if found + if exp.DataType.Type.NULL in (type1, type2): + return exp.DataType.Type.NULL + if exp.DataType.Type.UNKNOWN in (type1, type2): + return exp.DataType.Type.UNKNOWN + return type2 if type2 in self.coerces_to[type1] else type1 def _annotate_binary(self, expression): self._annotate_args(expression) - if isinstance(expression, (exp.Condition, exp.Predicate)): + left_type = expression.left.type + right_type = expression.right.type + + if isinstance(expression, (exp.And, exp.Or)): 
+ if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL: + expression.type = exp.DataType.Type.NULL + elif exp.DataType.Type.NULL in (left_type, right_type): + expression.type = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")) + else: + expression.type = exp.DataType.Type.BOOLEAN + elif isinstance(expression, (exp.Condition, exp.Predicate)): expression.type = exp.DataType.Type.BOOLEAN else: - expression.type = self._maybe_coerce(expression.left.type, expression.right.type) + expression.type = self._maybe_coerce(left_type, right_type) return expression @@ -157,6 +273,6 @@ class TypeAnnotator: return expression - def _annotate_boolean(self, expression): - expression.type = exp.DataType.Type.BOOLEAN - return expression + def _annotate_with_type(self, expression, target_type): + expression.type = target_type + return self._annotate_args(expression) diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index d29c22b..3e435f5 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -44,6 +44,7 @@ UNMERGABLE_ARGS = set(exp.Select.arg_types) - { "joins", "where", "order", + "hint", } @@ -67,21 +68,22 @@ def merge_ctes(expression, leave_tables_isolated=False): singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1] for outer_scope, inner_scope, table in singular_cte_selections: inner_select = inner_scope.expression.unnest() - if _mergeable(outer_scope, inner_select, leave_tables_isolated): - from_or_join = table.find_ancestor(exp.From, exp.Join) - + from_or_join = table.find_ancestor(exp.From, exp.Join) + if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): node_to_replace = table if isinstance(node_to_replace.parent, exp.Alias): node_to_replace = node_to_replace.parent alias = node_to_replace.alias else: alias = table.name + _rename_inner_sources(outer_scope, inner_scope, alias) _merge_from(outer_scope, inner_scope, node_to_replace, alias) _merge_expressions(outer_scope, inner_scope, alias) _merge_joins(outer_scope, inner_scope, from_or_join) _merge_where(outer_scope, inner_scope, from_or_join) _merge_order(outer_scope, inner_scope) + _merge_hints(outer_scope, inner_scope) _pop_cte(inner_scope) return expression @@ -90,9 +92,9 @@ def merge_derived_tables(expression, leave_tables_isolated=False): for outer_scope in traverse_scope(expression): for subquery in outer_scope.derived_tables: inner_select = subquery.unnest() - if _mergeable(outer_scope, inner_select, leave_tables_isolated): + from_or_join = subquery.find_ancestor(exp.From, exp.Join) + if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): alias = subquery.alias_or_name - from_or_join = subquery.find_ancestor(exp.From, exp.Join) inner_scope = outer_scope.sources[alias] _rename_inner_sources(outer_scope, inner_scope, alias) @@ -101,10 +103,11 @@ def merge_derived_tables(expression, leave_tables_isolated=False): _merge_joins(outer_scope, inner_scope, from_or_join) _merge_where(outer_scope, inner_scope, from_or_join) _merge_order(outer_scope, inner_scope) + _merge_hints(outer_scope, inner_scope) return expression -def _mergeable(outer_scope, inner_select, leave_tables_isolated): +def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): """ Return True if `inner_select` can be merged into outer query. 
@@ -112,6 +115,7 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated): outer_scope (Scope) inner_select (exp.Select) leave_tables_isolated (bool) + from_or_join (exp.From|exp.Join) Returns: bool: True if can be merged """ @@ -123,6 +127,16 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated): and inner_select.args.get("from") and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions) and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1) + and not ( + isinstance(from_or_join, exp.Join) + and inner_select.args.get("where") + and from_or_join.side in {"FULL", "LEFT", "RIGHT"} + ) + and not ( + isinstance(from_or_join, exp.From) + and inner_select.args.get("where") + and any(j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", [])) + ) ) @@ -170,6 +184,12 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias): """ new_subquery = inner_scope.expression.args.get("from").expressions[0] node_to_replace.replace(new_subquery) + for join_hint in outer_scope.join_hints: + tables = join_hint.find_all(exp.Table) + for table in tables: + if table.alias_or_name == node_to_replace.alias_or_name: + new_table = new_subquery.this if isinstance(new_subquery, exp.Alias) else new_subquery + table.set("this", exp.to_identifier(new_table.alias_or_name)) outer_scope.remove_source(alias) outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name]) @@ -273,6 +293,18 @@ def _merge_order(outer_scope, inner_scope): outer_scope.expression.set("order", inner_scope.expression.args.get("order")) +def _merge_hints(outer_scope, inner_scope): + inner_scope_hint = inner_scope.expression.args.get("hint") + if not inner_scope_hint: + return + outer_scope_hint = outer_scope.expression.args.get("hint") + if outer_scope_hint: + for hint_expression in inner_scope_hint.expressions: + outer_scope_hint.append("expressions", hint_expression) + else: + outer_scope.expression.set("hint", inner_scope_hint) + + def _pop_cte(inner_scope): """ Remove CTE from the AST. 
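The from_or_join argument threaded through _mergeable above prevents merges that would change outer-join semantics: a subquery carrying its own WHERE is no longer folded into the outer query when it sits behind a FULL, LEFT or RIGHT join. A sketch, assuming the module's top-level merge_subqueries entry point:

    from sqlglot import parse_one
    from sqlglot.optimizer.merge_subqueries import merge_subqueries

    # The inner WHERE filters y before NULL-extension; merging it into the
    # outer WHERE would filter after it, so this subquery is kept intact.
    sql = (
        "SELECT x.a, y.b FROM x "
        "LEFT JOIN (SELECT a, b FROM y WHERE b > 1) AS y ON x.a = y.a"
    )
    print(merge_subqueries(parse_one(sql)).sql())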
diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py index a070d70..9c8d71d 100644 --- a/sqlglot/optimizer/pushdown_predicates.py +++ b/sqlglot/optimizer/pushdown_predicates.py @@ -1,3 +1,5 @@ +from collections import defaultdict + from sqlglot import exp from sqlglot.optimizer.normalize import normalized from sqlglot.optimizer.scope import traverse_scope @@ -20,22 +22,30 @@ def pushdown_predicates(expression): Returns: sqlglot.Expression: optimized expression """ - for scope in reversed(traverse_scope(expression)): + scope_ref_count = defaultdict(lambda: 0) + scopes = traverse_scope(expression) + scopes.reverse() + + for scope in scopes: + for _, source in scope.selected_sources.values(): + scope_ref_count[id(source)] += 1 + + for scope in scopes: select = scope.expression where = select.args.get("where") if where: - pushdown(where.this, scope.selected_sources) + pushdown(where.this, scope.selected_sources, scope_ref_count) # joins should only pushdown into itself, not to other joins # so we limit the selected sources to only itself for join in select.args.get("joins") or []: name = join.this.alias_or_name - pushdown(join.args.get("on"), {name: scope.selected_sources[name]}) + pushdown(join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count) return expression -def pushdown(condition, sources): +def pushdown(condition, sources, scope_ref_count): if not condition: return @@ -45,17 +55,17 @@ def pushdown(condition, sources): predicates = list(condition.flatten() if isinstance(condition, exp.And if cnf_like else exp.Or) else [condition]) if cnf_like: - pushdown_cnf(predicates, sources) + pushdown_cnf(predicates, sources, scope_ref_count) else: - pushdown_dnf(predicates, sources) + pushdown_dnf(predicates, sources, scope_ref_count) -def pushdown_cnf(predicates, scope): +def pushdown_cnf(predicates, scope, scope_ref_count): """ If the predicates are in CNF like form, we can simply replace each block in the parent. """ for predicate in predicates: - for node in nodes_for_predicate(predicate, scope).values(): + for node in nodes_for_predicate(predicate, scope, scope_ref_count).values(): if isinstance(node, exp.Join): predicate.replace(exp.TRUE) node.on(predicate, copy=False) @@ -65,7 +75,7 @@ def pushdown_cnf(predicates, scope): node.where(replace_aliases(node, predicate), copy=False) -def pushdown_dnf(predicates, scope): +def pushdown_dnf(predicates, scope, scope_ref_count): """ If the predicates are in DNF form, we can only push down conditions that are in all blocks. Additionally, we can't remove predicates from their original form. 
@@ -91,7 +101,7 @@ def pushdown_dnf(predicates, scope): # (a.x AND and a.y AND b.x) OR (a.z AND c.y) -> (a.x AND a.y) OR (a.z) for table in sorted(pushdown_tables): for predicate in predicates: - nodes = nodes_for_predicate(predicate, scope) + nodes = nodes_for_predicate(predicate, scope, scope_ref_count) if table not in nodes: continue @@ -120,7 +130,7 @@ def pushdown_dnf(predicates, scope): node.where(replace_aliases(node, predicate), copy=False) -def nodes_for_predicate(predicate, sources): +def nodes_for_predicate(predicate, sources, scope_ref_count): nodes = {} tables = exp.column_table_names(predicate) where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where) @@ -133,7 +143,7 @@ def nodes_for_predicate(predicate, sources): if node and where_condition: node = node.find_ancestor(exp.Join, exp.From) - # a node can reference a CTE which should be push down + # a node can reference a CTE which should be pushed down if isinstance(node, exp.From) and not isinstance(source, exp.Table): node = source.expression @@ -142,7 +152,9 @@ def nodes_for_predicate(predicate, sources): return {} nodes[table] = node elif isinstance(node, exp.Select) and len(tables) == 1: - if not node.args.get("group"): + # we can't push down predicates to select statements if they are referenced in + # multiple places. + if not node.args.get("group") and scope_ref_count[id(source)] < 2: nodes[table] = node return nodes diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 72ce256..7d77ef1 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -31,8 +31,8 @@ def qualify_columns(expression, schema): _pop_table_column_aliases(scope.derived_tables) _expand_using(scope, resolver) _expand_group_by(scope, resolver) - _expand_order_by(scope) _qualify_columns(scope, resolver) + _expand_order_by(scope) if not isinstance(scope.expression, exp.UDTF): _expand_stars(scope, resolver) _qualify_outputs(scope) @@ -235,7 +235,7 @@ def _expand_stars(scope, resolver): for table in tables: if table not in scope.sources: raise OptimizeError(f"Unknown table: {table}") - columns = resolver.get_source_columns(table) + columns = resolver.get_source_columns(table, only_visible=True) table_id = id(table) for name in columns: if name not in except_columns.get(table_id, set()): @@ -332,7 +332,7 @@ class _Resolver: self._all_columns = set(column for columns in self._get_all_source_columns().values() for column in columns) return self._all_columns - def get_source_columns(self, name): + def get_source_columns(self, name, only_visible=False): """Resolve the source columns for a given source `name`""" if name not in self.scope.sources: raise OptimizeError(f"Unknown table: {name}") @@ -342,7 +342,7 @@ class _Resolver: # If referencing a table, return the columns from the schema if isinstance(source, exp.Table): try: - return self.schema.column_names(source) + return self.schema.column_names(source, only_visible) except Exception as e: raise OptimizeError(str(e)) from e diff --git a/sqlglot/optimizer/schema.py b/sqlglot/optimizer/schema.py index 1bbd86a..d7743c9 100644 --- a/sqlglot/optimizer/schema.py +++ b/sqlglot/optimizer/schema.py @@ -9,16 +9,28 @@ class Schema(abc.ABC): """Abstract base class for database schemas""" @abc.abstractmethod - def column_names(self, table): + def column_names(self, table, only_visible=False): """ Get the column names for a table. 
- Args: table (sqlglot.expressions.Table): Table expression instance + only_visible (bool): Whether to include invisible columns Returns: list[str]: list of column names """ + @abc.abstractmethod + def get_column_type(self, table, column): + """ + Get the exp.DataType type of a column in the schema. + + Args: + table (sqlglot.expressions.Table): The source table. + column (sqlglot.expressions.Column): The target column. + Returns: + sqlglot.expressions.DataType.Type: The resulting column type. + """ + class MappingSchema(Schema): """ @@ -29,10 +41,19 @@ class MappingSchema(Schema): 1. {table: {col: type}} 2. {db: {table: {col: type}}} 3. {catalog: {db: {table: {col: type}}}} + visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns + are assumed to be visible. The nesting should mirror that of the schema: + 1. {table: set(*cols)}} + 2. {db: {table: set(*cols)}}} + 3. {catalog: {db: {table: set(*cols)}}}} + dialect (str): The dialect to be used for custom type mappings. """ - def __init__(self, schema): + def __init__(self, schema, visible=None, dialect=None): self.schema = schema + self.visible = visible + self.dialect = dialect + self._type_mapping_cache = {} depth = _dict_depth(schema) @@ -49,7 +70,7 @@ class MappingSchema(Schema): self.forbidden_args = {"catalog", "db", "this"} - set(self.supported_table_args) - def column_names(self, table): + def column_names(self, table, only_visible=False): if not isinstance(table.this, exp.Identifier): return fs_get(table) @@ -58,7 +79,39 @@ class MappingSchema(Schema): for forbidden in self.forbidden_args: if table.text(forbidden): raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}") - return list(_nested_get(self.schema, *zip(self.supported_table_args, args))) + + columns = list(_nested_get(self.schema, *zip(self.supported_table_args, args))) + if not only_visible or not self.visible: + return columns + + visible = _nested_get(self.visible, *zip(self.supported_table_args, args)) + return [col for col in columns if col in visible] + + def get_column_type(self, table, column): + try: + schema_type = self.schema.get(table.name, {}).get(column.name).upper() + return self._convert_type(schema_type) + except: + raise OptimizeError(f"Failed to get type for column {column.sql()}") + + def _convert_type(self, schema_type): + """ + Convert a type represented as a string to the corresponding exp.DataType.Type object. + + Args: + schema_type (str): The type we want to convert. + Returns: + sqlglot.expressions.DataType.Type: The resulting expression type. 
+ """ + if schema_type not in self._type_mapping_cache: + try: + self._type_mapping_cache[schema_type] = exp.maybe_parse( + schema_type, into=exp.DataType, dialect=self.dialect + ).this + except AttributeError: + raise OptimizeError(f"Failed to convert type {schema_type}") + + return self._type_mapping_cache[schema_type] def ensure_schema(schema): diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 6332cdd..89de517 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -68,6 +68,7 @@ class Scope: self._selected_sources = None self._columns = None self._external_columns = None + self._join_hints = None def branch(self, expression, scope_type, chain_sources=None, **kwargs): """Branch from the current scope to a new, inner scope""" @@ -85,14 +86,17 @@ class Scope: self._subqueries = [] self._derived_tables = [] self._raw_columns = [] + self._join_hints = [] for node, parent, _ in self.walk(bfs=False): if node is self.expression: continue elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star): self._raw_columns.append(node) - elif isinstance(node, exp.Table): + elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint): self._tables.append(node) + elif isinstance(node, exp.JoinHint): + self._join_hints.append(node) elif isinstance(node, exp.UDTF): self._derived_tables.append(node) elif isinstance(node, exp.CTE): @@ -246,7 +250,7 @@ class Scope: table only becomes a selected source if it's included in a FROM or JOIN clause. Returns: - dict[str, (exp.Table|exp.Subquery, exp.Table|Scope)]: selected sources and nodes + dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes """ if self._selected_sources is None: referenced_names = [] @@ -310,6 +314,18 @@ class Scope: self._external_columns = [c for c in self.columns if c.table not in self.selected_sources] return self._external_columns + @property + def join_hints(self): + """ + Hints that exist in the scope that reference tables + + Returns: + list[exp.JoinHint]: Join hints that are referenced within the scope + """ + if self._join_hints is None: + return [] + return self._join_hints + def source_columns(self, source_name): """ Get all columns in the current scope for a particular source. 
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 319e6b6..c077906 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -56,12 +56,16 @@ def simplify_not(expression): NOT (x AND y) -> NOT x OR NOT y """ if isinstance(expression, exp.Not): + if isinstance(expression.this, exp.Null): + return NULL if isinstance(expression.this, exp.Paren): condition = expression.this.unnest() if isinstance(condition, exp.And): return exp.or_(exp.not_(condition.left), exp.not_(condition.right)) if isinstance(condition, exp.Or): return exp.and_(exp.not_(condition.left), exp.not_(condition.right)) + if isinstance(condition, exp.Null): + return NULL if always_true(expression.this): return FALSE if expression.this == FALSE: @@ -95,10 +99,10 @@ def simplify_connectors(expression): return left if isinstance(expression, exp.And): - if NULL in (left, right): - return NULL if FALSE in (left, right): return FALSE + if NULL in (left, right): + return NULL if always_true(left) and always_true(right): return TRUE if always_true(left): diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 5f20afc..c29e520 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -8,6 +8,18 @@ from sqlglot.tokens import Token, Tokenizer, TokenType logger = logging.getLogger("sqlglot") +def parse_var_map(args): + keys = [] + values = [] + for i in range(0, len(args), 2): + keys.append(args[i]) + values.append(args[i + 1]) + return exp.VarMap( + keys=exp.Array(expressions=keys), + values=exp.Array(expressions=values), + ) + + class Parser: """ Parser consumes a list of tokens produced by the :class:`~sqlglot.tokens.Tokenizer` @@ -48,6 +60,7 @@ class Parser: start=exp.Literal.number(1), length=exp.Literal.number(10), ), + "VAR_MAP": parse_var_map, } NO_PAREN_FUNCTIONS = { @@ -117,6 +130,7 @@ class Parser: TokenType.VAR, TokenType.ALTER, TokenType.ALWAYS, + TokenType.ANTI, TokenType.BEGIN, TokenType.BOTH, TokenType.BUCKET, @@ -164,6 +178,7 @@ class Parser: TokenType.ROWS, TokenType.SCHEMA_COMMENT, TokenType.SEED, + TokenType.SEMI, TokenType.SET, TokenType.SHOW, TokenType.STABLE, @@ -273,6 +288,8 @@ class Parser: TokenType.INNER, TokenType.OUTER, TokenType.CROSS, + TokenType.SEMI, + TokenType.ANTI, } COLUMN_OPERATORS = { @@ -318,6 +335,8 @@ class Parser: exp.Properties: lambda self: self._parse_properties(), exp.Where: lambda self: self._parse_where(), exp.Ordered: lambda self: self._parse_ordered(), + exp.Having: lambda self: self._parse_having(), + exp.With: lambda self: self._parse_with(), "JOIN_TYPE": lambda self: self._parse_join_side_and_kind(), } @@ -338,7 +357,6 @@ class Parser: TokenType.NULL: lambda *_: exp.Null(), TokenType.TRUE: lambda *_: exp.Boolean(this=True), TokenType.FALSE: lambda *_: exp.Boolean(this=False), - TokenType.PLACEHOLDER: lambda *_: exp.Placeholder(), TokenType.PARAMETER: lambda self, _: exp.Parameter(this=self._parse_var() or self._parse_primary()), TokenType.BIT_STRING: lambda _, token: exp.BitString(this=token.text), TokenType.HEX_STRING: lambda _, token: exp.HexString(this=token.text), @@ -910,7 +928,20 @@ class Parser: return self.expression(exp.Tuple, expressions=expressions) def _parse_select(self, nested=False, table=False): - if self._match(TokenType.SELECT): + cte = self._parse_with() + if cte: + this = self._parse_statement() + + if not this: + self.raise_error("Failed to parse any statement following CTE") + return cte + + if "with" in this.arg_types: + this.set("with", cte) + else: + self.raise_error(f"{this.key} does not support CTE") + 
this = cte + elif self._match(TokenType.SELECT): hint = self._parse_hint() all_ = self._match(TokenType.ALL) distinct = self._match(TokenType.DISTINCT) @@ -938,39 +969,6 @@ class Parser: if from_: this.set("from", from_) self._parse_query_modifiers(this) - elif self._match(TokenType.WITH): - recursive = self._match(TokenType.RECURSIVE) - - expressions = [] - - while True: - expressions.append(self._parse_cte()) - - if not self._match(TokenType.COMMA): - break - - cte = self.expression( - exp.With, - expressions=expressions, - recursive=recursive, - ) - this = self._parse_statement() - - if not this: - self.raise_error("Failed to parse any statement following CTE") - return cte - - if "with" in this.arg_types: - this.set( - "with", - self.expression( - exp.With, - expressions=expressions, - recursive=recursive, - ), - ) - else: - self.raise_error(f"{this.key} does not support CTE") elif (table or nested) and self._match(TokenType.L_PAREN): this = self._parse_table() if table else self._parse_select(nested=True) self._parse_query_modifiers(this) @@ -986,6 +984,26 @@ class Parser: return self._parse_set_operations(this) if this else None + def _parse_with(self): + if not self._match(TokenType.WITH): + return None + + recursive = self._match(TokenType.RECURSIVE) + + expressions = [] + + while True: + expressions.append(self._parse_cte()) + + if not self._match(TokenType.COMMA): + break + + return self.expression( + exp.With, + expressions=expressions, + recursive=recursive, + ) + def _parse_cte(self): alias = self._parse_table_alias() if not alias or not alias.this: @@ -1485,8 +1503,7 @@ class Parser: unnest = self._parse_unnest() if unnest: this = self.expression(exp.In, this=this, unnest=unnest) - else: - self._match_l_paren() + elif self._match(TokenType.L_PAREN): expressions = self._parse_csv(self._parse_select_or_expression) if len(expressions) == 1 and isinstance(expressions[0], exp.Subqueryable): @@ -1495,6 +1512,9 @@ class Parser: this = self.expression(exp.In, this=this, expressions=expressions) self._match_r_paren() + else: + this = self.expression(exp.In, this=this, field=self._parse_field()) + return this def _parse_between(self, this): @@ -1591,7 +1611,7 @@ class Parser: elif nested: expressions = self._parse_csv(self._parse_types) else: - expressions = self._parse_csv(self._parse_number) + expressions = self._parse_csv(self._parse_type) if not expressions: self._retreat(index) @@ -1706,7 +1726,7 @@ class Parser: def _parse_field(self, any_token=False): return self._parse_primary() or self._parse_function() or self._parse_id_var(any_token) - def _parse_function(self): + def _parse_function(self, functions=None): if not self._curr: return None @@ -1742,7 +1762,9 @@ class Parser: self._match_r_paren() return this - function = self.FUNCTIONS.get(upper) + if functions is None: + functions = self.FUNCTIONS + function = functions.get(upper) args = self._parse_csv(self._parse_lambda) if function: @@ -2025,10 +2047,20 @@ class Parser: return self.expression(exp.Cast, this=this, to=to) def _parse_position(self): - substr = self._parse_bitwise() + args = self._parse_csv(self._parse_bitwise) + if self._match(TokenType.IN): - string = self._parse_bitwise() - return self.expression(exp.StrPosition, this=string, substr=substr) + args.append(self._parse_bitwise()) + + # Note: we're parsing in order needle, haystack, position + this = exp.StrPosition.from_arg_list(args) + self.validate_expression(this, args) + + return this + + def _parse_join_hint(self, func_name): + args = 
self._parse_csv(self._parse_table) + return exp.JoinHint(this=func_name.upper(), expressions=args) def _parse_substring(self): # Postgres supports the form: substring(string [from int] [for int]) @@ -2247,6 +2279,9 @@ class Parser: def _parse_placeholder(self): if self._match(TokenType.PLACEHOLDER): return exp.Placeholder() + elif self._match(TokenType.COLON): + self._advance() + return exp.Placeholder(this=self._prev.text) return None def _parse_except(self): diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 39bf421..17c038c 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -104,6 +104,7 @@ class TokenType(AutoName): ALL = auto() ALTER = auto() ANALYZE = auto() + ANTI = auto() ANY = auto() ARRAY = auto() ASC = auto() @@ -236,6 +237,7 @@ class TokenType(AutoName): SCHEMA_COMMENT = auto() SEED = auto() SELECT = auto() + SEMI = auto() SEPARATOR = auto() SET = auto() SHOW = auto() @@ -262,6 +264,7 @@ class TokenType(AutoName): USE = auto() USING = auto() VALUES = auto() + VACUUM = auto() VIEW = auto() VOLATILE = auto() WHEN = auto() @@ -406,6 +409,7 @@ class Tokenizer(metaclass=_Tokenizer): "ALTER": TokenType.ALTER, "ANALYZE": TokenType.ANALYZE, "AND": TokenType.AND, + "ANTI": TokenType.ANTI, "ANY": TokenType.ANY, "ASC": TokenType.ASC, "AS": TokenType.ALIAS, @@ -528,6 +532,7 @@ class Tokenizer(metaclass=_Tokenizer): "ROWS": TokenType.ROWS, "SEED": TokenType.SEED, "SELECT": TokenType.SELECT, + "SEMI": TokenType.SEMI, "SET": TokenType.SET, "SHOW": TokenType.SHOW, "SOME": TokenType.SOME, @@ -551,6 +556,7 @@ class Tokenizer(metaclass=_Tokenizer): "UPDATE": TokenType.UPDATE, "USE": TokenType.USE, "USING": TokenType.USING, + "VACUUM": TokenType.VACUUM, "VALUES": TokenType.VALUES, "VIEW": TokenType.VIEW, "VOLATILE": TokenType.VOLATILE, @@ -577,6 +583,7 @@ class Tokenizer(metaclass=_Tokenizer): "INT8": TokenType.BIGINT, "DECIMAL": TokenType.DECIMAL, "MAP": TokenType.MAP, + "NULLABLE": TokenType.NULLABLE, "NUMBER": TokenType.DECIMAL, "NUMERIC": TokenType.DECIMAL, "FIXED": TokenType.DECIMAL, @@ -629,6 +636,7 @@ class Tokenizer(metaclass=_Tokenizer): TokenType.SHOW, TokenType.TRUNCATE, TokenType.USE, + TokenType.VACUUM, } # handle numeric literals like in hive (3L = BIGINT) -- cgit v1.2.3
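Two quick checks of the parser behavior introduced above (a hedged sketch using only the
public parse_one/find/simplify helpers, which are assumed from sqlglot's API; exact reprs
and output formatting are indicative):

    import sqlglot
    from sqlglot.expressions import Placeholder
    from sqlglot.optimizer.simplify import simplify

    # Named parameters: ":id" now tokenizes as COLON + identifier and parses
    # into a Placeholder node carrying the parameter name, alongside the
    # existing anonymous "?" form.
    expression = sqlglot.parse_one("SELECT * FROM users WHERE id = :id")
    print(expression.find(Placeholder))  # a Placeholder with this='id' (indicative)

    # Connector simplification: with the reordered checks, FALSE now wins over
    # NULL under AND, matching SQL three-valued logic (FALSE AND NULL = FALSE).
    print(simplify(sqlglot.parse_one("SELECT * FROM t WHERE FALSE AND NULL")).sql())
    # indicative output: SELECT * FROM t WHERE FALSE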