diff options
Diffstat (limited to 'sqlglot/executor/table.py')
-rw-r--r-- | sqlglot/executor/table.py | 43 |
1 files changed, 42 insertions, 1 deletions
diff --git a/sqlglot/executor/table.py b/sqlglot/executor/table.py index 6796740..f1b5b54 100644 --- a/sqlglot/executor/table.py +++ b/sqlglot/executor/table.py @@ -1,14 +1,27 @@ +from __future__ import annotations + +from sqlglot.helper import dict_depth +from sqlglot.schema import AbstractMappingSchema + + class Table: def __init__(self, columns, rows=None, column_range=None): self.columns = tuple(columns) self.column_range = column_range self.reader = RowReader(self.columns, self.column_range) - self.rows = rows or [] if rows: assert len(rows[0]) == len(self.columns) self.range_reader = RangeReader(self) + def add_columns(self, *columns: str) -> None: + self.columns += columns + if self.column_range: + self.column_range = range( + self.column_range.start, self.column_range.stop + len(columns) + ) + self.reader = RowReader(self.columns, self.column_range) + def append(self, row): assert len(row) == len(self.columns) self.rows.append(row) @@ -87,3 +100,31 @@ class RowReader: def __getitem__(self, column): return self.row[self.columns[column]] + + +class Tables(AbstractMappingSchema[Table]): + pass + + +def ensure_tables(d: dict | None) -> Tables: + return Tables(_ensure_tables(d)) + + +def _ensure_tables(d: dict | None) -> dict: + if not d: + return {} + + depth = dict_depth(d) + + if depth > 1: + return {k: _ensure_tables(v) for k, v in d.items()} + + result = {} + for name, table in d.items(): + if isinstance(table, Table): + result[name] = table + else: + columns = tuple(table[0]) if table else () + rows = [tuple(row[c] for c in columns) for row in table] + result[name] = Table(columns=columns, rows=rows) + return result |