summary | refs | log | tree | commit | diff | stats
path: root/sqlglot/planner.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/planner.py')
-rw-r--r-- sqlglot/planner.py 227
1 files changed, 162 insertions, 65 deletions
diff --git a/sqlglot/planner.py b/sqlglot/planner.py
index cd1de5e..51db2d4 100644
--- a/sqlglot/planner.py
+++ b/sqlglot/planner.py
@@ -1,5 +1,8 @@
+from __future__ import annotations
+
import itertools
import math
+import typing as t
from sqlglot import alias, exp
from sqlglot.errors import UnsupportedError
@@ -7,15 +10,15 @@ from sqlglot.optimizer.eliminate_joins import join_condition
class Plan:
- def __init__(self, expression):
- self.expression = expression
+ def __init__(self, expression: exp.Expression) -> None:
+ self.expression = expression.copy()
self.root = Step.from_expression(self.expression)
- self._dag = {}
+ self._dag: t.Dict[Step, t.Set[Step]] = {}
@property
- def dag(self):
+ def dag(self) -> t.Dict[Step, t.Set[Step]]:
if not self._dag:
- dag = {}
+ dag: t.Dict[Step, t.Set[Step]] = {}
nodes = {self.root}
while nodes:
@@ -29,32 +32,64 @@ class Plan:
return self._dag
@property
- def leaves(self):
+ def leaves(self) -> t.Generator[Step, None, None]:
return (node for node, deps in self.dag.items() if not deps)
+    def __repr__(self) -> str:
+        """Return a "Plan" header, a dashed rule, and the repr of the root step."""
+        return f"Plan\n----\n{self.root!r}"
+
class Step:
@classmethod
- def from_expression(cls, expression, ctes=None):
+ def from_expression(
+ cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
+ ) -> Step:
"""
- Build a DAG of Steps from a SQL expression.
-
- Giving an expression like:
-
- SELECT x.a, SUM(x.b)
- FROM x
- JOIN y
- ON x.a = y.a
+ Builds a DAG of Steps from a SQL expression so that it's easier to execute in an engine.
+ Note: the expression's tables and subqueries must be aliased for this method to work. For
+ example, given the following expression:
+
+ SELECT
+ x.a,
+ SUM(x.b)
+ FROM x AS x
+ JOIN y AS y
+ ON x.a = y.a
GROUP BY x.a
- Transform it into a DAG of the form:
-
- Aggregate(x.a, SUM(x.b))
- Join(y)
- Scan(x)
- Scan(y)
-
- This can then more easily be executed on by an engine.
+ the following DAG is produced (the expression IDs might differ per execution):
+
+ - Aggregate: x (4347984624)
+ Context:
+ Aggregations:
+ - SUM(x.b)
+ Group:
+ - x.a
+ Projections:
+ - x.a
+ - "x".""
+ Dependencies:
+ - Join: x (4347985296)
+ Context:
+ y:
+ On: x.a = y.a
+ Projections:
+ Dependencies:
+ - Scan: x (4347983136)
+ Context:
+ Source: x AS x
+ Projections:
+ - Scan: y (4343416624)
+ Context:
+ Source: y AS y
+ Projections:
+
+ Args:
+ expression: the expression to build the DAG from.
+ ctes: a dictionary that maps CTEs to their corresponding Step DAG by name.
+
+ Returns:
+ A Step DAG corresponding to `expression`.
"""
ctes = ctes or {}
with_ = expression.args.get("with")
@@ -65,11 +100,11 @@ class Step:
for cte in with_.expressions:
step = Step.from_expression(cte.this, ctes)
step.name = cte.alias
- ctes[step.name] = step
+ ctes[step.name] = step # type: ignore
from_ = expression.args.get("from")
- if from_:
+ if isinstance(expression, exp.Select) and from_:
from_ = from_.expressions
if len(from_) > 1:
raise UnsupportedError(
@@ -77,8 +112,10 @@ class Step:
)
step = Scan.from_expression(from_[0], ctes)
+ elif isinstance(expression, exp.Union):
+ step = SetOperation.from_expression(expression, ctes)
else:
- raise UnsupportedError("Static selects are unsupported.")
+ step = Scan()
joins = expression.args.get("joins")
@@ -115,7 +152,7 @@ class Step:
group = expression.args.get("group")
- if group:
+ if group or aggregations:
aggregate = Aggregate()
aggregate.source = step.name
aggregate.name = step.name
@@ -123,7 +160,15 @@ class Step:
alias(operand, alias_) for operand, alias_ in operands.items()
)
aggregate.aggregations = aggregations
- aggregate.group = group.expressions
+ # give aggregates names and replace projections with references to them
+ aggregate.group = {
+ f"_g{i}": e for i, e in enumerate(group.expressions if group else [])
+ }
+ for projection in projections:
+ for i, e in aggregate.group.items():
+ for child, _, _ in projection.walk():
+ if child == e:
+ child.replace(exp.column(i, step.name))
aggregate.add_dependency(step)
step = aggregate
@@ -150,22 +195,22 @@ class Step:
return step
- def __init__(self):
- self.name = None
- self.dependencies = set()
- self.dependents = set()
- self.projections = []
- self.limit = math.inf
- self.condition = None
+ def __init__(self) -> None:
+ self.name: t.Optional[str] = None
+ self.dependencies: t.Set[Step] = set()
+ self.dependents: t.Set[Step] = set()
+ self.projections: t.Sequence[exp.Expression] = []
+ self.limit: float = math.inf
+ self.condition: t.Optional[exp.Expression] = None
- def add_dependency(self, dependency):
+ def add_dependency(self, dependency: Step) -> None:
self.dependencies.add(dependency)
dependency.dependents.add(self)
- def __repr__(self):
+ def __repr__(self) -> str:
return self.to_s()
- def to_s(self, level=0):
+ def to_s(self, level: int = 0) -> str:
indent = " " * level
nested = f"{indent} "
@@ -175,7 +220,7 @@ class Step:
context = [f"{nested}Context:"] + context
lines = [
- f"{indent}- {self.__class__.__name__}: {self.name}",
+ f"{indent}- {self.id}",
*context,
f"{nested}Projections:",
]
@@ -193,13 +238,25 @@ class Step:
return "\n".join(lines)
- def _to_s(self, _indent):
+    @property
+    def type_name(self) -> str:
+        """Return the name of this step's class, used as the label prefix in `id`."""
+        return type(self).__name__
+
+    @property
+    def id(self) -> str:
+        """Return a display label: the type name, the step name if set, and `id(self)`."""
+        # An unnamed step contributes nothing between the colon and the object id.
+        label = f" {self.name}" if self.name else ""
+        return f"{self.type_name}:{label} ({id(self)})"
+
+ def _to_s(self, _indent: str) -> t.List[str]:
return []
class Scan(Step):
@classmethod
- def from_expression(cls, expression, ctes=None):
+ def from_expression(
+ cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
+ ) -> Step:
table = expression
alias_ = expression.alias
@@ -217,26 +274,24 @@ class Scan(Step):
step = Scan()
step.name = alias_
step.source = expression
- if table.name in ctes:
+ if ctes and table.name in ctes:
step.add_dependency(ctes[table.name])
return step
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
- self.source = None
-
- def _to_s(self, indent):
- return [f"{indent}Source: {self.source.sql()}"]
+ self.source: t.Optional[exp.Expression] = None
-
-class Write(Step):
- pass
+ def _to_s(self, indent: str) -> t.List[str]:
+ return [f"{indent}Source: {self.source.sql() if self.source else '-static-'}"] # type: ignore
class Join(Step):
@classmethod
- def from_joins(cls, joins, ctes=None):
+ def from_joins(
+ cls, joins: t.Iterable[exp.Join], ctes: t.Optional[t.Dict[str, Step]] = None
+ ) -> Step:
step = Join()
for join in joins:
@@ -252,28 +307,28 @@ class Join(Step):
return step
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
- self.joins = {}
+ self.joins: t.Dict[str, t.Dict[str, t.List[str] | exp.Expression]] = {}
- def _to_s(self, indent):
+ def _to_s(self, indent: str) -> t.List[str]:
lines = []
for name, join in self.joins.items():
lines.append(f"{indent}{name}: {join['side']}")
if join.get("condition"):
- lines.append(f"{indent}On: {join['condition'].sql()}")
+ lines.append(f"{indent}On: {join['condition'].sql()}") # type: ignore
return lines
class Aggregate(Step):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
- self.aggregations = []
- self.operands = []
- self.group = []
- self.source = None
+ self.aggregations: t.List[exp.Expression] = []
+ self.operands: t.Tuple[exp.Expression, ...] = ()
+ self.group: t.Dict[str, exp.Expression] = {}
+ self.source: t.Optional[str] = None
- def _to_s(self, indent):
+ def _to_s(self, indent: str) -> t.List[str]:
lines = [f"{indent}Aggregations:"]
for expression in self.aggregations:
@@ -281,7 +336,7 @@ class Aggregate(Step):
if self.group:
lines.append(f"{indent}Group:")
- for expression in self.group:
+ for expression in self.group.values():
lines.append(f"{indent} - {expression.sql()}")
if self.operands:
lines.append(f"{indent}Operands:")
@@ -292,14 +347,56 @@ class Aggregate(Step):
class Sort(Step):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.key = None
- def _to_s(self, indent):
+ def _to_s(self, indent: str) -> t.List[str]:
lines = [f"{indent}Key:"]
- for expression in self.key:
+ for expression in self.key: # type: ignore
lines.append(f"{indent} - {expression.sql()}")
return lines
+
+
+class SetOperation(Step):
+    """Plan step that combines the outputs of two dependency steps.
+
+    The concrete operation is identified by `op`, the class of the expression
+    this step was built from; `type_name` reports that class's name instead of
+    "SetOperation".
+    """
+
+    def __init__(
+        self,
+        op: t.Type[exp.Expression],
+        left: str | None,
+        right: str | None,
+        distinct: bool = False,
+    ) -> None:
+        """Store the operation class, the names of both input steps and the distinct flag.
+
+        Args:
+            op: the expression class of the set operation.
+            left: the name of the left input step, if it has one.
+            right: the name of the right input step, if it has one.
+            distinct: the value of the operation's "distinct" argument.
+        """
+        super().__init__()
+        self.op = op
+        self.left = left
+        self.right = right
+        self.distinct = distinct
+
+    @classmethod
+    def from_expression(
+        cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
+    ) -> Step:
+        """Build a step from a `exp.Union` expression (or a subclass of it).
+
+        Both operands are planned with `Step.from_expression` and registered as
+        dependencies of the returned step.
+
+        Args:
+            expression: the set operation to plan; must be an `exp.Union` instance.
+            ctes: CTE steps by name, forwarded unchanged to both operands.
+
+        Returns:
+            The new step, depending on the left and right operand steps.
+        """
+        assert isinstance(expression, exp.Union)
+        left = Step.from_expression(expression.left, ctes)
+        right = Step.from_expression(expression.right, ctes)
+        step = cls(
+            op=expression.__class__,
+            left=left.name,
+            right=right.name,
+            # `args.get` yields None when the arg is absent; keep the attribute a real bool.
+            distinct=bool(expression.args.get("distinct")),
+        )
+        step.add_dependency(left)
+        step.add_dependency(right)
+        return step
+
+    def _to_s(self, indent: str) -> t.List[str]:
+        """Return the context lines for `to_s`: a "Distinct" line only when the flag is set."""
+        lines = []
+        if self.distinct:
+            lines.append(f"{indent}Distinct: {self.distinct}")
+        return lines
+
+    @property
+    def type_name(self) -> str:
+        """Return the name of the operation's expression class rather than this step class."""
+        return self.op.__name__