Edit on GitHub

sqlglot.executor.env

  1import datetime
  2import inspect
  3import re
  4import statistics
  5from functools import wraps
  6
  7from sqlglot import exp
  8from sqlglot.generator import Generator
  9from sqlglot.helper import PYTHON_VERSION, is_int, seq_get
 10
 11
 12class reverse_key:
 13    def __init__(self, obj):
 14        self.obj = obj
 15
 16    def __eq__(self, other):
 17        return other.obj == self.obj
 18
 19    def __lt__(self, other):
 20        return other.obj < self.obj
 21
 22
 23def filter_nulls(func, empty_null=True):
 24    @wraps(func)
 25    def _func(values):
 26        filtered = tuple(v for v in values if v is not None)
 27        if not filtered and empty_null:
 28            return None
 29        return func(filtered)
 30
 31    return _func
 32
 33
 34def null_if_any(*required):
 35    """
 36    Decorator that makes a function return `None` if any of the `required` arguments are `None`.
 37
 38    This also supports decoration with no arguments, e.g.:
 39
 40        @null_if_any
 41        def foo(a, b): ...
 42
 43    In which case all arguments are required.
 44    """
 45    f = None
 46    if len(required) == 1 and callable(required[0]):
 47        f = required[0]
 48        required = ()
 49
 50    def decorator(func):
 51        if required:
 52            required_indices = [
 53                i for i, param in enumerate(inspect.signature(func).parameters) if param in required
 54            ]
 55
 56            def predicate(*args):
 57                return any(args[i] is None for i in required_indices)
 58
 59        else:
 60
 61            def predicate(*args):
 62                return any(a is None for a in args)
 63
 64        @wraps(func)
 65        def _func(*args):
 66            if predicate(*args):
 67                return None
 68            return func(*args)
 69
 70        return _func
 71
 72    if f:
 73        return decorator(f)
 74
 75    return decorator
 76
 77
 78@null_if_any("substr", "this")
 79def str_position(substr, this, position=None):
 80    position = position - 1 if position is not None else position
 81    return this.find(substr, position) + 1
 82
 83
 84@null_if_any("this")
 85def substring(this, start=None, length=None):
 86    if start is None:
 87        return this
 88    elif start == 0:
 89        return ""
 90    elif start < 0:
 91        start = len(this) + start
 92    else:
 93        start -= 1
 94
 95    end = None if length is None else start + length
 96
 97    return this[start:end]
 98
 99
@null_if_any
def cast(this, to):
    """
    Convert `this` to the Python representation of the data type `to`.

    Date and datetime targets accept `datetime.datetime`, `datetime.date` and ISO-format
    strings. Boolean, text, float and remaining numeric targets go through `bool`, `str`,
    `float` and `int` respectively.

    Raises `NotImplementedError` when no conversion applies, including a date or
    datetime target given a value of an unsupported Python type.
    """
    types = exp.DataType.Type

    if to == types.DATE:
        # `datetime.datetime` subclasses `datetime.date`, so it must be tested first.
        if isinstance(this, datetime.datetime):
            return this.date()
        if isinstance(this, datetime.date):
            return this
        if isinstance(this, str):
            return datetime.date.fromisoformat(this)

    if to in (types.DATETIME, types.TIMESTAMP):
        if isinstance(this, datetime.datetime):
            return this
        if isinstance(this, datetime.date):
            # Promote a plain date to midnight of the same day.
            return datetime.datetime(this.year, this.month, this.day)
        if isinstance(this, str):
            return datetime.datetime.fromisoformat(this)

    if to == types.BOOLEAN:
        return bool(this)

    if to in exp.DataType.TEXT_TYPES:
        return str(this)

    # Floating-point targets are checked before the general numeric set.
    if to in {types.FLOAT, types.DOUBLE}:
        return float(this)

    if to in exp.DataType.NUMERIC_TYPES:
        return int(this)

    raise NotImplementedError(f"Casting {this} to '{to}' not implemented.")
125
126
def ordered(this, desc, nulls_first):
    """
    Build a sort key for `this`.

    When `desc` is truthy the value is wrapped in `reverse_key` so that it sorts in
    descending order; otherwise it is returned unchanged. `nulls_first` is accepted
    but not used by this implementation.
    """
    return reverse_key(this) if desc else this
131
132
@null_if_any
def interval(this, unit):
    """
    Build a `datetime.timedelta` of `this` units.

    `unit` gets an "S" appended when that form is a key of `Generator.TIME_PART_SINGULARS`
    (presumably the plural time-part names; verify against the generator). The result is
    lower-cased and used as the `timedelta` keyword, with `this` converted to `float`.
    """
    pluralized = unit + "S"
    keyword = pluralized if pluralized in Generator.TIME_PART_SINGULARS else unit
    return datetime.timedelta(**{keyword.lower(): float(this)})
