Edit on GitHub

sqlglot.executor.env

  1import datetime
  2import inspect
  3import re
  4import statistics
  5from functools import wraps
  6
  7from sqlglot import exp
  8from sqlglot.helper import PYTHON_VERSION
  9
 10
 11class reverse_key:
 12    def __init__(self, obj):
 13        self.obj = obj
 14
 15    def __eq__(self, other):
 16        return other.obj == self.obj
 17
 18    def __lt__(self, other):
 19        return other.obj < self.obj
 20
 21
 22def filter_nulls(func, empty_null=True):
 23    @wraps(func)
 24    def _func(values):
 25        filtered = tuple(v for v in values if v is not None)
 26        if not filtered and empty_null:
 27            return None
 28        return func(filtered)
 29
 30    return _func
 31
 32
 33def null_if_any(*required):
 34    """
 35    Decorator that makes a function return `None` if any of the `required` arguments are `None`.
 36
 37    This also supports decoration with no arguments, e.g.:
 38
 39        @null_if_any
 40        def foo(a, b): ...
 41
 42    In which case all arguments are required.
 43    """
 44    f = None
 45    if len(required) == 1 and callable(required[0]):
 46        f = required[0]
 47        required = ()
 48
 49    def decorator(func):
 50        if required:
 51            required_indices = [
 52                i for i, param in enumerate(inspect.signature(func).parameters) if param in required
 53            ]
 54
 55            def predicate(*args):
 56                return any(args[i] is None for i in required_indices)
 57
 58        else:
 59
 60            def predicate(*args):
 61                return any(a is None for a in args)
 62
 63        @wraps(func)
 64        def _func(*args):
 65            if predicate(*args):
 66                return None
 67            return func(*args)
 68
 69        return _func
 70
 71    if f:
 72        return decorator(f)
 73
 74    return decorator
 75
 76
 77@null_if_any("substr", "this")
 78def str_position(substr, this, position=None):
 79    position = position - 1 if position is not None else position
 80    return this.find(substr, position) + 1
 81
 82
 83@null_if_any("this")
 84def substring(this, start=None, length=None):
 85    if start is None:
 86        return this
 87    elif start == 0:
 88        return ""
 89    elif start < 0:
 90        start = len(this) + start
 91    else:
 92        start -= 1
 93
 94    end = None if length is None else start + length
 95
 96    return this[start:end]
 97
 98
 99@null_if_any
100def cast(this, to):
101    if to == exp.DataType.Type.DATE:
102        return datetime.date.fromisoformat(this)
103    if to == exp.DataType.Type.DATETIME:
104        return datetime.datetime.fromisoformat(this)
105    if to in exp.DataType.TEXT_TYPES:
106        return str(this)
107    if to in {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE}:
108        return float(this)
109    if to in exp.DataType.NUMERIC_TYPES:
110        return int(this)
111    raise NotImplementedError(f"Casting to '{to}' not implemented.")
112
113
def ordered(this, desc, nulls_first):
    """Build the sort key for a single ORDER BY term.

    Descending terms are wrapped in `reverse_key` so that an ascending sort
    orders them in reverse. `nulls_first` is accepted but currently ignored.
    """
    return reverse_key(this) if desc else this
118
119
# `datetime.timedelta` keyword argument for each supported interval unit.
# Units without a fixed length (MONTH, QUARTER, YEAR, ...) are deliberately absent.
_INTERVAL_UNITS = {
    "MICROSECOND": "microseconds",
    "MILLISECOND": "milliseconds",
    "SECOND": "seconds",
    "MINUTE": "minutes",
    "HOUR": "hours",
    "DAY": "days",
    "WEEK": "weeks",
}


@null_if_any
def interval(this, unit):
    """Build a `datetime.timedelta` for an SQL INTERVAL value.

    Args:
        this: the interval magnitude; anything `float()` accepts.
        unit: the unit name, e.g. "DAY" (matched case-insensitively).

    Returns:
        The timedelta, or None if either argument is None.

    Raises:
        NotImplementedError: if `unit` is not a fixed-length unit listed in
            `_INTERVAL_UNITS`.
    """
    key = _INTERVAL_UNITS.get(str(unit).upper())
    if key is None:
        raise NotImplementedError(f"Interval unit '{unit}' not implemented.")
    return datetime.timedelta(**{key: float(this)})
125
126
def _like(this, pattern):
    """Evaluate `this LIKE pattern`.

    `%` matches any run of characters (including none), `_` matches exactly one
    character, and every other character matches itself literally. The pattern
    must cover the whole string, not just a prefix of it.
    """
    regex = "".join(
        ".*" if char == "%" else "." if char == "_" else re.escape(char) for char in pattern
    )
    # DOTALL so that the wildcards also match newline characters.
    return bool(re.fullmatch(regex, this, re.DOTALL))


