Edit on GitHub

sqlglot.executor.env

  1import datetime
  2import inspect
  3import re
  4import statistics
  5from functools import wraps
  6
  7from sqlglot import exp
  8from sqlglot.generator import Generator
  9from sqlglot.helper import PYTHON_VERSION
 10
 11
 12class reverse_key:
 13    def __init__(self, obj):
 14        self.obj = obj
 15
 16    def __eq__(self, other):
 17        return other.obj == self.obj
 18
 19    def __lt__(self, other):
 20        return other.obj < self.obj
 21
 22
 23def filter_nulls(func, empty_null=True):
 24    @wraps(func)
 25    def _func(values):
 26        filtered = tuple(v for v in values if v is not None)
 27        if not filtered and empty_null:
 28            return None
 29        return func(filtered)
 30
 31    return _func
 32
 33
 34def null_if_any(*required):
 35    """
 36    Decorator that makes a function return `None` if any of the `required` arguments are `None`.
 37
 38    This also supports decoration with no arguments, e.g.:
 39
 40        @null_if_any
 41        def foo(a, b): ...
 42
 43    In which case all arguments are required.
 44    """
 45    f = None
 46    if len(required) == 1 and callable(required[0]):
 47        f = required[0]
 48        required = ()
 49
 50    def decorator(func):
 51        if required:
 52            required_indices = [
 53                i for i, param in enumerate(inspect.signature(func).parameters) if param in required
 54            ]
 55
 56            def predicate(*args):
 57                return any(args[i] is None for i in required_indices)
 58
 59        else:
 60
 61            def predicate(*args):
 62                return any(a is None for a in args)
 63
 64        @wraps(func)
 65        def _func(*args):
 66            if predicate(*args):
 67                return None
 68            return func(*args)
 69
 70        return _func
 71
 72    if f:
 73        return decorator(f)
 74
 75    return decorator
 76
 77
 78@null_if_any("substr", "this")
 79def str_position(substr, this, position=None):
 80    position = position - 1 if position is not None else position
 81    return this.find(substr, position) + 1
 82
 83
 84@null_if_any("this")
 85def substring(this, start=None, length=None):
 86    if start is None:
 87        return this
 88    elif start == 0:
 89        return ""
 90    elif start < 0:
 91        start = len(this) + start
 92    else:
 93        start -= 1
 94
 95    end = None if length is None else start + length
 96
 97    return this[start:end]
 98
 99
@null_if_any
def cast(this, to):
    """
    Convert a non-NULL value to the SQL type ``to``.

    Args:
        this: the value to convert.
        to: the target type. It is compared against ``exp.DataType.Type``
            members and the ``exp.DataType.TEXT_TYPES`` /
            ``exp.DataType.NUMERIC_TYPES`` groups.

    Returns:
        The converted value.

    Raises:
        NotImplementedError: if ``to`` is not a supported target type.
    """
    if to == exp.DataType.Type.DATE:
        # datetime is a subclass of date, so test it first to drop the time part.
        if isinstance(this, datetime.datetime):
            return this.date()
        if isinstance(this, datetime.date):
            return this
        return datetime.date.fromisoformat(this)
    if to == exp.DataType.Type.DATETIME:
        if isinstance(this, datetime.datetime):
            return this
        if isinstance(this, datetime.date):
            return datetime.datetime.combine(this, datetime.time())
        return datetime.datetime.fromisoformat(this)
    if to == exp.DataType.Type.BOOLEAN:
        if isinstance(this, str):
            # bool() of any non-empty string is True, which would make
            # CAST('false' AS BOOLEAN) true, so recognise the usual literals.
            text = this.strip().lower()
            if text in ("true", "t", "yes", "y", "on", "1"):
                return True
            if text in ("false", "f", "no", "n", "off", "0"):
                return False
        return bool(this)
    if to in exp.DataType.TEXT_TYPES:
        return str(this)
    if to in {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE}:
        return float(this)
    if to in exp.DataType.NUMERIC_TYPES:
        return int(this)
    raise NotImplementedError(f"Casting to '{to}' not implemented.")
