summaryrefslogtreecommitdiffstats
path: root/sqlglot/planner.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2023-07-06 07:28:12 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2023-07-06 07:28:12 +0000
commit374a0f6318bcf423b1b784d30b25a8327c65cb24 (patch)
tree9303a1cbdba85b5d9781ebef32eb1902d3790c99 /sqlglot/planner.py
parentReleasing debian version 16.7.7-1. (diff)
downloadsqlglot-374a0f6318bcf423b1b784d30b25a8327c65cb24.tar.xz
sqlglot-374a0f6318bcf423b1b784d30b25a8327c65cb24.zip
Merging upstream version 17.2.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/planner.py')
-rw-r--r--sqlglot/planner.py29
1 files changed, 23 insertions, 6 deletions
diff --git a/sqlglot/planner.py b/sqlglot/planner.py
index f246702..07ee739 100644
--- a/sqlglot/planner.py
+++ b/sqlglot/planner.py
@@ -23,9 +23,11 @@ class Plan:
while nodes:
node = nodes.pop()
dag[node] = set()
+
for dep in node.dependencies:
dag[node].add(dep)
nodes.add(dep)
+
self._dag = dag
return self._dag
@@ -128,15 +130,22 @@ class Step:
agg_funcs = tuple(expression.find_all(exp.AggFunc))
if agg_funcs:
aggregations.add(expression)
+
for agg in agg_funcs:
for operand in agg.unnest_operands():
if isinstance(operand, exp.Column):
continue
if operand not in operands:
operands[operand] = next_operand_name()
+
operand.replace(exp.column(operands[operand], quoted=True))
+
return bool(agg_funcs)
+ def set_ops_and_aggs(step):
+ step.operands = tuple(alias(operand, alias_) for operand, alias_ in operands.items())
+ step.aggregations = list(aggregations)
+
for e in expression.expressions:
if e.find(exp.AggFunc):
projections.append(exp.column(e.alias_or_name, step.name, quoted=True))
@@ -164,10 +173,7 @@ class Step:
else:
aggregate.condition = having.this
- aggregate.operands = tuple(
- alias(operand, alias_) for operand, alias_ in operands.items()
- )
- aggregate.aggregations = list(aggregations)
+ set_ops_and_aggs(aggregate)
# give aggregates names and replace projections with references to them
aggregate.group = {
@@ -178,13 +184,14 @@ class Step:
for k, v in aggregate.group.items():
intermediate[v] = k
if isinstance(v, exp.Column):
- intermediate[v.alias_or_name] = k
+ intermediate[v.name] = k
for projection in projections:
for node, *_ in projection.walk():
name = intermediate.get(node)
if name:
node.replace(exp.column(name, step.name))
+
if aggregate.condition:
for node, *_ in aggregate.condition.walk():
name = intermediate.get(node) or intermediate.get(node.name)
@@ -197,6 +204,13 @@ class Step:
order = expression.args.get("order")
if order:
+ if isinstance(step, Aggregate):
+ for i, ordered in enumerate(order.expressions):
+ if extract_agg_operands(exp.alias_(ordered.this, f"_o_{i}", quoted=True)):
+ ordered.this.replace(exp.column(f"_o_{i}", step.name, quoted=True))
+
+ set_ops_and_aggs(aggregate)
+
sort = Sort()
sort.name = step.name
sort.key = order.expressions
@@ -340,7 +354,10 @@ class Join(Step):
def _to_s(self, indent: str) -> t.List[str]:
lines = []
for name, join in self.joins.items():
- lines.append(f"{indent}{name}: {join['side']}")
+ lines.append(f"{indent}{name}: {join['side'] or 'INNER'}")
+ join_key = ", ".join(str(key) for key in t.cast(list, join.get("join_key") or []))
+ if join_key:
+ lines.append(f"{indent}Key: {join_key}")
if join.get("condition"):
lines.append(f"{indent}On: {join['condition'].sql()}") # type: ignore
return lines