From f2981e8e4d28233864f1ca06ecec45ab80bf9eae Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sat, 19 Nov 2022 15:50:39 +0100 Subject: Merging upstream version 10.0.8. Signed-off-by: Daniel Baumann --- sqlglot/planner.py | 227 ++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 162 insertions(+), 65 deletions(-) (limited to 'sqlglot/planner.py') diff --git a/sqlglot/planner.py b/sqlglot/planner.py index cd1de5e..51db2d4 100644 --- a/sqlglot/planner.py +++ b/sqlglot/planner.py @@ -1,5 +1,8 @@ +from __future__ import annotations + import itertools import math +import typing as t from sqlglot import alias, exp from sqlglot.errors import UnsupportedError @@ -7,15 +10,15 @@ from sqlglot.optimizer.eliminate_joins import join_condition class Plan: - def __init__(self, expression): - self.expression = expression + def __init__(self, expression: exp.Expression) -> None: + self.expression = expression.copy() self.root = Step.from_expression(self.expression) - self._dag = {} + self._dag: t.Dict[Step, t.Set[Step]] = {} @property - def dag(self): + def dag(self) -> t.Dict[Step, t.Set[Step]]: if not self._dag: - dag = {} + dag: t.Dict[Step, t.Set[Step]] = {} nodes = {self.root} while nodes: @@ -29,32 +32,64 @@ class Plan: return self._dag @property - def leaves(self): + def leaves(self) -> t.Generator[Step, None, None]: return (node for node, deps in self.dag.items() if not deps) + def __repr__(self) -> str: + return f"Plan\n----\n{repr(self.root)}" + class Step: @classmethod - def from_expression(cls, expression, ctes=None): + def from_expression( + cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None + ) -> Step: """ - Build a DAG of Steps from a SQL expression. - - Giving an expression like: - - SELECT x.a, SUM(x.b) - FROM x - JOIN y - ON x.a = y.a + Builds a DAG of Steps from a SQL expression so that it's easier to execute in an engine. + Note: the expression's tables and subqueries must be aliased for this method to work. For + example, given the following expression: + + SELECT + x.a, + SUM(x.b) + FROM x AS x + JOIN y AS y + ON x.a = y.a GROUP BY x.a - Transform it into a DAG of the form: - - Aggregate(x.a, SUM(x.b)) - Join(y) - Scan(x) - Scan(y) - - This can then more easily be executed on by an engine. + the following DAG is produced (the expression IDs might differ per execution): + + - Aggregate: x (4347984624) + Context: + Aggregations: + - SUM(x.b) + Group: + - x.a + Projections: + - x.a + - "x"."" + Dependencies: + - Join: x (4347985296) + Context: + y: + On: x.a = y.a + Projections: + Dependencies: + - Scan: x (4347983136) + Context: + Source: x AS x + Projections: + - Scan: y (4343416624) + Context: + Source: y AS y + Projections: + + Args: + expression: the expression to build the DAG from. + ctes: a dictionary that maps CTEs to their corresponding Step DAG by name. + + Returns: + A Step DAG corresponding to `expression`. """ ctes = ctes or {} with_ = expression.args.get("with") @@ -65,11 +100,11 @@ class Step: for cte in with_.expressions: step = Step.from_expression(cte.this, ctes) step.name = cte.alias - ctes[step.name] = step + ctes[step.name] = step # type: ignore from_ = expression.args.get("from") - if from_: + if isinstance(expression, exp.Select) and from_: from_ = from_.expressions if len(from_) > 1: raise UnsupportedError( @@ -77,8 +112,10 @@ class Step: ) step = Scan.from_expression(from_[0], ctes) + elif isinstance(expression, exp.Union): + step = SetOperation.from_expression(expression, ctes) else: - raise UnsupportedError("Static selects are unsupported.") + step = Scan() joins = expression.args.get("joins") @@ -115,7 +152,7 @@ class Step: group = expression.args.get("group") - if group: + if group or aggregations: aggregate = Aggregate() aggregate.source = step.name aggregate.name = step.name @@ -123,7 +160,15 @@ class Step: alias(operand, alias_) for operand, alias_ in operands.items() ) aggregate.aggregations = aggregations - aggregate.group = group.expressions + # give aggregates names and replace projections with references to them + aggregate.group = { + f"_g{i}": e for i, e in enumerate(group.expressions if group else []) + } + for projection in projections: + for i, e in aggregate.group.items(): + for child, _, _ in projection.walk(): + if child == e: + child.replace(exp.column(i, step.name)) aggregate.add_dependency(step) step = aggregate @@ -150,22 +195,22 @@ class Step: return step - def __init__(self): - self.name = None - self.dependencies = set() - self.dependents = set() - self.projections = [] - self.limit = math.inf - self.condition = None + def __init__(self) -> None: + self.name: t.Optional[str] = None + self.dependencies: t.Set[Step] = set() + self.dependents: t.Set[Step] = set() + self.projections: t.Sequence[exp.Expression] = [] + self.limit: float = math.inf + self.condition: t.Optional[exp.Expression] = None - def add_dependency(self, dependency): + def add_dependency(self, dependency: Step) -> None: self.dependencies.add(dependency) dependency.dependents.add(self) - def __repr__(self): + def __repr__(self) -> str: return self.to_s() - def to_s(self, level=0): + def to_s(self, level: int = 0) -> str: indent = " " * level nested = f"{indent} " @@ -175,7 +220,7 @@ class Step: context = [f"{nested}Context:"] + context lines = [ - f"{indent}- {self.__class__.__name__}: {self.name}", + f"{indent}- {self.id}", *context, f"{nested}Projections:", ] @@ -193,13 +238,25 @@ class Step: return "\n".join(lines) - def _to_s(self, _indent): + @property + def type_name(self) -> str: + return self.__class__.__name__ + + @property + def id(self) -> str: + name = self.name + name = f" {name}" if name else "" + return f"{self.type_name}:{name} ({id(self)})" + + def _to_s(self, _indent: str) -> t.List[str]: return [] class Scan(Step): @classmethod - def from_expression(cls, expression, ctes=None): + def from_expression( + cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None + ) -> Step: table = expression alias_ = expression.alias @@ -217,26 +274,24 @@ class Scan(Step): step = Scan() step.name = alias_ step.source = expression - if table.name in ctes: + if ctes and table.name in ctes: step.add_dependency(ctes[table.name]) return step - def __init__(self): + def __init__(self) -> None: super().__init__() - self.source = None - - def _to_s(self, indent): - return [f"{indent}Source: {self.source.sql()}"] + self.source: t.Optional[exp.Expression] = None - -class Write(Step): - pass + def _to_s(self, indent: str) -> t.List[str]: + return [f"{indent}Source: {self.source.sql() if self.source else '-static-'}"] # type: ignore class Join(Step): @classmethod - def from_joins(cls, joins, ctes=None): + def from_joins( + cls, joins: t.Iterable[exp.Join], ctes: t.Optional[t.Dict[str, Step]] = None + ) -> Step: step = Join() for join in joins: @@ -252,28 +307,28 @@ class Join(Step): return step - def __init__(self): + def __init__(self) -> None: super().__init__() - self.joins = {} + self.joins: t.Dict[str, t.Dict[str, t.List[str] | exp.Expression]] = {} - def _to_s(self, indent): + def _to_s(self, indent: str) -> t.List[str]: lines = [] for name, join in self.joins.items(): lines.append(f"{indent}{name}: {join['side']}") if join.get("condition"): - lines.append(f"{indent}On: {join['condition'].sql()}") + lines.append(f"{indent}On: {join['condition'].sql()}") # type: ignore return lines class Aggregate(Step): - def __init__(self): + def __init__(self) -> None: super().__init__() - self.aggregations = [] - self.operands = [] - self.group = [] - self.source = None + self.aggregations: t.List[exp.Expression] = [] + self.operands: t.Tuple[exp.Expression, ...] = () + self.group: t.Dict[str, exp.Expression] = {} + self.source: t.Optional[str] = None - def _to_s(self, indent): + def _to_s(self, indent: str) -> t.List[str]: lines = [f"{indent}Aggregations:"] for expression in self.aggregations: @@ -281,7 +336,7 @@ class Aggregate(Step): if self.group: lines.append(f"{indent}Group:") - for expression in self.group: + for expression in self.group.values(): lines.append(f"{indent} - {expression.sql()}") if self.operands: lines.append(f"{indent}Operands:") @@ -292,14 +347,56 @@ class Aggregate(Step): class Sort(Step): - def __init__(self): + def __init__(self) -> None: super().__init__() self.key = None - def _to_s(self, indent): + def _to_s(self, indent: str) -> t.List[str]: lines = [f"{indent}Key:"] - for expression in self.key: + for expression in self.key: # type: ignore lines.append(f"{indent} - {expression.sql()}") return lines + + +class SetOperation(Step): + def __init__( + self, + op: t.Type[exp.Expression], + left: str | None, + right: str | None, + distinct: bool = False, + ) -> None: + super().__init__() + self.op = op + self.left = left + self.right = right + self.distinct = distinct + + @classmethod + def from_expression( + cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None + ) -> Step: + assert isinstance(expression, exp.Union) + left = Step.from_expression(expression.left, ctes) + right = Step.from_expression(expression.right, ctes) + step = cls( + op=expression.__class__, + left=left.name, + right=right.name, + distinct=expression.args.get("distinct"), + ) + step.add_dependency(left) + step.add_dependency(right) + return step + + def _to_s(self, indent: str) -> t.List[str]: + lines = [] + if self.distinct: + lines.append(f"{indent}Distinct: {self.distinct}") + return lines + + @property + def type_name(self) -> str: + return self.op.__name__ -- cgit v1.2.3