115
116
def ordered(this, desc, nulls_first):
    """
    Build the sort key for one ORDER BY term.

    Args:
        this: the value being ordered. ``None`` represents SQL NULL.
        desc: whether the term is sorted in descending order.
        nulls_first: whether NULLs must appear before non-NULL values in the
            final order.

    Returns:
        A key whose natural ordering honours ``desc`` and ``nulls_first``.
        NULLs are ranked separately from non-NULL values, so they are never
        compared against them; comparing ``None`` with a value raises TypeError.
    """
    if this is None:
        # reverse_key inverts the whole key for DESC, so invert the NULL rank
        # too, keeping NULLs where the caller asked for them.
        null_rank = -1 if bool(nulls_first) != bool(desc) else 1
        key = (null_rank, None)
    else:
        key = (0, this)

    return reverse_key(key) if desc else key
121
122
@null_if_any
def interval(this, unit):
    """
    Turn an INTERVAL literal into a ``datetime.timedelta``.

    Args:
        this: the interval amount; it is converted with ``float()``.
        unit: the unit name. It is lower-cased and pluralised when the plural
            form appears in ``Generator.TIME_PART_SINGULARS``.

    Returns:
        ``datetime.timedelta(**{unit: float(this)})``. ``timedelta`` only accepts
        weeks, days, hours, minutes, seconds, milliseconds and microseconds;
        other units raise TypeError.
    """
    normalized = unit.lower()
    if f"{normalized}s" in Generator.TIME_PART_SINGULARS:
        normalized = f"{normalized}s"
    amount = float(this)
    return datetime.timedelta(**{normalized: amount})
130
131
def _like_to_regex(pattern):
    """
    Translate a SQL LIKE pattern into an equivalent Python regular expression.

    ``%`` becomes ``.*`` and ``_`` becomes ``.``. Every other character is
    escaped, so regex metacharacters in the pattern (``.``, ``(``, ``+``, ...)
    match literally. LIKE escape sequences (``ESCAPE ...``) are not supported.
    """
    return "".join(
        ".*" if char == "%" else "." if char == "_" else re.escape(char) for char in pattern
    )


