# Provenance (recovered from the patch envelope this module was embedded in):
#   From 28cc22419e32a65fea2d1678400265b8cabc3aff Mon Sep 17 00:00:00 2001
#   From: Daniel Baumann
#   Date: Thu, 15 Sep 2022 18:46:17 +0200
#   Subject: Adding upstream version 6.0.4.
#   Signed-off-by: Daniel Baumann
#   File: sqlglot/optimizer/unnest_subqueries.py (new file, index 0000000..55c81c5)
#   -- cgit v1.2.3
import itertools

from sqlglot import exp
from sqlglot.optimizer.scope import traverse_scope


def unnest_subqueries(expression):
    """
    Rewrite sqlglot AST to convert some predicates with subqueries into joins.

    Convert the subquery into a group by so it is not a many to many left join.
    Unnesting can only occur if the subquery does not have LIMIT or OFFSET.
    Unnesting non correlated subqueries only happens on IN statements or = ANY statements.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
        >>> unnest_subqueries(expression).sql()
        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a)\
 AS "_u_0" ON x.a = "_u_0".a WHERE ("_u_0".a = 1 AND NOT "_u_0".a IS NULL)'

    Args:
        expression (sqlglot.Expression): expression to unnest
    Returns:
        sqlglot.Expression: unnested expression
    """
    # One counter is shared by every scope so that generated "_u_<n>" aliases
    # never collide within a single rewrite.
    sequence = itertools.count()

    for scope in traverse_scope(expression):
        select = scope.expression
        parent = select.parent_select
        # A scope that references columns from an enclosing query is correlated
        # and needs the more involved decorrelation rewrite.
        if scope.external_columns:
            decorrelate(select, parent, scope.external_columns, sequence)
        else:
            unnest(select, parent, sequence)

    return expression


def unnest(select, parent_select, sequence):
    """
    Turn a non correlated ``IN (subquery)`` / ``= ANY (subquery)`` into a LEFT JOIN.

    The rewrite is skipped (nothing is modified) when the subquery is not
    directly inside an IN / ANY predicate of ``parent_select``, when it projects
    more than one expression, or when it contains LIMIT or OFFSET.

    Args:
        select: the subquery expression to unnest.
        parent_select: the select that encloses ``select``.
        sequence: iterator of integers used to build unique ``_u_<n>`` aliases.
    """
    predicate = select.find_ancestor(exp.In, exp.Any)

    if not predicate or parent_select is not predicate.parent_select:
        return

    if len(select.selects) > 1 or select.find(exp.Limit, exp.Offset):
        return

    if isinstance(predicate, exp.Any):
        # Only `= ANY (...)` is handled: climb to the equality wrapping the ANY.
        predicate = predicate.find_ancestor(exp.EQ)

        if not predicate or parent_select is not predicate.parent_select:
            return

    column = _other_operand(predicate)
    value = select.selects[0]
    alias = _alias(sequence)

    on = exp.condition(f'{column} = "{alias}"."{value.alias}"')
    # The original predicate becomes a null check on the joined column; with a
    # LEFT JOIN a NULL there means no subquery row matched.
    _replace(predicate, f"NOT {on.right} IS NULL")

    parent_select.join(
        # Grouping by the projected value keeps the join from multiplying rows.
        select.group_by(value.this, copy=False),
        on=on,
        join_type="LEFT",
        join_alias=alias,
        copy=False,
    )


