Edit on GitHub

sqlglot.executor.env

  1import datetime
  2import inspect
  3import re
  4import statistics
  5from functools import wraps
  6
  7from sqlglot import exp
  8from sqlglot.generator import Generator
  9from sqlglot.helper import PYTHON_VERSION, is_int, seq_get
 10
 11
 12class reverse_key:
 13    def __init__(self, obj):
 14        self.obj = obj
 15
 16    def __eq__(self, other):
 17        return other.obj == self.obj
 18
 19    def __lt__(self, other):
 20        return other.obj < self.obj
 21
 22
 23def filter_nulls(func, empty_null=True):
 24    @wraps(func)
 25    def _func(values):
 26        filtered = tuple(v for v in values if v is not None)
 27        if not filtered and empty_null:
 28            return None
 29        return func(filtered)
 30
 31    return _func
 32
 33
 34def null_if_any(*required):
 35    """
 36    Decorator that makes a function return `None` if any of the `required` arguments are `None`.
 37
 38    This also supports decoration with no arguments, e.g.:
 39
 40        @null_if_any
 41        def foo(a, b): ...
 42
 43    In which case all arguments are required.
 44    """
 45    f = None
 46    if len(required) == 1 and callable(required[0]):
 47        f = required[0]
 48        required = ()
 49
 50    def decorator(func):
 51        if required:
 52            required_indices = [
 53                i for i, param in enumerate(inspect.signature(func).parameters) if param in required
 54            ]
 55
 56            def predicate(*args):
 57                return any(args[i] is None for i in required_indices)
 58
 59        else:
 60
 61            def predicate(*args):
 62                return any(a is None for a in args)
 63
 64        @wraps(func)
 65        def _func(*args):
 66            if predicate(*args):
 67                return None
 68            return func(*args)
 69
 70        return _func
 71
 72    if f:
 73        return decorator(f)
 74
 75    return decorator
 76
 77
 78@null_if_any("substr", "this")
 79def str_position(substr, this, position=None):
 80    position = position - 1 if position is not None else position
 81    return this.find(substr, position) + 1
 82
 83
 84@null_if_any("this")
 85def substring(this, start=None, length=None):
 86    if start is None:
 87        return this
 88    elif start == 0:
 89        return ""
 90    elif start < 0:
 91        start = len(this) + start
 92    else:
 93        start -= 1
 94
 95    end = None if length is None else start + length
 96
 97    return this[start:end]
 98
 99
@null_if_any
def cast(this, to):
    """Convert ``this`` to the Python representation of SQL type ``to``.

    Temporal targets (DATE, TIME, DATETIME/TIMESTAMP) accept datetime objects or
    ISO-8601 strings. BOOLEAN, text, FLOAT/DOUBLE and the remaining numeric types
    go through ``bool``, ``str``, ``float`` and ``int`` respectively. If either
    argument is ``None`` the decorator returns ``None`` without running this body.

    Raises:
        NotImplementedError: When none of the rules below handles the combination.
    """
    types = exp.DataType.Type

    if to == types.DATE:
        # datetime.datetime subclasses datetime.date, so test the narrower type first.
        if isinstance(this, datetime.datetime):
            return this.date()
        if isinstance(this, datetime.date):
            return this
        if isinstance(this, str):
            return datetime.date.fromisoformat(this)

    if to == types.TIME:
        if isinstance(this, datetime.datetime):
            return this.time()
        if isinstance(this, datetime.time):
            return this
        if isinstance(this, str):
            return datetime.time.fromisoformat(this)

    if to in (types.DATETIME, types.TIMESTAMP):
        if isinstance(this, datetime.datetime):
            return this
        if isinstance(this, datetime.date):
            # Promote a bare date to midnight of the same day.
            return datetime.datetime(this.year, this.month, this.day)
        if isinstance(this, str):
            return datetime.datetime.fromisoformat(this)

    # Temporal inputs not matched above fall through to the scalar rules.
    if to == types.BOOLEAN:
        return bool(this)
    if to in exp.DataType.TEXT_TYPES:
        return str(this)
    # Checked before NUMERIC_TYPES, which presumably also contains FLOAT/DOUBLE.
    if to in {types.FLOAT, types.DOUBLE}:
        return float(this)
    if to in exp.DataType.NUMERIC_TYPES:
        return int(this)

    raise NotImplementedError(f"Casting {this} to '{to}' not implemented.")
132
133
def ordered(this, desc, nulls_first):
    """Return the sort key for one ORDER BY term.

    Descending terms are wrapped in ``reverse_key`` so that a single ascending
    sort places larger values first; ascending terms are returned unchanged.
    ``nulls_first`` is not consulted here.
    """
    if not desc:
        return this
    return reverse_key(this)
138
139
@null_if_any
def interval(this, unit):
    """Build a ``datetime.timedelta`` for ``INTERVAL this unit``.

    ``timedelta`` takes plural keyword arguments (``days``, ``hours``, ...), so the
    unit is pluralised whenever ``unit + "S"`` is in
    ``Generator.TIME_PART_SINGULARS``. The chosen name is lower-cased and used as
    the keyword, with ``float(this)`` as its value. A NULL argument yields ``None``.
    """
    plural_unit = unit + "S"
    keyword = plural_unit if plural_unit in Generator.TIME_PART_SINGULARS else unit
    return datetime.timedelta(**{keyword.lower(): float(this)})
146
147
@null_if_any("this", "expression")
def arraytostring(this, expression, null=None):
    """SQL ``ARRAY_TO_STRING(this, expression[, null])``.

    Joins the elements of ``this`` using the separator ``expression``. ``None``
    elements are replaced by ``null`` when it is provided and skipped otherwise.
    """
    pieces = []
    for item in this:
        if item is None:
            item = null
        if item is not None:
            pieces.append(item)
    return expression.join(pieces)
151
152
@null_if_any("this", "expression")
def jsonextract(this, expression):
    """Walk ``this`` along the path segments in ``expression`` and return the result.

    Dict levels are looked up with ``dict.get``. List levels are indexed via
    ``seq_get`` when the segment satisfies ``is_int``. The walk stops early as
    soon as a level yields ``None``.

    Raises:
        NotImplementedError: If a segment cannot be applied to the current value.
    """
    node = this
    for segment in expression:
        if isinstance(node, dict):
            node = node.get(segment)
        elif isinstance(node, list) and is_int(segment):
            node = seq_get(node, int(segment))
        else:
            raise NotImplementedError(f"Unable to extract value for {node} at {segment}.")

        if node is None:
            break

    return node
167
168
def _sql_like(this, pattern):
    """Return True if ``this`` matches the SQL LIKE ``pattern`` in its entirety.

    ``%`` matches any run of characters (possibly empty) and ``_`` matches exactly
    one character. Every other character, including regex metacharacters such as
    ``.`` or ``(``, is matched literally.
    """
    regex = "".join(
        ".*" if char == "%" else "." if char == "_" else re.escape(char) for char in pattern
    )
    # fullmatch rather than match: LIKE is anchored at both ends ('abc' LIKE 'ab' is
    # false). DOTALL lets the wildcards span newlines too.
    return re.fullmatch(regex, this, flags=re.DOTALL) is not None


