diff options
Diffstat (limited to 'sqlglot/optimizer/canonicalize.py')
-rw-r--r-- | sqlglot/optimizer/canonicalize.py | 28 |
1 files changed, 25 insertions, 3 deletions
diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py index fc37a54..c5c780d 100644 --- a/sqlglot/optimizer/canonicalize.py +++ b/sqlglot/optimizer/canonicalize.py @@ -1,9 +1,12 @@ +from __future__ import annotations + import itertools from sqlglot import exp +from sqlglot.helper import should_identify -def canonicalize(expression: exp.Expression) -> exp.Expression: +def canonicalize(expression: exp.Expression, identify: str = "safe") -> exp.Expression: """Converts a sql expression into a standard form. This method relies on annotate_types because many of the @@ -11,15 +14,18 @@ def canonicalize(expression: exp.Expression) -> exp.Expression: Args: expression: The expression to canonicalize. + identify: Whether or not to force identify identifier. """ - exp.replace_children(expression, canonicalize) + exp.replace_children(expression, canonicalize, identify=identify) expression = add_text_to_concat(expression) expression = coerce_type(expression) expression = remove_redundant_casts(expression) + expression = ensure_bool_predicates(expression) if isinstance(expression, exp.Identifier): - expression.set("quoted", True) + if should_identify(expression.this, identify): + expression.set("quoted", True) return expression @@ -52,6 +58,17 @@ def remove_redundant_casts(expression: exp.Expression) -> exp.Expression: return expression +def ensure_bool_predicates(expression: exp.Expression) -> exp.Expression: + if isinstance(expression, exp.Connector): + _replace_int_predicate(expression.left) + _replace_int_predicate(expression.right) + + elif isinstance(expression, (exp.Where, exp.Having)): + _replace_int_predicate(expression.this) + + return expression + + def _coerce_date(a: exp.Expression, b: exp.Expression) -> None: for a, b in itertools.permutations([a, b]): if ( @@ -68,3 +85,8 @@ def _replace_cast(node: exp.Expression, to: str) -> None: cast = exp.Cast(this=node.copy(), to=data_type) cast.type = data_type node.replace(cast) + + +def _replace_int_predicate(expression: exp.Expression) -> None: + if expression.type and expression.type.this in exp.DataType.INTEGER_TYPES: + expression.replace(exp.NEQ(this=expression.copy(), expression=exp.Literal.number(0))) |