summary refs log tree commit diff stats
path: root/sqlglot/executor/python.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/executor/python.py')
-rw-r--r--  sqlglot/executor/python.py  190
1 file changed, 95 insertions(+), 95 deletions(-)
diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py
index fcb016b..7d1db32 100644
--- a/sqlglot/executor/python.py
+++ b/sqlglot/executor/python.py
@@ -1,15 +1,14 @@
import ast
import collections
import itertools
+import math
-from sqlglot import exp, planner
+from sqlglot import exp, generator, planner, tokens
from sqlglot.dialects.dialect import Dialect, inline_array_sql
from sqlglot.executor.context import Context
from sqlglot.executor.env import ENV
from sqlglot.executor.table import Table
-from sqlglot.generator import Generator
from sqlglot.helper import csv_reader
-from sqlglot.tokens import Tokenizer
class PythonExecutor:
@@ -26,7 +25,11 @@ class PythonExecutor:
while queue:
node = queue.pop()
context = self.context(
- {name: table for dep in node.dependencies for name, table in contexts[dep].tables.items()}
+ {
+ name: table
+ for dep in node.dependencies
+ for name, table in contexts[dep].tables.items()
+ }
)
running.add(node)
@@ -76,13 +79,10 @@ class PythonExecutor:
return Table(expression.alias_or_name for expression in expressions)
def scan(self, step, context):
- if hasattr(step, "source"):
- source = step.source
+ source = step.source
- if isinstance(source, exp.Expression):
- source = source.name or source.alias
- else:
- source = step.name
+ if isinstance(source, exp.Expression):
+ source = source.name or source.alias
condition = self.generate(step.condition)
projections = self.generate_tuple(step.projections)
@@ -96,14 +96,12 @@ class PythonExecutor:
if projections:
sink = self.table(step.projections)
- elif source in context:
- sink = Table(context[source].columns)
else:
sink = None
for reader, ctx in table_iter:
if sink is None:
- sink = Table(ctx[source].columns)
+ sink = Table(reader.columns)
if condition and not ctx.eval(condition):
continue
@@ -135,98 +133,79 @@ class PythonExecutor:
types.append(type(ast.literal_eval(v)))
except (ValueError, SyntaxError):
types.append(str)
- context.set_row(alias, tuple(t(v) for t, v in zip(types, row)))
- yield context[alias], context
+ context.set_row(tuple(t(v) for t, v in zip(types, row)))
+ yield context.table.reader, context
def join(self, step, context):
source = step.name
- join_context = self.context({source: context.tables[source]})
-
- def merge_context(ctx, table):
- # create a new context where all existing tables are mapped to a new one
- return self.context({name: table for name in ctx.tables})
+ source_table = context.tables[source]
+ source_context = self.context({source: source_table})
+ column_ranges = {source: range(0, len(source_table.columns))}
for name, join in step.joins.items():
- join_context = self.context({**join_context.tables, name: context.tables[name]})
+ table = context.tables[name]
+ start = max(r.stop for r in column_ranges.values())
+ column_ranges[name] = range(start, len(table.columns) + start)
+ join_context = self.context({name: table})
if join.get("source_key"):
- table = self.hash_join(join, source, name, join_context)
+ table = self.hash_join(join, source_context, join_context)
else:
- table = self.nested_loop_join(join, source, name, join_context)
+ table = self.nested_loop_join(join, source_context, join_context)
- join_context = merge_context(join_context, table)
-
- # apply projections or conditions
- context = self.scan(step, join_context)
+ source_context = self.context(
+ {
+ name: Table(table.columns, table.rows, column_range)
+ for name, column_range in column_ranges.items()
+ }
+ )
- # use the scan context since it returns a single table
- # otherwise there are no projections so all other tables are still in scope
- if step.projections:
- return context
+ condition = self.generate(step.condition)
+ projections = self.generate_tuple(step.projections)
- return merge_context(join_context, context.tables[source])
+ if not condition or not projections:
+ return source_context
- def nested_loop_join(self, _join, a, b, context):
- table = Table(context.tables[a].columns + context.tables[b].columns)
+ sink = self.table(step.projections if projections else source_context.columns)
- for reader_a, _ in context.table_iter(a):
- for reader_b, _ in context.table_iter(b):
- table.append(reader_a.row + reader_b.row)
+ for reader, ctx in join_context:
+ if condition and not ctx.eval(condition):
+ continue
- return table
+ if projections:
+ sink.append(ctx.eval_tuple(projections))
+ else:
+ sink.append(reader.row)
- def hash_join(self, join, a, b, context):
- a_key = self.generate_tuple(join["source_key"])
- b_key = self.generate_tuple(join["join_key"])
+ if len(sink) >= step.limit:
+ break
- results = collections.defaultdict(lambda: ([], []))
+ return self.context({step.name: sink})
- for reader, ctx in context.table_iter(a):
- results[ctx.eval_tuple(a_key)][0].append(reader.row)
- for reader, ctx in context.table_iter(b):
- results[ctx.eval_tuple(b_key)][1].append(reader.row)
+ def nested_loop_join(self, _join, source_context, join_context):
+ table = Table(source_context.columns + join_context.columns)
- table = Table(context.tables[a].columns + context.tables[b].columns)
- for a_group, b_group in results.values():
- for a_row, b_row in itertools.product(a_group, b_group):
- table.append(a_row + b_row)
+ for reader_a, _ in source_context:
+ for reader_b, _ in join_context:
+ table.append(reader_a.row + reader_b.row)
return table
- def sort_merge_join(self, join, a, b, context):
- a_key = self.generate_tuple(join["source_key"])
- b_key = self.generate_tuple(join["join_key"])
-
- context.sort(a, a_key)
- context.sort(b, b_key)
-
- a_i = 0
- b_i = 0
- a_n = len(context.tables[a])
- b_n = len(context.tables[b])
-
- table = Table(context.tables[a].columns + context.tables[b].columns)
-
- def get_key(source, key, i):
- context.set_index(source, i)
- return context.eval_tuple(key)
+ def hash_join(self, join, source_context, join_context):
+ source_key = self.generate_tuple(join["source_key"])
+ join_key = self.generate_tuple(join["join_key"])
- while a_i < a_n and b_i < b_n:
- key = min(get_key(a, a_key, a_i), get_key(b, b_key, b_i))
-
- a_group = []
-
- while a_i < a_n and key == get_key(a, a_key, a_i):
- a_group.append(context[a].row)
- a_i += 1
+ results = collections.defaultdict(lambda: ([], []))
- b_group = []
+ for reader, ctx in source_context:
+ results[ctx.eval_tuple(source_key)][0].append(reader.row)
+ for reader, ctx in join_context:
+ results[ctx.eval_tuple(join_key)][1].append(reader.row)
- while b_i < b_n and key == get_key(b, b_key, b_i):
- b_group.append(context[b].row)
- b_i += 1
+ table = Table(source_context.columns + join_context.columns)
+ for a_group, b_group in results.values():
for a_row, b_row in itertools.product(a_group, b_group):
table.append(a_row + b_row)
@@ -238,16 +217,18 @@ class PythonExecutor:
aggregations = self.generate_tuple(step.aggregations)
operands = self.generate_tuple(step.operands)
- context.sort(source, group_by)
-
- if step.operands:
+ if operands:
source_table = context.tables[source]
operand_table = Table(source_table.columns + self.table(step.operands).columns)
for reader, ctx in context:
operand_table.append(reader.row + ctx.eval_tuple(operands))
- context = self.context({source: operand_table})
+ context = self.context(
+ {None: operand_table, **{table: operand_table for table in context.tables}}
+ )
+
+ context.sort(group_by)
group = None
start = 0
@@ -256,15 +237,15 @@ class PythonExecutor:
table = self.table(step.group + step.aggregations)
for i in range(length):
- context.set_index(source, i)
+ context.set_index(i)
key = context.eval_tuple(group_by)
group = key if group is None else group
end += 1
if i == length - 1:
- context.set_range(source, start, end - 1)
+ context.set_range(start, end - 1)
elif key != group:
- context.set_range(source, start, end - 2)
+ context.set_range(start, end - 2)
else:
continue
@@ -272,13 +253,32 @@ class PythonExecutor:
group = key
start = end - 2
- return self.scan(step, self.context({source: table}))
+ context = self.context({step.name: table, **{name: table for name in context.tables}})
+
+ if step.projections:
+ return self.scan(step, context)
+ return context
def sort(self, step, context):
- table = list(context.tables)[0]
- key = self.generate_tuple(step.key)
- context.sort(table, key)
- return self.scan(step, context)
+ projections = self.generate_tuple(step.projections)
+
+ sink = self.table(step.projections)
+
+ for reader, ctx in context:
+ sink.append(ctx.eval_tuple(projections))
+
+ context = self.context(
+ {
+ None: sink,
+ **{table: sink for table in context.tables},
+ }
+ )
+ context.sort(self.generate_tuple(step.key))
+
+ if not math.isinf(step.limit):
+ context.table.rows = context.table.rows[0 : step.limit]
+
+ return self.context({step.name: context.table})
def _cast_py(self, expression):
@@ -293,7 +293,7 @@ def _cast_py(self, expression):
def _column_py(self, expression):
- table = self.sql(expression, "table")
+ table = self.sql(expression, "table") or None
this = self.sql(expression, "this")
return f"scope[{table}][{this}]"
@@ -319,10 +319,10 @@ def _ordered_py(self, expression):
class Python(Dialect):
- class Tokenizer(Tokenizer):
- ESCAPE = "\\"
+ class Tokenizer(tokens.Tokenizer):
+ ESCAPES = ["\\"]
- class Generator(Generator):
+ class Generator(generator.Generator):
TRANSFORMS = {
exp.Alias: lambda self, e: self.sql(e.this),
exp.Array: inline_array_sql,