diff options
author | Daniel Baumann <mail@daniel-baumann.ch> | 2023-12-10 10:46:01 +0000 |
---|---|---|
committer | Daniel Baumann <mail@daniel-baumann.ch> | 2023-12-10 10:46:01 +0000 |
commit | 8fe30fd23dc37ec3516e530a86d1c4b604e71241 (patch) | |
tree | 6e2ebbf565b0351fd0f003f488a8339e771ad90c /sqlglot/optimizer/simplify.py | |
parent | Releasing debian version 19.0.1-1. (diff) | |
download | sqlglot-8fe30fd23dc37ec3516e530a86d1c4b604e71241.tar.xz sqlglot-8fe30fd23dc37ec3516e530a86d1c4b604e71241.zip |
Merging upstream version 20.1.0.
Signed-off-by: Daniel Baumann <mail@daniel-baumann.ch>
Diffstat (limited to 'sqlglot/optimizer/simplify.py')
-rw-r--r-- | sqlglot/optimizer/simplify.py | 73 |
1 files changed, 49 insertions, 24 deletions
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index af03332..d4e2e60 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -507,6 +507,9 @@ def simplify_literals(expression, root=True): return exp.Literal.number(value[1:]) return exp.Literal.number(f"-{value}") + if type(expression) in INVERSE_DATE_OPS: + return _simplify_binary(expression, expression.this, expression.interval()) or expression + return expression @@ -530,22 +533,24 @@ def _simplify_binary(expression, a, b): return exp.null() if a.is_number and b.is_number: - a = int(a.name) if a.is_int else Decimal(a.name) - b = int(b.name) if b.is_int else Decimal(b.name) + num_a = int(a.name) if a.is_int else Decimal(a.name) + num_b = int(b.name) if b.is_int else Decimal(b.name) if isinstance(expression, exp.Add): - return exp.Literal.number(a + b) - if isinstance(expression, exp.Sub): - return exp.Literal.number(a - b) + return exp.Literal.number(num_a + num_b) if isinstance(expression, exp.Mul): - return exp.Literal.number(a * b) + return exp.Literal.number(num_a * num_b) + + # We only simplify Sub, Div if a and b have the same parent because they're not associative + if isinstance(expression, exp.Sub): + return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None if isinstance(expression, exp.Div): # engines have differing int div behavior so intdiv is not safe - if isinstance(a, int) and isinstance(b, int): + if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent: return None - return exp.Literal.number(a / b) + return exp.Literal.number(num_a / num_b) - boolean = eval_boolean(expression, a, b) + boolean = eval_boolean(expression, num_a, num_b) if boolean: return boolean @@ -557,15 +562,21 @@ def _simplify_binary(expression, a, b): elif _is_date_literal(a) and isinstance(b, exp.Interval): a, b = extract_date(a), extract_interval(b) if a and b: - if isinstance(expression, exp.Add): + if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)): return date_literal(a + b) - if isinstance(expression, exp.Sub): + if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)): return date_literal(a - b) elif isinstance(a, exp.Interval) and _is_date_literal(b): a, b = extract_interval(a), extract_date(b) # you cannot subtract a date from an interval if a and b and isinstance(expression, exp.Add): return date_literal(a + b) + elif _is_date_literal(a) and _is_date_literal(b): + if isinstance(expression, exp.Predicate): + a, b = extract_date(a), extract_date(b) + boolean = eval_boolean(expression, a, b) + if boolean: + return boolean return None @@ -590,6 +601,11 @@ def simplify_parens(expression): return expression +NONNULL_CONSTANTS = ( + exp.Literal, + exp.Boolean, +) + CONSTANTS = ( exp.Literal, exp.Boolean, @@ -597,11 +613,19 @@ CONSTANTS = ( ) +def _is_nonnull_constant(expression: exp.Expression) -> bool: + return isinstance(expression, NONNULL_CONSTANTS) or _is_date_literal(expression) + + +def _is_constant(expression: exp.Expression) -> bool: + return isinstance(expression, CONSTANTS) or _is_date_literal(expression) + + def simplify_coalesce(expression): # COALESCE(x) -> x if ( isinstance(expression, exp.Coalesce) - and not expression.expressions + and (not expression.expressions or _is_nonnull_constant(expression.this)) # COALESCE is also used as a Spark partitioning hint and not isinstance(expression.parent, exp.Hint) ): @@ -621,12 +645,12 @@ def simplify_coalesce(expression): # This transformation is valid for non-constants, # but it really only does anything if they are both constants. - if not isinstance(other, CONSTANTS): + if not _is_constant(other): return expression # Find the first constant arg for arg_index, arg in enumerate(coalesce.expressions): - if isinstance(arg, CONSTANTS): + if _is_constant(other): break else: return expression @@ -656,7 +680,6 @@ def simplify_coalesce(expression): CONCATS = (exp.Concat, exp.DPipe) -SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe) def simplify_concat(expression): @@ -672,10 +695,15 @@ def simplify_concat(expression): sep_expr, *expressions = expression.expressions sep = sep_expr.name concat_type = exp.ConcatWs + args = {} else: expressions = expression.expressions sep = "" - concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat + concat_type = exp.Concat + args = { + "safe": expression.args.get("safe"), + "coalesce": expression.args.get("coalesce"), + } new_args = [] for is_string_group, group in itertools.groupby( @@ -692,7 +720,7 @@ def simplify_concat(expression): if concat_type is exp.ConcatWs: new_args = [sep_expr] + new_args - return concat_type(expressions=new_args) + return concat_type(expressions=new_args, **args) def simplify_conditionals(expression): @@ -947,7 +975,7 @@ def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.da def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]: if isinstance(cast, exp.Cast): to = cast.to - elif isinstance(cast, exp.TsOrDsToDate): + elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"): to = exp.DataType.build(exp.DataType.Type.DATE) else: return None @@ -966,12 +994,11 @@ def _is_date_literal(expression: exp.Expression) -> bool: def extract_interval(expression): - n = int(expression.name) - unit = expression.text("unit").lower() - try: + n = int(expression.name) + unit = expression.text("unit").lower() return interval(unit, n) - except (UnsupportedUnit, ModuleNotFoundError): + except (UnsupportedUnit, ModuleNotFoundError, ValueError): return None @@ -1099,8 +1126,6 @@ GEN_MAP = { exp.DataType: lambda e: f"{e.this.name} {gen(tuple(e.args.values())[1:])}", exp.Div: lambda e: _binary(e, "/"), exp.Dot: lambda e: _binary(e, "."), - exp.DPipe: lambda e: _binary(e, "||"), - exp.SafeDPipe: lambda e: _binary(e, "||"), exp.EQ: lambda e: _binary(e, "="), exp.GT: lambda e: _binary(e, ">"), exp.GTE: lambda e: _binary(e, ">="), |