From f2981e8e4d28233864f1ca06ecec45ab80bf9eae Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sat, 19 Nov 2022 15:50:39 +0100 Subject: Merging upstream version 10.0.8. Signed-off-by: Daniel Baumann --- sqlglot/executor/__init__.py | 23 +++- sqlglot/executor/context.py | 47 ++++--- sqlglot/executor/env.py | 162 ++++++++++++++++++++++-- sqlglot/executor/python.py | 287 ++++++++++++++++++++++++++----------------- sqlglot/executor/table.py | 43 ++++++- 5 files changed, 412 insertions(+), 150 deletions(-) (limited to 'sqlglot/executor') diff --git a/sqlglot/executor/__init__.py b/sqlglot/executor/__init__.py index e765616..04621b5 100644 --- a/sqlglot/executor/__init__.py +++ b/sqlglot/executor/__init__.py @@ -1,20 +1,23 @@ import logging import time -from sqlglot import parse_one +from sqlglot import maybe_parse +from sqlglot.errors import ExecuteError from sqlglot.executor.python import PythonExecutor +from sqlglot.executor.table import Table, ensure_tables from sqlglot.optimizer import optimize from sqlglot.planner import Plan +from sqlglot.schema import ensure_schema logger = logging.getLogger("sqlglot") -def execute(sql, schema, read=None): +def execute(sql, schema=None, read=None, tables=None): """ Run a sql query against data. Args: - sql (str): a sql statement + sql (str|sqlglot.Expression): a sql statement schema (dict|sqlglot.optimizer.Schema): database schema. This can either be an instance of `sqlglot.optimizer.Schema` or a mapping in one of the following forms: @@ -23,10 +26,20 @@ def execute(sql, schema, read=None): 3. {catalog: {db: {table: {col: type}}}} read (str): the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql"). + tables (dict): additional tables to register. Returns: sqlglot.executor.Table: Simple columnar data structure. """ - expression = parse_one(sql, read=read) + tables = ensure_tables(tables) + if not schema: + schema = { + name: {column: type(table[0][column]).__name__ for column in table.columns} + for name, table in tables.mapping.items() + } + schema = ensure_schema(schema) + if tables.supported_table_args and tables.supported_table_args != schema.supported_table_args: + raise ExecuteError("Tables must support the same table args as schema") + expression = maybe_parse(sql, dialect=read) now = time.time() expression = optimize(expression, schema, leave_tables_isolated=True) logger.debug("Optimization finished: %f", time.time() - now) @@ -34,6 +47,6 @@ def execute(sql, schema, read=None): plan = Plan(expression) logger.debug("Logical Plan: %s", plan) now = time.time() - result = PythonExecutor().execute(plan) + result = PythonExecutor(tables=tables).execute(plan) logger.debug("Query finished: %f", time.time() - now) return result diff --git a/sqlglot/executor/context.py b/sqlglot/executor/context.py index 393347b..e9ff75b 100644 --- a/sqlglot/executor/context.py +++ b/sqlglot/executor/context.py @@ -1,5 +1,12 @@ +from __future__ import annotations + +import typing as t + from sqlglot.executor.env import ENV +if t.TYPE_CHECKING: + from sqlglot.executor.table import Table, TableIter + class Context: """ @@ -12,14 +19,14 @@ class Context: evaluation of aggregation functions. """ - def __init__(self, tables, env=None): + def __init__(self, tables: t.Dict[str, Table], env: t.Optional[t.Dict] = None) -> None: """ Args - tables (dict): table_name -> Table, representing the scope of the current execution context - env (Optional[dict]): dictionary of functions within the execution context + tables: representing the scope of the current execution context. + env: dictionary of functions within the execution context. """ self.tables = tables - self._table = None + self._table: t.Optional[Table] = None self.range_readers = {name: table.range_reader for name, table in self.tables.items()} self.row_readers = {name: table.reader for name, table in tables.items()} self.env = {**(env or {}), "scope": self.row_readers} @@ -31,7 +38,7 @@ class Context: return tuple(self.eval(code) for code in codes) @property - def table(self): + def table(self) -> Table: if self._table is None: self._table = list(self.tables.values())[0] for other in self.tables.values(): @@ -41,8 +48,12 @@ class Context: raise Exception(f"Rows are different.") return self._table + def add_columns(self, *columns: str) -> None: + for table in self.tables.values(): + table.add_columns(*columns) + @property - def columns(self): + def columns(self) -> t.Tuple: return self.table.columns def __iter__(self): @@ -52,35 +63,39 @@ class Context: reader = table[i] yield reader, self - def table_iter(self, table): + def table_iter(self, table: str) -> t.Generator[t.Tuple[TableIter, Context], None, None]: self.env["scope"] = self.row_readers for reader in self.tables[table]: yield reader, self - def sort(self, key): - table = self.table + def filter(self, condition) -> None: + rows = [reader.row for reader, _ in self if self.eval(condition)] - def sort_key(row): - table.reader.row = row + for table in self.tables.values(): + table.rows = rows + + def sort(self, key) -> None: + def sort_key(row: t.Tuple) -> t.Tuple: + self.set_row(row) return self.eval_tuple(key) - table.rows.sort(key=sort_key) + self.table.rows.sort(key=sort_key) - def set_row(self, row): + def set_row(self, row: t.Tuple) -> None: for table in self.tables.values(): table.reader.row = row self.env["scope"] = self.row_readers - def set_index(self, index): + def set_index(self, index: int) -> None: for table in self.tables.values(): table[index] self.env["scope"] = self.row_readers - def set_range(self, start, end): + def set_range(self, start: int, end: int) -> None: for name in self.tables: self.range_readers[name].range = range(start, end) self.env["scope"] = self.range_readers - def __contains__(self, table): + def __contains__(self, table: str) -> bool: return table in self.tables diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py index bbe6c81..ed80cc9 100644 --- a/sqlglot/executor/env.py +++ b/sqlglot/executor/env.py @@ -1,7 +1,10 @@ import datetime +import inspect import re import statistics +from functools import wraps +from sqlglot import exp from sqlglot.helper import PYTHON_VERSION @@ -16,20 +19,153 @@ class reverse_key: return other.obj < self.obj +def filter_nulls(func): + @wraps(func) + def _func(values): + return func(v for v in values if v is not None) + + return _func + + +def null_if_any(*required): + """ + Decorator that makes a function return `None` if any of the `required` arguments are `None`. + + This also supports decoration with no arguments, e.g.: + + @null_if_any + def foo(a, b): ... + + In which case all arguments are required. + """ + f = None + if len(required) == 1 and callable(required[0]): + f = required[0] + required = () + + def decorator(func): + if required: + required_indices = [ + i for i, param in enumerate(inspect.signature(func).parameters) if param in required + ] + + def predicate(*args): + return any(args[i] is None for i in required_indices) + + else: + + def predicate(*args): + return any(a is None for a in args) + + @wraps(func) + def _func(*args): + if predicate(*args): + return None + return func(*args) + + return _func + + if f: + return decorator(f) + + return decorator + + +@null_if_any("substr", "this") +def str_position(substr, this, position=None): + position = position - 1 if position is not None else position + return this.find(substr, position) + 1 + + +@null_if_any("this") +def substring(this, start=None, length=None): + if start is None: + return this + elif start == 0: + return "" + elif start < 0: + start = len(this) + start + else: + start -= 1 + + end = None if length is None else start + length + + return this[start:end] + + +@null_if_any +def cast(this, to): + if to == exp.DataType.Type.DATE: + return datetime.date.fromisoformat(this) + if to == exp.DataType.Type.DATETIME: + return datetime.datetime.fromisoformat(this) + if to in exp.DataType.TEXT_TYPES: + return str(this) + if to in {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE}: + return float(this) + if to in exp.DataType.NUMERIC_TYPES: + return int(this) + raise NotImplementedError(f"Casting to '{to}' not implemented.") + + +def ordered(this, desc, nulls_first): + if desc: + return reverse_key(this) + return this + + +@null_if_any +def interval(this, unit): + if unit == "DAY": + return datetime.timedelta(days=float(this)) + raise NotImplementedError + + ENV = { "__builtins__": {}, - "datetime": datetime, - "locals": locals, - "re": re, - "bool": bool, - "float": float, - "int": int, - "str": str, - "desc": reverse_key, - "SUM": sum, - "AVG": statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean, # type: ignore - "COUNT": lambda acc: sum(1 for e in acc if e is not None), - "MAX": max, - "MIN": min, + "exp": exp, + # aggs + "SUM": filter_nulls(sum), + "AVG": filter_nulls(statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean), # type: ignore + "COUNT": filter_nulls(lambda acc: sum(1 for _ in acc)), + "MAX": filter_nulls(max), + "MIN": filter_nulls(min), + # scalar functions + "ABS": null_if_any(lambda this: abs(this)), + "ADD": null_if_any(lambda e, this: e + this), + "BETWEEN": null_if_any(lambda this, low, high: low <= this and this <= high), + "BITWISEAND": null_if_any(lambda this, e: this & e), + "BITWISELEFTSHIFT": null_if_any(lambda this, e: this << e), + "BITWISEOR": null_if_any(lambda this, e: this | e), + "BITWISERIGHTSHIFT": null_if_any(lambda this, e: this >> e), + "BITWISEXOR": null_if_any(lambda this, e: this ^ e), + "CAST": cast, + "COALESCE": lambda *args: next((a for a in args if a is not None), None), + "CONCAT": null_if_any(lambda *args: "".join(args)), + "CONCATWS": null_if_any(lambda this, *args: this.join(args)), + "DIV": null_if_any(lambda e, this: e / this), + "EQ": null_if_any(lambda this, e: this == e), + "EXTRACT": null_if_any(lambda this, e: getattr(e, this)), + "GT": null_if_any(lambda this, e: this > e), + "GTE": null_if_any(lambda this, e: this >= e), + "IFNULL": lambda e, alt: alt if e is None else e, + "IF": lambda predicate, true, false: true if predicate else false, + "INTDIV": null_if_any(lambda e, this: e // this), + "INTERVAL": interval, + "LIKE": null_if_any( + lambda this, e: bool(re.match(e.replace("_", ".").replace("%", ".*"), this)) + ), + "LOWER": null_if_any(lambda arg: arg.lower()), + "LT": null_if_any(lambda this, e: this < e), + "LTE": null_if_any(lambda this, e: this <= e), + "MOD": null_if_any(lambda e, this: e % this), + "MUL": null_if_any(lambda e, this: e * this), + "NEQ": null_if_any(lambda this, e: this != e), + "ORD": null_if_any(ord), + "ORDERED": ordered, "POW": pow, + "STRPOSITION": str_position, + "SUB": null_if_any(lambda e, this: e - this), + "SUBSTRING": substring, + "UPPER": null_if_any(lambda arg: arg.upper()), } diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py index 7d1db32..cb2543c 100644 --- a/sqlglot/executor/python.py +++ b/sqlglot/executor/python.py @@ -5,16 +5,18 @@ import math from sqlglot import exp, generator, planner, tokens from sqlglot.dialects.dialect import Dialect, inline_array_sql +from sqlglot.errors import ExecuteError from sqlglot.executor.context import Context from sqlglot.executor.env import ENV -from sqlglot.executor.table import Table -from sqlglot.helper import csv_reader +from sqlglot.executor.table import RowReader, Table +from sqlglot.helper import csv_reader, subclasses class PythonExecutor: - def __init__(self, env=None): - self.generator = Python().generator(identify=True) + def __init__(self, env=None, tables=None): + self.generator = Python().generator(identify=True, comments=False) self.env = {**ENV, **(env or {})} + self.tables = tables or {} def execute(self, plan): running = set() @@ -24,36 +26,41 @@ class PythonExecutor: while queue: node = queue.pop() - context = self.context( - { - name: table - for dep in node.dependencies - for name, table in contexts[dep].tables.items() - } - ) - running.add(node) - - if isinstance(node, planner.Scan): - contexts[node] = self.scan(node, context) - elif isinstance(node, planner.Aggregate): - contexts[node] = self.aggregate(node, context) - elif isinstance(node, planner.Join): - contexts[node] = self.join(node, context) - elif isinstance(node, planner.Sort): - contexts[node] = self.sort(node, context) - else: - raise NotImplementedError - - running.remove(node) - finished.add(node) - - for dep in node.dependents: - if dep not in running and all(d in contexts for d in dep.dependencies): - queue.add(dep) - - for dep in node.dependencies: - if all(d in finished for d in dep.dependents): - contexts.pop(dep) + try: + context = self.context( + { + name: table + for dep in node.dependencies + for name, table in contexts[dep].tables.items() + } + ) + running.add(node) + + if isinstance(node, planner.Scan): + contexts[node] = self.scan(node, context) + elif isinstance(node, planner.Aggregate): + contexts[node] = self.aggregate(node, context) + elif isinstance(node, planner.Join): + contexts[node] = self.join(node, context) + elif isinstance(node, planner.Sort): + contexts[node] = self.sort(node, context) + elif isinstance(node, planner.SetOperation): + contexts[node] = self.set_operation(node, context) + else: + raise NotImplementedError + + running.remove(node) + finished.add(node) + + for dep in node.dependents: + if dep not in running and all(d in contexts for d in dep.dependencies): + queue.add(dep) + + for dep in node.dependencies: + if all(d in finished for d in dep.dependents): + contexts.pop(dep) + except Exception as e: + raise ExecuteError(f"Step '{node.id}' failed: {e}") from e root = plan.root return contexts[root].tables[root.name] @@ -76,38 +83,43 @@ class PythonExecutor: return Context(tables, env=self.env) def table(self, expressions): - return Table(expression.alias_or_name for expression in expressions) + return Table( + expression.alias_or_name if isinstance(expression, exp.Expression) else expression + for expression in expressions + ) def scan(self, step, context): source = step.source - if isinstance(source, exp.Expression): + if source and isinstance(source, exp.Expression): source = source.name or source.alias condition = self.generate(step.condition) projections = self.generate_tuple(step.projections) - if source in context: + if source is None: + context, table_iter = self.static() + elif source in context: if not projections and not condition: return self.context({step.name: context.tables[source]}) table_iter = context.table_iter(source) - else: + elif isinstance(step.source, exp.Table) and isinstance(step.source.this, exp.ReadCSV): table_iter = self.scan_csv(step) + context = next(table_iter) + else: + context, table_iter = self.scan_table(step) if projections: sink = self.table(step.projections) else: - sink = None - - for reader, ctx in table_iter: - if sink is None: - sink = Table(reader.columns) + sink = self.table(context.columns) - if condition and not ctx.eval(condition): + for reader in table_iter: + if condition and not context.eval(condition): continue if projections: - sink.append(ctx.eval_tuple(projections)) + sink.append(context.eval_tuple(projections)) else: sink.append(reader.row) @@ -116,14 +128,23 @@ class PythonExecutor: return self.context({step.name: sink}) + def static(self): + return self.context({}), [RowReader(())] + + def scan_table(self, step): + table = self.tables.find(step.source) + context = self.context({step.source.alias_or_name: table}) + return context, iter(table) + def scan_csv(self, step): - source = step.source - alias = source.alias + alias = step.source.alias + source = step.source.this with csv_reader(source) as reader: columns = next(reader) table = Table(columns) context = self.context({alias: table}) + yield context types = [] for row in reader: @@ -134,7 +155,7 @@ class PythonExecutor: except (ValueError, SyntaxError): types.append(str) context.set_row(tuple(t(v) for t, v in zip(types, row))) - yield context.table.reader, context + yield context.table.reader def join(self, step, context): source = step.name @@ -160,16 +181,19 @@ class PythonExecutor: for name, column_range in column_ranges.items() } ) + condition = self.generate(join["condition"]) + if condition: + source_context.filter(condition) condition = self.generate(step.condition) projections = self.generate_tuple(step.projections) - if not condition or not projections: + if not condition and not projections: return source_context sink = self.table(step.projections if projections else source_context.columns) - for reader, ctx in join_context: + for reader, ctx in source_context: if condition and not ctx.eval(condition): continue @@ -181,7 +205,15 @@ class PythonExecutor: if len(sink) >= step.limit: break - return self.context({step.name: sink}) + if projections: + return self.context({step.name: sink}) + else: + return self.context( + { + name: Table(table.columns, sink.rows, table.column_range) + for name, table in source_context.tables.items() + } + ) def nested_loop_join(self, _join, source_context, join_context): table = Table(source_context.columns + join_context.columns) @@ -195,6 +227,8 @@ class PythonExecutor: def hash_join(self, join, source_context, join_context): source_key = self.generate_tuple(join["source_key"]) join_key = self.generate_tuple(join["join_key"]) + left = join.get("side") == "LEFT" + right = join.get("side") == "RIGHT" results = collections.defaultdict(lambda: ([], [])) @@ -204,28 +238,47 @@ class PythonExecutor: results[ctx.eval_tuple(join_key)][1].append(reader.row) table = Table(source_context.columns + join_context.columns) + nulls = [(None,) * len(join_context.columns if left else source_context.columns)] for a_group, b_group in results.values(): + if left: + b_group = b_group or nulls + elif right: + a_group = a_group or nulls + for a_row, b_row in itertools.product(a_group, b_group): table.append(a_row + b_row) return table def aggregate(self, step, context): - source = step.source - group_by = self.generate_tuple(step.group) + group_by = self.generate_tuple(step.group.values()) aggregations = self.generate_tuple(step.aggregations) operands = self.generate_tuple(step.operands) if operands: - source_table = context.tables[source] - operand_table = Table(source_table.columns + self.table(step.operands).columns) + operand_table = Table(self.table(step.operands).columns) for reader, ctx in context: - operand_table.append(reader.row + ctx.eval_tuple(operands)) + operand_table.append(ctx.eval_tuple(operands)) + + for i, (a, b) in enumerate(zip(context.table.rows, operand_table.rows)): + context.table.rows[i] = a + b + + width = len(context.columns) + context.add_columns(*operand_table.columns) + + operand_table = Table( + context.columns, + context.table.rows, + range(width, width + len(operand_table.columns)), + ) context = self.context( - {None: operand_table, **{table: operand_table for table in context.tables}} + { + None: operand_table, + **context.tables, + } ) context.sort(group_by) @@ -233,25 +286,22 @@ class PythonExecutor: group = None start = 0 end = 1 - length = len(context.tables[source]) - table = self.table(step.group + step.aggregations) + length = len(context.table) + table = self.table(list(step.group) + step.aggregations) for i in range(length): context.set_index(i) key = context.eval_tuple(group_by) group = key if group is None else group end += 1 - + if key != group: + context.set_range(start, end - 2) + table.append(group + context.eval_tuple(aggregations)) + group = key + start = end - 2 if i == length - 1: context.set_range(start, end - 1) - elif key != group: - context.set_range(start, end - 2) - else: - continue - - table.append(group + context.eval_tuple(aggregations)) - group = key - start = end - 2 + table.append(group + context.eval_tuple(aggregations)) context = self.context({step.name: table, **{name: table for name in context.tables}}) @@ -262,60 +312,77 @@ class PythonExecutor: def sort(self, step, context): projections = self.generate_tuple(step.projections) - sink = self.table(step.projections) + projection_columns = [p.alias_or_name for p in step.projections] + all_columns = list(context.columns) + projection_columns + sink = self.table(all_columns) for reader, ctx in context: - sink.append(ctx.eval_tuple(projections)) + sink.append(reader.row + ctx.eval_tuple(projections)) - context = self.context( + sort_ctx = self.context( { None: sink, **{table: sink for table in context.tables}, } ) - context.sort(self.generate_tuple(step.key)) + sort_ctx.sort(self.generate_tuple(step.key)) if not math.isinf(step.limit): - context.table.rows = context.table.rows[0 : step.limit] + sort_ctx.table.rows = sort_ctx.table.rows[0 : step.limit] - return self.context({step.name: context.table}) + output = Table( + projection_columns, + rows=[r[len(context.columns) : len(all_columns)] for r in sort_ctx.table.rows], + ) + return self.context({step.name: output}) + def set_operation(self, step, context): + left = context.tables[step.left] + right = context.tables[step.right] -def _cast_py(self, expression): - to = expression.args["to"].this - this = self.sql(expression, "this") + sink = self.table(left.columns) + + if issubclass(step.op, exp.Intersect): + sink.rows = list(set(left.rows).intersection(set(right.rows))) + elif issubclass(step.op, exp.Except): + sink.rows = list(set(left.rows).difference(set(right.rows))) + elif issubclass(step.op, exp.Union) and step.distinct: + sink.rows = list(set(left.rows).union(set(right.rows))) + else: + sink.rows = left.rows + right.rows - if to == exp.DataType.Type.DATE: - return f"datetime.date.fromisoformat({this})" - if to == exp.DataType.Type.TEXT: - return f"str({this})" - raise NotImplementedError + return self.context({step.name: sink}) -def _column_py(self, expression): - table = self.sql(expression, "table") or None +def _ordered_py(self, expression): this = self.sql(expression, "this") - return f"scope[{table}][{this}]" + desc = "True" if expression.args.get("desc") else "False" + nulls_first = "True" if expression.args.get("nulls_first") else "False" + return f"ORDERED({this}, {desc}, {nulls_first})" -def _interval_py(self, expression): - this = self.sql(expression, "this") - unit = expression.text("unit").upper() - if unit == "DAY": - return f"datetime.timedelta(days=float({this}))" - raise NotImplementedError +def _rename(self, e): + try: + if "expressions" in e.args: + this = self.sql(e, "this") + this = f"{this}, " if this else "" + return f"{e.key.upper()}({this}{self.expressions(e)})" + return f"{e.key.upper()}({self.format_args(*e.args.values())})" + except Exception as ex: + raise Exception(f"Could not rename {repr(e)}") from ex -def _like_py(self, expression): +def _case_sql(self, expression): this = self.sql(expression, "this") - expression = self.sql(expression, "expression") - return f"""bool(re.match({expression}.replace("_", ".").replace("%", ".*"), {this}))""" + chain = self.sql(expression, "default") or "None" + for e in reversed(expression.args["ifs"]): + true = self.sql(e, "true") + condition = self.sql(e, "this") + condition = f"{this} = ({condition})" if this else condition + chain = f"{true} if {condition} else ({chain})" -def _ordered_py(self, expression): - this = self.sql(expression, "this") - desc = expression.args.get("desc") - return f"desc({this})" if desc else this + return chain class Python(Dialect): @@ -324,32 +391,22 @@ class Python(Dialect): class Generator(generator.Generator): TRANSFORMS = { + **{klass: _rename for klass in subclasses(exp.__name__, exp.Binary)}, + **{klass: _rename for klass in exp.ALL_FUNCTIONS}, + exp.Case: _case_sql, exp.Alias: lambda self, e: self.sql(e.this), exp.Array: inline_array_sql, exp.And: lambda self, e: self.binary(e, "and"), + exp.Between: _rename, exp.Boolean: lambda self, e: "True" if e.this else "False", - exp.Cast: _cast_py, - exp.Column: _column_py, - exp.EQ: lambda self, e: self.binary(e, "=="), + exp.Cast: lambda self, e: f"CAST({self.sql(e.this)}, exp.DataType.Type.{e.args['to']})", + exp.Column: lambda self, e: f"scope[{self.sql(e, 'table') or None}][{self.sql(e.this)}]", + exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})", exp.In: lambda self, e: f"{self.sql(e, 'this')} in {self.expressions(e)}", - exp.Interval: _interval_py, exp.Is: lambda self, e: self.binary(e, "is"), - exp.Like: _like_py, exp.Not: lambda self, e: f"not {self.sql(e.this)}", exp.Null: lambda *_: "None", exp.Or: lambda self, e: self.binary(e, "or"), exp.Ordered: _ordered_py, exp.Star: lambda *_: "1", } - - def case_sql(self, expression): - this = self.sql(expression, "this") - chain = self.sql(expression, "default") or "None" - - for e in reversed(expression.args["ifs"]): - true = self.sql(e, "true") - condition = self.sql(e, "this") - condition = f"{this} = ({condition})" if this else condition - chain = f"{true} if {condition} else ({chain})" - - return chain diff --git a/sqlglot/executor/table.py b/sqlglot/executor/table.py index 6796740..f1b5b54 100644 --- a/sqlglot/executor/table.py +++ b/sqlglot/executor/table.py @@ -1,14 +1,27 @@ +from __future__ import annotations + +from sqlglot.helper import dict_depth +from sqlglot.schema import AbstractMappingSchema + + class Table: def __init__(self, columns, rows=None, column_range=None): self.columns = tuple(columns) self.column_range = column_range self.reader = RowReader(self.columns, self.column_range) - self.rows = rows or [] if rows: assert len(rows[0]) == len(self.columns) self.range_reader = RangeReader(self) + def add_columns(self, *columns: str) -> None: + self.columns += columns + if self.column_range: + self.column_range = range( + self.column_range.start, self.column_range.stop + len(columns) + ) + self.reader = RowReader(self.columns, self.column_range) + def append(self, row): assert len(row) == len(self.columns) self.rows.append(row) @@ -87,3 +100,31 @@ class RowReader: def __getitem__(self, column): return self.row[self.columns[column]] + + +class Tables(AbstractMappingSchema[Table]): + pass + + +def ensure_tables(d: dict | None) -> Tables: + return Tables(_ensure_tables(d)) + + +def _ensure_tables(d: dict | None) -> dict: + if not d: + return {} + + depth = dict_depth(d) + + if depth > 1: + return {k: _ensure_tables(v) for k, v in d.items()} + + result = {} + for name, table in d.items(): + if isinstance(table, Table): + result[name] = table + else: + columns = tuple(table[0]) if table else () + rows = [tuple(row[c] for c in columns) for row in table] + result[name] = Table(columns=columns, rows=rows) + return result -- cgit v1.2.3