From 374a0f6318bcf423b1b784d30b25a8327c65cb24 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Thu, 6 Jul 2023 09:28:12 +0200 Subject: Merging upstream version 17.2.0. Signed-off-by: Daniel Baumann --- sqlglot/planner.py | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) (limited to 'sqlglot/planner.py') diff --git a/sqlglot/planner.py b/sqlglot/planner.py index f246702..07ee739 100644 --- a/sqlglot/planner.py +++ b/sqlglot/planner.py @@ -23,9 +23,11 @@ class Plan: while nodes: node = nodes.pop() dag[node] = set() + for dep in node.dependencies: dag[node].add(dep) nodes.add(dep) + self._dag = dag return self._dag @@ -128,15 +130,22 @@ class Step: agg_funcs = tuple(expression.find_all(exp.AggFunc)) if agg_funcs: aggregations.add(expression) + for agg in agg_funcs: for operand in agg.unnest_operands(): if isinstance(operand, exp.Column): continue if operand not in operands: operands[operand] = next_operand_name() + operand.replace(exp.column(operands[operand], quoted=True)) + return bool(agg_funcs) + def set_ops_and_aggs(step): + step.operands = tuple(alias(operand, alias_) for operand, alias_ in operands.items()) + step.aggregations = list(aggregations) + for e in expression.expressions: if e.find(exp.AggFunc): projections.append(exp.column(e.alias_or_name, step.name, quoted=True)) @@ -164,10 +173,7 @@ class Step: else: aggregate.condition = having.this - aggregate.operands = tuple( - alias(operand, alias_) for operand, alias_ in operands.items() - ) - aggregate.aggregations = list(aggregations) + set_ops_and_aggs(aggregate) # give aggregates names and replace projections with references to them aggregate.group = { @@ -178,13 +184,14 @@ class Step: for k, v in aggregate.group.items(): intermediate[v] = k if isinstance(v, exp.Column): - intermediate[v.alias_or_name] = k + intermediate[v.name] = k for projection in projections: for node, *_ in projection.walk(): name = intermediate.get(node) if name: node.replace(exp.column(name, step.name)) + if aggregate.condition: for node, *_ in aggregate.condition.walk(): name = intermediate.get(node) or intermediate.get(node.name) @@ -197,6 +204,13 @@ class Step: order = expression.args.get("order") if order: + if isinstance(step, Aggregate): + for i, ordered in enumerate(order.expressions): + if extract_agg_operands(exp.alias_(ordered.this, f"_o_{i}", quoted=True)): + ordered.this.replace(exp.column(f"_o_{i}", step.name, quoted=True)) + + set_ops_and_aggs(aggregate) + sort = Sort() sort.name = step.name sort.key = order.expressions @@ -340,7 +354,10 @@ class Join(Step): def _to_s(self, indent: str) -> t.List[str]: lines = [] for name, join in self.joins.items(): - lines.append(f"{indent}{name}: {join['side']}") + lines.append(f"{indent}{name}: {join['side'] or 'INNER'}") + join_key = ", ".join(str(key) for key in t.cast(list, join.get("join_key") or [])) + if join_key: + lines.append(f"{indent}Key: {join_key}") if join.get("condition"): lines.append(f"{indent}On: {join['condition'].sql()}") # type: ignore return lines -- cgit v1.2.3