From f2981e8e4d28233864f1ca06ecec45ab80bf9eae Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sat, 19 Nov 2022 15:50:39 +0100 Subject: Merging upstream version 10.0.8. Signed-off-by: Daniel Baumann --- sqlglot/executor/env.py | 162 ++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 149 insertions(+), 13 deletions(-) (limited to 'sqlglot/executor/env.py') diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py index bbe6c81..ed80cc9 100644 --- a/sqlglot/executor/env.py +++ b/sqlglot/executor/env.py @@ -1,7 +1,10 @@ import datetime +import inspect import re import statistics +from functools import wraps +from sqlglot import exp from sqlglot.helper import PYTHON_VERSION @@ -16,20 +19,153 @@ class reverse_key: return other.obj < self.obj +def filter_nulls(func): + @wraps(func) + def _func(values): + return func(v for v in values if v is not None) + + return _func + + +def null_if_any(*required): + """ + Decorator that makes a function return `None` if any of the `required` arguments are `None`. + + This also supports decoration with no arguments, e.g.: + + @null_if_any + def foo(a, b): ... + + In which case all arguments are required. + """ + f = None + if len(required) == 1 and callable(required[0]): + f = required[0] + required = () + + def decorator(func): + if required: + required_indices = [ + i for i, param in enumerate(inspect.signature(func).parameters) if param in required + ] + + def predicate(*args): + return any(args[i] is None for i in required_indices) + + else: + + def predicate(*args): + return any(a is None for a in args) + + @wraps(func) + def _func(*args): + if predicate(*args): + return None + return func(*args) + + return _func + + if f: + return decorator(f) + + return decorator + + +@null_if_any("substr", "this") +def str_position(substr, this, position=None): + position = position - 1 if position is not None else position + return this.find(substr, position) + 1 + + +@null_if_any("this") +def substring(this, start=None, length=None): + if start is None: + return this + elif start == 0: + return "" + elif start < 0: + start = len(this) + start + else: + start -= 1 + + end = None if length is None else start + length + + return this[start:end] + + +@null_if_any +def cast(this, to): + if to == exp.DataType.Type.DATE: + return datetime.date.fromisoformat(this) + if to == exp.DataType.Type.DATETIME: + return datetime.datetime.fromisoformat(this) + if to in exp.DataType.TEXT_TYPES: + return str(this) + if to in {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE}: + return float(this) + if to in exp.DataType.NUMERIC_TYPES: + return int(this) + raise NotImplementedError(f"Casting to '{to}' not implemented.") + + +def ordered(this, desc, nulls_first): + if desc: + return reverse_key(this) + return this + + +@null_if_any +def interval(this, unit): + if unit == "DAY": + return datetime.timedelta(days=float(this)) + raise NotImplementedError + + ENV = { "__builtins__": {}, - "datetime": datetime, - "locals": locals, - "re": re, - "bool": bool, - "float": float, - "int": int, - "str": str, - "desc": reverse_key, - "SUM": sum, - "AVG": statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean, # type: ignore - "COUNT": lambda acc: sum(1 for e in acc if e is not None), - "MAX": max, - "MIN": min, + "exp": exp, + # aggs + "SUM": filter_nulls(sum), + "AVG": filter_nulls(statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean), # type: ignore + "COUNT": filter_nulls(lambda acc: sum(1 for _ in acc)), + "MAX": filter_nulls(max), + "MIN": filter_nulls(min), + # scalar functions + "ABS": null_if_any(lambda this: abs(this)), + "ADD": null_if_any(lambda e, this: e + this), + "BETWEEN": null_if_any(lambda this, low, high: low <= this and this <= high), + "BITWISEAND": null_if_any(lambda this, e: this & e), + "BITWISELEFTSHIFT": null_if_any(lambda this, e: this << e), + "BITWISEOR": null_if_any(lambda this, e: this | e), + "BITWISERIGHTSHIFT": null_if_any(lambda this, e: this >> e), + "BITWISEXOR": null_if_any(lambda this, e: this ^ e), + "CAST": cast, + "COALESCE": lambda *args: next((a for a in args if a is not None), None), + "CONCAT": null_if_any(lambda *args: "".join(args)), + "CONCATWS": null_if_any(lambda this, *args: this.join(args)), + "DIV": null_if_any(lambda e, this: e / this), + "EQ": null_if_any(lambda this, e: this == e), + "EXTRACT": null_if_any(lambda this, e: getattr(e, this)), + "GT": null_if_any(lambda this, e: this > e), + "GTE": null_if_any(lambda this, e: this >= e), + "IFNULL": lambda e, alt: alt if e is None else e, + "IF": lambda predicate, true, false: true if predicate else false, + "INTDIV": null_if_any(lambda e, this: e // this), + "INTERVAL": interval, + "LIKE": null_if_any( + lambda this, e: bool(re.match(e.replace("_", ".").replace("%", ".*"), this)) + ), + "LOWER": null_if_any(lambda arg: arg.lower()), + "LT": null_if_any(lambda this, e: this < e), + "LTE": null_if_any(lambda this, e: this <= e), + "MOD": null_if_any(lambda e, this: e % this), + "MUL": null_if_any(lambda e, this: e * this), + "NEQ": null_if_any(lambda this, e: this != e), + "ORD": null_if_any(ord), + "ORDERED": ordered, "POW": pow, + "STRPOSITION": str_position, + "SUB": null_if_any(lambda e, this: e - this), + "SUBSTRING": substring, + "UPPER": null_if_any(lambda arg: arg.upper()), } -- cgit v1.2.3