diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-11-19 14:50:39 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-11-19 14:50:39 +0000 |
commit | f2981e8e4d28233864f1ca06ecec45ab80bf9eae (patch) | |
tree | b70cb633916830138ce3424aa361f0bbaff02be2 /sqlglot/executor/python.py | |
parent | Releasing debian version 10.0.1-1. (diff) | |
download | sqlglot-f2981e8e4d28233864f1ca06ecec45ab80bf9eae.tar.xz sqlglot-f2981e8e4d28233864f1ca06ecec45ab80bf9eae.zip |
Merging upstream version 10.0.8.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/executor/python.py')
-rw-r--r-- | sqlglot/executor/python.py | 287 |
1 files changed, 172 insertions, 115 deletions
diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py index 7d1db32..cb2543c 100644 --- a/sqlglot/executor/python.py +++ b/sqlglot/executor/python.py @@ -5,16 +5,18 @@ import math from sqlglot import exp, generator, planner, tokens from sqlglot.dialects.dialect import Dialect, inline_array_sql +from sqlglot.errors import ExecuteError from sqlglot.executor.context import Context from sqlglot.executor.env import ENV -from sqlglot.executor.table import Table -from sqlglot.helper import csv_reader +from sqlglot.executor.table import RowReader, Table +from sqlglot.helper import csv_reader, subclasses class PythonExecutor: - def __init__(self, env=None): - self.generator = Python().generator(identify=True) + def __init__(self, env=None, tables=None): + self.generator = Python().generator(identify=True, comments=False) self.env = {**ENV, **(env or {})} + self.tables = tables or {} def execute(self, plan): running = set() @@ -24,36 +26,41 @@ class PythonExecutor: while queue: node = queue.pop() - context = self.context( - { - name: table - for dep in node.dependencies - for name, table in contexts[dep].tables.items() - } - ) - running.add(node) - - if isinstance(node, planner.Scan): - contexts[node] = self.scan(node, context) - elif isinstance(node, planner.Aggregate): - contexts[node] = self.aggregate(node, context) - elif isinstance(node, planner.Join): - contexts[node] = self.join(node, context) - elif isinstance(node, planner.Sort): - contexts[node] = self.sort(node, context) - else: - raise NotImplementedError - - running.remove(node) - finished.add(node) - - for dep in node.dependents: - if dep not in running and all(d in contexts for d in dep.dependencies): - queue.add(dep) - - for dep in node.dependencies: - if all(d in finished for d in dep.dependents): - contexts.pop(dep) + try: + context = self.context( + { + name: table + for dep in node.dependencies + for name, table in contexts[dep].tables.items() + } + ) + running.add(node) + + if isinstance(node, planner.Scan): + contexts[node] = self.scan(node, context) + elif isinstance(node, planner.Aggregate): + contexts[node] = self.aggregate(node, context) + elif isinstance(node, planner.Join): + contexts[node] = self.join(node, context) + elif isinstance(node, planner.Sort): + contexts[node] = self.sort(node, context) + elif isinstance(node, planner.SetOperation): + contexts[node] = self.set_operation(node, context) + else: + raise NotImplementedError + + running.remove(node) + finished.add(node) + + for dep in node.dependents: + if dep not in running and all(d in contexts for d in dep.dependencies): + queue.add(dep) + + for dep in node.dependencies: + if all(d in finished for d in dep.dependents): + contexts.pop(dep) + except Exception as e: + raise ExecuteError(f"Step '{node.id}' failed: {e}") from e root = plan.root return contexts[root].tables[root.name] @@ -76,38 +83,43 @@ class PythonExecutor: return Context(tables, env=self.env) def table(self, expressions): - return Table(expression.alias_or_name for expression in expressions) + return Table( + expression.alias_or_name if isinstance(expression, exp.Expression) else expression + for expression in expressions + ) def scan(self, step, context): source = step.source - if isinstance(source, exp.Expression): + if source and isinstance(source, exp.Expression): source = source.name or source.alias condition = self.generate(step.condition) projections = self.generate_tuple(step.projections) - if source in context: + if source is None: + context, table_iter = self.static() + elif source in context: if not projections and not condition: return self.context({step.name: context.tables[source]}) table_iter = context.table_iter(source) - else: + elif isinstance(step.source, exp.Table) and isinstance(step.source.this, exp.ReadCSV): table_iter = self.scan_csv(step) + context = next(table_iter) + else: + context, table_iter = self.scan_table(step) if projections: sink = self.table(step.projections) else: - sink = None - - for reader, ctx in table_iter: - if sink is None: - sink = Table(reader.columns) + sink = self.table(context.columns) - if condition and not ctx.eval(condition): + for reader in table_iter: + if condition and not context.eval(condition): continue if projections: - sink.append(ctx.eval_tuple(projections)) + sink.append(context.eval_tuple(projections)) else: sink.append(reader.row) @@ -116,14 +128,23 @@ class PythonExecutor: return self.context({step.name: sink}) + def static(self): + return self.context({}), [RowReader(())] + + def scan_table(self, step): + table = self.tables.find(step.source) + context = self.context({step.source.alias_or_name: table}) + return context, iter(table) + def scan_csv(self, step): - source = step.source - alias = source.alias + alias = step.source.alias + source = step.source.this with csv_reader(source) as reader: columns = next(reader) table = Table(columns) context = self.context({alias: table}) + yield context types = [] for row in reader: @@ -134,7 +155,7 @@ class PythonExecutor: except (ValueError, SyntaxError): types.append(str) context.set_row(tuple(t(v) for t, v in zip(types, row))) - yield context.table.reader, context + yield context.table.reader def join(self, step, context): source = step.name @@ -160,16 +181,19 @@ class PythonExecutor: for name, column_range in column_ranges.items() } ) + condition = self.generate(join["condition"]) + if condition: + source_context.filter(condition) condition = self.generate(step.condition) projections = self.generate_tuple(step.projections) - if not condition or not projections: + if not condition and not projections: return source_context sink = self.table(step.projections if projections else source_context.columns) - for reader, ctx in join_context: + for reader, ctx in source_context: if condition and not ctx.eval(condition): continue @@ -181,7 +205,15 @@ class PythonExecutor: if len(sink) >= step.limit: break - return self.context({step.name: sink}) + if projections: + return self.context({step.name: sink}) + else: + return self.context( + { + name: Table(table.columns, sink.rows, table.column_range) + for name, table in source_context.tables.items() + } + ) def nested_loop_join(self, _join, source_context, join_context): table = Table(source_context.columns + join_context.columns) @@ -195,6 +227,8 @@ class PythonExecutor: def hash_join(self, join, source_context, join_context): source_key = self.generate_tuple(join["source_key"]) join_key = self.generate_tuple(join["join_key"]) + left = join.get("side") == "LEFT" + right = join.get("side") == "RIGHT" results = collections.defaultdict(lambda: ([], [])) @@ -204,28 +238,47 @@ class PythonExecutor: results[ctx.eval_tuple(join_key)][1].append(reader.row) table = Table(source_context.columns + join_context.columns) + nulls = [(None,) * len(join_context.columns if left else source_context.columns)] for a_group, b_group in results.values(): + if left: + b_group = b_group or nulls + elif right: + a_group = a_group or nulls + for a_row, b_row in itertools.product(a_group, b_group): table.append(a_row + b_row) return table def aggregate(self, step, context): - source = step.source - group_by = self.generate_tuple(step.group) + group_by = self.generate_tuple(step.group.values()) aggregations = self.generate_tuple(step.aggregations) operands = self.generate_tuple(step.operands) if operands: - source_table = context.tables[source] - operand_table = Table(source_table.columns + self.table(step.operands).columns) + operand_table = Table(self.table(step.operands).columns) for reader, ctx in context: - operand_table.append(reader.row + ctx.eval_tuple(operands)) + operand_table.append(ctx.eval_tuple(operands)) + + for i, (a, b) in enumerate(zip(context.table.rows, operand_table.rows)): + context.table.rows[i] = a + b + + width = len(context.columns) + context.add_columns(*operand_table.columns) + + operand_table = Table( + context.columns, + context.table.rows, + range(width, width + len(operand_table.columns)), + ) context = self.context( - {None: operand_table, **{table: operand_table for table in context.tables}} + { + None: operand_table, + **context.tables, + } ) context.sort(group_by) @@ -233,25 +286,22 @@ class PythonExecutor: group = None start = 0 end = 1 - length = len(context.tables[source]) - table = self.table(step.group + step.aggregations) + length = len(context.table) + table = self.table(list(step.group) + step.aggregations) for i in range(length): context.set_index(i) key = context.eval_tuple(group_by) group = key if group is None else group end += 1 - + if key != group: + context.set_range(start, end - 2) + table.append(group + context.eval_tuple(aggregations)) + group = key + start = end - 2 if i == length - 1: context.set_range(start, end - 1) - elif key != group: - context.set_range(start, end - 2) - else: - continue - - table.append(group + context.eval_tuple(aggregations)) - group = key - start = end - 2 + table.append(group + context.eval_tuple(aggregations)) context = self.context({step.name: table, **{name: table for name in context.tables}}) @@ -262,60 +312,77 @@ class PythonExecutor: def sort(self, step, context): projections = self.generate_tuple(step.projections) - sink = self.table(step.projections) + projection_columns = [p.alias_or_name for p in step.projections] + all_columns = list(context.columns) + projection_columns + sink = self.table(all_columns) for reader, ctx in context: - sink.append(ctx.eval_tuple(projections)) + sink.append(reader.row + ctx.eval_tuple(projections)) - context = self.context( + sort_ctx = self.context( { None: sink, **{table: sink for table in context.tables}, } ) - context.sort(self.generate_tuple(step.key)) + sort_ctx.sort(self.generate_tuple(step.key)) if not math.isinf(step.limit): - context.table.rows = context.table.rows[0 : step.limit] + sort_ctx.table.rows = sort_ctx.table.rows[0 : step.limit] - return self.context({step.name: context.table}) + output = Table( + projection_columns, + rows=[r[len(context.columns) : len(all_columns)] for r in sort_ctx.table.rows], + ) + return self.context({step.name: output}) + def set_operation(self, step, context): + left = context.tables[step.left] + right = context.tables[step.right] -def _cast_py(self, expression): - to = expression.args["to"].this - this = self.sql(expression, "this") + sink = self.table(left.columns) + + if issubclass(step.op, exp.Intersect): + sink.rows = list(set(left.rows).intersection(set(right.rows))) + elif issubclass(step.op, exp.Except): + sink.rows = list(set(left.rows).difference(set(right.rows))) + elif issubclass(step.op, exp.Union) and step.distinct: + sink.rows = list(set(left.rows).union(set(right.rows))) + else: + sink.rows = left.rows + right.rows - if to == exp.DataType.Type.DATE: - return f"datetime.date.fromisoformat({this})" - if to == exp.DataType.Type.TEXT: - return f"str({this})" - raise NotImplementedError + return self.context({step.name: sink}) -def _column_py(self, expression): - table = self.sql(expression, "table") or None +def _ordered_py(self, expression): this = self.sql(expression, "this") - return f"scope[{table}][{this}]" + desc = "True" if expression.args.get("desc") else "False" + nulls_first = "True" if expression.args.get("nulls_first") else "False" + return f"ORDERED({this}, {desc}, {nulls_first})" -def _interval_py(self, expression): - this = self.sql(expression, "this") - unit = expression.text("unit").upper() - if unit == "DAY": - return f"datetime.timedelta(days=float({this}))" - raise NotImplementedError +def _rename(self, e): + try: + if "expressions" in e.args: + this = self.sql(e, "this") + this = f"{this}, " if this else "" + return f"{e.key.upper()}({this}{self.expressions(e)})" + return f"{e.key.upper()}({self.format_args(*e.args.values())})" + except Exception as ex: + raise Exception(f"Could not rename {repr(e)}") from ex -def _like_py(self, expression): +def _case_sql(self, expression): this = self.sql(expression, "this") - expression = self.sql(expression, "expression") - return f"""bool(re.match({expression}.replace("_", ".").replace("%", ".*"), {this}))""" + chain = self.sql(expression, "default") or "None" + for e in reversed(expression.args["ifs"]): + true = self.sql(e, "true") + condition = self.sql(e, "this") + condition = f"{this} = ({condition})" if this else condition + chain = f"{true} if {condition} else ({chain})" -def _ordered_py(self, expression): - this = self.sql(expression, "this") - desc = expression.args.get("desc") - return f"desc({this})" if desc else this + return chain class Python(Dialect): @@ -324,32 +391,22 @@ class Python(Dialect): class Generator(generator.Generator): TRANSFORMS = { + **{klass: _rename for klass in subclasses(exp.__name__, exp.Binary)}, + **{klass: _rename for klass in exp.ALL_FUNCTIONS}, + exp.Case: _case_sql, exp.Alias: lambda self, e: self.sql(e.this), exp.Array: inline_array_sql, exp.And: lambda self, e: self.binary(e, "and"), + exp.Between: _rename, exp.Boolean: lambda self, e: "True" if e.this else "False", - exp.Cast: _cast_py, - exp.Column: _column_py, - exp.EQ: lambda self, e: self.binary(e, "=="), + exp.Cast: lambda self, e: f"CAST({self.sql(e.this)}, exp.DataType.Type.{e.args['to']})", + exp.Column: lambda self, e: f"scope[{self.sql(e, 'table') or None}][{self.sql(e.this)}]", + exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})", exp.In: lambda self, e: f"{self.sql(e, 'this')} in {self.expressions(e)}", - exp.Interval: _interval_py, exp.Is: lambda self, e: self.binary(e, "is"), - exp.Like: _like_py, exp.Not: lambda self, e: f"not {self.sql(e.this)}", exp.Null: lambda *_: "None", exp.Or: lambda self, e: self.binary(e, "or"), exp.Ordered: _ordered_py, exp.Star: lambda *_: "1", } - - def case_sql(self, expression): - this = self.sql(expression, "this") - chain = self.sql(expression, "default") or "None" - - for e in reversed(expression.args["ifs"]): - true = self.sql(e, "true") - condition = self.sql(e, "this") - condition = f"{this} = ({condition})" if this else condition - chain = f"{true} if {condition} else ({chain})" - - return chain |