diff options
Diffstat (limited to 'sqlglot/planner.py')
-rw-r--r-- | sqlglot/planner.py | 52 |
1 files changed, 34 insertions, 18 deletions
diff --git a/sqlglot/planner.py b/sqlglot/planner.py index 51db2d4..4967231 100644 --- a/sqlglot/planner.py +++ b/sqlglot/planner.py @@ -130,18 +130,20 @@ class Step: aggregations = [] sequence = itertools.count() - for e in expression.expressions: - aggregation = e.find(exp.AggFunc) - - if aggregation: - projections.append(exp.column(e.alias_or_name, step.name, quoted=True)) - aggregations.append(e) - for operand in aggregation.unnest_operands(): + def extract_agg_operands(expression): + for agg in expression.find_all(exp.AggFunc): + for operand in agg.unnest_operands(): if isinstance(operand, exp.Column): continue if operand not in operands: operands[operand] = f"_a_{next(sequence)}" operand.replace(exp.column(operands[operand], quoted=True)) + + for e in expression.expressions: + if e.find(exp.AggFunc): + projections.append(exp.column(e.alias_or_name, step.name, quoted=True)) + aggregations.append(e) + extract_agg_operands(e) else: projections.append(e) @@ -156,6 +158,13 @@ class Step: aggregate = Aggregate() aggregate.source = step.name aggregate.name = step.name + + having = expression.args.get("having") + + if having: + extract_agg_operands(having) + aggregate.condition = having.this + aggregate.operands = tuple( alias(operand, alias_) for operand, alias_ in operands.items() ) @@ -172,11 +181,6 @@ class Step: aggregate.add_dependency(step) step = aggregate - having = expression.args.get("having") - - if having: - step.condition = having.this - order = expression.args.get("order") if order: @@ -188,6 +192,17 @@ class Step: step.projections = projections + if isinstance(expression, exp.Select) and expression.args.get("distinct"): + distinct = Aggregate() + distinct.source = step.name + distinct.name = step.name + distinct.group = { + e.alias_or_name: exp.column(col=e.alias_or_name, table=step.name) + for e in projections or expression.expressions + } + distinct.add_dependency(step) + step = distinct + limit = expression.args.get("limit") if limit: @@ -231,6 +246,9 @@ class Step: if self.condition: lines.append(f"{nested}Condition: {self.condition.sql()}") + if self.limit is not math.inf: + lines.append(f"{nested}Limit: {self.limit}") + if self.dependencies: lines.append(f"{nested}Dependencies:") for dependency in self.dependencies: @@ -258,12 +276,7 @@ class Scan(Step): cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None ) -> Step: table = expression - alias_ = expression.alias - - if not alias_: - raise UnsupportedError( - "Tables/Subqueries must be aliased. Run it through the optimizer" - ) + alias_ = expression.alias_or_name if isinstance(expression, exp.Subquery): table = expression.this @@ -338,6 +351,9 @@ class Aggregate(Step): lines.append(f"{indent}Group:") for expression in self.group.values(): lines.append(f"{indent} - {expression.sql()}") + if self.condition: + lines.append(f"{indent}Having:") + lines.append(f"{indent} - {self.condition.sql()}") if self.operands: lines.append(f"{indent}Operands:") for expression in self.operands: |