From ebec59cc5cb6c6856705bf82ced7fe8d9f75b0d0 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Tue, 7 Mar 2023 19:09:31 +0100 Subject: Merging upstream version 11.3.0. Signed-off-by: Daniel Baumann --- sqlglot/optimizer/merge_subqueries.py | 24 +++++++++++++++++++----- sqlglot/optimizer/pushdown_projections.py | 6 ++++-- 2 files changed, 23 insertions(+), 7 deletions(-) (limited to 'sqlglot/optimizer') diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index 16aaf17..70172f4 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -314,13 +314,27 @@ def _merge_where(outer_scope, inner_scope, from_or_join): if not where or not where.this: return + expression = outer_scope.expression + if isinstance(from_or_join, exp.Join): # Merge predicates from an outer join to the ON clause - from_or_join.on(where.this, copy=False) - from_or_join.set("on", simplify(from_or_join.args.get("on"))) - else: - outer_scope.expression.where(where.this, copy=False) - outer_scope.expression.set("where", simplify(outer_scope.expression.args.get("where"))) + # if it only has columns that are already joined + from_ = expression.args.get("from") + sources = {table.alias_or_name for table in from_.expressions} if from_ else {} + + for join in expression.args["joins"]: + source = join.alias_or_name + sources.add(source) + if source == from_or_join.alias_or_name: + break + + if set(exp.column_table_names(where.this)) <= sources: + from_or_join.on(where.this, copy=False) + from_or_join.set("on", simplify(from_or_join.args.get("on"))) + return + + expression.where(where.this, copy=False) + expression.set("where", simplify(expression.args.get("where"))) def _merge_order(outer_scope, inner_scope): diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py index 3f360f9..07a1b70 100644 --- a/sqlglot/optimizer/pushdown_projections.py +++ b/sqlglot/optimizer/pushdown_projections.py @@ -13,7 +13,7 @@ SELECT_ALL = object() DEFAULT_SELECTION = lambda: alias("1", "_") -def pushdown_projections(expression, schema=None): +def pushdown_projections(expression, schema=None, remove_unused_selections=True): """ Rewrite sqlglot AST to remove unused columns projections. @@ -26,6 +26,7 @@ def pushdown_projections(expression, schema=None): Args: expression (sqlglot.Expression): expression to optimize + remove_unused_selections (bool): remove selects that are unused Returns: sqlglot.Expression: optimized expression """ @@ -57,7 +58,8 @@ def pushdown_projections(expression, schema=None): ] if isinstance(scope.expression, exp.Select): - _remove_unused_selections(scope, parent_selections, schema) + if remove_unused_selections: + _remove_unused_selections(scope, parent_selections, schema) # Group columns by source name selects = defaultdict(set) -- cgit v1.2.3