From 28cc22419e32a65fea2d1678400265b8cabc3aff Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Thu, 15 Sep 2022 18:46:17 +0200 Subject: Adding upstream version 6.0.4. Signed-off-by: Daniel Baumann --- sqlglot/planner.py | 340 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 340 insertions(+) create mode 100644 sqlglot/planner.py (limited to 'sqlglot/planner.py') diff --git a/sqlglot/planner.py b/sqlglot/planner.py new file mode 100644 index 0000000..2006a75 --- /dev/null +++ b/sqlglot/planner.py @@ -0,0 +1,340 @@ +import itertools +import math + +from sqlglot import alias, exp +from sqlglot.errors import UnsupportedError +from sqlglot.optimizer.simplify import simplify + + +class Plan: + def __init__(self, expression): + self.expression = expression + self.root = Step.from_expression(self.expression) + self._dag = {} + + @property + def dag(self): + if not self._dag: + dag = {} + nodes = {self.root} + + while nodes: + node = nodes.pop() + dag[node] = set() + for dep in node.dependencies: + dag[node].add(dep) + nodes.add(dep) + self._dag = dag + + return self._dag + + @property + def leaves(self): + return (node for node, deps in self.dag.items() if not deps) + + +class Step: + @classmethod + def from_expression(cls, expression, ctes=None): + """ + Build a DAG of Steps from a SQL expression. + + Giving an expression like: + + SELECT x.a, SUM(x.b) + FROM x + JOIN y + ON x.a = y.a + GROUP BY x.a + + Transform it into a DAG of the form: + + Aggregate(x.a, SUM(x.b)) + Join(y) + Scan(x) + Scan(y) + + This can then more easily be executed on by an engine. + """ + ctes = ctes or {} + with_ = expression.args.get("with") + + # CTEs break the mold of scope and introduce themselves to all in the context. + if with_: + ctes = ctes.copy() + for cte in with_.expressions: + step = Step.from_expression(cte.this, ctes) + step.name = cte.alias + ctes[step.name] = step + + from_ = expression.args.get("from") + + if from_: + from_ = from_.expressions + if len(from_) > 1: + raise UnsupportedError( + "Multi-from statements are unsupported. Run it through the optimizer" + ) + + step = Scan.from_expression(from_[0], ctes) + else: + raise UnsupportedError("Static selects are unsupported.") + + joins = expression.args.get("joins") + + if joins: + join = Join.from_joins(joins, ctes) + join.name = step.name + join.add_dependency(step) + step = join + + projections = [] # final selects in this chain of steps representing a select + operands = {} # intermediate computations of agg funcs eg x + 1 in SUM(x + 1) + aggregations = [] + sequence = itertools.count() + + for e in expression.expressions: + aggregation = e.find(exp.AggFunc) + + if aggregation: + projections.append(exp.column(e.alias_or_name, step.name, quoted=True)) + aggregations.append(e) + for operand in aggregation.unnest_operands(): + if isinstance(operand, exp.Column): + continue + if operand not in operands: + operands[operand] = f"_a_{next(sequence)}" + operand.replace( + exp.column(operands[operand], step.name, quoted=True) + ) + else: + projections.append(e) + + where = expression.args.get("where") + + if where: + step.condition = where.this + + group = expression.args.get("group") + + if group: + aggregate = Aggregate() + aggregate.source = step.name + aggregate.name = step.name + aggregate.operands = tuple( + alias(operand, alias_) for operand, alias_ in operands.items() + ) + aggregate.aggregations = aggregations + aggregate.group = [ + exp.column(e.alias_or_name, step.name, quoted=True) + for e in group.expressions + ] + aggregate.add_dependency(step) + step = aggregate + + having = expression.args.get("having") + + if having: + step.condition = having.this + + order = expression.args.get("order") + + if order: + sort = Sort() + sort.name = step.name + sort.key = order.expressions + sort.add_dependency(step) + step = sort + for k in sort.key + projections: + for column in k.find_all(exp.Column): + column.set("table", exp.to_identifier(step.name, quoted=True)) + + step.projections = projections + + limit = expression.args.get("limit") + + if limit: + step.limit = int(limit.text("expression")) + + return step + + def __init__(self): + self.name = None + self.dependencies = set() + self.dependents = set() + self.projections = [] + self.limit = math.inf + self.condition = None + + def add_dependency(self, dependency): + self.dependencies.add(dependency) + dependency.dependents.add(self) + + def __repr__(self): + return self.to_s() + + def to_s(self, level=0): + indent = " " * level + nested = f"{indent} " + + context = self._to_s(f"{nested} ") + + if context: + context = [f"{nested}Context:"] + context + + lines = [ + f"{indent}- {self.__class__.__name__}: {self.name}", + *context, + f"{nested}Projections:", + ] + + for expression in self.projections: + lines.append(f"{nested} - {expression.sql()}") + + if self.condition: + lines.append(f"{nested}Condition: {self.condition.sql()}") + + if self.dependencies: + lines.append(f"{nested}Dependencies:") + for dependency in self.dependencies: + lines.append(" " + dependency.to_s(level + 1)) + + return "\n".join(lines) + + def _to_s(self, _indent): + return [] + + +class Scan(Step): + @classmethod + def from_expression(cls, expression, ctes=None): + table = expression.this + alias_ = expression.alias + + if not alias_: + raise UnsupportedError( + "Tables/Subqueries must be aliased. Run it through the optimizer" + ) + + if isinstance(expression, exp.Subquery): + step = Step.from_expression(table, ctes) + step.name = alias_ + return step + + step = Scan() + step.name = alias_ + step.source = expression + if table.name in ctes: + step.add_dependency(ctes[table.name]) + + return step + + def __init__(self): + super().__init__() + self.source = None + + def _to_s(self, indent): + return [f"{indent}Source: {self.source.sql()}"] + + +class Write(Step): + pass + + +class Join(Step): + @classmethod + def from_joins(cls, joins, ctes=None): + step = Join() + + for join in joins: + name = join.this.alias + on = join.args.get("on") or exp.TRUE + source_key = [] + join_key = [] + + # find the join keys + # SELECT + # FROM x + # JOIN y + # ON x.a = y.b AND y.b > 1 + # + # should pull y.b as the join key and x.a as the source key + for condition in on.flatten() if isinstance(on, exp.And) else [on]: + if isinstance(condition, exp.EQ): + left, right = condition.unnest_operands() + left_tables = exp.column_table_names(left) + right_tables = exp.column_table_names(right) + + if name in left_tables and name not in right_tables: + join_key.append(left) + source_key.append(right) + condition.replace(exp.TRUE) + elif name in right_tables and name not in left_tables: + join_key.append(right) + source_key.append(left) + condition.replace(exp.TRUE) + + on = simplify(on) + + step.joins[name] = { + "side": join.side, + "join_key": join_key, + "source_key": source_key, + "condition": None if on == exp.TRUE else on, + } + + step.add_dependency(Scan.from_expression(join.this, ctes)) + + return step + + def __init__(self): + super().__init__() + self.joins = {} + + def _to_s(self, indent): + lines = [] + for name, join in self.joins.items(): + lines.append(f"{indent}{name}: {join['side']}") + if join.get("condition"): + lines.append(f"{indent}On: {join['condition'].sql()}") + return lines + + +class Aggregate(Step): + def __init__(self): + super().__init__() + self.aggregations = [] + self.operands = [] + self.group = [] + self.source = None + + def _to_s(self, indent): + lines = [f"{indent}Aggregations:"] + + for expression in self.aggregations: + lines.append(f"{indent} - {expression.sql()}") + + if self.group: + lines.append(f"{indent}Group:") + for expression in self.group: + lines.append(f"{indent} - {expression.sql()}") + if self.operands: + lines.append(f"{indent}Operands:") + for expression in self.operands: + lines.append(f"{indent} - {expression.sql()}") + + return lines + + +class Sort(Step): + def __init__(self): + super().__init__() + self.key = None + + def _to_s(self, indent): + lines = [f"{indent}Key:"] + + for expression in self.key: + lines.append(f"{indent} - {expression.sql()}") + + return lines -- cgit v1.2.3