summaryrefslogtreecommitdiffstats
path: root/sqlglot/executor/context.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/executor/context.py')
-rw-r--r--sqlglot/executor/context.py47
1 files changed, 31 insertions, 16 deletions
diff --git a/sqlglot/executor/context.py b/sqlglot/executor/context.py
index 393347b..e9ff75b 100644
--- a/sqlglot/executor/context.py
+++ b/sqlglot/executor/context.py
@@ -1,5 +1,12 @@
+from __future__ import annotations
+
+import typing as t
+
from sqlglot.executor.env import ENV
+if t.TYPE_CHECKING:
+ from sqlglot.executor.table import Table, TableIter
+
class Context:
"""
@@ -12,14 +19,14 @@ class Context:
evaluation of aggregation functions.
"""
- def __init__(self, tables, env=None):
+ def __init__(self, tables: t.Dict[str, Table], env: t.Optional[t.Dict] = None) -> None:
"""
Args
- tables (dict): table_name -> Table, representing the scope of the current execution context
- env (Optional[dict]): dictionary of functions within the execution context
+ tables: representing the scope of the current execution context.
+ env: dictionary of functions within the execution context.
"""
self.tables = tables
- self._table = None
+ self._table: t.Optional[Table] = None
self.range_readers = {name: table.range_reader for name, table in self.tables.items()}
self.row_readers = {name: table.reader for name, table in tables.items()}
self.env = {**(env or {}), "scope": self.row_readers}
@@ -31,7 +38,7 @@ class Context:
return tuple(self.eval(code) for code in codes)
@property
- def table(self):
+ def table(self) -> Table:
if self._table is None:
self._table = list(self.tables.values())[0]
for other in self.tables.values():
@@ -41,8 +48,12 @@ class Context:
raise Exception(f"Rows are different.")
return self._table
+ def add_columns(self, *columns: str) -> None:
+ for table in self.tables.values():
+ table.add_columns(*columns)
+
@property
- def columns(self):
+ def columns(self) -> t.Tuple:
return self.table.columns
def __iter__(self):
@@ -52,35 +63,39 @@ class Context:
reader = table[i]
yield reader, self
- def table_iter(self, table):
+ def table_iter(self, table: str) -> t.Generator[t.Tuple[TableIter, Context], None, None]:
self.env["scope"] = self.row_readers
for reader in self.tables[table]:
yield reader, self
- def sort(self, key):
- table = self.table
+ def filter(self, condition) -> None:
+ rows = [reader.row for reader, _ in self if self.eval(condition)]
- def sort_key(row):
- table.reader.row = row
+ for table in self.tables.values():
+ table.rows = rows
+
+ def sort(self, key) -> None:
+ def sort_key(row: t.Tuple) -> t.Tuple:
+ self.set_row(row)
return self.eval_tuple(key)
- table.rows.sort(key=sort_key)
+ self.table.rows.sort(key=sort_key)
- def set_row(self, row):
+ def set_row(self, row: t.Tuple) -> None:
for table in self.tables.values():
table.reader.row = row
self.env["scope"] = self.row_readers
- def set_index(self, index):
+ def set_index(self, index: int) -> None:
for table in self.tables.values():
table[index]
self.env["scope"] = self.row_readers
- def set_range(self, start, end):
+ def set_range(self, start: int, end: int) -> None:
for name in self.tables:
self.range_readers[name].range = range(start, end)
self.env["scope"] = self.range_readers
- def __contains__(self, table):
+ def __contains__(self, table: str) -> bool:
return table in self.tables