From beba715b97dd2349e01dde9b077d2535680ebdca Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Wed, 10 May 2023 08:44:58 +0200 Subject: Merging upstream version 12.2.0. Signed-off-by: Daniel Baumann --- sqlglot/optimizer/eliminate_joins.py | 2 +- sqlglot/optimizer/expand_laterals.py | 4 ++-- sqlglot/optimizer/normalize.py | 2 ++ sqlglot/optimizer/optimizer.py | 2 -- sqlglot/optimizer/qualify_columns.py | 26 +++++++++++++++++++------- sqlglot/optimizer/qualify_tables.py | 18 ++++++++++++++++-- sqlglot/optimizer/scope.py | 6 ++++++ sqlglot/optimizer/simplify.py | 17 +++++++++++++---- 8 files changed, 59 insertions(+), 18 deletions(-) (limited to 'sqlglot/optimizer') diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py index e0ddfa2..27de9c7 100644 --- a/sqlglot/optimizer/eliminate_joins.py +++ b/sqlglot/optimizer/eliminate_joins.py @@ -153,7 +153,7 @@ def join_condition(join): # # should pull y.b as the join key and x.a as the source key if normalized(on): - on = on if isinstance(on, exp.And) else exp.and_(on, exp.true()) + on = on if isinstance(on, exp.And) else exp.and_(on, exp.true(), copy=False) for condition in on.flatten(): if isinstance(condition, exp.EQ): diff --git a/sqlglot/optimizer/expand_laterals.py b/sqlglot/optimizer/expand_laterals.py index 59f3fec..5b2f706 100644 --- a/sqlglot/optimizer/expand_laterals.py +++ b/sqlglot/optimizer/expand_laterals.py @@ -29,6 +29,6 @@ def expand_laterals(expression: exp.Expression) -> exp.Expression: for column in projection.find_all(exp.Column): if not column.table and column.name in alias_to_expression: column.replace(alias_to_expression[column.name].copy()) - if isinstance(projection, exp.Alias): - alias_to_expression[projection.alias] = projection.this + if isinstance(projection, exp.Alias): + alias_to_expression[projection.alias] = projection.this return expression diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py index 40668ef..b013312 100644 --- a/sqlglot/optimizer/normalize.py +++ b/sqlglot/optimizer/normalize.py @@ -152,12 +152,14 @@ def _distribute(a, b, from_func, to_func, cache): lambda c: to_func( uniq_sort(flatten(from_func(c, b.left)), cache), uniq_sort(flatten(from_func(c, b.right)), cache), + copy=False, ), ) else: a = to_func( uniq_sort(flatten(from_func(a, b.left)), cache), uniq_sort(flatten(from_func(a, b.right)), cache), + copy=False, ) return a diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index 62eb11e..c165ffe 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -10,7 +10,6 @@ from sqlglot.optimizer.canonicalize import canonicalize from sqlglot.optimizer.eliminate_ctes import eliminate_ctes from sqlglot.optimizer.eliminate_joins import eliminate_joins from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries -from sqlglot.optimizer.expand_laterals import expand_laterals from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects from sqlglot.optimizer.isolate_table_selects import isolate_table_selects from sqlglot.optimizer.lower_identities import lower_identities @@ -30,7 +29,6 @@ RULES = ( qualify_tables, isolate_table_selects, qualify_columns, - expand_laterals, pushdown_projections, validate_qualify_columns, normalize, diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 0a31246..6ac39f0 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -3,11 +3,12 @@ import typing as t from sqlglot import alias, exp from sqlglot.errors import OptimizeError +from sqlglot.optimizer.expand_laterals import expand_laterals as _expand_laterals from sqlglot.optimizer.scope import Scope, traverse_scope from sqlglot.schema import ensure_schema -def qualify_columns(expression, schema): +def qualify_columns(expression, schema, expand_laterals=True): """ Rewrite sqlglot AST to have fully qualified columns. @@ -26,6 +27,9 @@ def qualify_columns(expression, schema): """ schema = ensure_schema(schema) + if not schema.mapping and expand_laterals: + expression = _expand_laterals(expression) + for scope in traverse_scope(expression): resolver = Resolver(scope, schema) _pop_table_column_aliases(scope.ctes) @@ -39,6 +43,9 @@ def qualify_columns(expression, schema): _expand_group_by(scope, resolver) _expand_order_by(scope) + if schema.mapping and expand_laterals: + expression = _expand_laterals(expression) + return expression @@ -124,7 +131,7 @@ def _expand_using(scope, resolver): tables[join_table] = None join.args.pop("using") - join.set("on", exp.and_(*conditions)) + join.set("on", exp.and_(*conditions, copy=False)) if column_tables: for column in scope.columns: @@ -240,7 +247,9 @@ def _qualify_columns(scope, resolver): # column_table can be a '' because bigquery unnest has no table alias if column_table: column.set("table", column_table) - elif column_table not in scope.sources: + elif column_table not in scope.sources and ( + not scope.parent or column_table not in scope.parent.sources + ): # structs are used like tables (e.g. "struct"."field"), so they need to be qualified # separately and represented as dot(dot(...(., field1), field2, ...)) @@ -376,10 +385,13 @@ def _qualify_outputs(scope): if not selection.output_name: selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}"))) elif not isinstance(selection, exp.Alias) and not selection.is_star: - alias_ = alias(exp.column(""), alias=selection.output_name or f"_col_{i}") - alias_.set("this", selection) - selection = alias_ - + selection = alias( + selection, + alias=selection.output_name or f"_col_{i}", + quoted=True + if isinstance(selection, exp.Column) and selection.this.quoted + else None, + ) if aliased_column: selection.set("alias", exp.to_identifier(aliased_column)) diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index a719ebe..1b451a6 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -7,21 +7,29 @@ from sqlglot.optimizer.scope import Scope, traverse_scope def qualify_tables(expression, db=None, catalog=None, schema=None): """ - Rewrite sqlglot AST to have fully qualified tables. + Rewrite sqlglot AST to have fully qualified tables. Additionally, this + replaces "join constructs" (*) by equivalent SELECT * subqueries. - Example: + Examples: >>> import sqlglot >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl") >>> qualify_tables(expression, db="db").sql() 'SELECT 1 FROM db.tbl AS tbl' + >>> + >>> expression = sqlglot.parse_one("SELECT * FROM (tbl1 JOIN tbl2 ON id1 = id2)") + >>> qualify_tables(expression).sql() + 'SELECT * FROM (SELECT * FROM tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2) AS _q_0' Args: expression (sqlglot.Expression): expression to qualify db (str): Database name catalog (str): Catalog name schema: A schema to populate + Returns: sqlglot.Expression: qualified expression + + (*) See section 7.2.1.2 in https://www.postgresql.org/docs/current/queries-table-expressions.html """ sequence = itertools.count() @@ -29,6 +37,12 @@ def qualify_tables(expression, db=None, catalog=None, schema=None): for scope in traverse_scope(expression): for derived_table in itertools.chain(scope.ctes, scope.derived_tables): + # Expand join construct + if isinstance(derived_table, exp.Subquery): + unnested = derived_table.unnest() + if isinstance(unnested, exp.Table): + derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False)) + if not derived_table.args.get("alias"): alias_ = f"_q_{next(sequence)}" derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_))) diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index b582eb0..e00b3c9 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -510,6 +510,9 @@ def _traverse_scope(scope): yield from _traverse_union(scope) elif isinstance(scope.expression, exp.Subquery): yield from _traverse_subqueries(scope) + elif isinstance(scope.expression, exp.Table): + # This case corresponds to a "join construct", i.e. (tbl1 JOIN tbl2 ON ..) + yield from _traverse_tables(scope) elif isinstance(scope.expression, exp.UDTF): pass else: @@ -587,6 +590,9 @@ def _traverse_tables(scope): for join in scope.expression.args.get("joins") or []: expressions.append(join.this) + if isinstance(scope.expression, exp.Table): + expressions.append(scope.expression) + expressions.extend(scope.expression.args.get("laterals") or []) for expression in expressions: diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 4e6c910..0904189 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -60,6 +60,7 @@ def rewrite_between(expression: exp.Expression) -> exp.Expression: return exp.and_( exp.GTE(this=expression.this.copy(), expression=expression.args["low"]), exp.LTE(this=expression.this.copy(), expression=expression.args["high"]), + copy=False, ) return expression @@ -76,9 +77,17 @@ def simplify_not(expression): if isinstance(expression.this, exp.Paren): condition = expression.this.unnest() if isinstance(condition, exp.And): - return exp.or_(exp.not_(condition.left), exp.not_(condition.right)) + return exp.or_( + exp.not_(condition.left, copy=False), + exp.not_(condition.right, copy=False), + copy=False, + ) if isinstance(condition, exp.Or): - return exp.and_(exp.not_(condition.left), exp.not_(condition.right)) + return exp.and_( + exp.not_(condition.left, copy=False), + exp.not_(condition.right, copy=False), + copy=False, + ) if is_null(condition): return exp.null() if always_true(expression.this): @@ -254,12 +263,12 @@ def uniq_sort(expression, cache=None, root=True): # A AND C AND B -> A AND B AND C for i, (sql, e) in enumerate(arr[1:]): if sql < arr[i][0]: - expression = result_func(*(e for _, e in sorted(arr))) + expression = result_func(*(e for _, e in sorted(arr)), copy=False) break else: # we didn't have to sort but maybe we need to dedup if len(deduped) < len(flattened): - expression = result_func(*deduped.values()) + expression = result_func(*deduped.values(), copy=False) return expression -- cgit v1.2.3