summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/canonicalize.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/canonicalize.py')
-rw-r--r--sqlglot/optimizer/canonicalize.py25
1 files changed, 22 insertions, 3 deletions
diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py
index 9b3d98a..33529a5 100644
--- a/sqlglot/optimizer/canonicalize.py
+++ b/sqlglot/optimizer/canonicalize.py
@@ -13,13 +13,16 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:
expression: The expression to canonicalize.
"""
exp.replace_children(expression, canonicalize)
+
expression = add_text_to_concat(expression)
expression = coerce_type(expression)
+ expression = remove_redundant_casts(expression)
+
return expression
def add_text_to_concat(node: exp.Expression) -> exp.Expression:
- if isinstance(node, exp.Add) and node.type in exp.DataType.TEXT_TYPES:
+ if isinstance(node, exp.Add) and node.type and node.type.this in exp.DataType.TEXT_TYPES:
node = exp.Concat(this=node.this, expression=node.expression)
return node
@@ -30,14 +33,30 @@ def coerce_type(node: exp.Expression) -> exp.Expression:
elif isinstance(node, exp.Between):
_coerce_date(node.this, node.args["low"])
elif isinstance(node, exp.Extract):
- if node.expression.type not in exp.DataType.TEMPORAL_TYPES:
+ if node.expression.type.this not in exp.DataType.TEMPORAL_TYPES:
_replace_cast(node.expression, "datetime")
return node
+def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
+ if (
+ isinstance(expression, exp.Cast)
+ and expression.to.type
+ and expression.this.type
+ and expression.to.type.this == expression.this.type.this
+ ):
+ return expression.this
+ return expression
+
+
def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
for a, b in itertools.permutations([a, b]):
- if a.type == exp.DataType.Type.DATE and b.type != exp.DataType.Type.DATE:
+ if (
+ a.type
+ and a.type.this == exp.DataType.Type.DATE
+ and b.type
+ and b.type.this != exp.DataType.Type.DATE
+ ):
_replace_cast(b, "date")