summaryrefslogtreecommitdiffstats
path: root/sqlglot/planner.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/planner.py')
-rw-r--r--sqlglot/planner.py40
1 files changed, 30 insertions, 10 deletions
diff --git a/sqlglot/planner.py b/sqlglot/planner.py
index 4ed7449..f246702 100644
--- a/sqlglot/planner.py
+++ b/sqlglot/planner.py
@@ -91,6 +91,7 @@ class Step:
A Step DAG corresponding to `expression`.
"""
ctes = ctes or {}
+ expression = expression.unnest()
with_ = expression.args.get("with")
# CTEs break the mold of scope and introduce themselves to all in the context.
@@ -120,22 +121,25 @@ class Step:
projections = [] # final selects in this chain of steps representing a select
operands = {} # intermediate computations of agg funcs eg x + 1 in SUM(x + 1)
- aggregations = []
+ aggregations = set()
next_operand_name = name_sequence("_a_")
def extract_agg_operands(expression):
- for agg in expression.find_all(exp.AggFunc):
+ agg_funcs = tuple(expression.find_all(exp.AggFunc))
+ if agg_funcs:
+ aggregations.add(expression)
+ for agg in agg_funcs:
for operand in agg.unnest_operands():
if isinstance(operand, exp.Column):
continue
if operand not in operands:
operands[operand] = next_operand_name()
operand.replace(exp.column(operands[operand], quoted=True))
+ return bool(agg_funcs)
for e in expression.expressions:
if e.find(exp.AggFunc):
projections.append(exp.column(e.alias_or_name, step.name, quoted=True))
- aggregations.append(e)
extract_agg_operands(e)
else:
projections.append(e)
@@ -155,22 +159,38 @@ class Step:
having = expression.args.get("having")
if having:
- extract_agg_operands(having)
- aggregate.condition = having.this
+ if extract_agg_operands(exp.alias_(having.this, "_h", quoted=True)):
+ aggregate.condition = exp.column("_h", step.name, quoted=True)
+ else:
+ aggregate.condition = having.this
aggregate.operands = tuple(
alias(operand, alias_) for operand, alias_ in operands.items()
)
- aggregate.aggregations = aggregations
+ aggregate.aggregations = list(aggregations)
+
# give aggregates names and replace projections with references to them
aggregate.group = {
f"_g{i}": e for i, e in enumerate(group.expressions if group else [])
}
+
+ intermediate: t.Dict[str | exp.Expression, str] = {}
+ for k, v in aggregate.group.items():
+ intermediate[v] = k
+ if isinstance(v, exp.Column):
+ intermediate[v.alias_or_name] = k
+
for projection in projections:
- for i, e in aggregate.group.items():
- for child, *_ in projection.walk():
- if child == e:
- child.replace(exp.column(i, step.name))
+ for node, *_ in projection.walk():
+ name = intermediate.get(node)
+ if name:
+ node.replace(exp.column(name, step.name))
+ if aggregate.condition:
+ for node, *_ in aggregate.condition.walk():
+ name = intermediate.get(node) or intermediate.get(node.name)
+ if name:
+ node.replace(exp.column(name, step.name))
+
aggregate.add_dependency(step)
step = aggregate