# Maps upper-case SQL function/operator names to the Python callables that
# implement them. Aggregates other than ARRAYAGG skip None inputs via
# `filter_nulls`. Most scalar functions propagate None via `null_if_any`;
# COALESCE, IFNULL and IF deliberately receive None arguments.
ENV = {
    "exp": exp,
    # aggs
    "ARRAYAGG": list,
    "AVG": filter_nulls(statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean),  # type: ignore
    "COUNT": filter_nulls(lambda acc: sum(1 for _ in acc), False),
    "MAX": filter_nulls(max),
    "MIN": filter_nulls(min),
    "SUM": filter_nulls(sum),
    # scalar functions
    "ABS": null_if_any(lambda this: abs(this)),
    "ADD": null_if_any(lambda e, this: e + this),
    "ARRAYANY": null_if_any(lambda arr, func: any(func(e) for e in arr)),
    "BETWEEN": null_if_any(lambda this, low, high: low <= this and this <= high),
    "BITWISEAND": null_if_any(lambda this, e: this & e),
    "BITWISELEFTSHIFT": null_if_any(lambda this, e: this << e),
    "BITWISEOR": null_if_any(lambda this, e: this | e),
    "BITWISERIGHTSHIFT": null_if_any(lambda this, e: this >> e),
    "BITWISEXOR": null_if_any(lambda this, e: this ^ e),
    "CAST": cast,
    "COALESCE": lambda *args: next((a for a in args if a is not None), None),
    "CONCAT": null_if_any(lambda *args: "".join(args)),
    "CONCATWS": null_if_any(lambda this, *args: this.join(args)),
    "DIV": null_if_any(lambda e, this: e / this),
    "EQ": null_if_any(lambda this, e: this == e),
    "EXTRACT": null_if_any(lambda this, e: getattr(e, this)),
    "GT": null_if_any(lambda this, e: this > e),
    "GTE": null_if_any(lambda this, e: this >= e),
    "IFNULL": lambda e, alt: alt if e is None else e,
    "IF": lambda predicate, true, false: true if predicate else false,
    "INTDIV": null_if_any(lambda e, this: e // this),
    "INTERVAL": interval,
    # Full-string match with regex metacharacters escaped (see `_like`).
    "LIKE": null_if_any(_like),
    "LOWER": null_if_any(lambda arg: arg.lower()),
    "LT": null_if_any(lambda this, e: this < e),
    "LTE": null_if_any(lambda this, e: this <= e),
    "MOD": null_if_any(lambda e, this: e % this),
    "MUL": null_if_any(lambda e, this: e * this),
    "NEQ": null_if_any(lambda this, e: this != e),
    "ORD": null_if_any(ord),
    "ORDERED": ordered,
    # NULL-safe like every other arithmetic operator (bare `pow(None, x)` raises).
    "POW": null_if_any(pow),
    "STRPOSITION": str_position,
    "SUB": null_if_any(lambda e, this: e - this),
    "SUBSTRING": substring,
    "TIMESTRTOTIME": null_if_any(lambda arg: datetime.datetime.fromisoformat(arg)),
    "UPPER": null_if_any(lambda arg: arg.upper()),
}
class reverse_key:
class reverse_key:
    """Sort-key wrapper that inverts the natural ordering of a value.

    `ordered` wraps descending ORDER BY values in this class; comparing two
    wrappers compares the wrapped objects with the operands swapped.
    """

    def __init__(self, obj):
        # The value whose ordering is inverted.
        self.obj = obj

    def __eq__(self, other):
        # Equality is unaffected by reversal; keep `other` on the left.
        wrapped = self.obj
        return other.obj == wrapped

    def __lt__(self, other):
        # self < other  <=>  other.obj < self.obj
        mine, theirs = self.obj, other.obj
        return theirs < mine
reverse_key(obj)
13    def __init__(self, obj):
14        self.obj = obj
def filter_nulls(func, empty_null=True):
def filter_nulls(func, empty_null=True):
    """Wrap an aggregate so it only ever receives the non-None input values.

    Args:
        func: aggregate callable taking a tuple of values.
        empty_null: when True, return None instead of calling `func` if no
            non-None values remain; when False, `func` is called with an
            empty tuple.

    Returns:
        A single-argument callable that accepts an iterable of values.
    """

    @wraps(func)
    def _func(values):
        non_null = tuple(filter(lambda value: value is not None, values))
        if empty_null and len(non_null) == 0:
            return None
        return func(non_null)

    return _func
def null_if_any(*required):
def null_if_any(*required):
    """
    Decorator that makes a function return `None` if any of the `required` arguments are `None`.

    This also supports decoration with no arguments, e.g.:

        @null_if_any
        def foo(a, b): ...

    In which case all arguments are required.

    Required arguments are matched by name against the decorated function's
    signature and checked positionally at call time. A required parameter the
    caller did not pass (i.e. left at its default value) is not treated as None.
    """
    f = None
    if len(required) == 1 and callable(required[0]):
        # Bare `@null_if_any` usage: the sole "argument" is the function itself.
        f = required[0]
        required = ()

    def decorator(func):
        if required:
            required_indices = [
                i for i, param in enumerate(inspect.signature(func).parameters) if param in required
            ]

            def predicate(*args):
                # Guard the index: an omitted trailing parameter falls back to
                # its default and must not raise IndexError here.
                return any(i < len(args) and args[i] is None for i in required_indices)

        else:

            def predicate(*args):
                return any(a is None for a in args)

        @wraps(func)
        def _func(*args):
            if predicate(*args):
                return None
            return func(*args)

        return _func

    # Compare against None rather than relying on the callable's truthiness.
    if f is not None:
        return decorator(f)

    return decorator

Decorator that makes a function return None if any of the required arguments are None.

This also supports decoration with no arguments, e.g.:

@null_if_any
def foo(a, b): ...

In that case, every argument is treated as required.

@null_if_any('substr', 'this')
def str_position(substr, this, position=None):
@null_if_any("substr", "this")
def str_position(substr, this, position=None):
    """Return the 1-based index of `substr` within `this`, or 0 if it is absent.

    `position`, when given, is the 1-based offset at which the search starts.
    Returns None when `substr` or `this` is None.
    """
    start = None if position is None else position - 1
    # str.find is 0-based and yields -1 on a miss; adding one maps a miss to 0.
    index = this.find(substr, start)
    return index + 1
@null_if_any('this')
def substring(this, start=None, length=None):
@null_if_any("this")
def substring(this, start=None, length=None):
    """SQL-style SUBSTRING over a Python string.

    `start` is 1-based. A negative `start` counts back from the end of the
    string, and a `start` of 0 yields an empty string. `length`, when given,
    caps the number of characters returned. Returns None when `this` is None.
    """
    if start is None:
        return this
    if start == 0:
        return ""

    # Translate the SQL start position into a Python slice offset.
    offset = len(this) + start if start < 0 else start - 1
    stop = offset + length if length is not None else None
    return this[offset:stop]
@null_if_any
def cast(this, to):
@null_if_any
def cast(this, to):
    """Convert `this` to the Python value for the SQL type `to`.

    Handles DATE and DATETIME (via `fromisoformat`, so `this` must be an ISO
    string), text types (`str`), FLOAT/DOUBLE (`float`) and the remaining
    numeric types (`int`). Returns None if either argument is None.

    Raises:
        NotImplementedError: for any other target type.
    """
    types = exp.DataType.Type
    if to == types.DATE:
        return datetime.date.fromisoformat(this)
    if to == types.DATETIME:
        return datetime.datetime.fromisoformat(this)
    if to in exp.DataType.TEXT_TYPES:
        return str(this)
    if to in {types.FLOAT, types.DOUBLE}:
        return float(this)
    # NOTE(review): if NUMERIC_TYPES contains DECIMAL-like types, int() truncates
    # (or raises on strings such as "1.5") — confirm this is intended.
    if to in exp.DataType.NUMERIC_TYPES:
        return int(this)
    raise NotImplementedError(f"Casting to '{to}' not implemented.")
def ordered(this, desc, nulls_first):
def ordered(this, desc, nulls_first):
    """Build the sort key for a single ORDER BY term.

    Descending terms are wrapped in `reverse_key` so that an ascending sort
    orders them in reverse. `nulls_first` is accepted but currently ignored.
    """
    return reverse_key(this) if desc else this
@null_if_any
def interval(this, unit):
# `datetime.timedelta` keyword argument for each supported interval unit.
# Units without a fixed length (MONTH, QUARTER, YEAR, ...) are deliberately absent.
_INTERVAL_UNITS = {
    "MICROSECOND": "microseconds",
    "MILLISECOND": "milliseconds",
    "SECOND": "seconds",
    "MINUTE": "minutes",
    "HOUR": "hours",
    "DAY": "days",
    "WEEK": "weeks",
}


@null_if_any
def interval(this, unit):
    """Build a `datetime.timedelta` for an SQL INTERVAL value.

    Args:
        this: the interval magnitude; anything `float()` accepts.
        unit: the unit name, e.g. "DAY" (matched case-insensitively).

    Returns:
        The timedelta, or None if either argument is None.

    Raises:
        NotImplementedError: if `unit` is not a fixed-length unit listed in
            `_INTERVAL_UNITS`.
    """
    key = _INTERVAL_UNITS.get(str(unit).upper())
    if key is None:
        raise NotImplementedError(f"Interval unit '{unit}' not implemented.")
    return datetime.timedelta(**{key: float(this)})