summaryrefslogtreecommitdiffstats
path: root/sqlglot/executor
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/executor')
-rw-r--r--sqlglot/executor/context.py4
-rw-r--r--sqlglot/executor/env.py4
-rw-r--r--sqlglot/executor/python.py13
3 files changed, 18 insertions, 3 deletions
diff --git a/sqlglot/executor/context.py b/sqlglot/executor/context.py
index e9ff75b..8a58287 100644
--- a/sqlglot/executor/context.py
+++ b/sqlglot/executor/context.py
@@ -29,10 +29,10 @@ class Context:
self._table: t.Optional[Table] = None
self.range_readers = {name: table.range_reader for name, table in self.tables.items()}
self.row_readers = {name: table.reader for name, table in tables.items()}
- self.env = {**(env or {}), "scope": self.row_readers}
+ self.env = {**ENV, **(env or {}), "scope": self.row_readers}
def eval(self, code):
- return eval(code, ENV, self.env)
+ return eval(code, self.env)
def eval_tuple(self, codes):
return tuple(self.eval(code) for code in codes)
diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py
index ad9397e..04dc938 100644
--- a/sqlglot/executor/env.py
+++ b/sqlglot/executor/env.py
@@ -127,14 +127,16 @@ def interval(this, unit):
ENV = {
"exp": exp,
# aggs
- "SUM": filter_nulls(sum),
+ "ARRAYAGG": list,
"AVG": filter_nulls(statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean), # type: ignore
"COUNT": filter_nulls(lambda acc: sum(1 for _ in acc), False),
"MAX": filter_nulls(max),
"MIN": filter_nulls(min),
+ "SUM": filter_nulls(sum),
# scalar functions
"ABS": null_if_any(lambda this: abs(this)),
"ADD": null_if_any(lambda e, this: e + this),
+ "ARRAYANY": null_if_any(lambda arr, func: any(func(e) for e in arr)),
"BETWEEN": null_if_any(lambda this, low, high: low <= this and this <= high),
"BITWISEAND": null_if_any(lambda this, e: this & e),
"BITWISELEFTSHIFT": null_if_any(lambda this, e: this << e),
diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py
index 9f22c45..29848c6 100644
--- a/sqlglot/executor/python.py
+++ b/sqlglot/executor/python.py
@@ -394,6 +394,18 @@ def _case_sql(self, expression):
return chain
+def _lambda_sql(self, e: exp.Lambda) -> str:
+ names = {e.name.lower() for e in e.expressions}
+
+ e = e.transform(
+ lambda n: exp.Var(this=n.name)
+ if isinstance(n, exp.Identifier) and n.name.lower() in names
+ else n
+ )
+
+ return f"lambda {self.expressions(e, flat=True)}: {self.sql(e, 'this')}"
+
+
class Python(Dialect):
class Tokenizer(tokens.Tokenizer):
ESCAPES = ["\\"]
@@ -414,6 +426,7 @@ class Python(Dialect):
exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})",
exp.In: lambda self, e: f"{self.sql(e, 'this')} in ({self.expressions(e, flat=True)})",
exp.Is: lambda self, e: self.binary(e, "is"),
+ exp.Lambda: _lambda_sql,
exp.Not: lambda self, e: f"not {self.sql(e.this)}",
exp.Null: lambda *_: "None",
exp.Or: lambda self, e: self.binary(e, "or"),