summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/canonicalize.py
blob: faf18c6fdc456206cb2e0f224c9840e498679a21 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
from __future__ import annotations

import itertools
import typing as t

from sqlglot import exp
from sqlglot.helper import is_date_unit, is_iso_date, is_iso_datetime


def canonicalize(expression: exp.Expression) -> exp.Expression:
    """Rewrite a SQL expression tree into a standard (canonical) form.

    Children are canonicalized first (post-order recursion through
    ``exp.replace_children``). Then a fixed sequence of rewrites is applied
    to ``expression`` itself, each one receiving the previous one's result.
    Several rewrites inspect ``.type`` on nodes, so the tree is expected to
    have been annotated with ``annotate_types`` beforehand.

    Args:
        expression: The expression to canonicalize.

    Returns:
        The canonical expression. It may be a different node than
        ``expression``, e.g. when an addition is turned into a CONCAT.
    """
    # Canonicalize operands first so every rewrite below sees canonical children.
    exp.replace_children(expression, canonicalize)

    rewrites = (
        add_text_to_concat,
        replace_date_funcs,
        coerce_type,
        remove_redundant_casts,
        lambda node: ensure_bools(node, _replace_int_predicate),
        remove_ascending_order,
    )
    for rewrite in rewrites:
        expression = rewrite(expression)

    return expression


def add_text_to_concat(node: exp.Expression) -> exp.Expression:
    """Turn a text-typed ``a + b`` into ``CONCAT(a, b)``.

    Args:
        node: The node to inspect.

    Returns:
        A new ``exp.Concat`` over the addition's two operands when ``node`` is
        an ``exp.Add`` whose annotated type is one of
        ``exp.DataType.TEXT_TYPES``; otherwise ``node`` unchanged.
    """
    if not isinstance(node, exp.Add):
        return node

    result_type = node.type
    if not result_type or result_type.this not in exp.DataType.TEXT_TYPES:
        return node

    return exp.Concat(expressions=[node.left, node.right])


def replace_date_funcs(node: exp.Expression) -> exp.Expression:
    """Rewrite single-argument DATE(...) / TIMESTAMP(...) calls as casts.

    * ``exp.Date`` with no extra ``expressions`` and no ``zone`` becomes
      ``CAST(this AS DATE)``.
    * ``exp.Timestamp`` with no second argument (``expression``) becomes
      ``CAST(this AS TIMESTAMP)``.

    Args:
        node: The node to inspect.

    Returns:
        The replacement cast, or ``node`` unchanged when no rule applies.
    """
    if isinstance(node, exp.Date):
        # Any extra argument or a zone changes the meaning, so leave those calls alone.
        if not (node.expressions or node.args.get("zone")):
            return exp.cast(node.this, to=exp.DataType.Type.DATE)

    if isinstance(node, exp.Timestamp):
        if not node.expression:
            return exp.cast(node.this, to=exp.DataType.Type.TIMESTAMP)

    return node


# Binary operator node types whose two operands ``coerce_type`` hands to
# ``_coerce_date`` so that date operands get compatible types.
COERCIBLE_DATE_OPS = (
    # Arithmetic
    exp.Add,
    exp.Sub,
    # Comparisons
    exp.EQ,
    exp.NEQ,
    exp.GT,
    exp.GTE,
    exp.LT,
    exp.LTE,
    # Null-safe comparisons
    exp.NullSafeEQ,
    exp.NullSafeNEQ,
)


def coerce_type(node: exp.Expression) -> exp.Expression:
    """Insert casts so that date/time operands of ``node`` have compatible types.

    * Binary ops in ``COERCIBLE_DATE_OPS``: both operands go through ``_coerce_date``.
    * ``BETWEEN``: the tested value is reconciled with BOTH the low and the
      high bound (previously only the low bound was handled).
    * ``EXTRACT``: an operand whose type is known and non-temporal is cast
      to DATETIME. Untyped operands are left alone.
    * ``DATE_ADD`` / ``DATE_SUB`` / ``DATE_TRUNC``: the date argument is
      coerced according to the unit via ``_coerce_timeunit_arg``.
    * ``DATEDIFF``: arguments are handled by ``_coerce_datediff_args``.

    Casts are applied in place on ``node``'s children.

    Args:
        node: The node to inspect.

    Returns:
        ``node`` itself (its children may have been replaced).
    """
    if isinstance(node, COERCIBLE_DATE_OPS):
        _coerce_date(node.left, node.right)
    elif isinstance(node, exp.Between):
        # Both bounds need the same treatment. node.this is re-read for the second
        # call because the first call may have replaced it with a cast.
        _coerce_date(node.this, node.args["low"])
        _coerce_date(node.this, node.args["high"])
    elif isinstance(node, exp.Extract):
        operand_type = node.expression.type
        # Without type information we cannot tell whether a cast is needed
        # (and calling .is_type on None would raise AttributeError).
        if operand_type and not operand_type.is_type(*exp.DataType.TEMPORAL_TYPES):
            _replace_cast(node.expression, exp.DataType.Type.DATETIME)
    elif isinstance(node, (exp.DateAdd, exp.DateSub, exp.DateTrunc)):
        _coerce_timeunit_arg(node.this, node.unit)
    elif isinstance(node, exp.DateDiff):
        _coerce_datediff_args(node)

    return node


