Edit on GitHub

sqlglot.executor.env

  1import datetime
  2import inspect
  3import re
  4import statistics
  5from functools import wraps
  6
  7from sqlglot import exp
  8from sqlglot.generator import Generator
  9from sqlglot.helper import PYTHON_VERSION
 10
 11
 12class reverse_key:
 13    def __init__(self, obj):
 14        self.obj = obj
 15
 16    def __eq__(self, other):
 17        return other.obj == self.obj
 18
 19    def __lt__(self, other):
 20        return other.obj < self.obj
 21
 22
 23def filter_nulls(func, empty_null=True):
 24    @wraps(func)
 25    def _func(values):
 26        filtered = tuple(v for v in values if v is not None)
 27        if not filtered and empty_null:
 28            return None
 29        return func(filtered)
 30
 31    return _func
 32
 33
 34def null_if_any(*required):
 35    """
 36    Decorator that makes a function return `None` if any of the `required` arguments are `None`.
 37
 38    This also supports decoration with no arguments, e.g.:
 39
 40        @null_if_any
 41        def foo(a, b): ...
 42
 43    In which case all arguments are required.
 44    """
 45    f = None
 46    if len(required) == 1 and callable(required[0]):
 47        f = required[0]
 48        required = ()
 49
 50    def decorator(func):
 51        if required:
 52            required_indices = [
 53                i for i, param in enumerate(inspect.signature(func).parameters) if param in required
 54            ]
 55
 56            def predicate(*args):
 57                return any(args[i] is None for i in required_indices)
 58
 59        else:
 60
 61            def predicate(*args):
 62                return any(a is None for a in args)
 63
 64        @wraps(func)
 65        def _func(*args):
 66            if predicate(*args):
 67                return None
 68            return func(*args)
 69
 70        return _func
 71
 72    if f:
 73        return decorator(f)
 74
 75    return decorator
 76
 77
 78@null_if_any("substr", "this")
 79def str_position(substr, this, position=None):
 80    position = position - 1 if position is not None else position
 81    return this.find(substr, position) + 1
 82
 83
 84@null_if_any("this")
 85def substring(this, start=None, length=None):
 86    if start is None:
 87        return this
 88    elif start == 0:
 89        return ""
 90    elif start < 0:
 91        start = len(this) + start
 92    else:
 93        start -= 1
 94
 95    end = None if length is None else start + length
 96
 97    return this[start:end]
 98
 99
@null_if_any
def cast(this, to):
    """Convert `this` to the Python representation of the SQL type `to`.

    `to` is compared against `exp.DataType.Type` members. Returns None when
    either argument is None (via `null_if_any`).

    Raises:
        NotImplementedError: if none of the handled target-type / value
            combinations below applies.
    """
    if to == exp.DataType.Type.DATE:
        # datetime is a subclass of date, so it must be checked first.
        if isinstance(this, datetime.datetime):
            return this.date()
        if isinstance(this, datetime.date):
            return this
        if isinstance(this, str):
            return datetime.date.fromisoformat(this)
    if to in (exp.DataType.Type.DATETIME, exp.DataType.Type.TIMESTAMP):
        if isinstance(this, datetime.datetime):
            return this
        if isinstance(this, datetime.date):
            return datetime.datetime(this.year, this.month, this.day)
        if isinstance(this, str):
            return datetime.datetime.fromisoformat(this)
    if to == exp.DataType.Type.BOOLEAN:
        if isinstance(this, str):
            # bool("false") is True, so map the common textual spellings
            # explicitly; any other string keeps the previous bool() behavior.
            text = this.strip().lower()
            if text in ("true", "t", "yes", "y", "1"):
                return True
            if text in ("false", "f", "no", "n", "0"):
                return False
        return bool(this)
    if to in exp.DataType.TEXT_TYPES:
        return str(this)
    if to in {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE}:
        return float(this)
    if to in exp.DataType.NUMERIC_TYPES:
        return int(this)
    raise NotImplementedError(f"Casting {this} to '{to}' not implemented.")
125
126
def ordered(this, desc, nulls_first):
    """Build the sort key for a single ORDER BY term.

    Descending terms are wrapped in `reverse_key` so they sort in reverse;
    ascending terms are returned unchanged. `nulls_first` is accepted for
    interface compatibility but is not used by this function.
    """
    return reverse_key(this) if desc else this
131
132
@null_if_any
def interval(this, unit):
    """Convert an INTERVAL value into a `datetime.timedelta`.

    `unit` is lower-cased. If the plural form (e.g. "day" -> "days") is a key
    of `Generator.TIME_PART_SINGULARS`, the plural is used as the timedelta
    keyword argument. The amount is converted with `float(this)`. Returns
    None if either argument is None.

    NOTE(review): `datetime.timedelta` only accepts weeks, days, hours,
    minutes, seconds, milliseconds and microseconds, so units such as month
    or year raise TypeError here.
    """
    normalized = unit.lower()
    if f"{normalized}s" in Generator.TIME_PART_SINGULARS:
        normalized = f"{normalized}s"
    return datetime.timedelta(**{normalized: float(this)})
140
141
def _like(this, pattern):
    """Evaluate SQL `this LIKE pattern`.

    `%` matches any run of characters (including newlines) and `_` matches
    exactly one character. Every other pattern character is matched literally
    (regex metacharacters are escaped), and the whole of `this` must match.
    No LIKE escape character is supported.
    """
    regex = "".join(
        ".*" if char == "%" else "." if char == "_" else re.escape(char) for char in pattern
    )
    return re.fullmatch(regex, this, re.DOTALL) is not None