139
140
@null_if_any("this", "expression")
def arrayjoin(this, expression, null=None):
    """
    Join the items of `this` using `expression` as the separator.

    `None` items are replaced by `null`; when `null` is itself `None` they are dropped.
    """
    substituted = (null if item is None else item for item in this)
    return expression.join(item for item in substituted if item is not None)
144
145
@null_if_any("this", "expression")
def jsonextract(this, expression):
    """
    Walk `this` along the path segments in `expression` and return the value reached.

    Dicts are traversed with `dict.get`. Lists are traversed with `seq_get` when the
    segment passes `is_int`. Traversal stops with `None` as soon as a step yields `None`.

    Raises `NotImplementedError` when the current value is neither a dict nor a list
    indexed by an integer-like segment.
    """
    current = this

    for segment in expression:
        if isinstance(current, dict):
            current = current.get(segment)
        elif isinstance(current, list) and is_int(segment):
            current = seq_get(current, int(segment))
        else:
            raise NotImplementedError(f"Unable to extract value for {current} at {segment}.")

        if current is None:
            return None

    return current
160
161
def _like_to_regex(pattern):
    """Translate a SQL LIKE pattern into a regex string, escaping every literal character."""
    return "".join(
        ".*" if char == "%" else "." if char == "_" else re.escape(char) for char in pattern
    )


