diff options
Diffstat (limited to 'sqlglot/optimizer/canonicalize.py')
-rw-r--r-- | sqlglot/optimizer/canonicalize.py | 19 |
1 files changed, 12 insertions, 7 deletions
diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py index e45d1e3..ec3b3af 100644 --- a/sqlglot/optimizer/canonicalize.py +++ b/sqlglot/optimizer/canonicalize.py @@ -45,9 +45,11 @@ def coerce_type(node: exp.Expression) -> exp.Expression: _coerce_date(node.left, node.right) elif isinstance(node, exp.Between): _coerce_date(node.this, node.args["low"]) - elif isinstance(node, exp.Extract): - if node.expression.type.this not in exp.DataType.TEMPORAL_TYPES: - _replace_cast(node.expression, "datetime") + elif isinstance(node, exp.Extract) and not node.expression.type.is_type( + *exp.DataType.TEMPORAL_TYPES + ): + _replace_cast(node.expression, exp.DataType.Type.DATETIME) + return node @@ -67,7 +69,7 @@ def ensure_bool_predicates(expression: exp.Expression) -> exp.Expression: _replace_int_predicate(expression.left) _replace_int_predicate(expression.right) - elif isinstance(expression, (exp.Where, exp.Having)): + elif isinstance(expression, (exp.Where, exp.Having, exp.If)): _replace_int_predicate(expression.this) return expression @@ -89,13 +91,16 @@ def _coerce_date(a: exp.Expression, b: exp.Expression) -> None: and b.type and b.type.this not in (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL) ): - _replace_cast(b, "date") + _replace_cast(b, exp.DataType.Type.DATE) -def _replace_cast(node: exp.Expression, to: str) -> None: +def _replace_cast(node: exp.Expression, to: exp.DataType.Type) -> None: node.replace(exp.cast(node.copy(), to=to)) def _replace_int_predicate(expression: exp.Expression) -> None: - if expression.type and expression.type.this in exp.DataType.INTEGER_TYPES: + if isinstance(expression, exp.Coalesce): + for _, child in expression.iter_expressions(): + _replace_int_predicate(child) + elif expression.type and expression.type.this in exp.DataType.INTEGER_TYPES: expression.replace(exp.NEQ(this=expression.copy(), expression=exp.Literal.number(0))) |