diff options
Diffstat (limited to 'sqlglot/optimizer/canonicalize.py')
-rw-r--r-- | sqlglot/optimizer/canonicalize.py | 25 |
1 files changed, 22 insertions, 3 deletions
diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py index 9b3d98a..33529a5 100644 --- a/sqlglot/optimizer/canonicalize.py +++ b/sqlglot/optimizer/canonicalize.py @@ -13,13 +13,16 @@ def canonicalize(expression: exp.Expression) -> exp.Expression: expression: The expression to canonicalize. """ exp.replace_children(expression, canonicalize) + expression = add_text_to_concat(expression) expression = coerce_type(expression) + expression = remove_redundant_casts(expression) + return expression def add_text_to_concat(node: exp.Expression) -> exp.Expression: - if isinstance(node, exp.Add) and node.type in exp.DataType.TEXT_TYPES: + if isinstance(node, exp.Add) and node.type and node.type.this in exp.DataType.TEXT_TYPES: node = exp.Concat(this=node.this, expression=node.expression) return node @@ -30,14 +33,30 @@ def coerce_type(node: exp.Expression) -> exp.Expression: elif isinstance(node, exp.Between): _coerce_date(node.this, node.args["low"]) elif isinstance(node, exp.Extract): - if node.expression.type not in exp.DataType.TEMPORAL_TYPES: + if node.expression.type.this not in exp.DataType.TEMPORAL_TYPES: _replace_cast(node.expression, "datetime") return node +def remove_redundant_casts(expression: exp.Expression) -> exp.Expression: + if ( + isinstance(expression, exp.Cast) + and expression.to.type + and expression.this.type + and expression.to.type.this == expression.this.type.this + ): + return expression.this + return expression + + def _coerce_date(a: exp.Expression, b: exp.Expression) -> None: for a, b in itertools.permutations([a, b]): - if a.type == exp.DataType.Type.DATE and b.type != exp.DataType.Type.DATE: + if ( + a.type + and a.type.this == exp.DataType.Type.DATE + and b.type + and b.type.this != exp.DataType.Type.DATE + ): _replace_cast(b, "date") |