summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/simplify.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2023-10-04 12:14:45 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2023-10-04 12:14:45 +0000
commita34653eb21369376f0e054dd989311afcb167f5b (patch)
tree5a0280adce195af0be654f79fd99395fd2932c19 /sqlglot/optimizer/simplify.py
parentReleasing debian version 18.7.0-1. (diff)
downloadsqlglot-a34653eb21369376f0e054dd989311afcb167f5b.tar.xz
sqlglot-a34653eb21369376f0e054dd989311afcb167f5b.zip
Merging upstream version 18.11.2.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/optimizer/simplify.py')
-rw-r--r--sqlglot/optimizer/simplify.py97
1 files changed, 70 insertions, 27 deletions
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index d08c692..51214c4 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -387,10 +387,6 @@ def _is_number(expression: exp.Expression) -> bool:
return expression.is_number
-def _is_date(expression: exp.Expression) -> bool:
- return isinstance(expression, exp.Cast) and extract_date(expression) is not None
-
-
def _is_interval(expression: exp.Expression) -> bool:
return isinstance(expression, exp.Interval) and extract_interval(expression) is not None
@@ -422,18 +418,15 @@ def simplify_equality(expression: exp.Expression) -> exp.Expression:
if r.is_number:
a_predicate = _is_number
b_predicate = _is_number
- elif _is_date(r):
- a_predicate = _is_date
+ elif _is_date_literal(r):
+ a_predicate = _is_date_literal
b_predicate = _is_interval
else:
return expression
if l.__class__ in INVERSE_DATE_OPS:
a = l.this
- b = exp.Interval(
- this=l.expression.copy(),
- unit=l.unit.copy(),
- )
+ b = l.interval()
else:
a, b = l.left, l.right
@@ -509,14 +502,14 @@ def _simplify_binary(expression, a, b):
if boolean:
return boolean
- elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
+ elif _is_date_literal(a) and isinstance(b, exp.Interval):
a, b = extract_date(a), extract_interval(b)
if a and b:
if isinstance(expression, exp.Add):
return date_literal(a + b)
if isinstance(expression, exp.Sub):
return date_literal(a - b)
- elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
+ elif isinstance(a, exp.Interval) and _is_date_literal(b):
a, b = extract_interval(a), extract_date(b)
# you cannot subtract a date from an interval
if a and b and isinstance(expression, exp.Add):
@@ -702,11 +695,7 @@ DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
- return (
- isinstance(left, (exp.DateTrunc, exp.TimestampTrunc))
- and isinstance(right, exp.Cast)
- and right.is_type(*exp.DataType.TEMPORAL_TYPES)
- )
+ return isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)) and _is_date_literal(right)
@catch(ModuleNotFoundError, UnsupportedUnit)
@@ -731,15 +720,26 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
unit = l.unit.name.lower()
date = extract_date(r)
+ if not date:
+ return expression
+
return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression
elif isinstance(expression, exp.In):
l = expression.this
rs = expression.expressions
- if all(_is_datetrunc_predicate(l, r) for r in rs):
+ if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
unit = l.unit.name.lower()
- ranges = [r for r in [_datetrunc_range(extract_date(r), unit) for r in rs] if r]
+ ranges = []
+ for r in rs:
+ date = extract_date(r)
+ if not date:
+ return expression
+ drange = _datetrunc_range(date, unit)
+ if drange:
+ ranges.append(drange)
+
if not ranges:
return expression
@@ -811,18 +811,59 @@ def eval_boolean(expression, a, b):
return None
-def extract_date(cast):
- # The "fromisoformat" conversion could fail if the cast is used on an identifier,
- # so in that case we can't extract the date.
+def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
+ if isinstance(value, datetime.datetime):
+ return value.date()
+ if isinstance(value, datetime.date):
+ return value
try:
- if cast.args["to"].this == exp.DataType.Type.DATE:
- return datetime.date.fromisoformat(cast.name)
- if cast.args["to"].this == exp.DataType.Type.DATETIME:
- return datetime.datetime.fromisoformat(cast.name)
+ return datetime.datetime.fromisoformat(value).date()
except ValueError:
return None
+def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
+ if isinstance(value, datetime.datetime):
+ return value
+ if isinstance(value, datetime.date):
+ return datetime.datetime(year=value.year, month=value.month, day=value.day)
+ try:
+ return datetime.datetime.fromisoformat(value)
+ except ValueError:
+ return None
+
+
+def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]:
+ if not value:
+ return None
+ if to.is_type(exp.DataType.Type.DATE):
+ return cast_as_date(value)
+ if to.is_type(*exp.DataType.TEMPORAL_TYPES):
+ return cast_as_datetime(value)
+ return None
+
+
+def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]:
+ if isinstance(cast, exp.Cast):
+ to = cast.to
+ elif isinstance(cast, exp.TsOrDsToDate):
+ to = exp.DataType.build(exp.DataType.Type.DATE)
+ else:
+ return None
+
+ if isinstance(cast.this, exp.Literal):
+ value: t.Any = cast.this.name
+ elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
+ value = extract_date(cast.this)
+ else:
+ return None
+ return cast_value(value, to)
+
+
+def _is_date_literal(expression: exp.Expression) -> bool:
+ return extract_date(expression) is not None
+
+
def extract_interval(expression):
n = int(expression.name)
unit = expression.text("unit").lower()
@@ -836,7 +877,9 @@ def extract_interval(expression):
def date_literal(date):
return exp.cast(
exp.Literal.string(date),
- "DATETIME" if isinstance(date, datetime.datetime) else "DATE",
+ exp.DataType.Type.DATETIME
+ if isinstance(date, datetime.datetime)
+ else exp.DataType.Type.DATE,
)