summaryrefslogtreecommitdiffstats
path: root/sqlglot/planner.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/planner.py')
-rw-r--r--sqlglot/planner.py52
1 files changed, 34 insertions, 18 deletions
diff --git a/sqlglot/planner.py b/sqlglot/planner.py
index 51db2d4..4967231 100644
--- a/sqlglot/planner.py
+++ b/sqlglot/planner.py
@@ -130,18 +130,20 @@ class Step:
aggregations = []
sequence = itertools.count()
- for e in expression.expressions:
- aggregation = e.find(exp.AggFunc)
-
- if aggregation:
- projections.append(exp.column(e.alias_or_name, step.name, quoted=True))
- aggregations.append(e)
- for operand in aggregation.unnest_operands():
+ def extract_agg_operands(expression):
+ for agg in expression.find_all(exp.AggFunc):
+ for operand in agg.unnest_operands():
if isinstance(operand, exp.Column):
continue
if operand not in operands:
operands[operand] = f"_a_{next(sequence)}"
operand.replace(exp.column(operands[operand], quoted=True))
+
+ for e in expression.expressions:
+ if e.find(exp.AggFunc):
+ projections.append(exp.column(e.alias_or_name, step.name, quoted=True))
+ aggregations.append(e)
+ extract_agg_operands(e)
else:
projections.append(e)
@@ -156,6 +158,13 @@ class Step:
aggregate = Aggregate()
aggregate.source = step.name
aggregate.name = step.name
+
+ having = expression.args.get("having")
+
+ if having:
+ extract_agg_operands(having)
+ aggregate.condition = having.this
+
aggregate.operands = tuple(
alias(operand, alias_) for operand, alias_ in operands.items()
)
@@ -172,11 +181,6 @@ class Step:
aggregate.add_dependency(step)
step = aggregate
- having = expression.args.get("having")
-
- if having:
- step.condition = having.this
-
order = expression.args.get("order")
if order:
@@ -188,6 +192,17 @@ class Step:
step.projections = projections
+ if isinstance(expression, exp.Select) and expression.args.get("distinct"):
+ distinct = Aggregate()
+ distinct.source = step.name
+ distinct.name = step.name
+ distinct.group = {
+ e.alias_or_name: exp.column(col=e.alias_or_name, table=step.name)
+ for e in projections or expression.expressions
+ }
+ distinct.add_dependency(step)
+ step = distinct
+
limit = expression.args.get("limit")
if limit:
@@ -231,6 +246,9 @@ class Step:
if self.condition:
lines.append(f"{nested}Condition: {self.condition.sql()}")
+ if self.limit is not math.inf:
+ lines.append(f"{nested}Limit: {self.limit}")
+
if self.dependencies:
lines.append(f"{nested}Dependencies:")
for dependency in self.dependencies:
@@ -258,12 +276,7 @@ class Scan(Step):
cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
) -> Step:
table = expression
- alias_ = expression.alias
-
- if not alias_:
- raise UnsupportedError(
- "Tables/Subqueries must be aliased. Run it through the optimizer"
- )
+ alias_ = expression.alias_or_name
if isinstance(expression, exp.Subquery):
table = expression.this
@@ -338,6 +351,9 @@ class Aggregate(Step):
lines.append(f"{indent}Group:")
for expression in self.group.values():
lines.append(f"{indent} - {expression.sql()}")
+ if self.condition:
+ lines.append(f"{indent}Having:")
+ lines.append(f"{indent} - {self.condition.sql()}")
if self.operands:
lines.append(f"{indent}Operands:")
for expression in self.operands: