diff options
Diffstat (limited to 'sqlglot/optimizer/simplify.py')
-rw-r--r-- | sqlglot/optimizer/simplify.py | 35 |
1 files changed, 35 insertions, 0 deletions
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 849643c..30de75b 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -70,6 +70,7 @@ def simplify(expression, constant_propagation=False): node = uniq_sort(node, generate, root) node = absorb_and_eliminate(node, root) node = simplify_concat(node) + node = simplify_conditionals(node) if constant_propagation: node = propagate_constants(node, root) @@ -477,9 +478,11 @@ def simplify_equality(expression: exp.Expression) -> exp.Expression: return expression if l.__class__ in INVERSE_DATE_OPS: + l = t.cast(exp.IntervalOp, l) a = l.this b = l.interval() else: + l = t.cast(exp.Binary, l) a, b = l.left, l.right if not a_predicate(a) and b_predicate(b): @@ -695,6 +698,32 @@ def simplify_concat(expression): return concat_type(expressions=new_args) +def simplify_conditionals(expression): + """Simplifies expressions like IF, CASE if their condition is statically known.""" + if isinstance(expression, exp.Case): + this = expression.this + for case in expression.args["ifs"]: + cond = case.this + if this: + # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ... + cond = cond.replace(this.pop().eq(cond)) + + if always_true(cond): + return case.args["true"] + + if always_false(cond): + case.pop() + if not expression.args["ifs"]: + return expression.args.get("default") or exp.null() + elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case): + if always_true(expression.this): + return expression.args["true"] + if always_false(expression.this): + return expression.args.get("false") or exp.null() + + return expression + + DateRange = t.Tuple[datetime.date, datetime.date] @@ -786,6 +815,7 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression: else: return expression + l = t.cast(exp.DateTrunc, l) unit = l.unit.name.lower() date = extract_date(r) @@ -798,6 +828,7 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression: rs = expression.expressions if rs and all(_is_datetrunc_predicate(l, r) for r in rs): + l = t.cast(exp.DateTrunc, l) unit = l.unit.name.lower() ranges = [] @@ -852,6 +883,10 @@ def always_true(expression): ) +def always_false(expression): + return is_false(expression) or is_null(expression) + + def is_complement(a, b): return isinstance(b, exp.Not) and b.this == a |