summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/canonicalize.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--sqlglot/optimizer/canonicalize.py19
1 files changed, 12 insertions, 7 deletions
diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py
index e45d1e3..ec3b3af 100644
--- a/sqlglot/optimizer/canonicalize.py
+++ b/sqlglot/optimizer/canonicalize.py
@@ -45,9 +45,11 @@ def coerce_type(node: exp.Expression) -> exp.Expression:
_coerce_date(node.left, node.right)
elif isinstance(node, exp.Between):
_coerce_date(node.this, node.args["low"])
- elif isinstance(node, exp.Extract):
- if node.expression.type.this not in exp.DataType.TEMPORAL_TYPES:
- _replace_cast(node.expression, "datetime")
+ elif isinstance(node, exp.Extract) and not node.expression.type.is_type(
+ *exp.DataType.TEMPORAL_TYPES
+ ):
+ _replace_cast(node.expression, exp.DataType.Type.DATETIME)
+
return node
@@ -67,7 +69,7 @@ def ensure_bool_predicates(expression: exp.Expression) -> exp.Expression:
_replace_int_predicate(expression.left)
_replace_int_predicate(expression.right)
- elif isinstance(expression, (exp.Where, exp.Having)):
+ elif isinstance(expression, (exp.Where, exp.Having, exp.If)):
_replace_int_predicate(expression.this)
return expression
@@ -89,13 +91,16 @@ def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
and b.type
and b.type.this not in (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL)
):
- _replace_cast(b, "date")
+ _replace_cast(b, exp.DataType.Type.DATE)
-def _replace_cast(node: exp.Expression, to: str) -> None:
+def _replace_cast(node: exp.Expression, to: exp.DataType.Type) -> None:
node.replace(exp.cast(node.copy(), to=to))
def _replace_int_predicate(expression: exp.Expression) -> None:
- if expression.type and expression.type.this in exp.DataType.INTEGER_TYPES:
+ if isinstance(expression, exp.Coalesce):
+ for _, child in expression.iter_expressions():
+ _replace_int_predicate(child)
+ elif expression.type and expression.type.this in exp.DataType.INTEGER_TYPES:
expression.replace(exp.NEQ(this=expression.copy(), expression=exp.Literal.number(0)))