def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
    """Drop a CAST whose operand already has exactly the target type.

    The cast is removed only when the base type enum AND the type's
    ``expressions`` match. In sqlglot, ``expressions`` holds type parameters
    such as DECIMAL precision/scale or VARCHAR length. Comparing only the base
    type used to drop casts that still change the value, e.g.
    ``CAST(x AS DECIMAL(10, 2))`` where ``x`` is ``DECIMAL(38, 10)``.

    Args:
        expression: The node to inspect.

    Returns:
        The cast's operand when the cast is a no-op, otherwise ``expression``.
    """
    if not isinstance(expression, exp.Cast):
        return expression

    target_type = expression.to.type
    source_type = expression.this.type
    if (
        target_type
        and source_type
        and target_type.this == source_type.this
        # Same base type with different parameters is not a no-op cast.
        and target_type.expressions == source_type.expressions
    ):
        return expression.this
    return expression


def ensure_bools(
    expression: exp.Expression, replace_func: t.Callable[[exp.Expression], None]
) -> exp.Expression:
    """Call ``replace_func`` on each operand of ``expression`` used as a predicate.

    ``replace_func`` is applied to:

    * both sides of an ``exp.Connector`` (e.g. AND / OR),
    * the operand of ``exp.Not``,
    * the condition of an ``exp.If``, except when the ``If`` is a branch of a
      ``CASE x WHEN ...`` (a ``Case`` that has its own ``this``),
    * the condition of ``exp.Where`` and ``exp.Having``.

    Args:
        expression: The node to inspect.
        replace_func: Callback invoked on each predicate operand. Its return
            value is ignored, so it must make its changes in place.

    Returns:
        ``expression`` itself.
    """
    if isinstance(expression, exp.Connector):
        replace_func(expression.left)
        replace_func(expression.right)
    elif isinstance(expression, exp.If):
        case = expression.parent
        # In CASE x WHEN num ..., `num` is compared against x: it is not the
        # full predicate, so it must not be rewritten.
        is_simple_case_branch = isinstance(case, exp.Case) and case.this
        if not is_simple_case_branch:
            replace_func(expression.this)
    elif isinstance(expression, (exp.Not, exp.Where, exp.Having)):
        replace_func(expression.this)

    return expression


def remove_ascending_order(expression: exp.Expression) -> exp.Expression:
    """Drop an explicit ASC from an ORDER BY term (``a ASC`` -> ``a``).

    Only ``exp.Ordered`` nodes whose ``desc`` arg is exactly ``False`` are
    modified; that arg is reset to ``None``.

    Args:
        expression: The node to inspect.

    Returns:
        ``expression`` itself (possibly modified in place).
    """
    if not isinstance(expression, exp.Ordered):
        return expression

    explicit_ascending = expression.args.get("desc") is False
    if explicit_ascending:
        expression.set("desc", None)

    return expression


# Partner types that must NOT be cast to DATE when compared/combined with a DATE:
# - other temporal types: casting e.g. a TIMESTAMP to DATE drops its time part
#   and changes comparison results;
# - INTERVAL: DATE +/- INTERVAL is already well-typed;
# - integers: `date + 1` must not become `date + CAST(1 AS DATE)`.
_DATE_COERCION_SKIP_TYPES = {
    exp.DataType.Type.DATE,
    exp.DataType.Type.INTERVAL,
    *exp.DataType.TEMPORAL_TYPES,
    *exp.DataType.INTEGER_TYPES,
}


