Edit on GitHub

sqlglot.transforms

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot import expressions as exp
  6from sqlglot.helper import find_new_name
  7
  8if t.TYPE_CHECKING:
  9    from sqlglot.generator import Generator
 10
 11
 12def unalias_group(expression: exp.Expression) -> exp.Expression:
 13    """
 14    Replace references to select aliases in GROUP BY clauses.
 15
 16    Example:
 17        >>> import sqlglot
 18        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
 19        'SELECT a AS b FROM x GROUP BY 1'
 20
 21    Args:
 22        expression: the expression that will be transformed.
 23
 24    Returns:
 25        The transformed expression.
 26    """
 27    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
 28        aliased_selects = {
 29            e.alias: i
 30            for i, e in enumerate(expression.parent.expressions, start=1)
 31            if isinstance(e, exp.Alias)
 32        }
 33
 34        for group_by in expression.expressions:
 35            if (
 36                isinstance(group_by, exp.Column)
 37                and not group_by.table
 38                and group_by.name in aliased_selects
 39            ):
 40                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))
 41
 42    return expression
 43
 44
 45def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
 46    """
 47    Convert SELECT DISTINCT ON statements to a subquery with a window function.
 48
 49    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
 50
 51    Args:
 52        expression: the expression that will be transformed.
 53
 54    Returns:
 55        The transformed expression.
 56    """
 57    if (
 58        isinstance(expression, exp.Select)
 59        and expression.args.get("distinct")
 60        and expression.args["distinct"].args.get("on")
 61        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
 62    ):
 63        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
 64        outer_selects = expression.selects
 65        row_number = find_new_name(expression.named_selects, "_row_number")
 66        window = exp.Window(
 67            this=exp.RowNumber(),
 68            partition_by=distinct_cols,
 69        )
 70        order = expression.args.get("order")
 71        if order:
 72            window.set("order", order.pop().copy())
 73        window = exp.alias_(window, row_number)
 74        expression.select(window, copy=False)
 75        return exp.select(*outer_selects).from_(expression.subquery()).where(f'"{row_number}" = 1')
 76    return expression
 77
 78
 79def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
 80    """
 81    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
 82
 83    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
 84    https://docs.snowflake.com/en/sql-reference/constructs/qualify
 85
 86    Some dialects don't support window functions in the WHERE clause, so we need to include them as
 87    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
 88    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
 89    otherwise we won't be able to refer to it in the outer query's WHERE clause.
 90    """
 91    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
 92        taken = set(expression.named_selects)
 93        for select in expression.selects:
 94            if not select.alias_or_name:
 95                alias = find_new_name(taken, "_c")
 96                select.replace(exp.alias_(select.copy(), alias))
 97                taken.add(alias)
 98
 99        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