def decorrelate(select, parent_select, external_columns, sequence):
    """
    Turn a correlated subquery into a LEFT JOIN against a grouped derived table.

    The rewrite is skipped (nothing is modified) when the subquery has no WHERE,
    when its WHERE contains an OR, when it contains LIMIT or OFFSET, when an
    external column is used outside of a binary predicate that belongs to that
    WHERE, or when none of the correlating predicates is an equality.

    Args:
        select: the correlated subquery expression.
        parent_select: the select that encloses ``select``.
        external_columns: columns inside ``select`` that refer to an outer query.
        sequence: iterator of integers used to build unique ``_u_<n>`` aliases.
    """
    where = select.args.get("where")

    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
        return

    # NOTE: the alias is drawn before the remaining bail-out checks, so a skipped
    # subquery still consumes a number from the sequence.
    table_alias = _alias(sequence)
    keys = []

    # for all external columns in the where statement,
    # split out the relevant data to convert it into a join
    for column in external_columns:
        if column.find_ancestor(exp.Where) is not where:
            return

        predicate = column.find_ancestor(exp.Predicate)

        if not predicate or predicate.find_ancestor(exp.Where) is not where:
            return

        if isinstance(predicate, exp.Binary):
            # The key is the side of the predicate that does NOT contain the
            # external column, i.e. the subquery's own expression.
            key = (
                predicate.right
                if any(node is column for node, *_ in predicate.left.walk())
                else predicate.left
            )
        else:
            return

        keys.append((key, column, predicate))

    # Without at least one equality there is nothing to join on.
    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
        return

    value = select.selects[0]
    key_aliases = {}
    group_by = []

    for key, _, predicate in keys:
        # if we filter on the value of the subquery, it needs to be unique
        if key == value.this:
            key_aliases[key] = value.alias
            group_by.append(key)
        else:
            if key not in key_aliases:
                key_aliases[key] = _alias(sequence)
            # all predicates that are equalities must also be in the unique
            # so that we don't do a many to many join
            if isinstance(predicate, exp.EQ) and key not in group_by:
                group_by.append(key)

    parent_predicate = select.find_ancestor(exp.Predicate)

    # if the value of the subquery is not an agg or a key, we need to collect it into an array
    # so that it can be grouped
    if not value.find(exp.AggFunc) and value.this not in group_by:
        select.select(
            f"ARRAY_AGG({value.this}) AS {value.alias}", append=False, copy=False
        )

    # exists queries should not have any selects as it only checks if there are any rows
    # all selects will be added by the optimizer and only used for join keys
    if isinstance(parent_predicate, exp.Exists):
        select.args["expressions"] = []

    for key, alias in key_aliases.items():
        if key in group_by:
            # add all keys to the projections of the subquery
            # so that we can use it as a join key
            if isinstance(parent_predicate, exp.Exists) or key != value.this:
                select.select(f"{key} AS {alias}", copy=False)
        else:
            # Non-grouped keys are collected so the outer query can test them
            # with ARRAY_CONTAINS / ARRAY_ANY below.
            select.select(f"ARRAY_AGG({key}) AS {alias}", copy=False)

    alias = exp.column(value.alias, table_alias)
    other = _other_operand(parent_predicate)

    # Rewrite the predicate that held the subquery so that it reads from the
    # joined table instead.
    if isinstance(parent_predicate, exp.Exists):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
        else:
            parent_predicate = _replace(parent_predicate, "TRUE")
    elif isinstance(parent_predicate, exp.All):
        parent_predicate = _replace(
            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
        )
    elif isinstance(parent_predicate, exp.Any):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})"
            )
    elif isinstance(parent_predicate, exp.In):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate,
                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
            )
    else:
        # Plain comparison against a scalar subquery: swap the subquery node
        # for the joined column and keep the comparison itself.
        select.parent.replace(alias)

    for key, column, predicate in keys:
        # The correlating predicate leaves the subquery's WHERE; its effect is
        # re-created on the outer predicate and/or the join condition.
        predicate.replace(exp.TRUE)
        nested = exp.column(key_aliases[key], table_alias)

        if key in group_by:
            key.replace(nested)
            parent_predicate = _replace(
                parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)"
            )
        elif isinstance(predicate, exp.EQ):
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
            )
        else:
            key.replace(exp.to_identifier("_x"))
            parent_predicate = _replace(
                parent_predicate,
                f'({parent_predicate} AND ARRAY_ANY({nested}, "_x" -> {predicate}))',
            )

    parent_select.join(
        select.group_by(*group_by, copy=False),
        # Only the equality predicates (with their keys already pointed at the
        # joined table above) form the join condition.
        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
        join_type="LEFT",
        join_alias=table_alias,
        copy=False,
    )


def _alias(sequence):
    """Return the next generated alias, ``_u_<n>``, drawing ``n`` from ``sequence``."""
    return f"_u_{next(sequence)}"


def _replace(expression, condition):
    """
    Replace ``expression`` in its tree with the condition built from ``condition``.

    Args:
        expression: the node to replace.
        condition: SQL text (or expression) handed to ``exp.condition``.
    Returns:
        Whatever ``expression.replace`` returns; callers here use it as the new node.
    """
    return expression.replace(exp.condition(condition))


def _other_operand(expression):
    """
    Return the operand of a predicate that is *not* the subquery side.

    For ``IN`` this is the tested value. For ``ANY`` / ``ALL`` it is resolved
    through the comparison that wraps them. For a binary predicate it is the
    side that does not hold a subquery-like node.

    The side is chosen by inspecting the operands themselves. ``arg_key`` only
    says where ``expression`` sits inside *its* parent, so deciding on it would
    return the subquery side whenever the predicate is the first or only
    condition of a WHERE clause.

    Args:
        expression: the predicate to inspect; may be None.
    Returns:
        The non-subquery operand, or None when ``expression`` is not an IN,
        ANY, ALL or binary predicate.
    """
    if isinstance(expression, exp.In):
        return expression.this

    if isinstance(expression, (exp.Any, exp.All)):
        # ANY / ALL carry only the subquery; the value they are compared with
        # lives on the enclosing comparison.
        return _other_operand(expression.parent)

    if isinstance(expression, exp.Binary):
        return (
            expression.right
            if isinstance(
                expression.left, (exp.Subquery, exp.Any, exp.Exists, exp.All)
            )
            else expression.left
        )

    return None