summaryrefslogtreecommitdiffstats
path: root/sqlglot/executor/env.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/executor/env.py')
-rw-r--r--sqlglot/executor/env.py162
1 files changed, 149 insertions, 13 deletions
diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py
index bbe6c81..ed80cc9 100644
--- a/sqlglot/executor/env.py
+++ b/sqlglot/executor/env.py
@@ -1,7 +1,10 @@
import datetime
+import inspect
import re
import statistics
+from functools import wraps
+from sqlglot import exp
from sqlglot.helper import PYTHON_VERSION
@@ -16,20 +19,153 @@ class reverse_key:
return other.obj < self.obj
+def filter_nulls(func):
+ @wraps(func)
+ def _func(values):
+ return func(v for v in values if v is not None)
+
+ return _func
+
+
+def null_if_any(*required):
+ """
+ Decorator that makes a function return `None` if any of the `required` arguments are `None`.
+
+ This also supports decoration with no arguments, e.g.:
+
+ @null_if_any
+ def foo(a, b): ...
+
+ In which case all arguments are required.
+ """
+ f = None
+ if len(required) == 1 and callable(required[0]):
+ f = required[0]
+ required = ()
+
+ def decorator(func):
+ if required:
+ required_indices = [
+ i for i, param in enumerate(inspect.signature(func).parameters) if param in required
+ ]
+
+ def predicate(*args):
+ return any(args[i] is None for i in required_indices)
+
+ else:
+
+ def predicate(*args):
+ return any(a is None for a in args)
+
+ @wraps(func)
+ def _func(*args):
+ if predicate(*args):
+ return None
+ return func(*args)
+
+ return _func
+
+ if f:
+ return decorator(f)
+
+ return decorator
+
+
+@null_if_any("substr", "this")
+def str_position(substr, this, position=None):
+ position = position - 1 if position is not None else position
+ return this.find(substr, position) + 1
+
+
+@null_if_any("this")
+def substring(this, start=None, length=None):
+ if start is None:
+ return this
+ elif start == 0:
+ return ""
+ elif start < 0:
+ start = len(this) + start
+ else:
+ start -= 1
+
+ end = None if length is None else start + length
+
+ return this[start:end]
+
+
+@null_if_any
+def cast(this, to):
+ if to == exp.DataType.Type.DATE:
+ return datetime.date.fromisoformat(this)
+ if to == exp.DataType.Type.DATETIME:
+ return datetime.datetime.fromisoformat(this)
+ if to in exp.DataType.TEXT_TYPES:
+ return str(this)
+ if to in {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE}:
+ return float(this)
+ if to in exp.DataType.NUMERIC_TYPES:
+ return int(this)
+ raise NotImplementedError(f"Casting to '{to}' not implemented.")
+
+
+def ordered(this, desc, nulls_first):
+ if desc:
+ return reverse_key(this)
+ return this
+
+
+@null_if_any
+def interval(this, unit):
+ if unit == "DAY":
+ return datetime.timedelta(days=float(this))
+ raise NotImplementedError
+
+
ENV = {
"__builtins__": {},
- "datetime": datetime,
- "locals": locals,
- "re": re,
- "bool": bool,
- "float": float,
- "int": int,
- "str": str,
- "desc": reverse_key,
- "SUM": sum,
- "AVG": statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean, # type: ignore
- "COUNT": lambda acc: sum(1 for e in acc if e is not None),
- "MAX": max,
- "MIN": min,
+ "exp": exp,
+ # aggs
+ "SUM": filter_nulls(sum),
+ "AVG": filter_nulls(statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean), # type: ignore
+ "COUNT": filter_nulls(lambda acc: sum(1 for _ in acc)),
+ "MAX": filter_nulls(max),
+ "MIN": filter_nulls(min),
+ # scalar functions
+ "ABS": null_if_any(lambda this: abs(this)),
+ "ADD": null_if_any(lambda e, this: e + this),
+ "BETWEEN": null_if_any(lambda this, low, high: low <= this and this <= high),
+ "BITWISEAND": null_if_any(lambda this, e: this & e),
+ "BITWISELEFTSHIFT": null_if_any(lambda this, e: this << e),
+ "BITWISEOR": null_if_any(lambda this, e: this | e),
+ "BITWISERIGHTSHIFT": null_if_any(lambda this, e: this >> e),
+ "BITWISEXOR": null_if_any(lambda this, e: this ^ e),
+ "CAST": cast,
+ "COALESCE": lambda *args: next((a for a in args if a is not None), None),
+ "CONCAT": null_if_any(lambda *args: "".join(args)),
+ "CONCATWS": null_if_any(lambda this, *args: this.join(args)),
+ "DIV": null_if_any(lambda e, this: e / this),
+ "EQ": null_if_any(lambda this, e: this == e),
+ "EXTRACT": null_if_any(lambda this, e: getattr(e, this)),
+ "GT": null_if_any(lambda this, e: this > e),
+ "GTE": null_if_any(lambda this, e: this >= e),
+ "IFNULL": lambda e, alt: alt if e is None else e,
+ "IF": lambda predicate, true, false: true if predicate else false,
+ "INTDIV": null_if_any(lambda e, this: e // this),
+ "INTERVAL": interval,
+ "LIKE": null_if_any(
+ lambda this, e: bool(re.match(e.replace("_", ".").replace("%", ".*"), this))
+ ),
+ "LOWER": null_if_any(lambda arg: arg.lower()),
+ "LT": null_if_any(lambda this, e: this < e),
+ "LTE": null_if_any(lambda this, e: this <= e),
+ "MOD": null_if_any(lambda e, this: e % this),
+ "MUL": null_if_any(lambda e, this: e * this),
+ "NEQ": null_if_any(lambda this, e: this != e),
+ "ORD": null_if_any(ord),
+ "ORDERED": ordered,
"POW": pow,
+ "STRPOSITION": str_position,
+ "SUB": null_if_any(lambda e, this: e - this),
+ "SUBSTRING": substring,
+ "UPPER": null_if_any(lambda arg: arg.upper()),
}