summaryrefslogtreecommitdiffstats
path: root/sqlglot/executor/context.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/executor/context.py')
-rw-r--r--sqlglot/executor/context.py44
1 files changed, 32 insertions, 12 deletions
diff --git a/sqlglot/executor/context.py b/sqlglot/executor/context.py
index d265a2c..393347b 100644
--- a/sqlglot/executor/context.py
+++ b/sqlglot/executor/context.py
@@ -19,6 +19,7 @@ class Context:
env (Optional[dict]): dictionary of functions within the execution context
"""
self.tables = tables
+ self._table = None
self.range_readers = {name: table.range_reader for name, table in self.tables.items()}
self.row_readers = {name: table.reader for name, table in tables.items()}
self.env = {**(env or {}), "scope": self.row_readers}
@@ -29,8 +30,27 @@ class Context:
def eval_tuple(self, codes):
return tuple(self.eval(code) for code in codes)
+ @property
+ def table(self):
+ if self._table is None:
+ self._table = list(self.tables.values())[0]
+ for other in self.tables.values():
+ if self._table.columns != other.columns:
+ raise Exception(f"Columns are different.")
+ if len(self._table.rows) != len(other.rows):
+ raise Exception(f"Rows are different.")
+ return self._table
+
+ @property
+ def columns(self):
+ return self.table.columns
+
def __iter__(self):
- return self.table_iter(list(self.tables)[0])
+ self.env["scope"] = self.row_readers
+ for i in range(len(self.table.rows)):
+ for table in self.tables.values():
+ reader = table[i]
+ yield reader, self
def table_iter(self, table):
self.env["scope"] = self.row_readers
@@ -38,8 +58,8 @@ class Context:
for reader in self.tables[table]:
yield reader, self
- def sort(self, table, key):
- table = self.tables[table]
+ def sort(self, key):
+ table = self.table
def sort_key(row):
table.reader.row = row
@@ -47,20 +67,20 @@ class Context:
table.rows.sort(key=sort_key)
- def set_row(self, table, row):
- self.row_readers[table].row = row
+ def set_row(self, row):
+ for table in self.tables.values():
+ table.reader.row = row
self.env["scope"] = self.row_readers
- def set_index(self, table, index):
- self.row_readers[table].row = self.tables[table].rows[index]
+ def set_index(self, index):
+ for table in self.tables.values():
+ table[index]
self.env["scope"] = self.row_readers
- def set_range(self, table, start, end):
- self.range_readers[table].range = range(start, end)
+ def set_range(self, start, end):
+ for name in self.tables:
+ self.range_readers[name].range = range(start, end)
self.env["scope"] = self.range_readers
- def __getitem__(self, table):
- return self.env["scope"][table]
-
def __contains__(self, table):
return table in self.tables