# Name -> callable table used by the Python executor to evaluate SQL functions.
# Aggregates receive an iterable of values; scalar functions receive positional args.
ENV = {
    "exp": exp,
    # aggs
    "ARRAYAGG": list,
    "AVG": filter_nulls(statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean),  # type: ignore
    "COUNT": filter_nulls(lambda acc: sum(1 for _ in acc), False),
    "MAX": filter_nulls(max),
    "MIN": filter_nulls(min),
    "SUM": filter_nulls(sum),
    # scalar functions
    "ABS": null_if_any(lambda this: abs(this)),
    "ADD": null_if_any(lambda e, this: e + this),
    "ARRAYANY": null_if_any(lambda arr, func: any(func(e) for e in arr)),
    "BETWEEN": null_if_any(lambda this, low, high: low <= this and this <= high),
    "BITWISEAND": null_if_any(lambda this, e: this & e),
    "BITWISELEFTSHIFT": null_if_any(lambda this, e: this << e),
    "BITWISEOR": null_if_any(lambda this, e: this | e),
    "BITWISERIGHTSHIFT": null_if_any(lambda this, e: this >> e),
    "BITWISEXOR": null_if_any(lambda this, e: this ^ e),
    "CAST": cast,
    "COALESCE": lambda *args: next((a for a in args if a is not None), None),
    "CONCAT": null_if_any(lambda *args: "".join(args)),
    "SAFECONCAT": null_if_any(lambda *args: "".join(str(arg) for arg in args)),
    "CONCATWS": null_if_any(lambda this, *args: this.join(args)),
    "DATEDIFF": null_if_any(lambda this, expression, *_: (this - expression).days),
    "DATESTRTODATE": null_if_any(lambda arg: datetime.date.fromisoformat(arg)),
    "DIV": null_if_any(lambda e, this: e / this),
    "DOT": null_if_any(lambda e, this: e[this]),
    "EQ": null_if_any(lambda this, e: this == e),
    "EXTRACT": null_if_any(lambda this, e: getattr(e, this)),
    "GT": null_if_any(lambda this, e: this > e),
    "GTE": null_if_any(lambda this, e: this >= e),
    "IF": lambda predicate, true, false: true if predicate else false,
    "INTDIV": null_if_any(lambda e, this: e // this),
    "INTERVAL": interval,
    "LEFT": null_if_any(lambda this, e: this[:e]),
    # Escapes regex metacharacters and anchors both ends (re.match only anchored the start).
    "LIKE": null_if_any(_like),
    "LOWER": null_if_any(lambda arg: arg.lower()),
    "LT": null_if_any(lambda this, e: this < e),
    "LTE": null_if_any(lambda this, e: this <= e),
    "MAP": null_if_any(lambda *args: dict(zip(*args))),  # type: ignore
    "MOD": null_if_any(lambda e, this: e % this),
    "MUL": null_if_any(lambda e, this: e * this),
    "NEQ": null_if_any(lambda this, e: this != e),
    "ORD": null_if_any(ord),
    "ORDERED": ordered,
    # NULL-safe like every other scalar function (pow(None, x) raises TypeError).
    "POW": null_if_any(pow),
    # this[-0:] is the whole string, so RIGHT(x, 0) must be special-cased to "".
    "RIGHT": null_if_any(lambda this, e: this[-e:] if e else ""),
    "STRPOSITION": str_position,
    "SUB": null_if_any(lambda e, this: e - this),
    "SUBSTRING": substring,
    "TIMESTRTOTIME": null_if_any(lambda arg: datetime.datetime.fromisoformat(arg)),
    "UPPER": null_if_any(lambda arg: arg.upper()),
    "YEAR": null_if_any(lambda arg: arg.year),
    "MONTH": null_if_any(lambda arg: arg.month),
    "DAY": null_if_any(lambda arg: arg.day),
    "CURRENTDATETIME": datetime.datetime.now,
    "CURRENTTIMESTAMP": datetime.datetime.now,
    "CURRENTTIME": datetime.datetime.now,
    "CURRENTDATE": datetime.date.today,
    "STRFTIME": null_if_any(lambda fmt, arg: datetime.datetime.fromisoformat(arg).strftime(fmt)),
}
class reverse_key:
    """Sort-key wrapper that inverts the natural ordering of the wrapped value.

    Wrapping a value in `reverse_key` makes `sorted`/`list.sort` put larger
    values first, so descending terms can be mixed with ascending ones inside a
    single tuple key. Instances are unhashable because `__eq__` is defined
    without `__hash__`.
    """

    def __init__(self, obj):
        # The value whose ordering is inverted.
        self.obj = obj

    def __repr__(self):
        return f"reverse_key({self.obj!r})"

    def __eq__(self, other):
        # Only reverse_key instances are comparable; returning NotImplemented lets
        # Python fall back to its default handling instead of raising AttributeError.
        if not isinstance(other, reverse_key):
            return NotImplemented
        return other.obj == self.obj

    def __lt__(self, other):
        if not isinstance(other, reverse_key):
            return NotImplemented
        # Operands are swapped on purpose: self < other iff other.obj < self.obj.
        return other.obj < self.obj
reverse_key(obj)
14    def __init__(self, obj):
15        self.obj = obj
obj
def filter_nulls(func, empty_null=True):
    """Wrap an aggregate so it only ever receives non-None values.

    Args:
        func: aggregate callable that is given a tuple of the non-None values.
        empty_null: when truthy, return None instead of calling `func` if no
            non-None values remain.

    Returns:
        A single-argument wrapper around `func` (metadata copied via `wraps`).
    """

    @wraps(func)
    def _func(values):
        non_null = tuple(filter(lambda value: value is not None, values))
        if empty_null and len(non_null) == 0:
            return None
        return func(non_null)

    return _func
def null_if_any(*required):
    """
    Decorator that makes a function return `None` if any of the `required` arguments are `None`.

    This also supports decoration with no arguments, e.g.:

        @null_if_any
        def foo(a, b): ...

    In that case, all arguments are treated as required.

    Every name in `required` must be a parameter of the decorated function;
    an unknown name raises `ValueError` at decoration time, so a typo cannot
    silently disable the NULL check. A required parameter that the caller
    omits (so its default applies) is not checked.
    """
    f = None
    if len(required) == 1 and callable(required[0]):
        # Bare-decorator form: the single "required" item is the function itself.
        f = required[0]
        required = ()

    def decorator(func):
        if required:
            params = list(inspect.signature(func).parameters)
            unknown = [name for name in required if name not in params]
            if unknown:
                name = getattr(func, "__name__", repr(func))
                raise ValueError(f"null_if_any: {name!r} has no parameter(s) {unknown}")
            required_indices = [i for i, param in enumerate(params) if param in required]

            def predicate(*args):
                # Indices past the supplied args belong to omitted parameters
                # that fall back to their defaults; they cannot be checked here.
                return any(args[i] is None for i in required_indices if i < len(args))

        else:

            def predicate(*args):
                return any(a is None for a in args)

        @wraps(func)
        def _func(*args):
            if predicate(*args):
                return None
            return func(*args)

        return _func

    if f is not None:
        return decorator(f)

    return decorator

Decorator that makes a function return `None` if any of the `required` arguments are `None`.

This also supports decoration with no arguments, e.g.:

    @null_if_any
    def foo(a, b): ...

In that case, all arguments are treated as required.

@null_if_any("substr", "this")
def str_position(substr, this, position=None):
    """Return the 1-based index of `substr` in `this`, or 0 if it does not occur.

    `position` is an optional 1-based index at which the search starts. Values
    below 1 are clamped to the start of the string; without the clamp they
    become negative Python offsets and the search silently starts from the
    end of `this`. Returns None when `substr` or `this` is None.
    """
    start = None if position is None else max(position - 1, 0)
    return this.find(substr, start) + 1
@null_if_any("this")
def substring(this, start=None, length=None):
    """SQL SUBSTRING with a 1-based `start` and an optional `length`.

    - `start` is None: return `this` unchanged.
    - `start` is 0: return "".
    - `start` < 0: count from the end of the string (-1 is the last character).
    - `length` is None: take everything from `start` onward.
    - `length` <= 0: return "". Without this, a negative length produced a
      negative end offset and sliced relative to the end of the string.

    Returns None when `this` is None.
    """
    if start is None:
        return this
    if start == 0:
        return ""
    if start < 0:
        start = len(this) + start
    else:
        start -= 1

    # Clamp so a negative length can never yield an end offset before `start`.
    end = None if length is None else start + max(length, 0)

    return this[start:end]
@null_if_any
def cast(this, to):
    """Convert `this` to the Python representation of the SQL type `to`.

    `to` is compared against `exp.DataType.Type` members. Returns None when
    either argument is None (via `null_if_any`).

    Raises:
        NotImplementedError: if none of the handled target-type / value
            combinations below applies.
    """
    if to == exp.DataType.Type.DATE:
        # datetime is a subclass of date, so it must be checked first.
        if isinstance(this, datetime.datetime):
            return this.date()
        if isinstance(this, datetime.date):
            return this
        if isinstance(this, str):
            return datetime.date.fromisoformat(this)
    if to in (exp.DataType.Type.DATETIME, exp.DataType.Type.TIMESTAMP):
        if isinstance(this, datetime.datetime):
            return this
        if isinstance(this, datetime.date):
            return datetime.datetime(this.year, this.month, this.day)
        if isinstance(this, str):
            return datetime.datetime.fromisoformat(this)
    if to == exp.DataType.Type.BOOLEAN:
        if isinstance(this, str):
            # bool("false") is True, so map the common textual spellings
            # explicitly; any other string keeps the previous bool() behavior.
            text = this.strip().lower()
            if text in ("true", "t", "yes", "y", "1"):
                return True
            if text in ("false", "f", "no", "n", "0"):
                return False
        return bool(this)
    if to in exp.DataType.TEXT_TYPES:
        return str(this)
    if to in {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE}:
        return float(this)
    if to in exp.DataType.NUMERIC_TYPES:
        return int(this)
    raise NotImplementedError(f"Casting {this} to '{to}' not implemented.")
def ordered(this, desc, nulls_first):
    """Build the sort key for a single ORDER BY term.

    Descending terms are wrapped in `reverse_key` so they sort in reverse;
    ascending terms are returned unchanged. `nulls_first` is accepted for
    interface compatibility but is not used by this function.
    """
    return reverse_key(this) if desc else this
@null_if_any
def interval(this, unit):
    """Convert an INTERVAL value into a `datetime.timedelta`.

    `unit` is lower-cased. If the plural form (e.g. "day" -> "days") is a key
    of `Generator.TIME_PART_SINGULARS`, the plural is used as the timedelta
    keyword argument. The amount is converted with `float(this)`. Returns
    None if either argument is None.

    NOTE(review): `datetime.timedelta` only accepts weeks, days, hours,
    minutes, seconds, milliseconds and microseconds, so units such as month
    or year raise TypeError here.
    """
    normalized = unit.lower()
    if f"{normalized}s" in Generator.TIME_PART_SINGULARS:
        normalized = f"{normalized}s"
    return datetime.timedelta(**{normalized: float(this)})
ENV = {'exp': <module 'sqlglot.expressions' from '/home/runner/work/sqlglot/sqlglot/sqlglot/expressions.py'>, 'ARRAYAGG': <class 'list'>, 'AVG': <function fmean>, 'COUNT': <function <lambda>>, 'MAX': <function max>, 'MIN': <function min>, 'SUM': <function sum>, 'ABS': <function <lambda>>, 'ADD': <function <lambda>>, 'ARRAYANY': <function <lambda>>, 'BETWEEN': <function <lambda>>, 'BITWISEAND': <function <lambda>>, 'BITWISELEFTSHIFT': <function <lambda>>, 'BITWISEOR': <function <lambda>>, 'BITWISERIGHTSHIFT': <function <lambda>>, 'BITWISEXOR': <function <lambda>>, 'CAST': <function cast>, 'COALESCE': <function <lambda>>, 'CONCAT': <function <lambda>>, 'SAFECONCAT': <function <lambda>>, 'CONCATWS': <function <lambda>>, 'DATEDIFF': <function <lambda>>, 'DATESTRTODATE': <function <lambda>>, 'DIV': <function <lambda>>, 'DOT': <function <lambda>>, 'EQ': <function <lambda>>, 'EXTRACT': <function <lambda>>, 'GT': <function <lambda>>, 'GTE': <function <lambda>>, 'IF': <function <lambda>>, 'INTDIV': <function <lambda>>, 'INTERVAL': <function interval>, 'LEFT': <function <lambda>>, 'LIKE': <function <lambda>>, 'LOWER': <function <lambda>>, 'LT': <function <lambda>>, 'LTE': <function <lambda>>, 'MAP': <function <lambda>>, 'MOD': <function <lambda>>, 'MUL': <function <lambda>>, 'NEQ': <function <lambda>>, 'ORD': <function ord>, 'ORDERED': <function ordered>, 'POW': <built-in function pow>, 'RIGHT': <function <lambda>>, 'STRPOSITION': <function str_position>, 'SUB': <function <lambda>>, 'SUBSTRING': <function substring>, 'TIMESTRTOTIME': <function <lambda>>, 'UPPER': <function <lambda>>, 'YEAR': <function <lambda>>, 'MONTH': <function <lambda>>, 'DAY': <function <lambda>>, 'CURRENTDATETIME': <built-in method now of type object>, 'CURRENTTIMESTAMP': <built-in method now of type object>, 'CURRENTTIME': <built-in method now of type object>, 'CURRENTDATE': <built-in method today of type object>, 'STRFTIME': <function <lambda>>}