100        qualify_filters = expression.args["qualify"].pop().this
101
102        for expr in qualify_filters.find_all((exp.Window, exp.Column)):
103            if isinstance(expr, exp.Window):
104                alias = find_new_name(expression.named_selects, "_w")
105                expression.select(exp.alias_(expr.copy(), alias), copy=False)
106                column = exp.column(alias)
107                if isinstance(expr.parent, exp.Qualify):
108                    qualify_filters = column
109                else:
110                    expr.replace(column)
111            elif expr.name not in expression.named_selects:
112                expression.select(expr.copy(), copy=False)
113
114        return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters)
115
116    return expression
117
118
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Strip precision arguments from parameterized types, for dialects that only accept them in
    DDL and not in other expressions. Only arguments that are themselves data types are kept.
    """
    for data_type in expression.find_all(exp.DataType):
        nested_types = [arg for arg in data_type.expressions if isinstance(arg, exp.DataType)]
        data_type.set("expressions", nested_types)
    return expression
127
128
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode (used in presto -> hive).

    Every join of a SELECT whose target is an UNNEST is removed and replaced by one lateral
    view per (unnested expression, alias column) pair. POSEXPLODE is used when the UNNEST has
    its "ordinality" arg set, EXPLODE otherwise.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: joins are removed from the live list inside the loop, and
        # removing from a list while iterating it makes the iterator skip the following
        # element, which left consecutive UNNEST joins unconverted.
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode

                expression.args["joins"].remove(join)

                # NOTE(review): when the UNNEST has no alias the zip below is empty, so the join
                # is dropped without any lateral being added -- confirm this is intended.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )
    return expression
151
152
def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
    """Rewrite explode/posexplode projections as unnest sources (used in hive -> presto)."""
    if not isinstance(expression, exp.Select):
        return expression

    from sqlglot.optimizer.scope import build_scope

    used_projection_names = set(expression.named_selects)
    used_source_names = set(build_scope(expression).selected_sources)

    for projection in expression.selects:
        # Peel off any alias wrapper while remembering the names it supplies.
        position_name = ""
        element_name = ""
        candidate = projection

        if isinstance(projection, exp.Alias):
            element_name = projection.alias
            candidate = projection.this
        elif isinstance(projection, exp.Aliases):
            position_name = projection.aliases[0].name
            element_name = projection.aliases[1].name
            candidate = projection.this

        if not isinstance(candidate, (exp.Explode, exp.Posexplode)):
            continue

        with_position = isinstance(candidate, exp.Posexplode)
        argument = candidate.this
        unnest = exp.Unnest(expressions=[argument.copy()], ordinality=with_position)

        # This ensures that we won't use [POS]EXPLODE's argument as a new selection
        if isinstance(argument, exp.Column):
            used_projection_names.add(argument.output_name)

        source_name = find_new_name(used_source_names, "_u")
        used_source_names.add(source_name)

        # Without a user-supplied alias, generate fresh column names.
        if not element_name:
            element_name = find_new_name(used_projection_names, "col")
            used_projection_names.add(element_name)

            if with_position:
                position_name = find_new_name(used_projection_names, "pos")
                used_projection_names.add(position_name)

        if with_position:
            # The single projection is removed and two name-based projections are appended.
            columns = [element_name, position_name]
            projection.pop()
            expression.select(position_name, element_name, copy=False)
        else:
            columns = [element_name]
            projection.replace(exp.column(element_name))

        aliased_unnest = exp.alias_(unnest, source_name, table=columns)

        # The unnest becomes the FROM source if there is none, otherwise a cross join.
        if expression.args.get("from"):
            expression.join(aliased_unnest, join_type="CROSS", copy=False)
        else:
            expression.from_(aliased_unnest, copy=False)

    return expression
212
213
def remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
    """Drop the table qualifier from WHEN-clause columns that point at the merge target."""
    if not isinstance(expression, exp.Merge):
        return expression

    target = expression.this
    target_alias = target.args.get("alias")

    # A column counts as targeting the merge table if qualified by its name or its alias.
    qualifiers = {target.this}
    if target_alias:
        qualifiers.add(target_alias.this)

    def _unqualify(node):
        if isinstance(node, exp.Column) and node.args.get("table") in qualifiers:
            return exp.column(node.name)
        return node

    for clause in expression.expressions:
        clause.transform(_unqualify, copy=False)

    return expression
230
231
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using an appropriate `Generator.TRANSFORMS` function.

    The returned function works on a copy of the expression it is given, so the caller's
    expression is not handed to the transforms. An empty `transforms` list is allowed: the copy
    is then passed straight to the generator method.

    Args:
        transforms: sequence of transform functions. These will be called in order.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        # Copy once up front so the loop needs no special case for the first transform and
        # an empty list no longer raises IndexError.
        expression = expression.copy()
        for transform in transforms:
            expression = transform(expression)

        # A transform may return a different expression type, so the generator method is
        # looked up from the key of the final expression.
        return getattr(self, expression.key + "_sql")(expression)

    return _to_sql
def unalias_group( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Swap GROUP BY references to select aliases for the ordinal position of the aliased projection.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Group):
        return expression

    select = expression.parent
    if not isinstance(select, exp.Select):
        return expression

    # Map every aliased projection to its 1-based position in the SELECT list.
    positions = {}
    for position, projection in enumerate(select.expressions, start=1):
        if isinstance(projection, exp.Alias):
            positions[projection.alias] = position

    for term in expression.expressions:
        # Only unqualified column references can be select aliases.
        if not isinstance(term, exp.Column) or term.table:
            continue

        name = term.name
        if name in positions:
            term.replace(exp.Literal.number(positions[name]))

    return expression

Replace references to select aliases in GROUP BY clauses.

Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_distinct_on( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite SELECT DISTINCT ON as an outer query that filters on a ROW_NUMBER() window.

    Intended for dialects that lack SELECT DISTINCT ON but do support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    distinct = expression.args.get("distinct")
    if not distinct:
        return expression

    on = distinct.args.get("on")
    if not on or not isinstance(on, exp.Tuple):
        return expression

    # Detach DISTINCT from the query; its ON columns become the window's partition.
    partition_columns = distinct.pop().args["on"].expressions

    # Capture the projections and pick the helper name before the window is appended.
    projections = expression.selects
    helper_name = find_new_name(expression.named_selects, "_row_number")

    ranking = exp.Window(this=exp.RowNumber(), partition_by=partition_columns)

    # The query's ORDER BY, if present, is moved into the window.
    ordering = expression.args.get("order")
    if ordering:
        ranking.set("order", ordering.pop().copy())

    expression.select(exp.alias_(ranking, helper_name), copy=False)

    row_filter = f'"{helper_name}" = 1'
    return exp.select(*projections).from_(expression.subquery()).where(row_filter)

Convert SELECT DISTINCT ON statements to a subquery with a window function.

This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_qualify( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
 80def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
 81    """
 82    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
 83
 84    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
 85    https://docs.snowflake.com/en/sql-reference/constructs/qualify
 86
 87    Some dialects don't support window functions in the WHERE clause, so we need to include them as
 88    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
 89    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
 90    otherwise we won't be able to refer to it in the outer query's WHERE clause.
 91    """
 92    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
 93        taken = set(expression.named_selects)
 94        for select in expression.selects:
 95            if not select.alias_or_name:
 96                alias = find_new_name(taken, "_c")
 97                select.replace(exp.alias_(select.copy(), alias))
 98                taken.add(alias)
 99
100        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
101        qualify_filters = expression.args["qualify"].pop().this
102
103        for expr in qualify_filters.find_all((exp.Window, exp.Column)):
104            if isinstance(expr, exp.Window):
105                alias = find_new_name(expression.named_selects, "_w")
106                expression.select(exp.alias_(expr.copy(), alias), copy=False)
107                column = exp.column(alias)
108                if isinstance(expr.parent, exp.Qualify):
109                    qualify_filters = column
110                else:
111                    expr.replace(column)
112            elif expr.name not in expression.named_selects:
113                expression.select(expr.copy(), copy=False)
114
115        return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters)
116
117    return expression

Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify

Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause.

def remove_precision_parameterized_types( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Strip precision arguments from parameterized types, for dialects that only accept them in
    DDL and not in other expressions. Only arguments that are themselves data types are kept.
    """
    for data_type in expression.find_all(exp.DataType):
        nested_types = [arg for arg in data_type.expressions if isinstance(arg, exp.DataType)]
        data_type.set("expressions", nested_types)
    return expression

Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.

