summaryrefslogtreecommitdiffstats
path: root/sqlglot/executor
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/executor')
-rw-r--r--sqlglot/executor/__init__.py16
-rw-r--r--sqlglot/executor/env.py15
-rw-r--r--sqlglot/executor/python.py19
3 files changed, 37 insertions, 13 deletions
diff --git a/sqlglot/executor/__init__.py b/sqlglot/executor/__init__.py
index a67c155..017d5bc 100644
--- a/sqlglot/executor/__init__.py
+++ b/sqlglot/executor/__init__.py
@@ -14,9 +14,10 @@ from sqlglot import maybe_parse
from sqlglot.errors import ExecuteError
from sqlglot.executor.python import PythonExecutor
from sqlglot.executor.table import Table, ensure_tables
+from sqlglot.helper import dict_depth
from sqlglot.optimizer import optimize
from sqlglot.planner import Plan
-from sqlglot.schema import ensure_schema
+from sqlglot.schema import ensure_schema, flatten_schema, nested_get, nested_set
logger = logging.getLogger("sqlglot")
@@ -52,10 +53,15 @@ def execute(
tables_ = ensure_tables(tables)
if not schema:
- schema = {
- name: {column: type(table[0][column]).__name__ for column in table.columns}
- for name, table in tables_.mapping.items()
- }
+ schema = {}
+ flattened_tables = flatten_schema(tables_.mapping, depth=dict_depth(tables_.mapping))
+
+ for keys in flattened_tables:
+ table = nested_get(tables_.mapping, *zip(keys, keys))
+ assert table is not None
+
+ for column in table.columns:
+ nested_set(schema, [*keys, column], type(table[0][column]).__name__)
schema = ensure_schema(schema, dialect=read)
diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py
index 8f64cce..51cffbd 100644
--- a/sqlglot/executor/env.py
+++ b/sqlglot/executor/env.py
@@ -5,6 +5,7 @@ import statistics
from functools import wraps
from sqlglot import exp
+from sqlglot.generator import Generator
from sqlglot.helper import PYTHON_VERSION
@@ -102,6 +103,8 @@ def cast(this, to):
return datetime.date.fromisoformat(this)
if to == exp.DataType.Type.DATETIME:
return datetime.datetime.fromisoformat(this)
+ if to == exp.DataType.Type.BOOLEAN:
+ return bool(this)
if to in exp.DataType.TEXT_TYPES:
return str(this)
if to in {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE}:
@@ -119,9 +122,11 @@ def ordered(this, desc, nulls_first):
@null_if_any
def interval(this, unit):
- if unit == "DAY":
- return datetime.timedelta(days=float(this))
- raise NotImplementedError
+ unit = unit.lower()
+ plural = unit + "s"
+ if plural in Generator.TIME_PART_SINGULARS:
+ unit = plural
+ return datetime.timedelta(**{unit: float(this)})
ENV = {
@@ -147,7 +152,9 @@ ENV = {
"COALESCE": lambda *args: next((a for a in args if a is not None), None),
"CONCAT": null_if_any(lambda *args: "".join(args)),
"CONCATWS": null_if_any(lambda this, *args: this.join(args)),
+ "DATESTRTODATE": null_if_any(lambda arg: datetime.date.fromisoformat(arg)),
"DIV": null_if_any(lambda e, this: e / this),
+ "DOT": null_if_any(lambda e, this: e[this]),
"EQ": null_if_any(lambda this, e: this == e),
"EXTRACT": null_if_any(lambda this, e: getattr(e, this)),
"GT": null_if_any(lambda this, e: this > e),
@@ -162,6 +169,7 @@ ENV = {
"LOWER": null_if_any(lambda arg: arg.lower()),
"LT": null_if_any(lambda this, e: this < e),
"LTE": null_if_any(lambda this, e: this <= e),
+ "MAP": null_if_any(lambda *args: dict(zip(*args))), # type: ignore
"MOD": null_if_any(lambda e, this: e % this),
"MUL": null_if_any(lambda e, this: e * this),
"NEQ": null_if_any(lambda this, e: this != e),
@@ -180,4 +188,5 @@ ENV = {
"CURRENTTIMESTAMP": datetime.datetime.now,
"CURRENTTIME": datetime.datetime.now,
"CURRENTDATE": datetime.date.today,
+ "STRFTIME": null_if_any(lambda fmt, arg: datetime.datetime.fromisoformat(arg).strftime(fmt)),
}
diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py
index b71cc6a..f114e5c 100644
--- a/sqlglot/executor/python.py
+++ b/sqlglot/executor/python.py
@@ -360,11 +360,19 @@ def _ordered_py(self, expression):
def _rename(self, e):
try:
- if "expressions" in e.args:
- this = self.sql(e, "this")
- this = f"{this}, " if this else ""
- return f"{e.key.upper()}({this}{self.expressions(e)})"
- return self.func(e.key, *e.args.values())
+ values = list(e.args.values())
+
+ if len(values) == 1:
+ values = values[0]
+ if not isinstance(values, list):
+ return self.func(e.key, values)
+ return self.func(e.key, *values)
+
+ if isinstance(e, exp.Func) and e.is_var_len_args:
+ *head, tail = values
+ return self.func(e.key, *head, *tail)
+
+ return self.func(e.key, *values)
except Exception as ex:
raise Exception(f"Could not rename {repr(e)}") from ex
@@ -413,6 +421,7 @@ class Python(Dialect):
exp.Distinct: lambda self, e: f"set({self.sql(e, 'this')})",
exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})",
exp.In: lambda self, e: f"{self.sql(e, 'this')} in ({self.expressions(e, flat=True)})",
+ exp.Interval: lambda self, e: f"INTERVAL({self.sql(e.this)}, '{self.sql(e.unit)}')",
exp.Is: lambda self, e: self.binary(e, "is"),
exp.Lambda: _lambda_sql,
exp.Not: lambda self, e: f"not {self.sql(e.this)}",