summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/canonicalize.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/canonicalize.py')
-rw-r--r--sqlglot/optimizer/canonicalize.py48
1 files changed, 48 insertions, 0 deletions
diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py
new file mode 100644
index 0000000..9b3d98a
--- /dev/null
+++ b/sqlglot/optimizer/canonicalize.py
@@ -0,0 +1,48 @@
+import itertools
+
+from sqlglot import exp
+
+
+def canonicalize(expression: exp.Expression) -> exp.Expression:
+ """Converts a sql expression into a standard form.
+
+ This method relies on annotate_types because many of the
+ conversions rely on type inference.
+
+ Args:
+ expression: The expression to canonicalize.
+ """
+ exp.replace_children(expression, canonicalize)
+ expression = add_text_to_concat(expression)
+ expression = coerce_type(expression)
+ return expression
+
+
+def add_text_to_concat(node: exp.Expression) -> exp.Expression:
+ if isinstance(node, exp.Add) and node.type in exp.DataType.TEXT_TYPES:
+ node = exp.Concat(this=node.this, expression=node.expression)
+ return node
+
+
+def coerce_type(node: exp.Expression) -> exp.Expression:
+ if isinstance(node, exp.Binary):
+ _coerce_date(node.left, node.right)
+ elif isinstance(node, exp.Between):
+ _coerce_date(node.this, node.args["low"])
+ elif isinstance(node, exp.Extract):
+ if node.expression.type not in exp.DataType.TEMPORAL_TYPES:
+ _replace_cast(node.expression, "datetime")
+ return node
+
+
+def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
+ for a, b in itertools.permutations([a, b]):
+ if a.type == exp.DataType.Type.DATE and b.type != exp.DataType.Type.DATE:
+ _replace_cast(b, "date")
+
+
+def _replace_cast(node: exp.Expression, to: str) -> None:
+ data_type = exp.DataType.build(to)
+ cast = exp.Cast(this=node.copy(), to=data_type)
+ cast.type = data_type
+ node.replace(cast)