diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/executor/python.py | 190 |
1 files changed, 95 insertions, 95 deletions
diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py index fcb016b..7d1db32 100644 --- a/sqlglot/executor/python.py +++ b/sqlglot/executor/python.py @@ -1,15 +1,14 @@ import ast import collections import itertools +import math -from sqlglot import exp, planner +from sqlglot import exp, generator, planner, tokens from sqlglot.dialects.dialect import Dialect, inline_array_sql from sqlglot.executor.context import Context from sqlglot.executor.env import ENV from sqlglot.executor.table import Table -from sqlglot.generator import Generator from sqlglot.helper import csv_reader -from sqlglot.tokens import Tokenizer class PythonExecutor: @@ -26,7 +25,11 @@ class PythonExecutor: while queue: node = queue.pop() context = self.context( - {name: table for dep in node.dependencies for name, table in contexts[dep].tables.items()} + { + name: table + for dep in node.dependencies + for name, table in contexts[dep].tables.items() + } ) running.add(node) @@ -76,13 +79,10 @@ class PythonExecutor: return Table(expression.alias_or_name for expression in expressions) def scan(self, step, context): - if hasattr(step, "source"): - source = step.source + source = step.source - if isinstance(source, exp.Expression): - source = source.name or source.alias - else: - source = step.name + if isinstance(source, exp.Expression): + source = source.name or source.alias condition = self.generate(step.condition) projections = self.generate_tuple(step.projections) @@ -96,14 +96,12 @@ class PythonExecutor: if projections: sink = self.table(step.projections) - elif source in context: - sink = Table(context[source].columns) else: sink = None for reader, ctx in table_iter: if sink is None: - sink = Table(ctx[source].columns) + sink = Table(reader.columns) if condition and not ctx.eval(condition): continue @@ -135,98 +133,79 @@ class PythonExecutor: types.append(type(ast.literal_eval(v))) except (ValueError, SyntaxError): types.append(str) - context.set_row(alias, tuple(t(v) for t, v in zip(types, row))) - yield context[alias], context + context.set_row(tuple(t(v) for t, v in zip(types, row))) + yield context.table.reader, context def join(self, step, context): source = step.name - join_context = self.context({source: context.tables[source]}) - - def merge_context(ctx, table): - # create a new context where all existing tables are mapped to a new one - return self.context({name: table for name in ctx.tables}) + source_table = context.tables[source] + source_context = self.context({source: source_table}) + column_ranges = {source: range(0, len(source_table.columns))} for name, join in step.joins.items(): - join_context = self.context({**join_context.tables, name: context.tables[name]}) + table = context.tables[name] + start = max(r.stop for r in column_ranges.values()) + column_ranges[name] = range(start, len(table.columns) + start) + join_context = self.context({name: table}) if join.get("source_key"): - table = self.hash_join(join, source, name, join_context) + table = self.hash_join(join, source_context, join_context) else: - table = self.nested_loop_join(join, source, name, join_context) + table = self.nested_loop_join(join, source_context, join_context) - join_context = merge_context(join_context, table) - - # apply projections or conditions - context = self.scan(step, join_context) + source_context = self.context( + { + name: Table(table.columns, table.rows, column_range) + for name, column_range in column_ranges.items() + } + ) - # use the scan context since it returns a single table - # otherwise there are no projections so all other tables are still in scope - if step.projections: - return context + condition = self.generate(step.condition) + projections = self.generate_tuple(step.projections) - return merge_context(join_context, context.tables[source]) + if not condition or not projections: + return source_context - def nested_loop_join(self, _join, a, b, context): - table = Table(context.tables[a].columns + context.tables[b].columns) + sink = self.table(step.projections if projections else source_context.columns) - for reader_a, _ in context.table_iter(a): - for reader_b, _ in context.table_iter(b): - table.append(reader_a.row + reader_b.row) + for reader, ctx in join_context: + if condition and not ctx.eval(condition): + continue - return table + if projections: + sink.append(ctx.eval_tuple(projections)) + else: + sink.append(reader.row) - def hash_join(self, join, a, b, context): - a_key = self.generate_tuple(join["source_key"]) - b_key = self.generate_tuple(join["join_key"]) + if len(sink) >= step.limit: + break - results = collections.defaultdict(lambda: ([], [])) + return self.context({step.name: sink}) - for reader, ctx in context.table_iter(a): - results[ctx.eval_tuple(a_key)][0].append(reader.row) - for reader, ctx in context.table_iter(b): - results[ctx.eval_tuple(b_key)][1].append(reader.row) + def nested_loop_join(self, _join, source_context, join_context): + table = Table(source_context.columns + join_context.columns) - table = Table(context.tables[a].columns + context.tables[b].columns) - for a_group, b_group in results.values(): - for a_row, b_row in itertools.product(a_group, b_group): - table.append(a_row + b_row) + for reader_a, _ in source_context: + for reader_b, _ in join_context: + table.append(reader_a.row + reader_b.row) return table - def sort_merge_join(self, join, a, b, context): - a_key = self.generate_tuple(join["source_key"]) - b_key = self.generate_tuple(join["join_key"]) - - context.sort(a, a_key) - context.sort(b, b_key) - - a_i = 0 - b_i = 0 - a_n = len(context.tables[a]) - b_n = len(context.tables[b]) - - table = Table(context.tables[a].columns + context.tables[b].columns) - - def get_key(source, key, i): - context.set_index(source, i) - return context.eval_tuple(key) + def hash_join(self, join, source_context, join_context): + source_key = self.generate_tuple(join["source_key"]) + join_key = self.generate_tuple(join["join_key"]) - while a_i < a_n and b_i < b_n: - key = min(get_key(a, a_key, a_i), get_key(b, b_key, b_i)) - - a_group = [] - - while a_i < a_n and key == get_key(a, a_key, a_i): - a_group.append(context[a].row) - a_i += 1 + results = collections.defaultdict(lambda: ([], [])) - b_group = [] + for reader, ctx in source_context: + results[ctx.eval_tuple(source_key)][0].append(reader.row) + for reader, ctx in join_context: + results[ctx.eval_tuple(join_key)][1].append(reader.row) - while b_i < b_n and key == get_key(b, b_key, b_i): - b_group.append(context[b].row) - b_i += 1 + table = Table(source_context.columns + join_context.columns) + for a_group, b_group in results.values(): for a_row, b_row in itertools.product(a_group, b_group): table.append(a_row + b_row) @@ -238,16 +217,18 @@ class PythonExecutor: aggregations = self.generate_tuple(step.aggregations) operands = self.generate_tuple(step.operands) - context.sort(source, group_by) - - if step.operands: + if operands: source_table = context.tables[source] operand_table = Table(source_table.columns + self.table(step.operands).columns) for reader, ctx in context: operand_table.append(reader.row + ctx.eval_tuple(operands)) - context = self.context({source: operand_table}) + context = self.context( + {None: operand_table, **{table: operand_table for table in context.tables}} + ) + + context.sort(group_by) group = None start = 0 @@ -256,15 +237,15 @@ class PythonExecutor: table = self.table(step.group + step.aggregations) for i in range(length): - context.set_index(source, i) + context.set_index(i) key = context.eval_tuple(group_by) group = key if group is None else group end += 1 if i == length - 1: - context.set_range(source, start, end - 1) + context.set_range(start, end - 1) elif key != group: - context.set_range(source, start, end - 2) + context.set_range(start, end - 2) else: continue @@ -272,13 +253,32 @@ class PythonExecutor: group = key start = end - 2 - return self.scan(step, self.context({source: table})) + context = self.context({step.name: table, **{name: table for name in context.tables}}) + + if step.projections: + return self.scan(step, context) + return context def sort(self, step, context): - table = list(context.tables)[0] - key = self.generate_tuple(step.key) - context.sort(table, key) - return self.scan(step, context) + projections = self.generate_tuple(step.projections) + + sink = self.table(step.projections) + + for reader, ctx in context: + sink.append(ctx.eval_tuple(projections)) + + context = self.context( + { + None: sink, + **{table: sink for table in context.tables}, + } + ) + context.sort(self.generate_tuple(step.key)) + + if not math.isinf(step.limit): + context.table.rows = context.table.rows[0 : step.limit] + + return self.context({step.name: context.table}) def _cast_py(self, expression): @@ -293,7 +293,7 @@ def _cast_py(self, expression): def _column_py(self, expression): - table = self.sql(expression, "table") + table = self.sql(expression, "table") or None this = self.sql(expression, "this") return f"scope[{table}][{this}]" @@ -319,10 +319,10 @@ def _ordered_py(self, expression): class Python(Dialect): - class Tokenizer(Tokenizer): - ESCAPE = "\\" + class Tokenizer(tokens.Tokenizer): + ESCAPES = ["\\"] - class Generator(Generator): + class Generator(generator.Generator): TRANSFORMS = { exp.Alias: lambda self, e: self.sql(e.this), exp.Array: inline_array_sql, |