# Upper-case SQL function name -> Python implementation. Presumably looked up by
# the executor when evaluating generated code (TODO confirm against the caller).
ENV = {
    "exp": exp,
    # aggs
    "ARRAYAGG": list,
    # dict.fromkeys de-duplicates while keeping first-seen order, so the result is
    # deterministic (set order for strings changes with hash randomization).
    "ARRAYUNIQUEAGG": filter_nulls(lambda acc: list(dict.fromkeys(acc))),
    "AVG": filter_nulls(statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean),  # type: ignore
    "COUNT": filter_nulls(lambda acc: sum(1 for _ in acc), False),
    "MAX": filter_nulls(max),
    "MIN": filter_nulls(min),
    "SUM": filter_nulls(sum),
    # scalar functions
    "ABS": null_if_any(lambda this: abs(this)),
    "ADD": null_if_any(lambda e, this: e + this),
    "ARRAYANY": null_if_any(lambda arr, func: any(func(e) for e in arr)),
    "ARRAYTOSTRING": arraytostring,
    "BETWEEN": null_if_any(lambda this, low, high: low <= this <= high),
    "BITWISEAND": null_if_any(lambda this, e: this & e),
    "BITWISELEFTSHIFT": null_if_any(lambda this, e: this << e),
    "BITWISEOR": null_if_any(lambda this, e: this | e),
    "BITWISERIGHTSHIFT": null_if_any(lambda this, e: this >> e),
    "BITWISEXOR": null_if_any(lambda this, e: this ^ e),
    "CAST": cast,
    "COALESCE": lambda *args: next((a for a in args if a is not None), None),
    "CONCAT": null_if_any(lambda *args: "".join(args)),
    "SAFECONCAT": null_if_any(lambda *args: "".join(str(arg) for arg in args)),
    "CONCATWS": null_if_any(lambda this, *args: this.join(args)),
    "DATEDIFF": null_if_any(lambda this, expression, *_: (this - expression).days),
    "DATESTRTODATE": null_if_any(lambda arg: datetime.date.fromisoformat(arg)),
    "DIV": null_if_any(lambda e, this: e / this),
    "DOT": null_if_any(lambda e, this: e[this]),
    "EQ": null_if_any(lambda this, e: this == e),
    "EXTRACT": null_if_any(lambda this, e: getattr(e, this)),
    "GT": null_if_any(lambda this, e: this > e),
    "GTE": null_if_any(lambda this, e: this >= e),
    "IF": lambda predicate, true, false: true if predicate else false,
    "INTDIV": null_if_any(lambda e, this: e // this),
    "INTERVAL": interval,
    "JSONEXTRACT": jsonextract,
    "LEFT": null_if_any(lambda this, e: this[:e]),
    "LIKE": null_if_any(_sql_like),
    "LOWER": null_if_any(lambda arg: arg.lower()),
    "LT": null_if_any(lambda this, e: this < e),
    "LTE": null_if_any(lambda this, e: this <= e),
    "MAP": null_if_any(lambda *args: dict(zip(*args))),  # type: ignore
    "MOD": null_if_any(lambda e, this: e % this),
    "MUL": null_if_any(lambda e, this: e * this),
    "NEQ": null_if_any(lambda this, e: this != e),
    "ORD": null_if_any(ord),
    "ORDERED": ordered,
    # NULL-safe like its siblings: pow(None, 2) would otherwise raise TypeError.
    "POW": null_if_any(pow),
    # this[-0:] is the whole string, so a zero count is special-cased to "".
    "RIGHT": null_if_any(lambda this, e: this[-e:] if e else ""),
    "ROUND": null_if_any(lambda this, decimals=None, truncate=None: round(this, ndigits=decimals)),
    "STRPOSITION": str_position,
    "SUB": null_if_any(lambda e, this: e - this),
    "SUBSTRING": substring,
    "TIMESTRTOTIME": null_if_any(lambda arg: datetime.datetime.fromisoformat(arg)),
    "UPPER": null_if_any(lambda arg: arg.upper()),
    "YEAR": null_if_any(lambda arg: arg.year),
    "MONTH": null_if_any(lambda arg: arg.month),
    "DAY": null_if_any(lambda arg: arg.day),
    "CURRENTDATETIME": datetime.datetime.now,
    "CURRENTTIMESTAMP": datetime.datetime.now,
    "CURRENTTIME": datetime.datetime.now,
    "CURRENTDATE": datetime.date.today,
    "STRFTIME": null_if_any(lambda fmt, arg: datetime.datetime.fromisoformat(arg).strftime(fmt)),
    "STRTOTIME": null_if_any(lambda arg, format: datetime.datetime.strptime(arg, format)),
    "TRIM": null_if_any(lambda this, e=None: this.strip(e)),
    # Pairs of (key, value) arguments; pairs with a NULL key or value are dropped.
    "STRUCT": lambda *args: {
        args[x]: args[x + 1]
        for x in range(0, len(args), 2)
        if (args[x + 1] is not None and args[x] is not None)
    },
    "UNIXTOTIME": null_if_any(
        lambda arg: datetime.datetime.fromtimestamp(arg, datetime.timezone.utc)
    ),
}
class reverse_key:
    """Sort key that flips the natural ordering of the value it wraps.

    Sorting ``reverse_key`` objects orders their wrapped values from largest to
    smallest, which lets a descending ORDER BY term take part in one ascending sort.
    """

    # Defining __eq__ already makes instances unhashable; spelled out for clarity.
    __hash__ = None

    def __init__(self, obj):
        self.obj = obj  # the wrapped value

    def __eq__(self, that):
        return that.obj == self.obj

    def __lt__(self, that):
        # Operands swapped: self sorts "before" that when its value is the larger one.
        return that.obj < self.obj
def filter_nulls(func, empty_null=True):
    """Wrap an aggregate so that it only ever sees non-NULL values.

    Args:
        func: Aggregate callable that receives a single tuple of values.
        empty_null: When true (the default), an input containing no non-NULL
            values yields ``None`` without calling ``func``.

    Returns:
        A one-argument wrapper around ``func`` that keeps ``func``'s metadata.
    """

    @wraps(func)
    def _func(values):
        non_null = tuple(filter(lambda value: value is not None, values))
        if empty_null and not non_null:
            return None
        return func(non_null)

    return _func
def null_if_any(*required):
    """Short-circuit a function to ``None`` when selected positional arguments are ``None``.

    Two forms are supported::

        @null_if_any("a")      # only parameter ``a`` is checked
        def foo(a, b): ...

        @null_if_any           # bare form: every positional argument is checked
        def foo(a, b): ...

    Args:
        *required: Parameter names to check, or the decorated function itself
            when used in the bare form.

    Returns:
        The wrapped function in the bare form, otherwise a decorator.
    """
    target = None
    # In the bare form Python passes the function itself instead of parameter names.
    if len(required) == 1 and callable(required[0]):
        target, required = required[0], ()

    def decorator(func):
        if not required:

            def predicate(*args):
                return any(arg is None for arg in args)

        else:
            # Resolve the requested names to positional indices once, at decoration time.
            positions = [
                index
                for index, name in enumerate(inspect.signature(func).parameters)
                if name in required
            ]

            def predicate(*args):
                return any(args[index] is None for index in positions)

        @wraps(func)
        def _func(*args):
            return None if predicate(*args) else func(*args)

        return _func

    return decorator(target) if target else decorator

Decorator that makes a function return None if any of the required arguments are None.

This also supports decoration with no arguments, e.g.:

@null_if_any
def foo(a, b): ...

In that case, all arguments are required.

@null_if_any("substr", "this")
def str_position(substr, this, position=None):
    """Return the 1-based index of ``substr`` within ``this``, or 0 when it is absent.

    ``position`` is the 1-based index at which the search begins; when omitted
    the search starts at the beginning. A NULL ``substr`` or ``this`` yields ``None``.
    """
    start = None if position is None else position - 1
    # str.find reports -1 for "not found", which becomes SQL's 0 after the +1 shift.
    return this.find(substr, start) + 1
@null_if_any("this")
def substring(this, start=None, length=None):
    """SQL ``SUBSTRING(this, start[, length])`` with 1-based positions.

    Args:
        this: Input string; ``None`` short-circuits to ``None`` via the decorator.
        start: 1-based start position. ``None`` returns ``this`` unchanged, ``0``
            returns ``""``, and a negative value counts back from the end
            (``-1`` is the last character). A negative start that reaches before
            the first character returns ``""``.
        length: Maximum number of characters to return; ``None`` means "to the
            end of the string". A negative length returns ``""``.

    Returns:
        The requested substring.
    """
    if start is None:
        return this
    if start == 0:
        return ""
    if start < 0:
        start = len(this) + start
        # Still negative means the requested start lies before the string; without
        # this guard Python slicing would wrap the index around to the end again
        # (e.g. SUBSTRING('abc', -4) would yield 'c').
        if start < 0:
            return ""
    else:
        start -= 1

    # Clamp the length so a negative value cannot turn `end` into a wrapped index.
    end = None if length is None else start + max(length, 0)
    return this[start:end]
@null_if_any
def cast(this, to):
    """Convert ``this`` to the Python representation of SQL type ``to``.

    Temporal targets (DATE, TIME, DATETIME/TIMESTAMP) accept datetime objects or
    ISO-8601 strings. BOOLEAN, text, FLOAT/DOUBLE and the remaining numeric types
    go through ``bool``, ``str``, ``float`` and ``int`` respectively. If either
    argument is ``None`` the decorator returns ``None`` without running this body.

    Raises:
        NotImplementedError: When none of the rules below handles the combination.
    """
    types = exp.DataType.Type

    if to == types.DATE:
        # datetime.datetime subclasses datetime.date, so test the narrower type first.
        if isinstance(this, datetime.datetime):
            return this.date()
        if isinstance(this, datetime.date):
            return this
        if isinstance(this, str):
            return datetime.date.fromisoformat(this)

    if to == types.TIME:
        if isinstance(this, datetime.datetime):
            return this.time()
        if isinstance(this, datetime.time):
            return this
        if isinstance(this, str):
            return datetime.time.fromisoformat(this)

    if to in (types.DATETIME, types.TIMESTAMP):
        if isinstance(this, datetime.datetime):
            return this
        if isinstance(this, datetime.date):
            # Promote a bare date to midnight of the same day.
            return datetime.datetime(this.year, this.month, this.day)
        if isinstance(this, str):
            return datetime.datetime.fromisoformat(this)

    # Temporal inputs not matched above fall through to the scalar rules.
    if to == types.BOOLEAN:
        return bool(this)
    if to in exp.DataType.TEXT_TYPES:
        return str(this)
    # Checked before NUMERIC_TYPES, which presumably also contains FLOAT/DOUBLE.
    if to in {types.FLOAT, types.DOUBLE}:
        return float(this)
    if to in exp.DataType.NUMERIC_TYPES:
        return int(this)

    raise NotImplementedError(f"Casting {this} to '{to}' not implemented.")
def ordered(this, desc, nulls_first):
    """Return the sort key for one ORDER BY term.

    Descending terms are wrapped in ``reverse_key`` so that a single ascending
    sort places larger values first; ascending terms are returned unchanged.
    ``nulls_first`` is not consulted here.
    """
    if not desc:
        return this
    return reverse_key(this)
@null_if_any
def interval(this, unit):
    """Build a ``datetime.timedelta`` for ``INTERVAL this unit``.

    ``timedelta`` takes plural keyword arguments (``days``, ``hours``, ...), so the
    unit is pluralised whenever ``unit + "S"`` is in
    ``Generator.TIME_PART_SINGULARS``. The chosen name is lower-cased and used as
    the keyword, with ``float(this)`` as its value. A NULL argument yields ``None``.
    """
    plural_unit = unit + "S"
    keyword = plural_unit if plural_unit in Generator.TIME_PART_SINGULARS else unit
    return datetime.timedelta(**{keyword.lower(): float(this)})
@null_if_any("this", "expression")
def arraytostring(this, expression, null=None):
    """SQL ``ARRAY_TO_STRING(this, expression[, null])``.

    Joins the elements of ``this`` using the separator ``expression``. ``None``
    elements are replaced by ``null`` when it is provided and skipped otherwise.
    """
    pieces = []
    for item in this:
        if item is None:
            item = null
        if item is not None:
            pieces.append(item)
    return expression.join(pieces)
@null_if_any("this", "expression")
def jsonextract(this, expression):
    """Walk ``this`` along the path segments in ``expression`` and return the result.

    Dict levels are looked up with ``dict.get``. List levels are indexed via
    ``seq_get`` when the segment satisfies ``is_int``. The walk stops early as
    soon as a level yields ``None``.

    Raises:
        NotImplementedError: If a segment cannot be applied to the current value.
    """
    node = this
    for segment in expression:
        if isinstance(node, dict):
            node = node.get(segment)
        elif isinstance(node, list) and is_int(segment):
            node = seq_get(node, int(segment))
        else:
            raise NotImplementedError(f"Unable to extract value for {node} at {segment}.")

        if node is None:
            break

    return node
ENV = {'exp': <module 'sqlglot.expressions' from '/home/runner/work/sqlglot/sqlglot/sqlglot/expressions.py'>, 'ARRAYAGG': <class 'list'>, 'ARRAYUNIQUEAGG': <function <lambda>>, 'AVG': <function fmean>, 'COUNT': <function <lambda>>, 'MAX': <function max>, 'MIN': <function min>, 'SUM': <function sum>, 'ABS': <function <lambda>>, 'ADD': <function <lambda>>, 'ARRAYANY': <function <lambda>>, 'ARRAYTOSTRING': <function arraytostring>, 'BETWEEN': <function <lambda>>, 'BITWISEAND': <function <lambda>>, 'BITWISELEFTSHIFT': <function <lambda>>, 'BITWISEOR': <function <lambda>>, 'BITWISERIGHTSHIFT': <function <lambda>>, 'BITWISEXOR': <function <lambda>>, 'CAST': <function cast>, 'COALESCE': <function <lambda>>, 'CONCAT': <function <lambda>>, 'SAFECONCAT': <function <lambda>>, 'CONCATWS': <function <lambda>>, 'DATEDIFF': <function <lambda>>, 'DATESTRTODATE': <function <lambda>>, 'DIV': <function <lambda>>, 'DOT': <function <lambda>>, 'EQ': <function <lambda>>, 'EXTRACT': <function <lambda>>, 'GT': <function <lambda>>, 'GTE': <function <lambda>>, 'IF': <function <lambda>>, 'INTDIV': <function <lambda>>, 'INTERVAL': <function interval>, 'JSONEXTRACT': <function jsonextract>, 'LEFT': <function <lambda>>, 'LIKE': <function <lambda>>, 'LOWER': <function <lambda>>, 'LT': <function <lambda>>, 'LTE': <function <lambda>>, 'MAP': <function <lambda>>, 'MOD': <function <lambda>>, 'MUL': <function <lambda>>, 'NEQ': <function <lambda>>, 'ORD': <function ord>, 'ORDERED': <function ordered>, 'POW': <built-in function pow>, 'RIGHT': <function <lambda>>, 'ROUND': <function <lambda>>, 'STRPOSITION': <function str_position>, 'SUB': <function <lambda>>, 'SUBSTRING': <function substring>, 'TIMESTRTOTIME': <function <lambda>>, 'UPPER': <function <lambda>>, 'YEAR': <function <lambda>>, 'MONTH': <function <lambda>>, 'DAY': <function <lambda>>, 'CURRENTDATETIME': <built-in method now of type object>, 'CURRENTTIMESTAMP': <built-in method now of type object>, 'CURRENTTIME': <built-in method now of type 
object>, 'CURRENTDATE': <built-in method today of type object>, 'STRFTIME': <function <lambda>>, 'STRTOTIME': <function <lambda>>, 'TRIM': <function <lambda>>, 'STRUCT': <function <lambda>>, 'UNIXTOTIME': <function <lambda>>}