summaryrefslogtreecommitdiffstats
path: root/sqlglot/executor/table.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/executor/table.py')
-rw-r--r--sqlglot/executor/table.py43
1 files changed, 42 insertions, 1 deletions
diff --git a/sqlglot/executor/table.py b/sqlglot/executor/table.py
index 6796740..f1b5b54 100644
--- a/sqlglot/executor/table.py
+++ b/sqlglot/executor/table.py
@@ -1,14 +1,27 @@
+from __future__ import annotations
+
+from sqlglot.helper import dict_depth
+from sqlglot.schema import AbstractMappingSchema
+
+
class Table:
def __init__(self, columns, rows=None, column_range=None):
self.columns = tuple(columns)
self.column_range = column_range
self.reader = RowReader(self.columns, self.column_range)
-
self.rows = rows or []
if rows:
assert len(rows[0]) == len(self.columns)
self.range_reader = RangeReader(self)
+ def add_columns(self, *columns: str) -> None:
+ self.columns += columns
+ if self.column_range:
+ self.column_range = range(
+ self.column_range.start, self.column_range.stop + len(columns)
+ )
+ self.reader = RowReader(self.columns, self.column_range)
+
def append(self, row):
assert len(row) == len(self.columns)
self.rows.append(row)
@@ -87,3 +100,31 @@ class RowReader:
def __getitem__(self, column):
return self.row[self.columns[column]]
+
+
+class Tables(AbstractMappingSchema[Table]):
+ pass
+
+
+def ensure_tables(d: dict | None) -> Tables:
+ return Tables(_ensure_tables(d))
+
+
+def _ensure_tables(d: dict | None) -> dict:
+ if not d:
+ return {}
+
+ depth = dict_depth(d)
+
+ if depth > 1:
+ return {k: _ensure_tables(v) for k, v in d.items()}
+
+ result = {}
+ for name, table in d.items():
+ if isinstance(table, Table):
+ result[name] = table
+ else:
+ columns = tuple(table[0]) if table else ()
+ rows = [tuple(row[c] for c in columns) for row in table]
+ result[name] = Table(columns=columns, rows=rows)
+ return result