Diffstat (limited to 'sqlglot/executor')
-rw-r--r--  sqlglot/executor/__init__.py   23
-rw-r--r--  sqlglot/executor/context.py    47
-rw-r--r--  sqlglot/executor/env.py       162
-rw-r--r--  sqlglot/executor/python.py    287
-rw-r--r--  sqlglot/executor/table.py      43
5 files changed, 412 insertions, 150 deletions
diff --git a/sqlglot/executor/__init__.py b/sqlglot/executor/__init__.py
index e765616..04621b5 100644
--- a/sqlglot/executor/__init__.py
+++ b/sqlglot/executor/__init__.py
@@ -1,20 +1,23 @@
import logging
import time
-from sqlglot import parse_one
+from sqlglot import maybe_parse
+from sqlglot.errors import ExecuteError
from sqlglot.executor.python import PythonExecutor
+from sqlglot.executor.table import Table, ensure_tables
from sqlglot.optimizer import optimize
from sqlglot.planner import Plan
+from sqlglot.schema import ensure_schema
logger = logging.getLogger("sqlglot")
-def execute(sql, schema, read=None):
+def execute(sql, schema=None, read=None, tables=None):
"""
Run a sql query against data.
Args:
- sql (str): a sql statement
+ sql (str|sqlglot.Expression): a sql statement
schema (dict|sqlglot.optimizer.Schema): database schema.
This can either be an instance of `sqlglot.optimizer.Schema` or a mapping in one of
the following forms:
@@ -23,10 +26,20 @@ def execute(sql, schema, read=None):
3. {catalog: {db: {table: {col: type}}}}
read (str): the SQL dialect to apply during parsing
(eg. "spark", "hive", "presto", "mysql").
+ tables (dict): additional tables to register.
Returns:
sqlglot.executor.Table: Simple columnar data structure.
"""
- expression = parse_one(sql, read=read)
+ tables = ensure_tables(tables)
+ if not schema:
+ schema = {
+ name: {column: type(table[0][column]).__name__ for column in table.columns}
+ for name, table in tables.mapping.items()
+ }
+ schema = ensure_schema(schema)
+ if tables.supported_table_args and tables.supported_table_args != schema.supported_table_args:
+ raise ExecuteError("Tables must support the same table args as schema")
+ expression = maybe_parse(sql, dialect=read)
now = time.time()
expression = optimize(expression, schema, leave_tables_isolated=True)
logger.debug("Optimization finished: %f", time.time() - now)
@@ -34,6 +47,6 @@ def execute(sql, schema, read=None):
plan = Plan(expression)
logger.debug("Logical Plan: %s", plan)
now = time.time()
- result = PythonExecutor().execute(plan)
+ result = PythonExecutor(tables=tables).execute(plan)
logger.debug("Query finished: %f", time.time() - now)
return result
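
With the new `tables` argument, execute() can run directly against in-memory rows and infer a schema from the first row's Python types. A minimal usage sketch (the table name and rows below are hypothetical):

    from sqlglot.executor import execute

    # Hypothetical in-memory table; column types are inferred from the first row's values.
    tables = {"sushi": [{"id": 1, "price": 1.0}, {"id": 2, "price": 2.0}]}
    result = execute("SELECT id, price FROM sushi WHERE price > 1", tables=tables)
    print(result.columns, result.rows)
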
diff --git a/sqlglot/executor/context.py b/sqlglot/executor/context.py
index 393347b..e9ff75b 100644
--- a/sqlglot/executor/context.py
+++ b/sqlglot/executor/context.py
@@ -1,5 +1,12 @@
+from __future__ import annotations
+
+import typing as t
+
from sqlglot.executor.env import ENV
+if t.TYPE_CHECKING:
+ from sqlglot.executor.table import Table, TableIter
+
class Context:
"""
@@ -12,14 +19,14 @@ class Context:
evaluation of aggregation functions.
"""
- def __init__(self, tables, env=None):
+ def __init__(self, tables: t.Dict[str, Table], env: t.Optional[t.Dict] = None) -> None:
"""
Args
- tables (dict): table_name -> Table, representing the scope of the current execution context
- env (Optional[dict]): dictionary of functions within the execution context
+ tables: representing the scope of the current execution context.
+ env: dictionary of functions within the execution context.
"""
self.tables = tables
- self._table = None
+ self._table: t.Optional[Table] = None
self.range_readers = {name: table.range_reader for name, table in self.tables.items()}
self.row_readers = {name: table.reader for name, table in tables.items()}
self.env = {**(env or {}), "scope": self.row_readers}
@@ -31,7 +38,7 @@ class Context:
return tuple(self.eval(code) for code in codes)
@property
- def table(self):
+ def table(self) -> Table:
if self._table is None:
self._table = list(self.tables.values())[0]
for other in self.tables.values():
@@ -41,8 +48,12 @@ class Context:
raise Exception(f"Rows are different.")
return self._table
+ def add_columns(self, *columns: str) -> None:
+ for table in self.tables.values():
+ table.add_columns(*columns)
+
@property
- def columns(self):
+ def columns(self) -> t.Tuple:
return self.table.columns
def __iter__(self):
@@ -52,35 +63,39 @@ class Context:
reader = table[i]
yield reader, self
- def table_iter(self, table):
+ def table_iter(self, table: str) -> t.Generator[t.Tuple[TableIter, Context], None, None]:
self.env["scope"] = self.row_readers
for reader in self.tables[table]:
yield reader, self
- def sort(self, key):
- table = self.table
+ def filter(self, condition) -> None:
+ rows = [reader.row for reader, _ in self if self.eval(condition)]
- def sort_key(row):
- table.reader.row = row
+ for table in self.tables.values():
+ table.rows = rows
+
+ def sort(self, key) -> None:
+ def sort_key(row: t.Tuple) -> t.Tuple:
+ self.set_row(row)
return self.eval_tuple(key)
- table.rows.sort(key=sort_key)
+ self.table.rows.sort(key=sort_key)
- def set_row(self, row):
+ def set_row(self, row: t.Tuple) -> None:
for table in self.tables.values():
table.reader.row = row
self.env["scope"] = self.row_readers
- def set_index(self, index):
+ def set_index(self, index: int) -> None:
for table in self.tables.values():
table[index]
self.env["scope"] = self.row_readers
- def set_range(self, start, end):
+ def set_range(self, start: int, end: int) -> None:
for name in self.tables:
self.range_readers[name].range = range(start, end)
self.env["scope"] = self.range_readers
- def __contains__(self, table):
+ def __contains__(self, table: str) -> bool:
return table in self.tables
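
The new Context.filter() drops rows in place across every table in the scope. A rough standalone sketch (condition strings are normally produced by PythonExecutor.generate; the literal here is hypothetical):

    from sqlglot.executor.context import Context
    from sqlglot.executor.table import Table

    t = Table(("a",), rows=[(1,), (2,), (3,)])
    ctx = Context({"t": t})
    # The condition is evaluated against the row readers exposed as `scope`.
    ctx.filter("scope['t']['a'] > 1")
    assert t.rows == [(2,), (3,)]
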
diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py
index bbe6c81..ed80cc9 100644
--- a/sqlglot/executor/env.py
+++ b/sqlglot/executor/env.py
@@ -1,7 +1,10 @@
import datetime
+import inspect
import re
import statistics
+from functools import wraps
+from sqlglot import exp
from sqlglot.helper import PYTHON_VERSION
@@ -16,20 +19,153 @@ class reverse_key:
return other.obj < self.obj
+def filter_nulls(func):
+ @wraps(func)
+ def _func(values):
+ return func(v for v in values if v is not None)
+
+ return _func
+
+
+def null_if_any(*required):
+ """
+ Decorator that makes a function return `None` if any of the `required` arguments are `None`.
+
+ This also supports decoration with no arguments, e.g.:
+
+ @null_if_any
+ def foo(a, b): ...
+
+ In which case all arguments are required.
+ """
+ f = None
+ if len(required) == 1 and callable(required[0]):
+ f = required[0]
+ required = ()
+
+ def decorator(func):
+ if required:
+ required_indices = [
+ i for i, param in enumerate(inspect.signature(func).parameters) if param in required
+ ]
+
+ def predicate(*args):
+ return any(args[i] is None for i in required_indices)
+
+ else:
+
+ def predicate(*args):
+ return any(a is None for a in args)
+
+ @wraps(func)
+ def _func(*args):
+ if predicate(*args):
+ return None
+ return func(*args)
+
+ return _func
+
+ if f:
+ return decorator(f)
+
+ return decorator
+
+
+@null_if_any("substr", "this")
+def str_position(substr, this, position=None):
+ position = position - 1 if position is not None else position
+ return this.find(substr, position) + 1
+
+
+@null_if_any("this")
+def substring(this, start=None, length=None):
+ if start is None:
+ return this
+ elif start == 0:
+ return ""
+ elif start < 0:
+ start = len(this) + start
+ else:
+ start -= 1
+
+ end = None if length is None else start + length
+
+ return this[start:end]
+
+
+@null_if_any
+def cast(this, to):
+ if to == exp.DataType.Type.DATE:
+ return datetime.date.fromisoformat(this)
+ if to == exp.DataType.Type.DATETIME:
+ return datetime.datetime.fromisoformat(this)
+ if to in exp.DataType.TEXT_TYPES:
+ return str(this)
+ if to in {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE}:
+ return float(this)
+ if to in exp.DataType.NUMERIC_TYPES:
+ return int(this)
+ raise NotImplementedError(f"Casting to '{to}' not implemented.")
+
+
+def ordered(this, desc, nulls_first):
+ if desc:
+ return reverse_key(this)
+ return this
+
+
+@null_if_any
+def interval(this, unit):
+ if unit == "DAY":
+ return datetime.timedelta(days=float(this))
+ raise NotImplementedError
+
+
ENV = {
"__builtins__": {},
- "datetime": datetime,
- "locals": locals,
- "re": re,
- "bool": bool,
- "float": float,
- "int": int,
- "str": str,
- "desc": reverse_key,
- "SUM": sum,
- "AVG": statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean, # type: ignore
- "COUNT": lambda acc: sum(1 for e in acc if e is not None),
- "MAX": max,
- "MIN": min,
+ "exp": exp,
+ # aggs
+ "SUM": filter_nulls(sum),
+ "AVG": filter_nulls(statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean), # type: ignore
+ "COUNT": filter_nulls(lambda acc: sum(1 for _ in acc)),
+ "MAX": filter_nulls(max),
+ "MIN": filter_nulls(min),
+ # scalar functions
+ "ABS": null_if_any(lambda this: abs(this)),
+ "ADD": null_if_any(lambda e, this: e + this),
+ "BETWEEN": null_if_any(lambda this, low, high: low <= this and this <= high),
+ "BITWISEAND": null_if_any(lambda this, e: this & e),
+ "BITWISELEFTSHIFT": null_if_any(lambda this, e: this << e),
+ "BITWISEOR": null_if_any(lambda this, e: this | e),
+ "BITWISERIGHTSHIFT": null_if_any(lambda this, e: this >> e),
+ "BITWISEXOR": null_if_any(lambda this, e: this ^ e),
+ "CAST": cast,
+ "COALESCE": lambda *args: next((a for a in args if a is not None), None),
+ "CONCAT": null_if_any(lambda *args: "".join(args)),
+ "CONCATWS": null_if_any(lambda this, *args: this.join(args)),
+ "DIV": null_if_any(lambda e, this: e / this),
+ "EQ": null_if_any(lambda this, e: this == e),
+ "EXTRACT": null_if_any(lambda this, e: getattr(e, this)),
+ "GT": null_if_any(lambda this, e: this > e),
+ "GTE": null_if_any(lambda this, e: this >= e),
+ "IFNULL": lambda e, alt: alt if e is None else e,
+ "IF": lambda predicate, true, false: true if predicate else false,
+ "INTDIV": null_if_any(lambda e, this: e // this),
+ "INTERVAL": interval,
+ "LIKE": null_if_any(
+ lambda this, e: bool(re.match(e.replace("_", ".").replace("%", ".*"), this))
+ ),
+ "LOWER": null_if_any(lambda arg: arg.lower()),
+ "LT": null_if_any(lambda this, e: this < e),
+ "LTE": null_if_any(lambda this, e: this <= e),
+ "MOD": null_if_any(lambda e, this: e % this),
+ "MUL": null_if_any(lambda e, this: e * this),
+ "NEQ": null_if_any(lambda this, e: this != e),
+ "ORD": null_if_any(ord),
+ "ORDERED": ordered,
"POW": pow,
+ "STRPOSITION": str_position,
+ "SUB": null_if_any(lambda e, this: e - this),
+ "SUBSTRING": substring,
+ "UPPER": null_if_any(lambda arg: arg.upper()),
}
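
The null-handling wrappers above drive every scalar and aggregate in ENV. A minimal illustration (the `add` helper below is hypothetical):

    from sqlglot.executor.env import ENV, null_if_any

    @null_if_any("a")
    def add(a, b):
        return a + (b or 0)

    assert add(None, 1) is None           # "a" is required, so None short-circuits
    assert add(1, None) == 1              # "b" is not required and may be None
    assert ENV["SUM"]([1, None, 2]) == 3  # filter_nulls drops NULLs before aggregating
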
diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py
index 7d1db32..cb2543c 100644
--- a/sqlglot/executor/python.py
+++ b/sqlglot/executor/python.py
@@ -5,16 +5,18 @@ import math
from sqlglot import exp, generator, planner, tokens
from sqlglot.dialects.dialect import Dialect, inline_array_sql
+from sqlglot.errors import ExecuteError
from sqlglot.executor.context import Context
from sqlglot.executor.env import ENV
-from sqlglot.executor.table import Table
-from sqlglot.helper import csv_reader
+from sqlglot.executor.table import RowReader, Table
+from sqlglot.helper import csv_reader, subclasses
class PythonExecutor:
- def __init__(self, env=None):
- self.generator = Python().generator(identify=True)
+ def __init__(self, env=None, tables=None):
+ self.generator = Python().generator(identify=True, comments=False)
self.env = {**ENV, **(env or {})}
+ self.tables = tables or {}
def execute(self, plan):
running = set()
@@ -24,36 +26,41 @@ class PythonExecutor:
while queue:
node = queue.pop()
- context = self.context(
- {
- name: table
- for dep in node.dependencies
- for name, table in contexts[dep].tables.items()
- }
- )
- running.add(node)
-
- if isinstance(node, planner.Scan):
- contexts[node] = self.scan(node, context)
- elif isinstance(node, planner.Aggregate):
- contexts[node] = self.aggregate(node, context)
- elif isinstance(node, planner.Join):
- contexts[node] = self.join(node, context)
- elif isinstance(node, planner.Sort):
- contexts[node] = self.sort(node, context)
- else:
- raise NotImplementedError
-
- running.remove(node)
- finished.add(node)
-
- for dep in node.dependents:
- if dep not in running and all(d in contexts for d in dep.dependencies):
- queue.add(dep)
-
- for dep in node.dependencies:
- if all(d in finished for d in dep.dependents):
- contexts.pop(dep)
+ try:
+ context = self.context(
+ {
+ name: table
+ for dep in node.dependencies
+ for name, table in contexts[dep].tables.items()
+ }
+ )
+ running.add(node)
+
+ if isinstance(node, planner.Scan):
+ contexts[node] = self.scan(node, context)
+ elif isinstance(node, planner.Aggregate):
+ contexts[node] = self.aggregate(node, context)
+ elif isinstance(node, planner.Join):
+ contexts[node] = self.join(node, context)
+ elif isinstance(node, planner.Sort):
+ contexts[node] = self.sort(node, context)
+ elif isinstance(node, planner.SetOperation):
+ contexts[node] = self.set_operation(node, context)
+ else:
+ raise NotImplementedError
+
+ running.remove(node)
+ finished.add(node)
+
+ for dep in node.dependents:
+ if dep not in running and all(d in contexts for d in dep.dependencies):
+ queue.add(dep)
+
+ for dep in node.dependencies:
+ if all(d in finished for d in dep.dependents):
+ contexts.pop(dep)
+ except Exception as e:
+ raise ExecuteError(f"Step '{node.id}' failed: {e}") from e
root = plan.root
return contexts[root].tables[root.name]
@@ -76,38 +83,43 @@ class PythonExecutor:
return Context(tables, env=self.env)
def table(self, expressions):
- return Table(expression.alias_or_name for expression in expressions)
+ return Table(
+ expression.alias_or_name if isinstance(expression, exp.Expression) else expression
+ for expression in expressions
+ )
def scan(self, step, context):
source = step.source
- if isinstance(source, exp.Expression):
+ if source and isinstance(source, exp.Expression):
source = source.name or source.alias
condition = self.generate(step.condition)
projections = self.generate_tuple(step.projections)
- if source in context:
+ if source is None:
+ context, table_iter = self.static()
+ elif source in context:
if not projections and not condition:
return self.context({step.name: context.tables[source]})
table_iter = context.table_iter(source)
- else:
+ elif isinstance(step.source, exp.Table) and isinstance(step.source.this, exp.ReadCSV):
table_iter = self.scan_csv(step)
+ context = next(table_iter)
+ else:
+ context, table_iter = self.scan_table(step)
if projections:
sink = self.table(step.projections)
else:
- sink = None
-
- for reader, ctx in table_iter:
- if sink is None:
- sink = Table(reader.columns)
+ sink = self.table(context.columns)
- if condition and not ctx.eval(condition):
+ for reader in table_iter:
+ if condition and not context.eval(condition):
continue
if projections:
- sink.append(ctx.eval_tuple(projections))
+ sink.append(context.eval_tuple(projections))
else:
sink.append(reader.row)
@@ -116,14 +128,23 @@ class PythonExecutor:
return self.context({step.name: sink})
+ def static(self):
+ return self.context({}), [RowReader(())]
+
+ def scan_table(self, step):
+ table = self.tables.find(step.source)
+ context = self.context({step.source.alias_or_name: table})
+ return context, iter(table)
+
def scan_csv(self, step):
- source = step.source
- alias = source.alias
+ alias = step.source.alias
+ source = step.source.this
with csv_reader(source) as reader:
columns = next(reader)
table = Table(columns)
context = self.context({alias: table})
+ yield context
types = []
for row in reader:
@@ -134,7 +155,7 @@ class PythonExecutor:
except (ValueError, SyntaxError):
types.append(str)
context.set_row(tuple(t(v) for t, v in zip(types, row)))
- yield context.table.reader, context
+ yield context.table.reader
def join(self, step, context):
source = step.name
@@ -160,16 +181,19 @@ class PythonExecutor:
for name, column_range in column_ranges.items()
}
)
+ condition = self.generate(join["condition"])
+ if condition:
+ source_context.filter(condition)
condition = self.generate(step.condition)
projections = self.generate_tuple(step.projections)
- if not condition or not projections:
+ if not condition and not projections:
return source_context
sink = self.table(step.projections if projections else source_context.columns)
- for reader, ctx in join_context:
+ for reader, ctx in source_context:
if condition and not ctx.eval(condition):
continue
@@ -181,7 +205,15 @@ class PythonExecutor:
if len(sink) >= step.limit:
break
- return self.context({step.name: sink})
+ if projections:
+ return self.context({step.name: sink})
+ else:
+ return self.context(
+ {
+ name: Table(table.columns, sink.rows, table.column_range)
+ for name, table in source_context.tables.items()
+ }
+ )
def nested_loop_join(self, _join, source_context, join_context):
table = Table(source_context.columns + join_context.columns)
@@ -195,6 +227,8 @@ class PythonExecutor:
def hash_join(self, join, source_context, join_context):
source_key = self.generate_tuple(join["source_key"])
join_key = self.generate_tuple(join["join_key"])
+ left = join.get("side") == "LEFT"
+ right = join.get("side") == "RIGHT"
results = collections.defaultdict(lambda: ([], []))
@@ -204,28 +238,47 @@ class PythonExecutor:
results[ctx.eval_tuple(join_key)][1].append(reader.row)
table = Table(source_context.columns + join_context.columns)
+ nulls = [(None,) * len(join_context.columns if left else source_context.columns)]
for a_group, b_group in results.values():
+ if left:
+ b_group = b_group or nulls
+ elif right:
+ a_group = a_group or nulls
+
for a_row, b_row in itertools.product(a_group, b_group):
table.append(a_row + b_row)
return table
def aggregate(self, step, context):
- source = step.source
- group_by = self.generate_tuple(step.group)
+ group_by = self.generate_tuple(step.group.values())
aggregations = self.generate_tuple(step.aggregations)
operands = self.generate_tuple(step.operands)
if operands:
- source_table = context.tables[source]
- operand_table = Table(source_table.columns + self.table(step.operands).columns)
+ operand_table = Table(self.table(step.operands).columns)
for reader, ctx in context:
- operand_table.append(reader.row + ctx.eval_tuple(operands))
+ operand_table.append(ctx.eval_tuple(operands))
+
+ for i, (a, b) in enumerate(zip(context.table.rows, operand_table.rows)):
+ context.table.rows[i] = a + b
+
+ width = len(context.columns)
+ context.add_columns(*operand_table.columns)
+
+ operand_table = Table(
+ context.columns,
+ context.table.rows,
+ range(width, width + len(operand_table.columns)),
+ )
context = self.context(
- {None: operand_table, **{table: operand_table for table in context.tables}}
+ {
+ None: operand_table,
+ **context.tables,
+ }
)
context.sort(group_by)
@@ -233,25 +286,22 @@ class PythonExecutor:
group = None
start = 0
end = 1
- length = len(context.tables[source])
- table = self.table(step.group + step.aggregations)
+ length = len(context.table)
+ table = self.table(list(step.group) + step.aggregations)
for i in range(length):
context.set_index(i)
key = context.eval_tuple(group_by)
group = key if group is None else group
end += 1
-
+ if key != group:
+ context.set_range(start, end - 2)
+ table.append(group + context.eval_tuple(aggregations))
+ group = key
+ start = end - 2
if i == length - 1:
context.set_range(start, end - 1)
- elif key != group:
- context.set_range(start, end - 2)
- else:
- continue
-
- table.append(group + context.eval_tuple(aggregations))
- group = key
- start = end - 2
+ table.append(group + context.eval_tuple(aggregations))
context = self.context({step.name: table, **{name: table for name in context.tables}})
@@ -262,60 +312,77 @@ class PythonExecutor:
def sort(self, step, context):
projections = self.generate_tuple(step.projections)
- sink = self.table(step.projections)
+ projection_columns = [p.alias_or_name for p in step.projections]
+ all_columns = list(context.columns) + projection_columns
+ sink = self.table(all_columns)
for reader, ctx in context:
- sink.append(ctx.eval_tuple(projections))
+ sink.append(reader.row + ctx.eval_tuple(projections))
- context = self.context(
+ sort_ctx = self.context(
{
None: sink,
**{table: sink for table in context.tables},
}
)
- context.sort(self.generate_tuple(step.key))
+ sort_ctx.sort(self.generate_tuple(step.key))
if not math.isinf(step.limit):
- context.table.rows = context.table.rows[0 : step.limit]
+ sort_ctx.table.rows = sort_ctx.table.rows[0 : step.limit]
- return self.context({step.name: context.table})
+ output = Table(
+ projection_columns,
+ rows=[r[len(context.columns) : len(all_columns)] for r in sort_ctx.table.rows],
+ )
+ return self.context({step.name: output})
+ def set_operation(self, step, context):
+ left = context.tables[step.left]
+ right = context.tables[step.right]
-def _cast_py(self, expression):
- to = expression.args["to"].this
- this = self.sql(expression, "this")
+ sink = self.table(left.columns)
+
+ if issubclass(step.op, exp.Intersect):
+ sink.rows = list(set(left.rows).intersection(set(right.rows)))
+ elif issubclass(step.op, exp.Except):
+ sink.rows = list(set(left.rows).difference(set(right.rows)))
+ elif issubclass(step.op, exp.Union) and step.distinct:
+ sink.rows = list(set(left.rows).union(set(right.rows)))
+ else:
+ sink.rows = left.rows + right.rows
- if to == exp.DataType.Type.DATE:
- return f"datetime.date.fromisoformat({this})"
- if to == exp.DataType.Type.TEXT:
- return f"str({this})"
- raise NotImplementedError
+ return self.context({step.name: sink})
-def _column_py(self, expression):
- table = self.sql(expression, "table") or None
+def _ordered_py(self, expression):
this = self.sql(expression, "this")
- return f"scope[{table}][{this}]"
+ desc = "True" if expression.args.get("desc") else "False"
+ nulls_first = "True" if expression.args.get("nulls_first") else "False"
+ return f"ORDERED({this}, {desc}, {nulls_first})"
-def _interval_py(self, expression):
- this = self.sql(expression, "this")
- unit = expression.text("unit").upper()
- if unit == "DAY":
- return f"datetime.timedelta(days=float({this}))"
- raise NotImplementedError
+def _rename(self, e):
+ try:
+ if "expressions" in e.args:
+ this = self.sql(e, "this")
+ this = f"{this}, " if this else ""
+ return f"{e.key.upper()}({this}{self.expressions(e)})"
+ return f"{e.key.upper()}({self.format_args(*e.args.values())})"
+ except Exception as ex:
+ raise Exception(f"Could not rename {repr(e)}") from ex
-def _like_py(self, expression):
+def _case_sql(self, expression):
this = self.sql(expression, "this")
- expression = self.sql(expression, "expression")
- return f"""bool(re.match({expression}.replace("_", ".").replace("%", ".*"), {this}))"""
+ chain = self.sql(expression, "default") or "None"
+ for e in reversed(expression.args["ifs"]):
+ true = self.sql(e, "true")
+ condition = self.sql(e, "this")
+ condition = f"{this} = ({condition})" if this else condition
+ chain = f"{true} if {condition} else ({chain})"
-def _ordered_py(self, expression):
- this = self.sql(expression, "this")
- desc = expression.args.get("desc")
- return f"desc({this})" if desc else this
+ return chain
class Python(Dialect):
@@ -324,32 +391,22 @@ class Python(Dialect):
class Generator(generator.Generator):
TRANSFORMS = {
+ **{klass: _rename for klass in subclasses(exp.__name__, exp.Binary)},
+ **{klass: _rename for klass in exp.ALL_FUNCTIONS},
+ exp.Case: _case_sql,
exp.Alias: lambda self, e: self.sql(e.this),
exp.Array: inline_array_sql,
exp.And: lambda self, e: self.binary(e, "and"),
+ exp.Between: _rename,
exp.Boolean: lambda self, e: "True" if e.this else "False",
- exp.Cast: _cast_py,
- exp.Column: _column_py,
- exp.EQ: lambda self, e: self.binary(e, "=="),
+ exp.Cast: lambda self, e: f"CAST({self.sql(e.this)}, exp.DataType.Type.{e.args['to']})",
+ exp.Column: lambda self, e: f"scope[{self.sql(e, 'table') or None}][{self.sql(e.this)}]",
+ exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})",
exp.In: lambda self, e: f"{self.sql(e, 'this')} in {self.expressions(e)}",
- exp.Interval: _interval_py,
exp.Is: lambda self, e: self.binary(e, "is"),
- exp.Like: _like_py,
exp.Not: lambda self, e: f"not {self.sql(e.this)}",
exp.Null: lambda *_: "None",
exp.Or: lambda self, e: self.binary(e, "or"),
exp.Ordered: _ordered_py,
exp.Star: lambda *_: "1",
}
-
- def case_sql(self, expression):
- this = self.sql(expression, "this")
- chain = self.sql(expression, "default") or "None"
-
- for e in reversed(expression.args["ifs"]):
- true = self.sql(e, "true")
- condition = self.sql(e, "this")
- condition = f"{this} = ({condition})" if this else condition
- chain = f"{true} if {condition} else ({chain})"
-
- return chain
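
Taken together, the LEFT/RIGHT padding in hash_join and the new set_operation step let outer joins and UNION/INTERSECT/EXCEPT run end to end. A rough sketch, assuming this revision's planner emits Join and SetOperation steps for these queries (table names and rows are hypothetical):

    from sqlglot.executor import execute

    tables = {
        "x": [{"id": 1}, {"id": 2}],
        "y": [{"id": 2}, {"id": 3}],
    }
    # Unmatched "x" rows are padded with a NULL row by hash_join when side == "LEFT".
    execute("SELECT x.id AS xid, y.id AS yid FROM x LEFT JOIN y ON x.id = y.id", tables=tables)
    # UNION goes through set_operation; step.distinct makes it a set union of the row tuples.
    execute("SELECT id FROM x UNION SELECT id FROM y", tables=tables)
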
diff --git a/sqlglot/executor/table.py b/sqlglot/executor/table.py
index 6796740..f1b5b54 100644
--- a/sqlglot/executor/table.py
+++ b/sqlglot/executor/table.py
@@ -1,14 +1,27 @@
+from __future__ import annotations
+
+from sqlglot.helper import dict_depth
+from sqlglot.schema import AbstractMappingSchema
+
+
class Table:
def __init__(self, columns, rows=None, column_range=None):
self.columns = tuple(columns)
self.column_range = column_range
self.reader = RowReader(self.columns, self.column_range)
-
self.rows = rows or []
if rows:
assert len(rows[0]) == len(self.columns)
self.range_reader = RangeReader(self)
+ def add_columns(self, *columns: str) -> None:
+ self.columns += columns
+ if self.column_range:
+ self.column_range = range(
+ self.column_range.start, self.column_range.stop + len(columns)
+ )
+ self.reader = RowReader(self.columns, self.column_range)
+
def append(self, row):
assert len(row) == len(self.columns)
self.rows.append(row)
@@ -87,3 +100,31 @@ class RowReader:
def __getitem__(self, column):
return self.row[self.columns[column]]
+
+
+class Tables(AbstractMappingSchema[Table]):
+ pass
+
+
+def ensure_tables(d: dict | None) -> Tables:
+ return Tables(_ensure_tables(d))
+
+
+def _ensure_tables(d: dict | None) -> dict:
+ if not d:
+ return {}
+
+ depth = dict_depth(d)
+
+ if depth > 1:
+ return {k: _ensure_tables(v) for k, v in d.items()}
+
+ result = {}
+ for name, table in d.items():
+ if isinstance(table, Table):
+ result[name] = table
+ else:
+ columns = tuple(table[0]) if table else ()
+ rows = [tuple(row[c] for c in columns) for row in table]
+ result[name] = Table(columns=columns, rows=rows)
+ return result
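
ensure_tables accepts either ready-made Table objects or plain row data. A small sketch (table name and data are hypothetical):

    from sqlglot.executor.table import Table, ensure_tables

    # Rows given as dicts: columns come from the first row and each row becomes a tuple.
    tables = ensure_tables({"t": [{"a": 1, "b": "x"}, {"a": 2, "b": "y"}]})
    # Pre-built Table instances are passed through unchanged.
    tables = ensure_tables({"t": Table(("a", "b"), rows=[(1, "x")])})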