summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/canonicalize.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/canonicalize.py')
-rw-r--r--sqlglot/optimizer/canonicalize.py28
1 files changed, 25 insertions, 3 deletions
diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py
index fc37a54..c5c780d 100644
--- a/sqlglot/optimizer/canonicalize.py
+++ b/sqlglot/optimizer/canonicalize.py
@@ -1,9 +1,12 @@
+from __future__ import annotations
+
import itertools
from sqlglot import exp
+from sqlglot.helper import should_identify
-def canonicalize(expression: exp.Expression) -> exp.Expression:
+def canonicalize(expression: exp.Expression, identify: str = "safe") -> exp.Expression:
"""Converts a sql expression into a standard form.
This method relies on annotate_types because many of the
@@ -11,15 +14,18 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:
Args:
expression: The expression to canonicalize.
+ identify: Whether or not to force identify identifier.
"""
- exp.replace_children(expression, canonicalize)
+ exp.replace_children(expression, canonicalize, identify=identify)
expression = add_text_to_concat(expression)
expression = coerce_type(expression)
expression = remove_redundant_casts(expression)
+ expression = ensure_bool_predicates(expression)
if isinstance(expression, exp.Identifier):
- expression.set("quoted", True)
+ if should_identify(expression.this, identify):
+ expression.set("quoted", True)
return expression
@@ -52,6 +58,17 @@ def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
return expression
+def ensure_bool_predicates(expression: exp.Expression) -> exp.Expression:
+ if isinstance(expression, exp.Connector):
+ _replace_int_predicate(expression.left)
+ _replace_int_predicate(expression.right)
+
+ elif isinstance(expression, (exp.Where, exp.Having)):
+ _replace_int_predicate(expression.this)
+
+ return expression
+
+
def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
for a, b in itertools.permutations([a, b]):
if (
@@ -68,3 +85,8 @@ def _replace_cast(node: exp.Expression, to: str) -> None:
cast = exp.Cast(this=node.copy(), to=data_type)
cast.type = data_type
node.replace(cast)
+
+
+def _replace_int_predicate(expression: exp.Expression) -> None:
+ if expression.type and expression.type.this in exp.DataType.INTEGER_TYPES:
+ expression.replace(exp.NEQ(this=expression.copy(), expression=exp.Literal.number(0)))