def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
    """Cast the partner of a DATE operand to DATE, in place.

    Both orderings of ``(a, b)`` are tried. When one side is an
    ``exp.Interval``, the other side is first coerced according to the
    interval's unit (see ``_coerce_timeunit_arg``). Then, if one side is
    DATE-typed and the other side has a known type outside
    ``_DATE_COERCION_SKIP_TYPES`` (typically a text literal), the other side
    is wrapped in ``CAST(... AS DATE)``.

    Args:
        a: One operand.
        b: The other operand.
    """
    for operand, other in itertools.permutations([a, b]):
        if isinstance(other, exp.Interval):
            # _coerce_timeunit_arg may replace `operand`; continue with the new node.
            operand = _coerce_timeunit_arg(operand, other.unit)

        if not (operand.type and operand.type.this == exp.DataType.Type.DATE):
            continue
        if other.type and other.type.this not in _DATE_COERCION_SKIP_TYPES:
            _replace_cast(other, exp.DataType.Type.DATE)


def _coerce_timeunit_arg(arg: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.Expression:
    """Cast ``arg`` to DATE or DATETIME depending on its type and ``unit``.

    * Text-typed ``arg`` whose ``name`` is an ISO date, with a date ``unit``
      -> ``CAST(arg AS DATE)``.
    * Text-typed ``arg`` whose ``name`` is an ISO date or ISO datetime
      otherwise -> ``CAST(arg AS DATETIME)``.
    * DATE-typed ``arg`` with a non-date ``unit`` -> ``CAST(arg AS DATETIME)``.
    * Anything else, including an ``arg`` without a type, is left unchanged.

    Args:
        arg: The date/time argument to coerce.
        unit: The time unit the argument is used with, if any.

    Returns:
        The result of ``arg.replace(<cast>)`` when a cast is inserted,
        otherwise ``arg``.
    """
    arg_type = arg.type
    if not arg_type:
        return arg

    if arg_type.this in exp.DataType.TEXT_TYPES:
        text = arg.name
        looks_like_date = is_iso_date(text)
        if looks_like_date and is_date_unit(unit):
            target = exp.DataType.Type.DATE
        elif looks_like_date or is_iso_datetime(text):
            # An ISO date is also a valid ISO datetime, but not vice versa.
            target = exp.DataType.Type.DATETIME
        else:
            return arg
    elif arg_type.this == exp.DataType.Type.DATE and not is_date_unit(unit):
        target = exp.DataType.Type.DATETIME
    else:
        return arg

    return arg.replace(exp.cast(arg.copy(), to=target))


def _coerce_datediff_args(node: exp.DateDiff) -> None:
    """Cast non-temporal DATEDIFF arguments to DATETIME, in place.

    Arguments that are missing or have no type information are skipped,
    matching how ``_coerce_timeunit_arg`` treats untyped arguments. The
    previous code dereferenced ``e.type.this`` unconditionally and raised
    ``AttributeError`` for such arguments.

    Args:
        node: The DATEDIFF node whose ``this`` and ``expression`` are coerced.
    """
    for arg in (node.this, node.expression):
        if arg is None:
            continue
        arg_type = arg.type
        if arg_type and arg_type.this not in exp.DataType.TEMPORAL_TYPES:
            arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATETIME))


def _replace_cast(node: exp.Expression, to: exp.DataType.Type) -> None:
    """Replace ``node`` in the tree with ``CAST(<copy of node> AS to)``.

    Args:
        node: The node to wrap in a cast.
        to: The target type of the cast.
    """
    wrapped = exp.cast(node.copy(), to=to)
    node.replace(wrapped)


# Originally written for Presto; a similar transform exists for T-SQL. Unlike
# the T-SQL one, this only touches integer-typed expressions: Presto has a real
# BOOLEAN type (whereas T-SQL code typically uses BIT), and comparing a boolean
# with 0 is an error there, e.g.
#   WITH y AS (SELECT TRUE AS x) SELECT x = 0 FROM y  -- illegal Presto query
def _replace_int_predicate(expression: exp.Expression) -> None:
    """Rewrite an integer-typed predicate ``e`` as ``e <> 0``, in place.

    For ``COALESCE`` the rewrite is applied recursively to each child
    expression instead.

    Args:
        expression: The node used as a predicate.
    """
    if isinstance(expression, exp.Coalesce):
        for _key, operand in expression.iter_expressions():
            _replace_int_predicate(operand)
        return

    expr_type = expression.type
    if expr_type and expr_type.this in exp.DataType.INTEGER_TYPES:
        expression.replace(expression.neq(0))