def unnest_to_explode( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode (used in presto -> hive).

    Every join of a SELECT whose target is an UNNEST is removed and replaced by one lateral
    view per (unnested expression, alias column) pair. POSEXPLODE is used when the UNNEST has
    its "ordinality" arg set, EXPLODE otherwise.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: joins are removed from the live list inside the loop, and
        # removing from a list while iterating it makes the iterator skip the following
        # element, which left consecutive UNNEST joins unconverted.
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode

                expression.args["joins"].remove(join)

                # NOTE(review): when the UNNEST has no alias the zip below is empty, so the join
                # is dropped without any lateral being added -- confirm this is intended.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )
    return expression

Convert cross join unnest into lateral view explode (used in presto -> hive).

def explode_to_unnest( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
    """Rewrite explode/posexplode projections as unnest sources (used in hive -> presto)."""
    if not isinstance(expression, exp.Select):
        return expression

    from sqlglot.optimizer.scope import build_scope

    used_projection_names = set(expression.named_selects)
    used_source_names = set(build_scope(expression).selected_sources)

    for projection in expression.selects:
        # Peel off any alias wrapper while remembering the names it supplies.
        position_name = ""
        element_name = ""
        candidate = projection

        if isinstance(projection, exp.Alias):
            element_name = projection.alias
            candidate = projection.this
        elif isinstance(projection, exp.Aliases):
            position_name = projection.aliases[0].name
            element_name = projection.aliases[1].name
            candidate = projection.this

        if not isinstance(candidate, (exp.Explode, exp.Posexplode)):
            continue

        with_position = isinstance(candidate, exp.Posexplode)
        argument = candidate.this
        unnest = exp.Unnest(expressions=[argument.copy()], ordinality=with_position)

        # This ensures that we won't use [POS]EXPLODE's argument as a new selection
        if isinstance(argument, exp.Column):
            used_projection_names.add(argument.output_name)

        source_name = find_new_name(used_source_names, "_u")
        used_source_names.add(source_name)

        # Without a user-supplied alias, generate fresh column names.
        if not element_name:
            element_name = find_new_name(used_projection_names, "col")
            used_projection_names.add(element_name)

            if with_position:
                position_name = find_new_name(used_projection_names, "pos")
                used_projection_names.add(position_name)

        if with_position:
            # The single projection is removed and two name-based projections are appended.
            columns = [element_name, position_name]
            projection.pop()
            expression.select(position_name, element_name, copy=False)
        else:
            columns = [element_name]
            projection.replace(exp.column(element_name))

        aliased_unnest = exp.alias_(unnest, source_name, table=columns)

        # The unnest becomes the FROM source if there is none, otherwise a cross join.
        if expression.args.get("from"):
            expression.join(aliased_unnest, join_type="CROSS", copy=False)
        else:
            expression.from_(aliased_unnest, copy=False)

    return expression

Convert explode/posexplode into unnest (used in hive -> presto).

def remove_target_from_merge( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
    """Drop the table qualifier from WHEN-clause columns that point at the merge target."""
    if not isinstance(expression, exp.Merge):
        return expression

    target = expression.this
    target_alias = target.args.get("alias")

    # A column counts as targeting the merge table if qualified by its name or its alias.
    qualifiers = {target.this}
    if target_alias:
        qualifiers.add(target_alias.this)

    def _unqualify(node):
        if isinstance(node, exp.Column) and node.args.get("table") in qualifiers:
            return exp.column(node.name)
        return node

    for clause in expression.expressions:
        clause.transform(_unqualify, copy=False)

    return expression

Remove table refs from columns in when statements.

def preprocess( transforms: List[Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]]) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using an appropriate `Generator.TRANSFORMS` function.

    The returned function works on a copy of the expression it is given, so the caller's
    expression is not handed to the transforms. An empty `transforms` list is allowed: the copy
    is then passed straight to the generator method.

    Args:
        transforms: sequence of transform functions. These will be called in order.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        # Copy once up front so the loop needs no special case for the first transform and
        # an empty list no longer raises IndexError.
        expression = expression.copy()
        for transform in transforms:
            expression = transform(expression)

        # A transform may return a different expression type, so the generator method is
        # looked up from the key of the final expression.
        return getattr(self, expression.key + "_sql")(expression)

    return _to_sql

Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL, using an appropriate Generator.TRANSFORMS function.

Arguments:
  • transforms: sequence of transform functions. These will be called in order.
Returns:

Function that can be used as a generator transform.