diff options
Diffstat (limited to 'sqlglot/optimizer/simplify.py')
-rw-r--r-- | sqlglot/optimizer/simplify.py | 19 |
1 files changed, 13 insertions, 6 deletions
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index c0719f2..f560760 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -361,7 +361,7 @@ def _simplify_binary(expression, a, b): return boolean elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval): a, b = extract_date(a), extract_interval(b) - if b: + if a and b: if isinstance(expression, exp.Add): return date_literal(a + b) if isinstance(expression, exp.Sub): @@ -369,7 +369,7 @@ def _simplify_binary(expression, a, b): elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast): a, b = extract_interval(a), extract_date(b) # you cannot subtract a date from an interval - if a and isinstance(expression, exp.Add): + if a and b and isinstance(expression, exp.Add): return date_literal(a + b) return None @@ -424,9 +424,15 @@ def eval_boolean(expression, a, b): def extract_date(cast): - if cast.args["to"].this == exp.DataType.Type.DATE: - return datetime.date.fromisoformat(cast.name) - return None + # The "fromisoformat" conversion could fail if the cast is used on an identifier, + # so in that case we can't extract the date. + try: + if cast.args["to"].this == exp.DataType.Type.DATE: + return datetime.date.fromisoformat(cast.name) + if cast.args["to"].this == exp.DataType.Type.DATETIME: + return datetime.datetime.fromisoformat(cast.name) + except ValueError: + return None def extract_interval(interval): @@ -450,7 +456,8 @@ def extract_interval(interval): def date_literal(date): - return exp.Cast(this=exp.Literal.string(date), to=exp.DataType.build("DATE")) + expr_type = exp.DataType.build("DATETIME" if isinstance(date, datetime.datetime) else "DATE") + return exp.Cast(this=exp.Literal.string(date), to=expr_type) def boolean_literal(condition): |