diff options
Diffstat (limited to 'sqlglot/executor')
-rw-r--r-- | sqlglot/executor/__init__.py | 16 | ||||
-rw-r--r-- | sqlglot/executor/env.py | 15 | ||||
-rw-r--r-- | sqlglot/executor/python.py | 19 |
3 files changed, 37 insertions, 13 deletions
diff --git a/sqlglot/executor/__init__.py b/sqlglot/executor/__init__.py index a67c155..017d5bc 100644 --- a/sqlglot/executor/__init__.py +++ b/sqlglot/executor/__init__.py @@ -14,9 +14,10 @@ from sqlglot import maybe_parse from sqlglot.errors import ExecuteError from sqlglot.executor.python import PythonExecutor from sqlglot.executor.table import Table, ensure_tables +from sqlglot.helper import dict_depth from sqlglot.optimizer import optimize from sqlglot.planner import Plan -from sqlglot.schema import ensure_schema +from sqlglot.schema import ensure_schema, flatten_schema, nested_get, nested_set logger = logging.getLogger("sqlglot") @@ -52,10 +53,15 @@ def execute( tables_ = ensure_tables(tables) if not schema: - schema = { - name: {column: type(table[0][column]).__name__ for column in table.columns} - for name, table in tables_.mapping.items() - } + schema = {} + flattened_tables = flatten_schema(tables_.mapping, depth=dict_depth(tables_.mapping)) + + for keys in flattened_tables: + table = nested_get(tables_.mapping, *zip(keys, keys)) + assert table is not None + + for column in table.columns: + nested_set(schema, [*keys, column], type(table[0][column]).__name__) schema = ensure_schema(schema, dialect=read) diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py index 8f64cce..51cffbd 100644 --- a/sqlglot/executor/env.py +++ b/sqlglot/executor/env.py @@ -5,6 +5,7 @@ import statistics from functools import wraps from sqlglot import exp +from sqlglot.generator import Generator from sqlglot.helper import PYTHON_VERSION @@ -102,6 +103,8 @@ def cast(this, to): return datetime.date.fromisoformat(this) if to == exp.DataType.Type.DATETIME: return datetime.datetime.fromisoformat(this) + if to == exp.DataType.Type.BOOLEAN: + return bool(this) if to in exp.DataType.TEXT_TYPES: return str(this) if to in {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE}: @@ -119,9 +122,11 @@ def ordered(this, desc, nulls_first): @null_if_any def interval(this, unit): - if unit == "DAY": - return datetime.timedelta(days=float(this)) - raise NotImplementedError + unit = unit.lower() + plural = unit + "s" + if plural in Generator.TIME_PART_SINGULARS: + unit = plural + return datetime.timedelta(**{unit: float(this)}) ENV = { @@ -147,7 +152,9 @@ ENV = { "COALESCE": lambda *args: next((a for a in args if a is not None), None), "CONCAT": null_if_any(lambda *args: "".join(args)), "CONCATWS": null_if_any(lambda this, *args: this.join(args)), + "DATESTRTODATE": null_if_any(lambda arg: datetime.date.fromisoformat(arg)), "DIV": null_if_any(lambda e, this: e / this), + "DOT": null_if_any(lambda e, this: e[this]), "EQ": null_if_any(lambda this, e: this == e), "EXTRACT": null_if_any(lambda this, e: getattr(e, this)), "GT": null_if_any(lambda this, e: this > e), @@ -162,6 +169,7 @@ ENV = { "LOWER": null_if_any(lambda arg: arg.lower()), "LT": null_if_any(lambda this, e: this < e), "LTE": null_if_any(lambda this, e: this <= e), + "MAP": null_if_any(lambda *args: dict(zip(*args))), # type: ignore "MOD": null_if_any(lambda e, this: e % this), "MUL": null_if_any(lambda e, this: e * this), "NEQ": null_if_any(lambda this, e: this != e), @@ -180,4 +188,5 @@ ENV = { "CURRENTTIMESTAMP": datetime.datetime.now, "CURRENTTIME": datetime.datetime.now, "CURRENTDATE": datetime.date.today, + "STRFTIME": null_if_any(lambda fmt, arg: datetime.datetime.fromisoformat(arg).strftime(fmt)), } diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py index b71cc6a..f114e5c 100644 --- a/sqlglot/executor/python.py +++ b/sqlglot/executor/python.py @@ -360,11 +360,19 @@ def _ordered_py(self, expression): def _rename(self, e): try: - if "expressions" in e.args: - this = self.sql(e, "this") - this = f"{this}, " if this else "" - return f"{e.key.upper()}({this}{self.expressions(e)})" - return self.func(e.key, *e.args.values()) + values = list(e.args.values()) + + if len(values) == 1: + values = values[0] + if not isinstance(values, list): + return self.func(e.key, values) + return self.func(e.key, *values) + + if isinstance(e, exp.Func) and e.is_var_len_args: + *head, tail = values + return self.func(e.key, *head, *tail) + + return self.func(e.key, *values) except Exception as ex: raise Exception(f"Could not rename {repr(e)}") from ex @@ -413,6 +421,7 @@ class Python(Dialect): exp.Distinct: lambda self, e: f"set({self.sql(e, 'this')})", exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})", exp.In: lambda self, e: f"{self.sql(e, 'this')} in ({self.expressions(e, flat=True)})", + exp.Interval: lambda self, e: f"INTERVAL({self.sql(e.this)}, '{self.sql(e.unit)}')", exp.Is: lambda self, e: self.binary(e, "is"), exp.Lambda: _lambda_sql, exp.Not: lambda self, e: f"not {self.sql(e.this)}", |