# Mapping from upper-case function/operator name to its Python implementation.
# Aggregates take one iterable of values; scalar functions take positional arguments.
# Most scalar entries are wrapped in `null_if_any` so a `None` input gives `None`.
ENV = {
    "exp": exp,
    # aggs
    "ARRAYAGG": list,
    "ARRAYUNIQUEAGG": filter_nulls(lambda acc: list(set(acc))),
    "AVG": filter_nulls(statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean),  # type: ignore
    # COUNT of an all-null input is 0, not NULL, hence `empty_null=False`.
    "COUNT": filter_nulls(lambda acc: sum(1 for _ in acc), False),
    "MAX": filter_nulls(max),
    "MIN": filter_nulls(min),
    "SUM": filter_nulls(sum),
    # scalar functions
    "ABS": null_if_any(lambda this: abs(this)),
    "ADD": null_if_any(lambda e, this: e + this),
    "ARRAYANY": null_if_any(lambda arr, func: any(func(e) for e in arr)),
    "ARRAYJOIN": arrayjoin,
    "BETWEEN": null_if_any(lambda this, low, high: low <= this and this <= high),
    "BITWISEAND": null_if_any(lambda this, e: this & e),
    "BITWISELEFTSHIFT": null_if_any(lambda this, e: this << e),
    "BITWISEOR": null_if_any(lambda this, e: this | e),
    "BITWISERIGHTSHIFT": null_if_any(lambda this, e: this >> e),
    "BITWISEXOR": null_if_any(lambda this, e: this ^ e),
    "CAST": cast,
    "COALESCE": lambda *args: next((a for a in args if a is not None), None),
    "CONCAT": null_if_any(lambda *args: "".join(args)),
    "SAFECONCAT": null_if_any(lambda *args: "".join(str(arg) for arg in args)),
    "CONCATWS": null_if_any(lambda this, *args: this.join(args)),
    "DATEDIFF": null_if_any(lambda this, expression, *_: (this - expression).days),
    "DATESTRTODATE": null_if_any(lambda arg: datetime.date.fromisoformat(arg)),
    "DIV": null_if_any(lambda e, this: e / this),
    "DOT": null_if_any(lambda e, this: e[this]),
    "EQ": null_if_any(lambda this, e: this == e),
    "EXTRACT": null_if_any(lambda this, e: getattr(e, this)),
    "GT": null_if_any(lambda this, e: this > e),
    "GTE": null_if_any(lambda this, e: this >= e),
    "IF": lambda predicate, true, false: true if predicate else false,
    "INTDIV": null_if_any(lambda e, this: e // this),
    "INTERVAL": interval,
    "JSONEXTRACT": jsonextract,
    "LEFT": null_if_any(lambda this, e: this[:e]),
    # LIKE must match the whole value and treat regex metacharacters in the pattern as
    # literals; DOTALL lets the wildcards span newlines. ESCAPE clauses are not handled.
    "LIKE": null_if_any(
        lambda this, e: bool(re.fullmatch(_like_to_regex(e), this, flags=re.DOTALL))
    ),
    "LOWER": null_if_any(lambda arg: arg.lower()),
    "LT": null_if_any(lambda this, e: this < e),
    "LTE": null_if_any(lambda this, e: this <= e),
    "MAP": null_if_any(lambda *args: dict(zip(*args))),  # type: ignore
    "MOD": null_if_any(lambda e, this: e % this),
    "MUL": null_if_any(lambda e, this: e * this),
    "NEQ": null_if_any(lambda this, e: this != e),
    "ORD": null_if_any(ord),
    "ORDERED": ordered,
    # Null-propagating like the other arithmetic operators, instead of raising TypeError.
    "POW": null_if_any(pow),
    # `this[-0:]` is the whole value, so a count of 0 needs an explicit empty slice.
    "RIGHT": null_if_any(lambda this, e: this[-e:] if e else this[:0]),
    "STRPOSITION": str_position,
    "SUB": null_if_any(lambda e, this: e - this),
    "SUBSTRING": substring,
    "TIMESTRTOTIME": null_if_any(lambda arg: datetime.datetime.fromisoformat(arg)),
    "UPPER": null_if_any(lambda arg: arg.upper()),
    "YEAR": null_if_any(lambda arg: arg.year),
    "MONTH": null_if_any(lambda arg: arg.month),
    "DAY": null_if_any(lambda arg: arg.day),
    "CURRENTDATETIME": datetime.datetime.now,
    "CURRENTTIMESTAMP": datetime.datetime.now,
    "CURRENTTIME": datetime.datetime.now,
    "CURRENTDATE": datetime.date.today,
    "STRFTIME": null_if_any(lambda fmt, arg: datetime.datetime.fromisoformat(arg).strftime(fmt)),
    "TRIM": null_if_any(lambda this, e=None: this.strip(e)),
    # Arguments alternate key, value; pairs with a `None` key or value are omitted.
    "STRUCT": lambda *args: {
        args[x]: args[x + 1]
        for x in range(0, len(args), 2)
        if (args[x + 1] is not None and args[x] is not None)
    },
}
class reverse_key:
13class reverse_key:
14    def __init__(self, obj):
15        self.obj = obj
16
17    def __eq__(self, other):
18        return other.obj == self.obj
19
20    def __lt__(self, other):
21        return other.obj < self.obj
reverse_key(obj)
14    def __init__(self, obj):
15        self.obj = obj
obj
def filter_nulls(func, empty_null=True):
24def filter_nulls(func, empty_null=True):
25    @wraps(func)
26    def _func(values):
27        filtered = tuple(v for v in values if v is not None)
28        if not filtered and empty_null:
29            return None
30        return func(filtered)
31
32    return _func
def null_if_any(*required):
35def null_if_any(*required):
36    """
37    Decorator that makes a function return `None` if any of the `required` arguments are `None`.
38
39    This also supports decoration with no arguments, e.g.:
40
41        @null_if_any
42        def foo(a, b): ...
43
44    In which case all arguments are required.
45    """
46    f = None
47    if len(required) == 1 and callable(required[0]):
48        f = required[0]
49        required = ()
50
51    def decorator(func):
52        if required:
53            required_indices = [
54                i for i, param in enumerate(inspect.signature(func).parameters) if param in required
55            ]
56
57            def predicate(*args):
58                return any(args[i] is None for i in required_indices)
59
60        else:
61
62            def predicate(*args):
63                return any(a is None for a in args)
64
65        @wraps(func)
66        def _func(*args):
67            if predicate(*args):
68                return None
69            return func(*args)
70
71        return _func
72
73    if f:
74        return decorator(f)
75
76    return decorator

Decorator that makes a function return None if any of the required arguments are None.

This also supports decoration with no arguments, e.g.:

@null_if_any
def foo(a, b): ...

In which case all arguments are required.

@null_if_any('substr', 'this')
def str_position(substr, this, position=None):
79@null_if_any("substr", "this")
80def str_position(substr, this, position=None):
81    position = position - 1 if position is not None else position
82    return this.find(substr, position) + 1
@null_if_any('this')
def substring(this, start=None, length=None):
85@null_if_any("this")
86def substring(this, start=None, length=None):
87    if start is None:
88        return this
89    elif start == 0:
90        return ""
91    elif start < 0:
92        start = len(this) + start
93    else:
94        start -= 1
95
96    end = None if length is None else start + length
97
98    return this[start:end]
@null_if_any
def cast(this, to):
101@null_if_any
102def cast(this, to):
103    if to == exp.DataType.Type.DATE:
104        if isinstance(this, datetime.datetime):
105            return this.date()
106        if isinstance(this, datetime.date):
107            return this
108        if isinstance(this, str):
109            return datetime.date.fromisoformat(this)
110    if to in (exp.DataType.Type.DATETIME, exp.DataType.Type.TIMESTAMP):
111        if isinstance(this, datetime.datetime):
112            return this
113        if isinstance(this, datetime.date):
114            return datetime.datetime(this.year, this.month, this.day)
115        if isinstance(this, str):
116            return datetime.datetime.fromisoformat(this)
117    if to == exp.DataType.Type.BOOLEAN:
118        return bool(this)
119    if to in exp.DataType.TEXT_TYPES:
120        return str(this)
121    if to in {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE}:
122        return float(this)
123    if to in exp.DataType.NUMERIC_TYPES:
124        return int(this)
125    raise NotImplementedError(f"Casting {this} to '{to}' not implemented.")
def ordered(this, desc, nulls_first):
128def ordered(this, desc, nulls_first):
129    if desc:
130        return reverse_key(this)
131    return this
@null_if_any
def interval(this, unit):
134@null_if_any
135def interval(this, unit):
136    plural = unit + "S"
137    if plural in Generator.TIME_PART_SINGULARS:
138        unit = plural
139    return datetime.timedelta(**{unit.lower(): float(this)})
@null_if_any('this', 'expression')
def arrayjoin(this, expression, null=None):
142@null_if_any("this", "expression")
143def arrayjoin(this, expression, null=None):
144    return expression.join(x for x in (x if x is not None else null for x in this) if x is not None)
@null_if_any('this', 'expression')
def jsonextract(this, expression):
147@null_if_any("this", "expression")
148def jsonextract(this, expression):
149    for path_segment in expression:
150        if isinstance(this, dict):
151            this = this.get(path_segment)
152        elif isinstance(this, list) and is_int(path_segment):
153            this = seq_get(this, int(path_segment))
154        else:
155            raise NotImplementedError(f"Unable to extract value for {this} at {path_segment}.")
156
157        if this is None:
158            break
159
160    return this
ENV = {'exp': <module 'sqlglot.expressions' from '/home/runner/work/sqlglot/sqlglot/sqlglot/expressions.py'>, 'ARRAYAGG': <class 'list'>, 'ARRAYUNIQUEAGG': <function <lambda>>, 'AVG': <function fmean>, 'COUNT': <function <lambda>>, 'MAX': <function max>, 'MIN': <function min>, 'SUM': <function sum>, 'ABS': <function <lambda>>, 'ADD': <function <lambda>>, 'ARRAYANY': <function <lambda>>, 'ARRAYJOIN': <function arrayjoin>, 'BETWEEN': <function <lambda>>, 'BITWISEAND': <function <lambda>>, 'BITWISELEFTSHIFT': <function <lambda>>, 'BITWISEOR': <function <lambda>>, 'BITWISERIGHTSHIFT': <function <lambda>>, 'BITWISEXOR': <function <lambda>>, 'CAST': <function cast>, 'COALESCE': <function <lambda>>, 'CONCAT': <function <lambda>>, 'SAFECONCAT': <function <lambda>>, 'CONCATWS': <function <lambda>>, 'DATEDIFF': <function <lambda>>, 'DATESTRTODATE': <function <lambda>>, 'DIV': <function <lambda>>, 'DOT': <function <lambda>>, 'EQ': <function <lambda>>, 'EXTRACT': <function <lambda>>, 'GT': <function <lambda>>, 'GTE': <function <lambda>>, 'IF': <function <lambda>>, 'INTDIV': <function <lambda>>, 'INTERVAL': <function interval>, 'JSONEXTRACT': <function jsonextract>, 'LEFT': <function <lambda>>, 'LIKE': <function <lambda>>, 'LOWER': <function <lambda>>, 'LT': <function <lambda>>, 'LTE': <function <lambda>>, 'MAP': <function <lambda>>, 'MOD': <function <lambda>>, 'MUL': <function <lambda>>, 'NEQ': <function <lambda>>, 'ORD': <function ord>, 'ORDERED': <function ordered>, 'POW': <built-in function pow>, 'RIGHT': <function <lambda>>, 'STRPOSITION': <function str_position>, 'SUB': <function <lambda>>, 'SUBSTRING': <function substring>, 'TIMESTRTOTIME': <function <lambda>>, 'UPPER': <function <lambda>>, 'YEAR': <function <lambda>>, 'MONTH': <function <lambda>>, 'DAY': <function <lambda>>, 'CURRENTDATETIME': <built-in method now of type object>, 'CURRENTTIMESTAMP': <built-in method now of type object>, 'CURRENTTIME': <built-in method now of type object>, 'CURRENTDATE': <built-in 
method today of type object>, 'STRFTIME': <function <lambda>>, 'TRIM': <function <lambda>>, 'STRUCT': <function <lambda>>}