summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/simplify.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/simplify.py')
-rw-r--r--sqlglot/optimizer/simplify.py35
1 files changed, 35 insertions, 0 deletions
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index 849643c..30de75b 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -70,6 +70,7 @@ def simplify(expression, constant_propagation=False):
node = uniq_sort(node, generate, root)
node = absorb_and_eliminate(node, root)
node = simplify_concat(node)
+ node = simplify_conditionals(node)
if constant_propagation:
node = propagate_constants(node, root)
@@ -477,9 +478,11 @@ def simplify_equality(expression: exp.Expression) -> exp.Expression:
return expression
if l.__class__ in INVERSE_DATE_OPS:
+ l = t.cast(exp.IntervalOp, l)
a = l.this
b = l.interval()
else:
+ l = t.cast(exp.Binary, l)
a, b = l.left, l.right
if not a_predicate(a) and b_predicate(b):
@@ -695,6 +698,32 @@ def simplify_concat(expression):
return concat_type(expressions=new_args)
+def simplify_conditionals(expression):
+ """Simplifies expressions like IF, CASE if their condition is statically known."""
+ if isinstance(expression, exp.Case):
+ this = expression.this
+ for case in expression.args["ifs"]:
+ cond = case.this
+ if this:
+ # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
+ cond = cond.replace(this.pop().eq(cond))
+
+ if always_true(cond):
+ return case.args["true"]
+
+ if always_false(cond):
+ case.pop()
+ if not expression.args["ifs"]:
+ return expression.args.get("default") or exp.null()
+ elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
+ if always_true(expression.this):
+ return expression.args["true"]
+ if always_false(expression.this):
+ return expression.args.get("false") or exp.null()
+
+ return expression
+
+
DateRange = t.Tuple[datetime.date, datetime.date]
@@ -786,6 +815,7 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
else:
return expression
+ l = t.cast(exp.DateTrunc, l)
unit = l.unit.name.lower()
date = extract_date(r)
@@ -798,6 +828,7 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
rs = expression.expressions
if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
+ l = t.cast(exp.DateTrunc, l)
unit = l.unit.name.lower()
ranges = []
@@ -852,6 +883,10 @@ def always_true(expression):
)
+def always_false(expression):
+ return is_false(expression) or is_null(expression)
+
+
def is_complement(a, b):
return isinstance(b, exp.Not) and b.this == a