diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-06-29 13:02:29 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-06-29 13:02:29 +0000 |
commit | 9b39dac84e82bf473216939e50b8836170f01d23 (patch) | |
tree | 9b405bc86ef7e2ea28cddc6b787ed70355cf7fce /sqlglot/planner.py | |
parent | Releasing debian version 16.4.2-1. (diff) | |
download | sqlglot-9b39dac84e82bf473216939e50b8836170f01d23.tar.xz sqlglot-9b39dac84e82bf473216939e50b8836170f01d23.zip |
Merging upstream version 16.7.3.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/planner.py')
-rw-r--r-- | sqlglot/planner.py | 40 |
1 files changed, 30 insertions, 10 deletions
diff --git a/sqlglot/planner.py b/sqlglot/planner.py index 4ed7449..f246702 100644 --- a/sqlglot/planner.py +++ b/sqlglot/planner.py @@ -91,6 +91,7 @@ class Step: A Step DAG corresponding to `expression`. """ ctes = ctes or {} + expression = expression.unnest() with_ = expression.args.get("with") # CTEs break the mold of scope and introduce themselves to all in the context. @@ -120,22 +121,25 @@ class Step: projections = [] # final selects in this chain of steps representing a select operands = {} # intermediate computations of agg funcs eg x + 1 in SUM(x + 1) - aggregations = [] + aggregations = set() next_operand_name = name_sequence("_a_") def extract_agg_operands(expression): - for agg in expression.find_all(exp.AggFunc): + agg_funcs = tuple(expression.find_all(exp.AggFunc)) + if agg_funcs: + aggregations.add(expression) + for agg in agg_funcs: for operand in agg.unnest_operands(): if isinstance(operand, exp.Column): continue if operand not in operands: operands[operand] = next_operand_name() operand.replace(exp.column(operands[operand], quoted=True)) + return bool(agg_funcs) for e in expression.expressions: if e.find(exp.AggFunc): projections.append(exp.column(e.alias_or_name, step.name, quoted=True)) - aggregations.append(e) extract_agg_operands(e) else: projections.append(e) @@ -155,22 +159,38 @@ class Step: having = expression.args.get("having") if having: - extract_agg_operands(having) - aggregate.condition = having.this + if extract_agg_operands(exp.alias_(having.this, "_h", quoted=True)): + aggregate.condition = exp.column("_h", step.name, quoted=True) + else: + aggregate.condition = having.this aggregate.operands = tuple( alias(operand, alias_) for operand, alias_ in operands.items() ) - aggregate.aggregations = aggregations + aggregate.aggregations = list(aggregations) + # give aggregates names and replace projections with references to them aggregate.group = { f"_g{i}": e for i, e in enumerate(group.expressions if group else []) } + + intermediate: t.Dict[str | exp.Expression, str] = {} + for k, v in aggregate.group.items(): + intermediate[v] = k + if isinstance(v, exp.Column): + intermediate[v.alias_or_name] = k + for projection in projections: - for i, e in aggregate.group.items(): - for child, *_ in projection.walk(): - if child == e: - child.replace(exp.column(i, step.name)) + for node, *_ in projection.walk(): + name = intermediate.get(node) + if name: + node.replace(exp.column(name, step.name)) + if aggregate.condition: + for node, *_ in aggregate.condition.walk(): + name = intermediate.get(node) or intermediate.get(node.name) + if name: + node.replace(exp.column(name, step.name)) + aggregate.add_dependency(step) step = aggregate |