summary | refs | log | tree | commit | diff | stats
path: root/sqlglot/optimizer/canonicalize.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/canonicalize.py')
-rw-r--r-- sqlglot/optimizer/canonicalize.py 85
1 files changed, 73 insertions, 12 deletions
diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py
index fc5c348..faf18c6 100644
--- a/sqlglot/optimizer/canonicalize.py
+++ b/sqlglot/optimizer/canonicalize.py
@@ -1,8 +1,10 @@
from __future__ import annotations
import itertools
+import typing as t
from sqlglot import exp
+from sqlglot.helper import is_date_unit, is_iso_date, is_iso_datetime
def canonicalize(expression: exp.Expression) -> exp.Expression:
@@ -20,7 +22,7 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:
expression = replace_date_funcs(expression)
expression = coerce_type(expression)
expression = remove_redundant_casts(expression)
- expression = ensure_bool_predicates(expression)
+ expression = ensure_bools(expression, _replace_int_predicate)
expression = remove_ascending_order(expression)
return expression
@@ -40,8 +42,22 @@ def replace_date_funcs(node: exp.Expression) -> exp.Expression:
return node
+COERCIBLE_DATE_OPS = (
+ exp.Add,
+ exp.Sub,
+ exp.EQ,
+ exp.NEQ,
+ exp.GT,
+ exp.GTE,
+ exp.LT,
+ exp.LTE,
+ exp.NullSafeEQ,
+ exp.NullSafeNEQ,
+)
+
+
def coerce_type(node: exp.Expression) -> exp.Expression:
- if isinstance(node, exp.Binary):
+ if isinstance(node, COERCIBLE_DATE_OPS):
_coerce_date(node.left, node.right)
elif isinstance(node, exp.Between):
_coerce_date(node.this, node.args["low"])
@@ -49,6 +65,10 @@ def coerce_type(node: exp.Expression) -> exp.Expression:
*exp.DataType.TEMPORAL_TYPES
):
_replace_cast(node.expression, exp.DataType.Type.DATETIME)
+ elif isinstance(node, (exp.DateAdd, exp.DateSub, exp.DateTrunc)):
+ _coerce_timeunit_arg(node.this, node.unit)
+ elif isinstance(node, exp.DateDiff):
+ _coerce_datediff_args(node)
return node
@@ -64,17 +84,21 @@ def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
return expression
-def ensure_bool_predicates(expression: exp.Expression) -> exp.Expression:
+def ensure_bools(
+ expression: exp.Expression, replace_func: t.Callable[[exp.Expression], None]
+) -> exp.Expression:
if isinstance(expression, exp.Connector):
- _replace_int_predicate(expression.left)
- _replace_int_predicate(expression.right)
-
- elif isinstance(expression, (exp.Where, exp.Having)) or (
+ replace_func(expression.left)
+ replace_func(expression.right)
+ elif isinstance(expression, exp.Not):
+ replace_func(expression.this)
# We can't replace num in CASE x WHEN num ..., because it's not the full predicate
- isinstance(expression, exp.If)
- and not (isinstance(expression.parent, exp.Case) and expression.parent.this)
+ elif isinstance(expression, exp.If) and not (
+ isinstance(expression.parent, exp.Case) and expression.parent.this
):
- _replace_int_predicate(expression.this)
+ replace_func(expression.this)
+ elif isinstance(expression, (exp.Where, exp.Having)):
+ replace_func(expression.this)
return expression
@@ -89,22 +113,59 @@ def remove_ascending_order(expression: exp.Expression) -> exp.Expression:
def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
for a, b in itertools.permutations([a, b]):
+ if isinstance(b, exp.Interval):
+ a = _coerce_timeunit_arg(a, b.unit)
if (
a.type
and a.type.this == exp.DataType.Type.DATE
and b.type
- and b.type.this not in (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL)
+ and b.type.this
+ not in (
+ exp.DataType.Type.DATE,
+ exp.DataType.Type.INTERVAL,
+ )
):
_replace_cast(b, exp.DataType.Type.DATE)
+def _coerce_timeunit_arg(arg: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.Expression:
+ if not arg.type:
+ return arg
+
+ if arg.type.this in exp.DataType.TEXT_TYPES:
+ date_text = arg.name
+ is_iso_date_ = is_iso_date(date_text)
+
+ if is_iso_date_ and is_date_unit(unit):
+ return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATE))
+
+ # An ISO date is also an ISO datetime, but not vice versa
+ if is_iso_date_ or is_iso_datetime(date_text):
+ return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATETIME))
+
+ elif arg.type.this == exp.DataType.Type.DATE and not is_date_unit(unit):
+ return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATETIME))
+
+ return arg
+
+
+def _coerce_datediff_args(node: exp.DateDiff) -> None:
+ for e in (node.this, node.expression):
+ if e.type.this not in exp.DataType.TEMPORAL_TYPES:
+ e.replace(exp.cast(e.copy(), to=exp.DataType.Type.DATETIME))
+
+
def _replace_cast(node: exp.Expression, to: exp.DataType.Type) -> None:
node.replace(exp.cast(node.copy(), to=to))
+# this was originally designed for presto, there is a similar transform for tsql
+# this is different in that it only operates on int types, this is because
+# presto has a boolean type whereas tsql doesn't (people use bits)
+# with y as (select true as x) select x = 0 FROM y -- illegal presto query
def _replace_int_predicate(expression: exp.Expression) -> None:
if isinstance(expression, exp.Coalesce):
for _, child in expression.iter_expressions():
_replace_int_predicate(child)
elif expression.type and expression.type.this in exp.DataType.INTEGER_TYPES:
- expression.replace(exp.NEQ(this=expression.copy(), expression=exp.Literal.number(0)))
+ expression.replace(expression.neq(0))