summaryrefslogtreecommitdiffstats
path: root/sqlglot/executor
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/executor')
-rw-r--r--sqlglot/executor/context.py44
-rw-r--r--sqlglot/executor/env.py4
-rw-r--r--sqlglot/executor/python.py190
-rw-r--r--sqlglot/executor/table.py27
4 files changed, 149 insertions, 116 deletions
diff --git a/sqlglot/executor/context.py b/sqlglot/executor/context.py
index d265a2c..393347b 100644
--- a/sqlglot/executor/context.py
+++ b/sqlglot/executor/context.py
@@ -19,6 +19,7 @@ class Context:
env (Optional[dict]): dictionary of functions within the execution context
"""
self.tables = tables
+ self._table = None
self.range_readers = {name: table.range_reader for name, table in self.tables.items()}
self.row_readers = {name: table.reader for name, table in tables.items()}
self.env = {**(env or {}), "scope": self.row_readers}
@@ -29,8 +30,27 @@ class Context:
def eval_tuple(self, codes):
return tuple(self.eval(code) for code in codes)
+ @property
+ def table(self):
+ if self._table is None:
+ self._table = list(self.tables.values())[0]
+ for other in self.tables.values():
+ if self._table.columns != other.columns:
+ raise Exception(f"Columns are different.")
+ if len(self._table.rows) != len(other.rows):
+ raise Exception(f"Rows are different.")
+ return self._table
+
+ @property
+ def columns(self):
+ return self.table.columns
+
def __iter__(self):
- return self.table_iter(list(self.tables)[0])
+ self.env["scope"] = self.row_readers
+ for i in range(len(self.table.rows)):
+ for table in self.tables.values():
+ reader = table[i]
+ yield reader, self
def table_iter(self, table):
self.env["scope"] = self.row_readers
@@ -38,8 +58,8 @@ class Context:
for reader in self.tables[table]:
yield reader, self
- def sort(self, table, key):
- table = self.tables[table]
+ def sort(self, key):
+ table = self.table
def sort_key(row):
table.reader.row = row
@@ -47,20 +67,20 @@ class Context:
table.rows.sort(key=sort_key)
- def set_row(self, table, row):
- self.row_readers[table].row = row
+ def set_row(self, row):
+ for table in self.tables.values():
+ table.reader.row = row
self.env["scope"] = self.row_readers
- def set_index(self, table, index):
- self.row_readers[table].row = self.tables[table].rows[index]
+ def set_index(self, index):
+ for table in self.tables.values():
+ table[index]
self.env["scope"] = self.row_readers
- def set_range(self, table, start, end):
- self.range_readers[table].range = range(start, end)
+ def set_range(self, start, end):
+ for name in self.tables:
+ self.range_readers[name].range = range(start, end)
self.env["scope"] = self.range_readers
- def __getitem__(self, table):
- return self.env["scope"][table]
-
def __contains__(self, table):
return table in self.tables
diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py
index 9c49dd1..bbe6c81 100644
--- a/sqlglot/executor/env.py
+++ b/sqlglot/executor/env.py
@@ -2,6 +2,8 @@ import datetime
import re
import statistics
+from sqlglot.helper import PYTHON_VERSION
+
class reverse_key:
def __init__(self, obj):
@@ -25,7 +27,7 @@ ENV = {
"str": str,
"desc": reverse_key,
"SUM": sum,
- "AVG": statistics.fmean if hasattr(statistics, "fmean") else statistics.mean,
+ "AVG": statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean, # type: ignore
"COUNT": lambda acc: sum(1 for e in acc if e is not None),
"MAX": max,
"MIN": min,
diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py
index fcb016b..7d1db32 100644
--- a/sqlglot/executor/python.py
+++ b/sqlglot/executor/python.py
@@ -1,15 +1,14 @@
import ast
import collections
import itertools
+import math
-from sqlglot import exp, planner
+from sqlglot import exp, generator, planner, tokens
from sqlglot.dialects.dialect import Dialect, inline_array_sql
from sqlglot.executor.context import Context
from sqlglot.executor.env import ENV
from sqlglot.executor.table import Table
-from sqlglot.generator import Generator
from sqlglot.helper import csv_reader
-from sqlglot.tokens import Tokenizer
class PythonExecutor:
@@ -26,7 +25,11 @@ class PythonExecutor:
while queue:
node = queue.pop()
context = self.context(
- {name: table for dep in node.dependencies for name, table in contexts[dep].tables.items()}
+ {
+ name: table
+ for dep in node.dependencies
+ for name, table in contexts[dep].tables.items()
+ }
)
running.add(node)
@@ -76,13 +79,10 @@ class PythonExecutor:
return Table(expression.alias_or_name for expression in expressions)
def scan(self, step, context):
- if hasattr(step, "source"):
- source = step.source
+ source = step.source
- if isinstance(source, exp.Expression):
- source = source.name or source.alias
- else:
- source = step.name
+ if isinstance(source, exp.Expression):
+ source = source.name or source.alias
condition = self.generate(step.condition)
projections = self.generate_tuple(step.projections)
@@ -96,14 +96,12 @@ class PythonExecutor:
if projections:
sink = self.table(step.projections)
- elif source in context:
- sink = Table(context[source].columns)
else:
sink = None
for reader, ctx in table_iter:
if sink is None:
- sink = Table(ctx[source].columns)
+ sink = Table(reader.columns)
if condition and not ctx.eval(condition):
continue
@@ -135,98 +133,79 @@ class PythonExecutor:
types.append(type(ast.literal_eval(v)))
except (ValueError, SyntaxError):
types.append(str)
- context.set_row(alias, tuple(t(v) for t, v in zip(types, row)))
- yield context[alias], context
+ context.set_row(tuple(t(v) for t, v in zip(types, row)))
+ yield context.table.reader, context
def join(self, step, context):
source = step.name
- join_context = self.context({source: context.tables[source]})
-
- def merge_context(ctx, table):
- # create a new context where all existing tables are mapped to a new one
- return self.context({name: table for name in ctx.tables})
+ source_table = context.tables[source]
+ source_context = self.context({source: source_table})
+ column_ranges = {source: range(0, len(source_table.columns))}
for name, join in step.joins.items():
- join_context = self.context({**join_context.tables, name: context.tables[name]})
+ table = context.tables[name]
+ start = max(r.stop for r in column_ranges.values())
+ column_ranges[name] = range(start, len(table.columns) + start)
+ join_context = self.context({name: table})
if join.get("source_key"):
- table = self.hash_join(join, source, name, join_context)
+ table = self.hash_join(join, source_context, join_context)
else:
- table = self.nested_loop_join(join, source, name, join_context)
+ table = self.nested_loop_join(join, source_context, join_context)
- join_context = merge_context(join_context, table)
-
- # apply projections or conditions
- context = self.scan(step, join_context)
+ source_context = self.context(
+ {
+ name: Table(table.columns, table.rows, column_range)
+ for name, column_range in column_ranges.items()
+ }
+ )
- # use the scan context since it returns a single table
- # otherwise there are no projections so all other tables are still in scope
- if step.projections:
- return context
+ condition = self.generate(step.condition)
+ projections = self.generate_tuple(step.projections)
- return merge_context(join_context, context.tables[source])
+ if not condition or not projections:
+ return source_context
- def nested_loop_join(self, _join, a, b, context):
- table = Table(context.tables[a].columns + context.tables[b].columns)
+ sink = self.table(step.projections if projections else source_context.columns)
- for reader_a, _ in context.table_iter(a):
- for reader_b, _ in context.table_iter(b):
- table.append(reader_a.row + reader_b.row)
+ for reader, ctx in join_context:
+ if condition and not ctx.eval(condition):
+ continue
- return table
+ if projections:
+ sink.append(ctx.eval_tuple(projections))
+ else:
+ sink.append(reader.row)
- def hash_join(self, join, a, b, context):
- a_key = self.generate_tuple(join["source_key"])
- b_key = self.generate_tuple(join["join_key"])
+ if len(sink) >= step.limit:
+ break
- results = collections.defaultdict(lambda: ([], []))
+ return self.context({step.name: sink})
- for reader, ctx in context.table_iter(a):
- results[ctx.eval_tuple(a_key)][0].append(reader.row)
- for reader, ctx in context.table_iter(b):
- results[ctx.eval_tuple(b_key)][1].append(reader.row)
+ def nested_loop_join(self, _join, source_context, join_context):
+ table = Table(source_context.columns + join_context.columns)
- table = Table(context.tables[a].columns + context.tables[b].columns)
- for a_group, b_group in results.values():
- for a_row, b_row in itertools.product(a_group, b_group):
- table.append(a_row + b_row)
+ for reader_a, _ in source_context:
+ for reader_b, _ in join_context:
+ table.append(reader_a.row + reader_b.row)
return table
- def sort_merge_join(self, join, a, b, context):
- a_key = self.generate_tuple(join["source_key"])
- b_key = self.generate_tuple(join["join_key"])
-
- context.sort(a, a_key)
- context.sort(b, b_key)
-
- a_i = 0
- b_i = 0
- a_n = len(context.tables[a])
- b_n = len(context.tables[b])
-
- table = Table(context.tables[a].columns + context.tables[b].columns)
-
- def get_key(source, key, i):
- context.set_index(source, i)
- return context.eval_tuple(key)
+ def hash_join(self, join, source_context, join_context):
+ source_key = self.generate_tuple(join["source_key"])
+ join_key = self.generate_tuple(join["join_key"])
- while a_i < a_n and b_i < b_n:
- key = min(get_key(a, a_key, a_i), get_key(b, b_key, b_i))
-
- a_group = []
-
- while a_i < a_n and key == get_key(a, a_key, a_i):
- a_group.append(context[a].row)
- a_i += 1
+ results = collections.defaultdict(lambda: ([], []))
- b_group = []
+ for reader, ctx in source_context:
+ results[ctx.eval_tuple(source_key)][0].append(reader.row)
+ for reader, ctx in join_context:
+ results[ctx.eval_tuple(join_key)][1].append(reader.row)
- while b_i < b_n and key == get_key(b, b_key, b_i):
- b_group.append(context[b].row)
- b_i += 1
+ table = Table(source_context.columns + join_context.columns)
+ for a_group, b_group in results.values():
for a_row, b_row in itertools.product(a_group, b_group):
table.append(a_row + b_row)
@@ -238,16 +217,18 @@ class PythonExecutor:
aggregations = self.generate_tuple(step.aggregations)
operands = self.generate_tuple(step.operands)
- context.sort(source, group_by)
-
- if step.operands:
+ if operands:
source_table = context.tables[source]
operand_table = Table(source_table.columns + self.table(step.operands).columns)
for reader, ctx in context:
operand_table.append(reader.row + ctx.eval_tuple(operands))
- context = self.context({source: operand_table})
+ context = self.context(
+ {None: operand_table, **{table: operand_table for table in context.tables}}
+ )
+
+ context.sort(group_by)
group = None
start = 0
@@ -256,15 +237,15 @@ class PythonExecutor:
table = self.table(step.group + step.aggregations)
for i in range(length):
- context.set_index(source, i)
+ context.set_index(i)
key = context.eval_tuple(group_by)
group = key if group is None else group
end += 1
if i == length - 1:
- context.set_range(source, start, end - 1)
+ context.set_range(start, end - 1)
elif key != group:
- context.set_range(source, start, end - 2)
+ context.set_range(start, end - 2)
else:
continue
@@ -272,13 +253,32 @@ class PythonExecutor:
group = key
start = end - 2
- return self.scan(step, self.context({source: table}))
+ context = self.context({step.name: table, **{name: table for name in context.tables}})
+
+ if step.projections:
+ return self.scan(step, context)
+ return context
def sort(self, step, context):
- table = list(context.tables)[0]
- key = self.generate_tuple(step.key)
- context.sort(table, key)
- return self.scan(step, context)
+ projections = self.generate_tuple(step.projections)
+
+ sink = self.table(step.projections)
+
+ for reader, ctx in context:
+ sink.append(ctx.eval_tuple(projections))
+
+ context = self.context(
+ {
+ None: sink,
+ **{table: sink for table in context.tables},
+ }
+ )
+ context.sort(self.generate_tuple(step.key))
+
+ if not math.isinf(step.limit):
+ context.table.rows = context.table.rows[0 : step.limit]
+
+ return self.context({step.name: context.table})
def _cast_py(self, expression):
@@ -293,7 +293,7 @@ def _cast_py(self, expression):
def _column_py(self, expression):
- table = self.sql(expression, "table")
+ table = self.sql(expression, "table") or None
this = self.sql(expression, "this")
return f"scope[{table}][{this}]"
@@ -319,10 +319,10 @@ def _ordered_py(self, expression):
class Python(Dialect):
- class Tokenizer(Tokenizer):
- ESCAPE = "\\"
+ class Tokenizer(tokens.Tokenizer):
+ ESCAPES = ["\\"]
- class Generator(Generator):
+ class Generator(generator.Generator):
TRANSFORMS = {
exp.Alias: lambda self, e: self.sql(e.this),
exp.Array: inline_array_sql,
diff --git a/sqlglot/executor/table.py b/sqlglot/executor/table.py
index 80674cb..6796740 100644
--- a/sqlglot/executor/table.py
+++ b/sqlglot/executor/table.py
@@ -1,10 +1,12 @@
class Table:
- def __init__(self, *columns, rows=None):
- self.columns = tuple(columns if isinstance(columns[0], str) else columns[0])
+ def __init__(self, columns, rows=None, column_range=None):
+ self.columns = tuple(columns)
+ self.column_range = column_range
+ self.reader = RowReader(self.columns, self.column_range)
+
self.rows = rows or []
if rows:
assert len(rows[0]) == len(self.columns)
- self.reader = RowReader(self.columns)
self.range_reader = RangeReader(self)
def append(self, row):
@@ -29,15 +31,22 @@ class Table:
return self.reader
def __repr__(self):
- widths = {column: len(column) for column in self.columns}
- lines = [" ".join(column for column in self.columns)]
+ columns = tuple(
+ column
+ for i, column in enumerate(self.columns)
+ if not self.column_range or i in self.column_range
+ )
+ widths = {column: len(column) for column in columns}
+ lines = [" ".join(column for column in columns)]
for i, row in enumerate(self):
if i > 10:
break
lines.append(
- " ".join(str(row[column]).rjust(widths[column])[0 : widths[column]] for column in self.columns)
+ " ".join(
+ str(row[column]).rjust(widths[column])[0 : widths[column]] for column in columns
+ )
)
return "\n".join(lines)
@@ -70,8 +79,10 @@ class RangeReader:
class RowReader:
- def __init__(self, columns):
- self.columns = {column: i for i, column in enumerate(columns)}
+ def __init__(self, columns, column_range=None):
+ self.columns = {
+ column: i for i, column in enumerate(columns) if not column_range or i in column_range
+ }
self.row = None
def __getitem__(self, column):