diff options
Diffstat (limited to 'sqlglot/executor')
-rw-r--r-- | sqlglot/executor/context.py | 44 | ||||
-rw-r--r-- | sqlglot/executor/env.py | 4 | ||||
-rw-r--r-- | sqlglot/executor/python.py | 190 | ||||
-rw-r--r-- | sqlglot/executor/table.py | 27 |
4 files changed, 149 insertions, 116 deletions
diff --git a/sqlglot/executor/context.py b/sqlglot/executor/context.py index d265a2c..393347b 100644 --- a/sqlglot/executor/context.py +++ b/sqlglot/executor/context.py @@ -19,6 +19,7 @@ class Context: env (Optional[dict]): dictionary of functions within the execution context """ self.tables = tables + self._table = None self.range_readers = {name: table.range_reader for name, table in self.tables.items()} self.row_readers = {name: table.reader for name, table in tables.items()} self.env = {**(env or {}), "scope": self.row_readers} @@ -29,8 +30,27 @@ class Context: def eval_tuple(self, codes): return tuple(self.eval(code) for code in codes) + @property + def table(self): + if self._table is None: + self._table = list(self.tables.values())[0] + for other in self.tables.values(): + if self._table.columns != other.columns: + raise Exception(f"Columns are different.") + if len(self._table.rows) != len(other.rows): + raise Exception(f"Rows are different.") + return self._table + + @property + def columns(self): + return self.table.columns + def __iter__(self): - return self.table_iter(list(self.tables)[0]) + self.env["scope"] = self.row_readers + for i in range(len(self.table.rows)): + for table in self.tables.values(): + reader = table[i] + yield reader, self def table_iter(self, table): self.env["scope"] = self.row_readers @@ -38,8 +58,8 @@ class Context: for reader in self.tables[table]: yield reader, self - def sort(self, table, key): - table = self.tables[table] + def sort(self, key): + table = self.table def sort_key(row): table.reader.row = row @@ -47,20 +67,20 @@ class Context: table.rows.sort(key=sort_key) - def set_row(self, table, row): - self.row_readers[table].row = row + def set_row(self, row): + for table in self.tables.values(): + table.reader.row = row self.env["scope"] = self.row_readers - def set_index(self, table, index): - self.row_readers[table].row = self.tables[table].rows[index] + def set_index(self, index): + for table in self.tables.values(): + table[index] self.env["scope"] = self.row_readers - def set_range(self, table, start, end): - self.range_readers[table].range = range(start, end) + def set_range(self, start, end): + for name in self.tables: + self.range_readers[name].range = range(start, end) self.env["scope"] = self.range_readers - def __getitem__(self, table): - return self.env["scope"][table] - def __contains__(self, table): return table in self.tables diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py index 9c49dd1..bbe6c81 100644 --- a/sqlglot/executor/env.py +++ b/sqlglot/executor/env.py @@ -2,6 +2,8 @@ import datetime import re import statistics +from sqlglot.helper import PYTHON_VERSION + class reverse_key: def __init__(self, obj): @@ -25,7 +27,7 @@ ENV = { "str": str, "desc": reverse_key, "SUM": sum, - "AVG": statistics.fmean if hasattr(statistics, "fmean") else statistics.mean, + "AVG": statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean, # type: ignore "COUNT": lambda acc: sum(1 for e in acc if e is not None), "MAX": max, "MIN": min, diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py index fcb016b..7d1db32 100644 --- a/sqlglot/executor/python.py +++ b/sqlglot/executor/python.py @@ -1,15 +1,14 @@ import ast import collections import itertools +import math -from sqlglot import exp, planner +from sqlglot import exp, generator, planner, tokens from sqlglot.dialects.dialect import Dialect, inline_array_sql from sqlglot.executor.context import Context from sqlglot.executor.env import ENV from sqlglot.executor.table import Table -from sqlglot.generator import Generator from sqlglot.helper import csv_reader -from sqlglot.tokens import Tokenizer class PythonExecutor: @@ -26,7 +25,11 @@ class PythonExecutor: while queue: node = queue.pop() context = self.context( - {name: table for dep in node.dependencies for name, table in contexts[dep].tables.items()} + { + name: table + for dep in node.dependencies + for name, table in contexts[dep].tables.items() + } ) running.add(node) @@ -76,13 +79,10 @@ class PythonExecutor: return Table(expression.alias_or_name for expression in expressions) def scan(self, step, context): - if hasattr(step, "source"): - source = step.source + source = step.source - if isinstance(source, exp.Expression): - source = source.name or source.alias - else: - source = step.name + if isinstance(source, exp.Expression): + source = source.name or source.alias condition = self.generate(step.condition) projections = self.generate_tuple(step.projections) @@ -96,14 +96,12 @@ class PythonExecutor: if projections: sink = self.table(step.projections) - elif source in context: - sink = Table(context[source].columns) else: sink = None for reader, ctx in table_iter: if sink is None: - sink = Table(ctx[source].columns) + sink = Table(reader.columns) if condition and not ctx.eval(condition): continue @@ -135,98 +133,79 @@ class PythonExecutor: types.append(type(ast.literal_eval(v))) except (ValueError, SyntaxError): types.append(str) - context.set_row(alias, tuple(t(v) for t, v in zip(types, row))) - yield context[alias], context + context.set_row(tuple(t(v) for t, v in zip(types, row))) + yield context.table.reader, context def join(self, step, context): source = step.name - join_context = self.context({source: context.tables[source]}) - - def merge_context(ctx, table): - # create a new context where all existing tables are mapped to a new one - return self.context({name: table for name in ctx.tables}) + source_table = context.tables[source] + source_context = self.context({source: source_table}) + column_ranges = {source: range(0, len(source_table.columns))} for name, join in step.joins.items(): - join_context = self.context({**join_context.tables, name: context.tables[name]}) + table = context.tables[name] + start = max(r.stop for r in column_ranges.values()) + column_ranges[name] = range(start, len(table.columns) + start) + join_context = self.context({name: table}) if join.get("source_key"): - table = self.hash_join(join, source, name, join_context) + table = self.hash_join(join, source_context, join_context) else: - table = self.nested_loop_join(join, source, name, join_context) + table = self.nested_loop_join(join, source_context, join_context) - join_context = merge_context(join_context, table) - - # apply projections or conditions - context = self.scan(step, join_context) + source_context = self.context( + { + name: Table(table.columns, table.rows, column_range) + for name, column_range in column_ranges.items() + } + ) - # use the scan context since it returns a single table - # otherwise there are no projections so all other tables are still in scope - if step.projections: - return context + condition = self.generate(step.condition) + projections = self.generate_tuple(step.projections) - return merge_context(join_context, context.tables[source]) + if not condition or not projections: + return source_context - def nested_loop_join(self, _join, a, b, context): - table = Table(context.tables[a].columns + context.tables[b].columns) + sink = self.table(step.projections if projections else source_context.columns) - for reader_a, _ in context.table_iter(a): - for reader_b, _ in context.table_iter(b): - table.append(reader_a.row + reader_b.row) + for reader, ctx in join_context: + if condition and not ctx.eval(condition): + continue - return table + if projections: + sink.append(ctx.eval_tuple(projections)) + else: + sink.append(reader.row) - def hash_join(self, join, a, b, context): - a_key = self.generate_tuple(join["source_key"]) - b_key = self.generate_tuple(join["join_key"]) + if len(sink) >= step.limit: + break - results = collections.defaultdict(lambda: ([], [])) + return self.context({step.name: sink}) - for reader, ctx in context.table_iter(a): - results[ctx.eval_tuple(a_key)][0].append(reader.row) - for reader, ctx in context.table_iter(b): - results[ctx.eval_tuple(b_key)][1].append(reader.row) + def nested_loop_join(self, _join, source_context, join_context): + table = Table(source_context.columns + join_context.columns) - table = Table(context.tables[a].columns + context.tables[b].columns) - for a_group, b_group in results.values(): - for a_row, b_row in itertools.product(a_group, b_group): - table.append(a_row + b_row) + for reader_a, _ in source_context: + for reader_b, _ in join_context: + table.append(reader_a.row + reader_b.row) return table - def sort_merge_join(self, join, a, b, context): - a_key = self.generate_tuple(join["source_key"]) - b_key = self.generate_tuple(join["join_key"]) - - context.sort(a, a_key) - context.sort(b, b_key) - - a_i = 0 - b_i = 0 - a_n = len(context.tables[a]) - b_n = len(context.tables[b]) - - table = Table(context.tables[a].columns + context.tables[b].columns) - - def get_key(source, key, i): - context.set_index(source, i) - return context.eval_tuple(key) + def hash_join(self, join, source_context, join_context): + source_key = self.generate_tuple(join["source_key"]) + join_key = self.generate_tuple(join["join_key"]) - while a_i < a_n and b_i < b_n: - key = min(get_key(a, a_key, a_i), get_key(b, b_key, b_i)) - - a_group = [] - - while a_i < a_n and key == get_key(a, a_key, a_i): - a_group.append(context[a].row) - a_i += 1 + results = collections.defaultdict(lambda: ([], [])) - b_group = [] + for reader, ctx in source_context: + results[ctx.eval_tuple(source_key)][0].append(reader.row) + for reader, ctx in join_context: + results[ctx.eval_tuple(join_key)][1].append(reader.row) - while b_i < b_n and key == get_key(b, b_key, b_i): - b_group.append(context[b].row) - b_i += 1 + table = Table(source_context.columns + join_context.columns) + for a_group, b_group in results.values(): for a_row, b_row in itertools.product(a_group, b_group): table.append(a_row + b_row) @@ -238,16 +217,18 @@ class PythonExecutor: aggregations = self.generate_tuple(step.aggregations) operands = self.generate_tuple(step.operands) - context.sort(source, group_by) - - if step.operands: + if operands: source_table = context.tables[source] operand_table = Table(source_table.columns + self.table(step.operands).columns) for reader, ctx in context: operand_table.append(reader.row + ctx.eval_tuple(operands)) - context = self.context({source: operand_table}) + context = self.context( + {None: operand_table, **{table: operand_table for table in context.tables}} + ) + + context.sort(group_by) group = None start = 0 @@ -256,15 +237,15 @@ class PythonExecutor: table = self.table(step.group + step.aggregations) for i in range(length): - context.set_index(source, i) + context.set_index(i) key = context.eval_tuple(group_by) group = key if group is None else group end += 1 if i == length - 1: - context.set_range(source, start, end - 1) + context.set_range(start, end - 1) elif key != group: - context.set_range(source, start, end - 2) + context.set_range(start, end - 2) else: continue @@ -272,13 +253,32 @@ class PythonExecutor: group = key start = end - 2 - return self.scan(step, self.context({source: table})) + context = self.context({step.name: table, **{name: table for name in context.tables}}) + + if step.projections: + return self.scan(step, context) + return context def sort(self, step, context): - table = list(context.tables)[0] - key = self.generate_tuple(step.key) - context.sort(table, key) - return self.scan(step, context) + projections = self.generate_tuple(step.projections) + + sink = self.table(step.projections) + + for reader, ctx in context: + sink.append(ctx.eval_tuple(projections)) + + context = self.context( + { + None: sink, + **{table: sink for table in context.tables}, + } + ) + context.sort(self.generate_tuple(step.key)) + + if not math.isinf(step.limit): + context.table.rows = context.table.rows[0 : step.limit] + + return self.context({step.name: context.table}) def _cast_py(self, expression): @@ -293,7 +293,7 @@ def _cast_py(self, expression): def _column_py(self, expression): - table = self.sql(expression, "table") + table = self.sql(expression, "table") or None this = self.sql(expression, "this") return f"scope[{table}][{this}]" @@ -319,10 +319,10 @@ def _ordered_py(self, expression): class Python(Dialect): - class Tokenizer(Tokenizer): - ESCAPE = "\\" + class Tokenizer(tokens.Tokenizer): + ESCAPES = ["\\"] - class Generator(Generator): + class Generator(generator.Generator): TRANSFORMS = { exp.Alias: lambda self, e: self.sql(e.this), exp.Array: inline_array_sql, diff --git a/sqlglot/executor/table.py b/sqlglot/executor/table.py index 80674cb..6796740 100644 --- a/sqlglot/executor/table.py +++ b/sqlglot/executor/table.py @@ -1,10 +1,12 @@ class Table: - def __init__(self, *columns, rows=None): - self.columns = tuple(columns if isinstance(columns[0], str) else columns[0]) + def __init__(self, columns, rows=None, column_range=None): + self.columns = tuple(columns) + self.column_range = column_range + self.reader = RowReader(self.columns, self.column_range) + self.rows = rows or [] if rows: assert len(rows[0]) == len(self.columns) - self.reader = RowReader(self.columns) self.range_reader = RangeReader(self) def append(self, row): @@ -29,15 +31,22 @@ class Table: return self.reader def __repr__(self): - widths = {column: len(column) for column in self.columns} - lines = [" ".join(column for column in self.columns)] + columns = tuple( + column + for i, column in enumerate(self.columns) + if not self.column_range or i in self.column_range + ) + widths = {column: len(column) for column in columns} + lines = [" ".join(column for column in columns)] for i, row in enumerate(self): if i > 10: break lines.append( - " ".join(str(row[column]).rjust(widths[column])[0 : widths[column]] for column in self.columns) + " ".join( + str(row[column]).rjust(widths[column])[0 : widths[column]] for column in columns + ) ) return "\n".join(lines) @@ -70,8 +79,10 @@ class RangeReader: class RowReader: - def __init__(self, columns): - self.columns = {column: i for i, column in enumerate(columns)} + def __init__(self, columns, column_range=None): + self.columns = { + column: i for i, column in enumerate(columns) if not column_range or i in column_range + } self.row = None def __getitem__(self, column): |