# Name -> Python implementation table used by the executor's generated code.
ENV = {
    "exp": exp,
    # aggs
    "ARRAYAGG": list,
    "AVG": filter_nulls(statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean),  # type: ignore
    # filter_nulls passes a tuple of the non-NULL values, so len() counts them directly.
    # empty_null=False makes a COUNT over only NULLs return 0 rather than NULL.
    "COUNT": filter_nulls(len, False),
    "MAX": filter_nulls(max),
    "MIN": filter_nulls(min),
    "SUM": filter_nulls(sum),
    # scalar functions
    "ABS": null_if_any(lambda this: abs(this)),
    "ADD": null_if_any(lambda e, this: e + this),
    "ARRAYANY": null_if_any(lambda arr, func: any(func(e) for e in arr)),
    "BETWEEN": null_if_any(lambda this, low, high: low <= this and this <= high),
    "BITWISEAND": null_if_any(lambda this, e: this & e),
    "BITWISELEFTSHIFT": null_if_any(lambda this, e: this << e),
    "BITWISEOR": null_if_any(lambda this, e: this | e),
    "BITWISERIGHTSHIFT": null_if_any(lambda this, e: this >> e),
    "BITWISEXOR": null_if_any(lambda this, e: this ^ e),
    "CAST": cast,
    "COALESCE": lambda *args: next((a for a in args if a is not None), None),
    "CONCAT": null_if_any(lambda *args: "".join(args)),
    "CONCATWS": null_if_any(lambda this, *args: this.join(args)),
    "DATESTRTODATE": null_if_any(lambda arg: datetime.date.fromisoformat(arg)),
    "DIV": null_if_any(lambda e, this: e / this),
    "DOT": null_if_any(lambda e, this: e[this]),
    "EQ": null_if_any(lambda this, e: this == e),
    "EXTRACT": null_if_any(lambda this, e: getattr(e, this)),
    "GT": null_if_any(lambda this, e: this > e),
    "GTE": null_if_any(lambda this, e: this >= e),
    "IFNULL": lambda e, alt: alt if e is None else e,
    "IF": lambda predicate, true, false: true if predicate else false,
    "INTDIV": null_if_any(lambda e, this: e // this),
    "INTERVAL": interval,
    # LIKE must match the whole string (re.match only anchors at the start),
    # '%' and '_' must also match newlines, and regex metacharacters must be literal.
    "LIKE": null_if_any(
        lambda this, e: bool(re.fullmatch(_like_to_regex(e), this, re.DOTALL))
    ),
    "LOWER": null_if_any(lambda arg: arg.lower()),
    "LT": null_if_any(lambda this, e: this < e),
    "LTE": null_if_any(lambda this, e: this <= e),
    "MAP": null_if_any(lambda *args: dict(zip(*args))),  # type: ignore
    "MOD": null_if_any(lambda e, this: e % this),
    "MUL": null_if_any(lambda e, this: e * this),
    "NEQ": null_if_any(lambda this, e: this != e),
    "ORD": null_if_any(ord),
    "ORDERED": ordered,
    # Propagate NULL like the other arithmetic operators instead of raising TypeError.
    "POW": null_if_any(pow),
    "STRPOSITION": str_position,
    "SUB": null_if_any(lambda e, this: e - this),
    "SUBSTRING": substring,
    "TIMESTRTOTIME": null_if_any(lambda arg: datetime.datetime.fromisoformat(arg)),
    "UPPER": null_if_any(lambda arg: arg.upper()),
    "YEAR": null_if_any(lambda arg: arg.year),
    "MONTH": null_if_any(lambda arg: arg.month),
    "DAY": null_if_any(lambda arg: arg.day),
    "CURRENTDATETIME": datetime.datetime.now,
    "CURRENTTIMESTAMP": datetime.datetime.now,
    "CURRENTTIME": datetime.datetime.now,
    "CURRENTDATE": datetime.date.today,
    "STRFTIME": null_if_any(lambda fmt, arg: datetime.datetime.fromisoformat(arg).strftime(fmt)),
}
class reverse_key:
    """
    Wrap a value so that it sorts in the opposite direction.

    Comparisons between two ``reverse_key`` instances delegate to the wrapped
    values with the operands swapped, turning an ascending sort into a
    descending one.
    """

    def __init__(self, obj):
        # The wrapped value whose natural ordering is inverted.
        self.obj = obj

    def __eq__(self, other):
        theirs = other.obj
        return theirs == self.obj

    def __lt__(self, other):
        # "self < other" holds exactly when the wrapped values compare the other way round.
        mine, theirs = self.obj, other.obj
        return theirs < mine
def filter_nulls(func, empty_null=True):
    """
    Adapt an aggregate function so that it ignores NULL (``None``) inputs.

    Args:
        func: callable that receives a tuple of the non-``None`` values.
        empty_null: when true, the wrapper returns ``None`` instead of calling
            ``func`` if no non-``None`` values remain.

    Returns:
        A wrapper carrying ``func``'s metadata that accepts an iterable of values.
    """

    @wraps(func)
    def _func(values):
        non_null = tuple(filter(lambda value: value is not None, values))
        if empty_null and not non_null:
            return None
        return func(non_null)

    return _func
def null_if_any(*required):
    """
    Wrap a function so it returns ``None`` whenever a required argument is ``None``.

    Two forms are supported:

        @null_if_any("a")    # only the parameter named ``a`` is checked
        def foo(a, b): ...

        @null_if_any         # bare form: every positional argument is checked
        def foo(a, b): ...

    Args:
        *required: names of the parameters to check or, in the bare form, the
            function being decorated.

    Returns:
        The wrapped function (bare form) or a decorator that produces it.

    Note:
        The wrapper accepts positional arguments only. Required names are
        resolved to positions with ``inspect.signature`` when the decorator is
        applied, and names that match no parameter are ignored. A required
        parameter must actually be passed, otherwise indexing ``args`` fails.
    """
    target = None
    if len(required) == 1 and callable(required[0]):
        # Bare form: the single positional argument is the function itself.
        target, required = required[0], ()

    def decorator(func):
        if not required:

            def predicate(*args):
                return any(arg is None for arg in args)

        else:
            param_names = inspect.signature(func).parameters
            positions = [pos for pos, name in enumerate(param_names) if name in required]

            def predicate(*args):
                return any(args[pos] is None for pos in positions)

        @wraps(func)
        def _func(*args):
            return None if predicate(*args) else func(*args)

        return _func

    return decorator(target) if target else decorator

Decorator that makes a function return None if any of the required arguments are None.

This also supports decoration with no arguments, e.g.:

@null_if_any
def foo(a, b): ...

In that case, all arguments are required.

@null_if_any("substr", "this")
def str_position(substr, this, position=None):
    """
    Return the 1-based position of ``substr`` inside ``this`` (0 when absent).

    Args:
        substr: the string to look for.
        this: the string searched.
        position: optional 1-based index at which the search starts.

    NOTE(review): a ``position`` <= 0 becomes a negative start, which
    ``str.find`` interprets relative to the end of ``this``; confirm whether
    that matches the intended SQL semantics.
    """
    start = None if position is None else position - 1
    index = this.find(substr, start)
    # str.find reports "not found" as -1, which maps to the SQL convention of 0.
    return index + 1
@null_if_any("this")
def substring(this, start=None, length=None):
    """
    SQL SUBSTRING: extract part of ``this`` using 1-based, SQL-style positions.

    Args:
        this: the source string.
        start: 1-based start position. ``None`` returns ``this`` unchanged and
            ``0`` returns an empty string. A negative value counts back from
            the end of the string.
        length: optional number of characters to take. Zero or negative yields
            an empty string.

    Returns:
        The extracted substring.
    """
    if start is None:
        return this
    if start == 0:
        return ""
    if start < 0:
        start = len(this) + start
    else:
        start -= 1

    if length is None:
        end = None
    else:
        # Clamp so that Python's negative-index semantics never leak in:
        # a negative length, or a window that ends before the string starts,
        # must select nothing rather than wrap around to the end.
        end = max(start + max(length, 0), 0)

    # A start before the beginning of the string is clamped to index 0.
    return this[max(start, 0):end]
@null_if_any
def cast(this, to):
    """
    Convert a non-NULL value to the SQL type ``to``.

    Args:
        this: the value to convert.
        to: the target type. It is compared against ``exp.DataType.Type``
            members and the ``exp.DataType.TEXT_TYPES`` /
            ``exp.DataType.NUMERIC_TYPES`` groups.

    Returns:
        The converted value.

    Raises:
        NotImplementedError: if ``to`` is not a supported target type.
    """
    if to == exp.DataType.Type.DATE:
        # datetime is a subclass of date, so test it first to drop the time part.
        if isinstance(this, datetime.datetime):
            return this.date()
        if isinstance(this, datetime.date):
            return this
        return datetime.date.fromisoformat(this)
    if to == exp.DataType.Type.DATETIME:
        if isinstance(this, datetime.datetime):
            return this
        if isinstance(this, datetime.date):
            return datetime.datetime.combine(this, datetime.time())
        return datetime.datetime.fromisoformat(this)
    if to == exp.DataType.Type.BOOLEAN:
        if isinstance(this, str):
            # bool() of any non-empty string is True, which would make
            # CAST('false' AS BOOLEAN) true, so recognise the usual literals.
            text = this.strip().lower()
            if text in ("true", "t", "yes", "y", "on", "1"):
                return True
            if text in ("false", "f", "no", "n", "off", "0"):
                return False
        return bool(this)
    if to in exp.DataType.TEXT_TYPES:
        return str(this)
    if to in {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE}:
        return float(this)
    if to in exp.DataType.NUMERIC_TYPES:
        return int(this)
    raise NotImplementedError(f"Casting to '{to}' not implemented.")
def ordered(this, desc, nulls_first):
    """
    Build the sort key for one ORDER BY term.

    Args:
        this: the value being ordered. ``None`` represents SQL NULL.
        desc: whether the term is sorted in descending order.
        nulls_first: whether NULLs must appear before non-NULL values in the
            final order.

    Returns:
        A key whose natural ordering honours ``desc`` and ``nulls_first``.
        NULLs are ranked separately from non-NULL values, so they are never
        compared against them; comparing ``None`` with a value raises TypeError.
    """
    if this is None:
        # reverse_key inverts the whole key for DESC, so invert the NULL rank
        # too, keeping NULLs where the caller asked for them.
        null_rank = -1 if bool(nulls_first) != bool(desc) else 1
        key = (null_rank, None)
    else:
        key = (0, this)

    return reverse_key(key) if desc else key
@null_if_any
def interval(this, unit):
    """
    Turn an INTERVAL literal into a ``datetime.timedelta``.

    Args:
        this: the interval amount; it is converted with ``float()``.
        unit: the unit name. It is lower-cased and pluralised when the plural
            form appears in ``Generator.TIME_PART_SINGULARS``.

    Returns:
        ``datetime.timedelta(**{unit: float(this)})``. ``timedelta`` only accepts
        weeks, days, hours, minutes, seconds, milliseconds and microseconds;
        other units raise TypeError.
    """
    normalized = unit.lower()
    if f"{normalized}s" in Generator.TIME_PART_SINGULARS:
        normalized = f"{normalized}s"
    amount = float(this)
    return datetime.timedelta(**{normalized: amount})