From 49af28576db02470fe1d2de04e3901309b60c2e4 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Mon, 10 Jul 2023 07:36:29 +0200 Subject: Merging upstream version 17.3.0. Signed-off-by: Daniel Baumann --- sqlglot/dialects/bigquery.py | 36 ++++++++++++++++++++++++------------ sqlglot/dialects/hive.py | 1 + sqlglot/dialects/postgres.py | 1 + sqlglot/dialects/spark.py | 6 ++++++ sqlglot/dialects/spark2.py | 3 --- sqlglot/dialects/sqlite.py | 2 +- sqlglot/expressions.py | 16 +++++++++++++--- sqlglot/generator.py | 10 +++++++--- sqlglot/optimizer/qualify_tables.py | 34 +++++++++++++++++++--------------- sqlglot/optimizer/scope.py | 8 +++----- sqlglot/parser.py | 30 ++++++++++++++++++++++++++---- 11 files changed, 101 insertions(+), 46 deletions(-) (limited to 'sqlglot') diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 82162b4..35892f7 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -174,6 +174,12 @@ def _parse_date(args: t.List) -> exp.Date | exp.DateFromParts: return expr_type.from_arg_list(args) +def _parse_to_hex(args: t.List) -> exp.Hex | exp.MD5: + # TO_HEX(MD5(..)) is common in BigQuery, so it's parsed into MD5 to simplify its transpilation + arg = seq_get(args, 0) + return exp.MD5(this=arg.this) if isinstance(arg, exp.MD5Digest) else exp.Hex(this=arg) + + class BigQuery(Dialect): UNNEST_COLUMN_ONLY = True @@ -275,6 +281,8 @@ class BigQuery(Dialect): "DATETIME_SUB": parse_date_delta_with_interval(exp.DatetimeSub), "DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)), "GENERATE_ARRAY": exp.GenerateSeries.from_arg_list, + "MD5": exp.MD5Digest.from_arg_list, + "TO_HEX": _parse_to_hex, "PARSE_DATE": lambda args: format_time_lambda(exp.StrToDate, "bigquery")( [seq_get(args, 1), seq_get(args, 0)] ), @@ -379,22 +387,27 @@ class BigQuery(Dialect): exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"), exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]), + exp.Create: _create_sql, exp.CTE: transforms.preprocess([_pushdown_cte_column_names]), exp.DateAdd: _date_add_sql("DATE", "ADD"), + exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})", exp.DateFromParts: rename_func("DATE"), + exp.DateStrToDate: datestrtodate_sql, exp.DateSub: _date_add_sql("DATE", "SUB"), exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"), exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"), - exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})", - exp.DateStrToDate: datestrtodate_sql, exp.DateTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, e.text("unit")), - exp.JSONFormat: rename_func("TO_JSON_STRING"), exp.GenerateSeries: rename_func("GENERATE_ARRAY"), exp.GroupConcat: rename_func("STRING_AGG"), + exp.Hex: rename_func("TO_HEX"), exp.ILike: no_ilike_sql, exp.IntDiv: rename_func("DIV"), + exp.JSONFormat: rename_func("TO_JSON_STRING"), exp.Max: max_or_greatest, + exp.MD5: lambda self, e: self.func("TO_HEX", self.func("MD5", e.this)), + exp.MD5Digest: rename_func("MD5"), exp.Min: min_or_least, + exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", exp.RegexpExtract: lambda self, e: self.func( "REGEXP_EXTRACT", e.this, @@ -403,6 +416,7 @@ class BigQuery(Dialect): e.args.get("occurrence"), ), exp.RegexpLike: rename_func("REGEXP_CONTAINS"), + exp.ReturnsProperty: _returnsproperty_sql, exp.Select: transforms.preprocess( [ transforms.explode_to_unnest, @@ -411,6 +425,9 @@ class BigQuery(Dialect): _alias_ordered_group, ] ), + exp.StabilityProperty: lambda self, e: f"DETERMINISTIC" + if e.name == "IMMUTABLE" + else "NOT DETERMINISTIC", exp.StrToDate: lambda self, e: f"PARSE_DATE({self.format_time(e)}, {self.sql(e, 'this')})", exp.StrToTime: lambda self, e: self.func( "PARSE_TIMESTAMP", self.format_time(e), e.this, e.args.get("zone") @@ -420,17 +437,12 @@ class BigQuery(Dialect): exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"), exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"), exp.TimeStrToTime: timestrtotime_sql, - exp.TsOrDsToDate: ts_or_ds_to_date_sql("bigquery"), + exp.Trim: lambda self, e: self.func(f"TRIM", e.this, e.expression), exp.TsOrDsAdd: _date_add_sql("DATE", "ADD"), - exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", - exp.VariancePop: rename_func("VAR_POP"), + exp.TsOrDsToDate: ts_or_ds_to_date_sql("bigquery"), + exp.Unhex: rename_func("FROM_HEX"), exp.Values: _derived_table_values_to_unnest, - exp.ReturnsProperty: _returnsproperty_sql, - exp.Create: _create_sql, - exp.Trim: lambda self, e: self.func(f"TRIM", e.this, e.expression), - exp.StabilityProperty: lambda self, e: f"DETERMINISTIC" - if e.name == "IMMUTABLE" - else "NOT DETERMINISTIC", + exp.VariancePop: rename_func("VAR_POP"), } TYPE_MAPPING = { diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 1abc0f4..5762efb 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -357,6 +357,7 @@ class Hive(Dialect): exp.Left: left_to_substring_sql, exp.Map: var_map_sql, exp.Max: max_or_greatest, + exp.MD5Digest: lambda self, e: self.func("UNHEX", self.func("MD5", e.this)), exp.Min: min_or_least, exp.VarMap: var_map_sql, exp.Create: create_with_partitions_sql, diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 6d78a07..7706456 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -263,6 +263,7 @@ class Postgres(Dialect): "DO": TokenType.COMMAND, "HSTORE": TokenType.HSTORE, "JSONB": TokenType.JSONB, + "MONEY": TokenType.MONEY, "REFRESH": TokenType.COMMAND, "REINDEX": TokenType.COMMAND, "RESET": TokenType.COMMAND, diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 7a7ee01..73f4370 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -41,6 +41,12 @@ class Spark(Spark2): } class Generator(Spark2.Generator): + TYPE_MAPPING = { + **Spark2.Generator.TYPE_MAPPING, + exp.DataType.Type.MONEY: "DECIMAL(15, 4)", + exp.DataType.Type.SMALLMONEY: "DECIMAL(6, 4)", + exp.DataType.Type.UNIQUEIDENTIFIER: "STRING", + } TRANSFORMS = Spark2.Generator.TRANSFORMS.copy() TRANSFORMS.pop(exp.DateDiff) TRANSFORMS.pop(exp.Group) diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py index afe2482..f909e8c 100644 --- a/sqlglot/dialects/spark2.py +++ b/sqlglot/dialects/spark2.py @@ -177,9 +177,6 @@ class Spark2(Hive): TYPE_MAPPING = { **Hive.Generator.TYPE_MAPPING, - exp.DataType.Type.TINYINT: "BYTE", - exp.DataType.Type.SMALLINT: "SHORT", - exp.DataType.Type.BIGINT: "LONG", } PROPERTIES_LOCATION = { diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index 5ded6df..90b774e 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -192,7 +192,7 @@ class SQLite(Dialect): if len(expression.expressions) > 1: return rename_func("MIN")(self, expression) - return self.expressions(expression) + return self.sql(expression, "this") def transaction_sql(self, expression: exp.Transaction) -> str: this = expression.this diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index fdf02c8..242e66c 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -274,12 +274,16 @@ class Expression(metaclass=_Expression): def set(self, arg_key: str, value: t.Any) -> None: """ - Sets `arg_key` to `value`. + Sets arg_key to value. Args: - arg_key (str): name of the expression arg. + arg_key: name of the expression arg. value: value to set the arg to. """ + if value is None: + self.args.pop(arg_key, None) + return + self.args[arg_key] = value self._set_parent(arg_key, value) @@ -2278,6 +2282,7 @@ class Table(Expression): "pivots": False, "hints": False, "system_time": False, + "wrapped": False, } @property @@ -4249,7 +4254,7 @@ class JSONArrayContains(Binary, Predicate, Func): class Least(Func): - arg_types = {"expressions": False} + arg_types = {"this": True, "expressions": False} is_var_len_args = True @@ -4342,6 +4347,11 @@ class MD5(Func): _sql_names = ["MD5"] +# Represents the variant of the MD5 function that returns a binary value +class MD5Digest(Func): + _sql_names = ["MD5_DIGEST"] + + class Min(AggFunc): arg_types = {"this": True, "expressions": False} is_var_len_args = True diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 1ce2aaa..4ac988f 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -1215,7 +1215,8 @@ class Generator: system_time = expression.args.get("system_time") system_time = f" {self.sql(expression, 'system_time')}" if system_time else "" - return f"{table}{system_time}{alias}{hints}{pivots}{joins}{laterals}" + sql = f"{table}{system_time}{alias}{hints}{pivots}{joins}{laterals}" + return f"({sql})" if expression.args.get("wrapped") else sql def tablesample_sql( self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS " @@ -2289,11 +2290,14 @@ class Generator: def function_fallback_sql(self, expression: exp.Func) -> str: args = [] - for arg_value in expression.args.values(): + + for key in expression.arg_types: + arg_value = expression.args.get(key) + if isinstance(arg_value, list): for value in arg_value: args.append(value) - else: + elif arg_value is not None: args.append(arg_value) return self.func(expression.sql_name(), *args) diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index 9c931d6..af8c716 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -15,8 +15,7 @@ def qualify_tables( schema: t.Optional[Schema] = None, ) -> E: """ - Rewrite sqlglot AST to have fully qualified tables. Additionally, this - replaces "join constructs" (*) by equivalent SELECT * subqueries. + Rewrite sqlglot AST to have fully qualified, unnested tables. Examples: >>> import sqlglot @@ -24,9 +23,18 @@ def qualify_tables( >>> qualify_tables(expression, db="db").sql() 'SELECT 1 FROM db.tbl AS tbl' >>> + >>> expression = sqlglot.parse_one("SELECT * FROM (tbl)") + >>> qualify_tables(expression).sql() + 'SELECT * FROM tbl AS tbl' + >>> >>> expression = sqlglot.parse_one("SELECT * FROM (tbl1 JOIN tbl2 ON id1 = id2)") >>> qualify_tables(expression).sql() - 'SELECT * FROM (SELECT * FROM tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2) AS _q_0' + 'SELECT * FROM tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2' + + Note: + This rule effectively enforces a left-to-right join order, since all joins + are unnested. This means that the optimizer doesn't necessarily preserve the + original join order, e.g. when parentheses are used to specify it explicitly. Args: expression: Expression to qualify @@ -36,19 +44,11 @@ def qualify_tables( Returns: The qualified expression. - - (*) See section 7.2.1.2 in https://www.postgresql.org/docs/current/queries-table-expressions.html """ next_alias_name = name_sequence("_q_") for scope in traverse_scope(expression): for derived_table in itertools.chain(scope.ctes, scope.derived_tables): - # Expand join construct - if isinstance(derived_table, exp.Subquery): - unnested = derived_table.unnest() - if isinstance(unnested, exp.Table): - derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False)) - if not derived_table.args.get("alias"): alias_ = next_alias_name() derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_))) @@ -66,13 +66,17 @@ def qualify_tables( if not source.args.get("catalog"): source.set("catalog", exp.to_identifier(catalog)) + # Unnest joins attached in tables by appending them to the closest query + for join in source.args.get("joins") or []: + scope.expression.append("joins", join) + + source.set("joins", None) + source.set("wrapped", None) + if not source.alias: source = source.replace( alias( - source, - name or source.name or next_alias_name(), - copy=True, - table=True, + source, name or source.name or next_alias_name(), copy=True, table=True ) ) diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index bc649e4..7dcfb37 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -548,9 +548,6 @@ def _traverse_scope(scope): yield from _traverse_union(scope) elif isinstance(scope.expression, exp.Subquery): yield from _traverse_subqueries(scope) - elif isinstance(scope.expression, exp.Table): - # This case corresponds to a "join construct", i.e. (tbl1 JOIN tbl2 ON ..) - yield from _traverse_tables(scope) elif isinstance(scope.expression, exp.UDTF): pass else: @@ -632,8 +629,9 @@ def _traverse_tables(scope): if from_: expressions.append(from_.this) - for join in scope.expression.args.get("joins") or []: - expressions.append(join.this) + for expression in (scope.expression, *scope.find_all(exp.Table)): + for join in expression.args.get("joins") or []: + expressions.append(join.this) if isinstance(scope.expression, exp.Table): expressions.append(scope.expression) diff --git a/sqlglot/parser.py b/sqlglot/parser.py index f7fd6ba..c7f4050 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -1969,10 +1969,31 @@ class Parser(metaclass=_Parser): self._match_r_paren() - # early return so that subquery unions aren't parsed again - # SELECT * FROM (SELECT 1) UNION ALL SELECT 1 - # Union ALL should be a property of the top select node, not the subquery - return self._parse_subquery(this, parse_alias=parse_subquery_alias) + alias = None + + # Ensure "wrapped" tables are not parsed as Subqueries. The exception to this is when there's + # an alias that can be applied to the parentheses, because that would shadow all wrapped table + # names, and so we want to parse it as a Subquery to represent the inner scope appropriately. + # Additionally, we want the node under the Subquery to be an actual query, so we will replace + # the table reference with a star query that selects from it. + if isinstance(this, exp.Table): + alias = self._parse_table_alias() + if not alias: + this.set("wrapped", True) + return this + + this.set("wrapped", None) + joins = this.args.pop("joins", None) + this = this.replace(exp.select("*").from_(this.copy(), copy=False)) + this.set("joins", joins) + + subquery = self._parse_subquery(this, parse_alias=parse_subquery_alias and not alias) + if subquery and alias: + subquery.set("alias", alias) + + # We return early here so that the UNION isn't attached to the subquery by the + # following call to _parse_set_operations, but instead becomes the parent node + return subquery elif self._match(TokenType.VALUES): this = self.expression( exp.Values, @@ -2292,6 +2313,7 @@ class Parser(metaclass=_Parser): else: joins = None self._retreat(index) + kwargs["this"].set("joins", joins) return self.expression(exp.Join, **kwargs) -- cgit v1.2.3