summary refs log tree commit diff stats
path: root/sqlglot/optimizer/simplify.py
diff options
context:
space:
mode:
author Daniel Baumann <mail@daniel-baumann.ch> 2023-12-10 10:46:01 +0000
committer Daniel Baumann <mail@daniel-baumann.ch> 2023-12-10 10:46:01 +0000
commit 8fe30fd23dc37ec3516e530a86d1c4b604e71241 (patch)
tree 6e2ebbf565b0351fd0f003f488a8339e771ad90c /sqlglot/optimizer/simplify.py
parent Releasing debian version 19.0.1-1. (diff)
download sqlglot-8fe30fd23dc37ec3516e530a86d1c4b604e71241.tar.xz
sqlglot-8fe30fd23dc37ec3516e530a86d1c4b604e71241.zip
Merging upstream version 20.1.0.
Signed-off-by: Daniel Baumann <mail@daniel-baumann.ch>
Diffstat (limited to 'sqlglot/optimizer/simplify.py')
-rw-r--r-- sqlglot/optimizer/simplify.py 73
1 files changed, 49 insertions, 24 deletions
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index af03332..d4e2e60 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -507,6 +507,9 @@ def simplify_literals(expression, root=True):
return exp.Literal.number(value[1:])
return exp.Literal.number(f"-{value}")
+ if type(expression) in INVERSE_DATE_OPS:
+ return _simplify_binary(expression, expression.this, expression.interval()) or expression
+
return expression
@@ -530,22 +533,24 @@ def _simplify_binary(expression, a, b):
return exp.null()
if a.is_number and b.is_number:
- a = int(a.name) if a.is_int else Decimal(a.name)
- b = int(b.name) if b.is_int else Decimal(b.name)
+ num_a = int(a.name) if a.is_int else Decimal(a.name)
+ num_b = int(b.name) if b.is_int else Decimal(b.name)
if isinstance(expression, exp.Add):
- return exp.Literal.number(a + b)
- if isinstance(expression, exp.Sub):
- return exp.Literal.number(a - b)
+ return exp.Literal.number(num_a + num_b)
if isinstance(expression, exp.Mul):
- return exp.Literal.number(a * b)
+ return exp.Literal.number(num_a * num_b)
+
+ # We only simplify Sub, Div if a and b have the same parent because they're not associative
+ if isinstance(expression, exp.Sub):
+ return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
if isinstance(expression, exp.Div):
# engines have differing int div behavior so intdiv is not safe
- if isinstance(a, int) and isinstance(b, int):
+ if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
return None
- return exp.Literal.number(a / b)
+ return exp.Literal.number(num_a / num_b)
- boolean = eval_boolean(expression, a, b)
+ boolean = eval_boolean(expression, num_a, num_b)
if boolean:
return boolean
@@ -557,15 +562,21 @@ def _simplify_binary(expression, a, b):
elif _is_date_literal(a) and isinstance(b, exp.Interval):
a, b = extract_date(a), extract_interval(b)
if a and b:
- if isinstance(expression, exp.Add):
+ if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
return date_literal(a + b)
- if isinstance(expression, exp.Sub):
+ if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
return date_literal(a - b)
elif isinstance(a, exp.Interval) and _is_date_literal(b):
a, b = extract_interval(a), extract_date(b)
# you cannot subtract a date from an interval
if a and b and isinstance(expression, exp.Add):
return date_literal(a + b)
+ elif _is_date_literal(a) and _is_date_literal(b):
+ if isinstance(expression, exp.Predicate):
+ a, b = extract_date(a), extract_date(b)
+ boolean = eval_boolean(expression, a, b)
+ if boolean:
+ return boolean
return None
@@ -590,6 +601,11 @@ def simplify_parens(expression):
return expression
+NONNULL_CONSTANTS = (
+ exp.Literal,
+ exp.Boolean,
+)
+
CONSTANTS = (
exp.Literal,
exp.Boolean,
@@ -597,11 +613,19 @@ CONSTANTS = (
)
+def _is_nonnull_constant(expression: exp.Expression) -> bool:
+ return isinstance(expression, NONNULL_CONSTANTS) or _is_date_literal(expression)
+
+
+def _is_constant(expression: exp.Expression) -> bool:
+ return isinstance(expression, CONSTANTS) or _is_date_literal(expression)
+
+
def simplify_coalesce(expression):
# COALESCE(x) -> x
if (
isinstance(expression, exp.Coalesce)
- and not expression.expressions
+ and (not expression.expressions or _is_nonnull_constant(expression.this))
# COALESCE is also used as a Spark partitioning hint
and not isinstance(expression.parent, exp.Hint)
):
@@ -621,12 +645,12 @@ def simplify_coalesce(expression):
# This transformation is valid for non-constants,
# but it really only does anything if they are both constants.
- if not isinstance(other, CONSTANTS):
+ if not _is_constant(other):
return expression
# Find the first constant arg
for arg_index, arg in enumerate(coalesce.expressions):
- if isinstance(arg, CONSTANTS):
+ if _is_constant(arg):
break
else:
return expression
@@ -656,7 +680,6 @@ def simplify_coalesce(expression):
CONCATS = (exp.Concat, exp.DPipe)
-SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe)
def simplify_concat(expression):
@@ -672,10 +695,15 @@ def simplify_concat(expression):
sep_expr, *expressions = expression.expressions
sep = sep_expr.name
concat_type = exp.ConcatWs
+ args = {}
else:
expressions = expression.expressions
sep = ""
- concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
+ concat_type = exp.Concat
+ args = {
+ "safe": expression.args.get("safe"),
+ "coalesce": expression.args.get("coalesce"),
+ }
new_args = []
for is_string_group, group in itertools.groupby(
@@ -692,7 +720,7 @@ def simplify_concat(expression):
if concat_type is exp.ConcatWs:
new_args = [sep_expr] + new_args
- return concat_type(expressions=new_args)
+ return concat_type(expressions=new_args, **args)
def simplify_conditionals(expression):
@@ -947,7 +975,7 @@ def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.da
def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]:
if isinstance(cast, exp.Cast):
to = cast.to
- elif isinstance(cast, exp.TsOrDsToDate):
+ elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
to = exp.DataType.build(exp.DataType.Type.DATE)
else:
return None
@@ -966,12 +994,11 @@ def _is_date_literal(expression: exp.Expression) -> bool:
def extract_interval(expression):
- n = int(expression.name)
- unit = expression.text("unit").lower()
-
try:
+ n = int(expression.name)
+ unit = expression.text("unit").lower()
return interval(unit, n)
- except (UnsupportedUnit, ModuleNotFoundError):
+ except (UnsupportedUnit, ModuleNotFoundError, ValueError):
return None
@@ -1099,8 +1126,6 @@ GEN_MAP = {
exp.DataType: lambda e: f"{e.this.name} {gen(tuple(e.args.values())[1:])}",
exp.Div: lambda e: _binary(e, "/"),
exp.Dot: lambda e: _binary(e, "."),
- exp.DPipe: lambda e: _binary(e, "||"),
- exp.SafeDPipe: lambda e: _binary(e, "||"),
exp.EQ: lambda e: _binary(e, "="),
exp.GT: lambda e: _binary(e, ">"),
exp.GTE: lambda e: _binary(e, ">="),