diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-08 08:11:53 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-08 08:12:02 +0000 |
commit | 8d36f5966675e23bee7026ba37ae0647fbf47300 (patch) | |
tree | df4227bbb3b07cb70df87237bcff03c8efd7822d /sqlglot/optimizer/simplify.py | |
parent | Releasing debian version 22.2.0-1. (diff) | |
download | sqlglot-8d36f5966675e23bee7026ba37ae0647fbf47300.tar.xz sqlglot-8d36f5966675e23bee7026ba37ae0647fbf47300.zip |
Merging upstream version 23.7.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/optimizer/simplify.py')
-rw-r--r-- | sqlglot/optimizer/simplify.py | 472 |
1 files changed, 356 insertions, 116 deletions
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 2e43d21..d9a0d2b 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -9,19 +9,25 @@ from decimal import Decimal import sqlglot from sqlglot import Dialect, exp -from sqlglot.helper import first, is_iterable, merge_ranges, while_changing +from sqlglot.helper import first, merge_ranges, while_changing from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope if t.TYPE_CHECKING: from sqlglot.dialects.dialect import DialectType DateTruncBinaryTransform = t.Callable[ - [exp.Expression, datetime.date, str, Dialect], t.Optional[exp.Expression] + [exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression] ] # Final means that an expression should not be simplified FINAL = "final" +# Value ranges for byte-sized signed/unsigned integers +TINYINT_MIN = -128 +TINYINT_MAX = 127 +UTINYINT_MIN = 0 +UTINYINT_MAX = 255 + class UnsupportedUnit(Exception): pass @@ -63,14 +69,14 @@ def simplify( group.meta[FINAL] = True for e in expression.selects: - for node, *_ in e.walk(): + for node in e.walk(): if node in groups: e.meta[FINAL] = True break having = expression.args.get("having") if having: - for node, *_ in having.walk(): + for node in having.walk(): if node in groups: having.meta[FINAL] = True break @@ -304,6 +310,8 @@ def _simplify_comparison(expression, left, right, or_=False): r = extract_date(r) if not r: return None + # python won't compare date and datetime, but many engines will upcast + l, r = cast_as_datetime(l), cast_as_datetime(r) for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))): if isinstance(a, LT_LTE) and isinstance(b, LT_LTE): @@ -431,7 +439,7 @@ def propagate_constants(expression, root=True): and sqlglot.optimizer.normalize.normalized(expression, dnf=True) ): constant_mapping = {} - for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)): + for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)): if isinstance(expr, exp.EQ): l, r = expr.left, expr.right @@ -544,7 +552,37 @@ def simplify_literals(expression, root=True): return expression +NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ) + + +def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression: + if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast): + this = _simplify_integer_cast(expr.this) + else: + this = expr.this + + if isinstance(expr, exp.Cast) and this.is_int: + num = int(this.name) + + # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. Downcasts on any + # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is + # engine-dependent + if ( + TINYINT_MIN <= num <= TINYINT_MAX and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES + ) or ( + UTINYINT_MIN <= num <= UTINYINT_MAX + and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES + ): + return this + + return expr + + def _simplify_binary(expression, a, b): + if isinstance(expression, COMPARISONS): + a = _simplify_integer_cast(a) + b = _simplify_integer_cast(b) + if isinstance(expression, exp.Is): if isinstance(b, exp.Not): c = b.this @@ -558,7 +596,7 @@ def _simplify_binary(expression, a, b): return exp.true() if not_ else exp.false() if is_null(a): return exp.false() if not_ else exp.true() - elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)): + elif isinstance(expression, NULL_OK): return None elif is_null(a) or is_null(b): return exp.null() @@ -591,17 +629,17 @@ def _simplify_binary(expression, a, b): if boolean: return boolean elif _is_date_literal(a) and isinstance(b, exp.Interval): - a, b = extract_date(a), extract_interval(b) - if a and b: + date, b = extract_date(a), extract_interval(b) + if date and b: if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)): - return date_literal(a + b) + return date_literal(date + b, extract_type(a)) if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)): - return date_literal(a - b) + return date_literal(date - b, extract_type(a)) elif isinstance(a, exp.Interval) and _is_date_literal(b): - a, b = extract_interval(a), extract_date(b) + a, date = extract_interval(a), extract_date(b) # you cannot subtract a date from an interval if a and b and isinstance(expression, exp.Add): - return date_literal(a + b) + return date_literal(a + date, extract_type(b)) elif _is_date_literal(a) and _is_date_literal(b): if isinstance(expression, exp.Predicate): a, b = extract_date(a), extract_date(b) @@ -618,12 +656,16 @@ def simplify_parens(expression): this = expression.this parent = expression.parent + parent_is_predicate = isinstance(parent, exp.Predicate) if not isinstance(this, exp.Select) and ( not isinstance(parent, (exp.Condition, exp.Binary)) or isinstance(parent, exp.Paren) - or not isinstance(this, exp.Binary) - or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate)) + or ( + not isinstance(this, exp.Binary) + and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate) + ) + or (isinstance(this, exp.Predicate) and not parent_is_predicate) or (isinstance(this, exp.Add) and isinstance(parent, exp.Add)) or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul)) or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub))) @@ -632,24 +674,12 @@ def simplify_parens(expression): return expression -NONNULL_CONSTANTS = ( - exp.Literal, - exp.Boolean, -) - -CONSTANTS = ( - exp.Literal, - exp.Boolean, - exp.Null, -) - - def _is_nonnull_constant(expression: exp.Expression) -> bool: - return isinstance(expression, NONNULL_CONSTANTS) or _is_date_literal(expression) + return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression) def _is_constant(expression: exp.Expression) -> bool: - return isinstance(expression, CONSTANTS) or _is_date_literal(expression) + return isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression) def simplify_coalesce(expression): @@ -820,45 +850,55 @@ def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Opti return floor, floor + interval(unit) -def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression: +def _datetrunc_eq_expression( + left: exp.Expression, drange: DateRange, target_type: t.Optional[exp.DataType] +) -> exp.Expression: """Get the logical expression for a date range""" return exp.and_( - left >= date_literal(drange[0]), - left < date_literal(drange[1]), + left >= date_literal(drange[0], target_type), + left < date_literal(drange[1], target_type), copy=False, ) def _datetrunc_eq( - left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect + left: exp.Expression, + date: datetime.date, + unit: str, + dialect: Dialect, + target_type: t.Optional[exp.DataType], ) -> t.Optional[exp.Expression]: drange = _datetrunc_range(date, unit, dialect) if not drange: return None - return _datetrunc_eq_expression(left, drange) + return _datetrunc_eq_expression(left, drange, target_type) def _datetrunc_neq( - left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect + left: exp.Expression, + date: datetime.date, + unit: str, + dialect: Dialect, + target_type: t.Optional[exp.DataType], ) -> t.Optional[exp.Expression]: drange = _datetrunc_range(date, unit, dialect) if not drange: return None return exp.and_( - left < date_literal(drange[0]), - left >= date_literal(drange[1]), + left < date_literal(drange[0], target_type), + left >= date_literal(drange[1], target_type), copy=False, ) DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = { - exp.LT: lambda l, dt, u, d: l - < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u)), - exp.GT: lambda l, dt, u, d: l >= date_literal(date_floor(dt, u, d) + interval(u)), - exp.LTE: lambda l, dt, u, d: l < date_literal(date_floor(dt, u, d) + interval(u)), - exp.GTE: lambda l, dt, u, d: l >= date_literal(date_ceil(dt, u, d)), + exp.LT: lambda l, dt, u, d, t: l + < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t), + exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t), + exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t), + exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t), exp.EQ: _datetrunc_eq, exp.NEQ: _datetrunc_neq, } @@ -876,9 +916,11 @@ def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expr comparison = expression.__class__ if isinstance(expression, DATETRUNCS): - date = extract_date(expression.this) + this = expression.this + trunc_type = extract_type(this) + date = extract_date(this) if date and expression.unit: - return date_literal(date_floor(date, expression.unit.name.lower(), dialect)) + return date_literal(date_floor(date, expression.unit.name.lower(), dialect), trunc_type) elif comparison not in DATETRUNC_COMPARISONS: return expression @@ -889,14 +931,21 @@ def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expr return expression l = t.cast(exp.DateTrunc, l) + trunc_arg = l.this unit = l.unit.name.lower() date = extract_date(r) if not date: return expression - return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit, dialect) or expression - elif isinstance(expression, exp.In): + return ( + DATETRUNC_BINARY_COMPARISONS[comparison]( + trunc_arg, date, unit, dialect, extract_type(trunc_arg, r) + ) + or expression + ) + + if isinstance(expression, exp.In): l = expression.this rs = expression.expressions @@ -917,8 +966,11 @@ def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expr return expression ranges = merge_ranges(ranges) + target_type = extract_type(l, *rs) - return exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False) + return exp.or_( + *[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges], copy=False + ) return expression @@ -954,7 +1006,7 @@ JOINS = { def remove_where_true(expression): for where in expression.find_all(exp.Where): if always_true(where.this): - where.parent.set("where", None) + where.pop() for join in expression.find_all(exp.Join): if ( always_true(join.args.get("on")) @@ -962,7 +1014,7 @@ def remove_where_true(expression): and not join.args.get("method") and (join.side, join.kind) in JOINS ): - join.set("on", None) + join.args["on"].pop() join.set("side", None) join.set("kind", "CROSS") @@ -1067,15 +1119,25 @@ def extract_interval(expression): return None -def date_literal(date): - return exp.cast( - exp.Literal.string(date), - ( +def extract_type(*expressions): + target_type = None + for expression in expressions: + target_type = expression.to if isinstance(expression, exp.Cast) else expression.type + if target_type: + break + + return target_type + + +def date_literal(date, target_type=None): + if not target_type or not target_type.is_type(*exp.DataType.TEMPORAL_TYPES): + target_type = ( exp.DataType.Type.DATETIME if isinstance(date, datetime.datetime) else exp.DataType.Type.DATE - ), - ) + ) + + return exp.cast(exp.Literal.string(date), target_type) def interval(unit: str, n: int = 1): @@ -1169,73 +1231,251 @@ def gen(expression: t.Any) -> str: Sorting and deduping sql is a necessary step for optimization. Calling the actual generator is expensive so we have a bare minimum sql generator here. """ - if expression is None: - return "_" - if is_iterable(expression): - return ",".join(gen(e) for e in expression) - if not isinstance(expression, exp.Expression): - return str(expression) - - etype = type(expression) - if etype in GEN_MAP: - return GEN_MAP[etype](expression) - return f"{expression.key} {gen(expression.args.values())}" - - -GEN_MAP = { - exp.Add: lambda e: _binary(e, "+"), - exp.And: lambda e: _binary(e, "AND"), - exp.Anonymous: lambda e: _anonymous(e), - exp.Between: lambda e: f"{gen(e.this)} BETWEEN {gen(e.args.get('low'))} AND {gen(e.args.get('high'))}", - exp.Boolean: lambda e: "TRUE" if e.this else "FALSE", - exp.Bracket: lambda e: f"{gen(e.this)}[{gen(e.expressions)}]", - exp.Column: lambda e: ".".join(gen(p) for p in e.parts), - exp.DataType: lambda e: f"{e.this.name} {gen(tuple(e.args.values())[1:])}", - exp.Div: lambda e: _binary(e, "/"), - exp.Dot: lambda e: _binary(e, "."), - exp.EQ: lambda e: _binary(e, "="), - exp.GT: lambda e: _binary(e, ">"), - exp.GTE: lambda e: _binary(e, ">="), - exp.Identifier: lambda e: f'"{e.name}"' if e.quoted else e.name, - exp.ILike: lambda e: _binary(e, "ILIKE"), - exp.In: lambda e: f"{gen(e.this)} IN ({gen(tuple(e.args.values())[1:])})", - exp.Is: lambda e: _binary(e, "IS"), - exp.Like: lambda e: _binary(e, "LIKE"), - exp.Literal: lambda e: f"'{e.name}'" if e.is_string else e.name, - exp.LT: lambda e: _binary(e, "<"), - exp.LTE: lambda e: _binary(e, "<="), - exp.Mod: lambda e: _binary(e, "%"), - exp.Mul: lambda e: _binary(e, "*"), - exp.Neg: lambda e: _unary(e, "-"), - exp.NEQ: lambda e: _binary(e, "<>"), - exp.Not: lambda e: _unary(e, "NOT"), - exp.Null: lambda e: "NULL", - exp.Or: lambda e: _binary(e, "OR"), - exp.Paren: lambda e: f"({gen(e.this)})", - exp.Sub: lambda e: _binary(e, "-"), - exp.Subquery: lambda e: f"({gen(e.args.values())})", - exp.Table: lambda e: gen(e.args.values()), - exp.Var: lambda e: e.name, -} + return Gen().gen(expression) + + +class Gen: + def __init__(self): + self.stack = [] + self.sqls = [] + + def gen(self, expression: exp.Expression) -> str: + self.stack = [expression] + self.sqls.clear() + + while self.stack: + node = self.stack.pop() + + if isinstance(node, exp.Expression): + exp_handler_name = f"{node.key}_sql" + + if hasattr(self, exp_handler_name): + getattr(self, exp_handler_name)(node) + elif isinstance(node, exp.Func): + self._function(node) + else: + key = node.key.upper() + self.stack.append(f"{key} " if self._args(node) else key) + elif type(node) is list: + for n in reversed(node): + if n is not None: + self.stack.extend((n, ",")) + if node: + self.stack.pop() + else: + if node is not None: + self.sqls.append(str(node)) + return "".join(self.sqls) -def _anonymous(e: exp.Anonymous) -> str: - this = e.this - if isinstance(this, str): - name = this.upper() - elif isinstance(this, exp.Identifier): - name = f'"{this.name}"' if this.quoted else this.name.upper() - else: - raise ValueError( - f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'." + def add_sql(self, e: exp.Add) -> None: + self._binary(e, " + ") + + def alias_sql(self, e: exp.Alias) -> None: + self.stack.extend( + ( + e.args.get("alias"), + " AS ", + e.args.get("this"), + ) + ) + + def and_sql(self, e: exp.And) -> None: + self._binary(e, " AND ") + + def anonymous_sql(self, e: exp.Anonymous) -> None: + this = e.this + if isinstance(this, str): + name = this.upper() + elif isinstance(this, exp.Identifier): + name = this.this + name = f'"{name}"' if this.quoted else name.upper() + else: + raise ValueError( + f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'." + ) + + self.stack.extend( + ( + ")", + e.expressions, + "(", + name, + ) + ) + + def between_sql(self, e: exp.Between) -> None: + self.stack.extend( + ( + e.args.get("high"), + " AND ", + e.args.get("low"), + " BETWEEN ", + e.this, + ) + ) + + def boolean_sql(self, e: exp.Boolean) -> None: + self.stack.append("TRUE" if e.this else "FALSE") + + def bracket_sql(self, e: exp.Bracket) -> None: + self.stack.extend( + ( + "]", + e.expressions, + "[", + e.this, + ) + ) + + def column_sql(self, e: exp.Column) -> None: + for p in reversed(e.parts): + self.stack.extend((p, ".")) + self.stack.pop() + + def datatype_sql(self, e: exp.DataType) -> None: + self._args(e, 1) + self.stack.append(f"{e.this.name} ") + + def div_sql(self, e: exp.Div) -> None: + self._binary(e, " / ") + + def dot_sql(self, e: exp.Dot) -> None: + self._binary(e, ".") + + def eq_sql(self, e: exp.EQ) -> None: + self._binary(e, " = ") + + def from_sql(self, e: exp.From) -> None: + self.stack.extend((e.this, "FROM ")) + + def gt_sql(self, e: exp.GT) -> None: + self._binary(e, " > ") + + def gte_sql(self, e: exp.GTE) -> None: + self._binary(e, " >= ") + + def identifier_sql(self, e: exp.Identifier) -> None: + self.stack.append(f'"{e.this}"' if e.quoted else e.this) + + def ilike_sql(self, e: exp.ILike) -> None: + self._binary(e, " ILIKE ") + + def in_sql(self, e: exp.In) -> None: + self.stack.append(")") + self._args(e, 1) + self.stack.extend( + ( + "(", + " IN ", + e.this, + ) ) - return f"{name} {','.join(gen(e) for e in e.expressions)}" + def intdiv_sql(self, e: exp.IntDiv) -> None: + self._binary(e, " DIV ") + + def is_sql(self, e: exp.Is) -> None: + self._binary(e, " IS ") + + def like_sql(self, e: exp.Like) -> None: + self._binary(e, " Like ") + + def literal_sql(self, e: exp.Literal) -> None: + self.stack.append(f"'{e.this}'" if e.is_string else e.this) + + def lt_sql(self, e: exp.LT) -> None: + self._binary(e, " < ") + + def lte_sql(self, e: exp.LTE) -> None: + self._binary(e, " <= ") + + def mod_sql(self, e: exp.Mod) -> None: + self._binary(e, " % ") + + def mul_sql(self, e: exp.Mul) -> None: + self._binary(e, " * ") + def neg_sql(self, e: exp.Neg) -> None: + self._unary(e, "-") + + def neq_sql(self, e: exp.NEQ) -> None: + self._binary(e, " <> ") + + def not_sql(self, e: exp.Not) -> None: + self._unary(e, "NOT ") + + def null_sql(self, e: exp.Null) -> None: + self.stack.append("NULL") + + def or_sql(self, e: exp.Or) -> None: + self._binary(e, " OR ") + + def paren_sql(self, e: exp.Paren) -> None: + self.stack.extend( + ( + ")", + e.this, + "(", + ) + ) + + def sub_sql(self, e: exp.Sub) -> None: + self._binary(e, " - ") + + def subquery_sql(self, e: exp.Subquery) -> None: + self._args(e, 2) + alias = e.args.get("alias") + if alias: + self.stack.append(alias) + self.stack.extend((")", e.this, "(")) + + def table_sql(self, e: exp.Table) -> None: + self._args(e, 4) + alias = e.args.get("alias") + if alias: + self.stack.append(alias) + for p in reversed(e.parts): + self.stack.extend((p, ".")) + self.stack.pop() + + def tablealias_sql(self, e: exp.TableAlias) -> None: + columns = e.columns + + if columns: + self.stack.extend((")", columns, "(")) + + self.stack.extend((e.this, " AS ")) + + def var_sql(self, e: exp.Var) -> None: + self.stack.append(e.this) + + def _binary(self, e: exp.Binary, op: str) -> None: + self.stack.extend((e.expression, op, e.this)) + + def _unary(self, e: exp.Unary, op: str) -> None: + self.stack.extend((e.this, op)) + + def _function(self, e: exp.Func) -> None: + self.stack.extend( + ( + ")", + list(e.args.values()), + "(", + e.sql_name(), + ) + ) -def _binary(e: exp.Binary, op: str) -> str: - return f"{gen(e.left)} {op} {gen(e.right)}" + def _args(self, node: exp.Expression, arg_index: int = 0) -> bool: + kvs = [] + arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types + for k in arg_types or arg_types: + v = node.args.get(k) -def _unary(e: exp.Unary, op: str) -> str: - return f"{op} {gen(e.this)}" + if v is not None: + kvs.append([f":{k}", v]) + if kvs: + self.stack.append(kvs) + return True + return False |