diff options
Diffstat (limited to 'sqlglot/optimizer/canonicalize.py')
-rw-r--r-- | sqlglot/optimizer/canonicalize.py | 85 |
1 files changed, 73 insertions, 12 deletions
diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py index fc5c348..faf18c6 100644 --- a/sqlglot/optimizer/canonicalize.py +++ b/sqlglot/optimizer/canonicalize.py @@ -1,8 +1,10 @@ from __future__ import annotations import itertools +import typing as t from sqlglot import exp +from sqlglot.helper import is_date_unit, is_iso_date, is_iso_datetime def canonicalize(expression: exp.Expression) -> exp.Expression: @@ -20,7 +22,7 @@ def canonicalize(expression: exp.Expression) -> exp.Expression: expression = replace_date_funcs(expression) expression = coerce_type(expression) expression = remove_redundant_casts(expression) - expression = ensure_bool_predicates(expression) + expression = ensure_bools(expression, _replace_int_predicate) expression = remove_ascending_order(expression) return expression @@ -40,8 +42,22 @@ def replace_date_funcs(node: exp.Expression) -> exp.Expression: return node +COERCIBLE_DATE_OPS = ( + exp.Add, + exp.Sub, + exp.EQ, + exp.NEQ, + exp.GT, + exp.GTE, + exp.LT, + exp.LTE, + exp.NullSafeEQ, + exp.NullSafeNEQ, +) + + def coerce_type(node: exp.Expression) -> exp.Expression: - if isinstance(node, exp.Binary): + if isinstance(node, COERCIBLE_DATE_OPS): _coerce_date(node.left, node.right) elif isinstance(node, exp.Between): _coerce_date(node.this, node.args["low"]) @@ -49,6 +65,10 @@ def coerce_type(node: exp.Expression) -> exp.Expression: *exp.DataType.TEMPORAL_TYPES ): _replace_cast(node.expression, exp.DataType.Type.DATETIME) + elif isinstance(node, (exp.DateAdd, exp.DateSub, exp.DateTrunc)): + _coerce_timeunit_arg(node.this, node.unit) + elif isinstance(node, exp.DateDiff): + _coerce_datediff_args(node) return node @@ -64,17 +84,21 @@ def remove_redundant_casts(expression: exp.Expression) -> exp.Expression: return expression -def ensure_bool_predicates(expression: exp.Expression) -> exp.Expression: +def ensure_bools( + expression: exp.Expression, replace_func: t.Callable[[exp.Expression], None] +) -> exp.Expression: if isinstance(expression, exp.Connector): - _replace_int_predicate(expression.left) - _replace_int_predicate(expression.right) - - elif isinstance(expression, (exp.Where, exp.Having)) or ( + replace_func(expression.left) + replace_func(expression.right) + elif isinstance(expression, exp.Not): + replace_func(expression.this) # We can't replace num in CASE x WHEN num ..., because it's not the full predicate - isinstance(expression, exp.If) - and not (isinstance(expression.parent, exp.Case) and expression.parent.this) + elif isinstance(expression, exp.If) and not ( + isinstance(expression.parent, exp.Case) and expression.parent.this ): - _replace_int_predicate(expression.this) + replace_func(expression.this) + elif isinstance(expression, (exp.Where, exp.Having)): + replace_func(expression.this) return expression @@ -89,22 +113,59 @@ def remove_ascending_order(expression: exp.Expression) -> exp.Expression: def _coerce_date(a: exp.Expression, b: exp.Expression) -> None: for a, b in itertools.permutations([a, b]): + if isinstance(b, exp.Interval): + a = _coerce_timeunit_arg(a, b.unit) if ( a.type and a.type.this == exp.DataType.Type.DATE and b.type - and b.type.this not in (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL) + and b.type.this + not in ( + exp.DataType.Type.DATE, + exp.DataType.Type.INTERVAL, + ) ): _replace_cast(b, exp.DataType.Type.DATE) +def _coerce_timeunit_arg(arg: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.Expression: + if not arg.type: + return arg + + if arg.type.this in exp.DataType.TEXT_TYPES: + date_text = arg.name + is_iso_date_ = is_iso_date(date_text) + + if is_iso_date_ and is_date_unit(unit): + return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATE)) + + # An ISO date is also an ISO datetime, but not vice versa + if is_iso_date_ or is_iso_datetime(date_text): + return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATETIME)) + + elif arg.type.this == exp.DataType.Type.DATE and not is_date_unit(unit): + return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATETIME)) + + return arg + + +def _coerce_datediff_args(node: exp.DateDiff) -> None: + for e in (node.this, node.expression): + if e.type.this not in exp.DataType.TEMPORAL_TYPES: + e.replace(exp.cast(e.copy(), to=exp.DataType.Type.DATETIME)) + + def _replace_cast(node: exp.Expression, to: exp.DataType.Type) -> None: node.replace(exp.cast(node.copy(), to=to)) +# this was originally designed for presto, there is a similar transform for tsql +# this is different in that it only operates on int types, this is because +# presto has a boolean type whereas tsql doesn't (people use bits) +# with y as (select true as x) select x = 0 FROM y -- illegal presto query def _replace_int_predicate(expression: exp.Expression) -> None: if isinstance(expression, exp.Coalesce): for _, child in expression.iter_expressions(): _replace_int_predicate(child) elif expression.type and expression.type.this in exp.DataType.INTEGER_TYPES: - expression.replace(exp.NEQ(this=expression.copy(), expression=exp.Literal.number(0))) + expression.replace(expression.neq(0)) |