summaryrefslogtreecommitdiffstats
path: root/sqlglot/transforms.py
blob: 7fc71dd98bd79c3b4e8f55160d815974f0f9cf76 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from sqlglot import expressions as exp


def unalias_group(expression):
    """
    Swap GROUP BY references to select aliases for their ordinal positions.

    Only `exp.Group` nodes whose parent is an `exp.Select` are rewritten; any
    other expression is returned as is. For a matching node, the replacements
    are applied to the result of `expression.copy()`, and that copy is returned.
    Every column without a table qualifier whose name matches an aliased select
    expression is replaced by a number literal holding that expression's
    1-based position in the select list.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression (exp.Expression): Node to inspect.
    Returns:
        exp.Expression: The rewritten copy for a qualifying GROUP BY node,
            otherwise the node that was passed in.
    """
    is_select_group = isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select)
    if not is_select_group:
        return expression

    # Alias name -> 1-based select position; a repeated alias keeps its last position.
    positions = {}
    for position, projection in enumerate(expression.parent.expressions, start=1):
        if isinstance(projection, exp.Alias):
            positions[projection.alias] = position

    rewritten = expression.copy()

    for column in rewritten.find_all(exp.Column):
        ordinal = positions.get(column.name)
        # Table-qualified columns and names that are not aliases stay untouched.
        if column.table or not ordinal:
            continue
        column.replace(exp.Literal.number(ordinal))

    return rewritten


def preprocess(transforms, to_sql):
    """
    Create a new transform function that can be used as a value in `Generator.TRANSFORMS`
    to convert expressions to SQL.

    Args:
        transforms (list[(exp.Expression) -> exp.Expression]):
            Sequence of transform functions. These will be called in order, each
            receiving the result of the previous one. May be empty, in which case
            the expression is handed to `to_sql` unchanged.
        to_sql ((sqlglot.generator.Generator, exp.Expression) -> str):
            Final transform that converts the resulting expression to a SQL string.
    Returns:
        (sqlglot.generator.Generator, exp.Expression) -> str:
            Function that can be used as a generator transform.
    """

    def _to_sql(self, expression):
        # Iterate instead of indexing `transforms[0]` so that an empty sequence
        # is a no-op rather than an IndexError.
        for transform in transforms:
            expression = transform(expression)
        return to_sql(self, expression)

    return _to_sql


def delegate(attr):
    """
    Build a function that forwards its call to the attribute named `attr`.

    The returned function looks up `attr` on its first argument each time it is
    called and invokes it with the remaining positional and keyword arguments.
    This is handy for `Generator.TRANSFORMS` entries that should simply call an
    existing generator method.

    Args:
        attr (str): Name of the attribute to look up on the first argument.
    Returns:
        Function `(self, *args, **kwargs)` returning `getattr(self, attr)(*args, **kwargs)`.
    """

    def _transform(self, *args, **kwargs):
        # Resolved per call, so a missing attribute only fails when invoked.
        target = getattr(self, attr)
        return target(*args, **kwargs)

    return _transform


# Transform table entry for GROUP BY nodes: run `unalias_group` on the node,
# then hand the result to the generator's `group_sql` method.
UNALIAS_GROUP = {
    exp.Group: preprocess(
        [unalias_group],
        delegate("group_sql"),
    ),
}