diff options
Diffstat (limited to 'sqlglot/executor/context.py')
-rw-r--r-- | sqlglot/executor/context.py | 44 |
1 files changed, 32 insertions, 12 deletions
diff --git a/sqlglot/executor/context.py b/sqlglot/executor/context.py index d265a2c..393347b 100644 --- a/sqlglot/executor/context.py +++ b/sqlglot/executor/context.py @@ -19,6 +19,7 @@ class Context: env (Optional[dict]): dictionary of functions within the execution context """ self.tables = tables + self._table = None self.range_readers = {name: table.range_reader for name, table in self.tables.items()} self.row_readers = {name: table.reader for name, table in tables.items()} self.env = {**(env or {}), "scope": self.row_readers} @@ -29,8 +30,27 @@ class Context: def eval_tuple(self, codes): return tuple(self.eval(code) for code in codes) + @property + def table(self): + if self._table is None: + self._table = list(self.tables.values())[0] + for other in self.tables.values(): + if self._table.columns != other.columns: + raise Exception(f"Columns are different.") + if len(self._table.rows) != len(other.rows): + raise Exception(f"Rows are different.") + return self._table + + @property + def columns(self): + return self.table.columns + def __iter__(self): - return self.table_iter(list(self.tables)[0]) + self.env["scope"] = self.row_readers + for i in range(len(self.table.rows)): + for table in self.tables.values(): + reader = table[i] + yield reader, self def table_iter(self, table): self.env["scope"] = self.row_readers @@ -38,8 +58,8 @@ class Context: for reader in self.tables[table]: yield reader, self - def sort(self, table, key): - table = self.tables[table] + def sort(self, key): + table = self.table def sort_key(row): table.reader.row = row @@ -47,20 +67,20 @@ class Context: table.rows.sort(key=sort_key) - def set_row(self, table, row): - self.row_readers[table].row = row + def set_row(self, row): + for table in self.tables.values(): + table.reader.row = row self.env["scope"] = self.row_readers - def set_index(self, table, index): - self.row_readers[table].row = self.tables[table].rows[index] + def set_index(self, index): + for table in self.tables.values(): + table[index] self.env["scope"] = self.row_readers - def set_range(self, table, start, end): - self.range_readers[table].range = range(start, end) + def set_range(self, start, end): + for name in self.tables: + self.range_readers[name].range = range(start, end) self.env["scope"] = self.range_readers - def __getitem__(self, table): - return self.env["scope"][table] - def __contains__(self, table): return table in self.tables |