Edit on GitHub

sqlglot.transforms

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot import expressions as exp
  6from sqlglot.helper import find_new_name
  7
  8if t.TYPE_CHECKING:
  9    from sqlglot.generator import Generator
 10
 11
 12def unalias_group(expression: exp.Expression) -> exp.Expression:
 13    """
 14    Replace references to select aliases in GROUP BY clauses.
 15
 16    Example:
 17        >>> import sqlglot
 18        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
 19        'SELECT a AS b FROM x GROUP BY 1'
 20
 21    Args:
 22        expression: the expression that will be transformed.
 23
 24    Returns:
 25        The transformed expression.
 26    """
 27    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
 28        aliased_selects = {
 29            e.alias: i
 30            for i, e in enumerate(expression.parent.expressions, start=1)
 31            if isinstance(e, exp.Alias)
 32        }
 33
 34        for group_by in expression.expressions:
 35            if (
 36                isinstance(group_by, exp.Column)
 37                and not group_by.table
 38                and group_by.name in aliased_selects
 39            ):
 40                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))
 41
 42    return expression
 43
 44
 45def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
 46    """
 47    Convert SELECT DISTINCT ON statements to a subquery with a window function.
 48
 49    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
 50
 51    Args:
 52        expression: the expression that will be transformed.
 53
 54    Returns:
 55        The transformed expression.
 56    """
 57    if (
 58        isinstance(expression, exp.Select)
 59        and expression.args.get("distinct")
 60        and expression.args["distinct"].args.get("on")
 61        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
 62    ):
 63        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
 64        outer_selects = expression.selects
 65        row_number = find_new_name(expression.named_selects, "_row_number")
 66        window = exp.Window(
 67            this=exp.RowNumber(),
 68            partition_by=distinct_cols,
 69        )
 70        order = expression.args.get("order")
 71        if order:
 72            window.set("order", order.pop().copy())
 73        window = exp.alias_(window, row_number)
 74        expression.select(window, copy=False)
 75        return exp.select(*outer_selects).from_(expression.subquery()).where(f'"{row_number}" = 1')
 76    return expression
 77
 78
 79def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
 80    """
 81    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
 82
 83    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
 84    https://docs.snowflake.com/en/sql-reference/constructs/qualify
 85
 86    Some dialects don't support window functions in the WHERE clause, so we need to include them as
 87    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
 88    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
 89    otherwise we won't be able to refer to it in the outer query's WHERE clause.
 90    """
 91    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
 92        taken = set(expression.named_selects)
 93        for select in expression.selects:
 94            if not select.alias_or_name:
 95                alias = find_new_name(taken, "_c")
 96                select.replace(exp.alias_(select.copy(), alias))
 97                taken.add(alias)
 98
 99        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
100        qualify_filters = expression.args["qualify"].pop().this
101
102        for expr in qualify_filters.find_all((exp.Window, exp.Column)):
103            if isinstance(expr, exp.Window):
104                alias = find_new_name(expression.named_selects, "_w")
105                expression.select(exp.alias_(expr.copy(), alias), copy=False)
106                column = exp.column(alias)
107                if isinstance(expr.parent, exp.Qualify):
108                    qualify_filters = column
109                else:
110                    expr.replace(column)
111            elif expr.name not in expression.named_selects:
112                expression.select(expr.copy(), copy=False)
113
114        return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters)
115
116    return expression
117
118
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Strip the precision from parameterized types.

    Some dialects only allow the precision for parameterized types to be defined in the DDL and not
    in other expressions. This transform rebuilds every DataType node, keeping only those of its
    parameters that are data types themselves.
    """

    def _strip_precision(node):
        if not isinstance(node, exp.DataType):
            return node

        type_args = dict(node.args)
        type_args["expressions"] = [
            param for param in node.expressions if isinstance(param, exp.DataType)
        ]
        return exp.DataType(**type_args)

    return expression.transform(_strip_precision)
138
139
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode (used in presto -> hive).

    Every join of a SELECT whose target is an UNNEST is removed, and one LATERAL VIEW is appended
    for each pairing of an unnested expression with an aliased column. POSEXPLODE is used when
    the UNNEST has its "ordinality" arg set, EXPLODE otherwise.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: joins are removed from the live list inside the loop, and
        # removing from a list while iterating over it skips the element that follows.
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode

                expression.args["joins"].remove(join)

                # NOTE(review): when the UNNEST has no aliased columns this loop adds nothing,
                # so the join is dropped without a replacement - confirm that this is intended.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )
    return expression
162
163
def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Convert explode/posexplode into unnest (used in hive -> presto).

    Each EXPLODE / POSEXPLODE projection of a SELECT is turned into an aliased UNNEST source. The
    source becomes the FROM clause when the query has none and is cross joined otherwise, while
    the projection itself is replaced by references to the UNNEST's columns.
    """
    if not isinstance(expression, exp.Select):
        return expression

    from sqlglot.optimizer.scope import build_scope

    used_select_names = set(expression.named_selects)
    used_source_names = set(build_scope(expression).selected_sources)

    for projection in expression.selects:
        original = projection

        position_alias = ""
        value_alias = ""

        # Look through the alias wrapper, remembering the names the user asked for.
        if isinstance(projection, exp.Alias):
            value_alias = projection.alias
            projection = projection.this
        elif isinstance(projection, exp.Aliases):
            position_alias = projection.aliases[0].name
            value_alias = projection.aliases[1].name
            projection = projection.this

        if not isinstance(projection, (exp.Explode, exp.Posexplode)):
            continue

        with_position = isinstance(projection, exp.Posexplode)

        argument = projection.this
        unnest = exp.Unnest(expressions=[argument.copy()], ordinality=with_position)

        # This ensures that we won't use [POS]EXPLODE's argument as a new selection
        if isinstance(argument, exp.Column):
            used_select_names.add(argument.output_name)

        source_alias = find_new_name(used_source_names, "_u")
        used_source_names.add(source_alias)

        if not value_alias:
            value_alias = find_new_name(used_select_names, "col")
            used_select_names.add(value_alias)

            if with_position:
                position_alias = find_new_name(used_select_names, "pos")
                used_select_names.add(position_alias)

        if with_position:
            unnest_columns = [value_alias, position_alias]
            # The single projection is dropped and two new ones are appended to the query.
            original.pop()
            expression.select(position_alias, value_alias, copy=False)
        else:
            unnest_columns = [value_alias]
            original.replace(exp.column(value_alias))

        aliased_unnest = exp.alias_(unnest, source_alias, table=unnest_columns)

        if expression.args.get("from"):
            expression.join(aliased_unnest, join_type="CROSS", copy=False)
        else:
            expression.from_(aliased_unnest, copy=False)

    return expression
223
224
def remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
    """
    Remove table refs from columns in when statements.

    Columns inside the WHEN clauses of a MERGE whose table is the merge target, or the target's
    alias, are replaced by columns that carry the name only.
    """
    if not isinstance(expression, exp.Merge):
        return expression

    target = expression.this
    target_alias = target.args.get("alias")

    target_names = {target.this}
    if target_alias:
        target_names.add(target_alias.this)

    def _unqualify(node):
        if isinstance(node, exp.Column) and node.args.get("table") in target_names:
            return exp.column(node.name)
        return node

    for when in expression.expressions:
        # copy=False: the WHEN clause is rewritten in place, so the result isn't needed.
        when.transform(_unqualify, copy=False)

    return expression
241
242
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using an appropriate `Generator.TRANSFORMS` function.

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is allowed, in which case the expression is generated unchanged.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        # The transforms run on a copy, so the expression handed to the generator is not the
        # object that gets passed through them.
        expression = expression.copy()
        for transform in transforms:
            expression = transform(expression)

        # A transform may return a different kind of node, so the generator method is looked up
        # from the key of the final expression.
        return getattr(self, expression.key + "_sql")(expression)

    return _to_sql
264
265
# Ready-made single-entry mappings from an expression type to a generator function that runs
# one of the transforms above before producing SQL.
UNALIAS_GROUP = {
    exp.Group: preprocess([unalias_group]),
}
ELIMINATE_DISTINCT_ON = {
    exp.Select: preprocess([eliminate_distinct_on]),
}
ELIMINATE_QUALIFY = {
    exp.Select: preprocess([eliminate_qualify]),
}
REMOVE_PRECISION_PARAMETERIZED_TYPES = {
    exp.Cast: preprocess([remove_precision_parameterized_types]),
}
def unalias_group( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Swap GROUP BY references to select aliases for the ordinal of the aliased projection.

    Only unqualified columns whose name matches an alias of the parent SELECT are rewritten;
    every other GROUP BY term is left as it is.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Group):
        return expression

    select = expression.parent
    if not isinstance(select, exp.Select):
        return expression

    # 1-based position of every aliased projection, keyed by its alias.
    alias_positions = {}
    for position, projection in enumerate(select.expressions, start=1):
        if isinstance(projection, exp.Alias):
            alias_positions[projection.alias] = position

    for term in expression.expressions:
        if not isinstance(term, exp.Column) or term.table:
            continue

        position = alias_positions.get(term.name)
        if position is not None:
            term.replace(exp.Literal.number(position))

    return expression

Replace references to select aliases in GROUP BY clauses.

Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_distinct_on( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
    The original query becomes a subquery that additionally projects a ROW_NUMBER() window
    partitioned by the DISTINCT ON columns, and the outer query keeps the rows numbered 1.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression. Expressions that are not a SELECT with a tuple-valued
        DISTINCT ON are returned untouched.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        outer_selects = expression.selects
        row_number = find_new_name(expression.named_selects, "_row_number")
        window = exp.Window(
            this=exp.RowNumber(),
            partition_by=distinct_cols,
        )

        order = expression.args.get("order")
        if order:
            # The query's ORDER BY decides which row of each group is numbered 1, so it is
            # moved from the query into the window.
            window.set("order", order.pop().copy())
        else:
            # Several engines reject ROW_NUMBER() over a window without ORDER BY. Ordering by
            # the partition keys does not reorder rows inside a partition, so the "any row per
            # group" meaning of an unordered DISTINCT ON is kept.
            window.set("order", exp.Order(expressions=[col.copy() for col in distinct_cols]))

        window = exp.alias_(window, row_number)
        expression.select(window, copy=False)

        # The filter references the alias through a column built from the same bare name
        # instead of a hard-quoted identifier: a quoted lower-case name does not match an
        # unquoted alias on engines that fold unquoted identifiers to upper case.
        keep_first = exp.EQ(this=exp.column(row_number), expression=exp.Literal.number(1))

        # Derived tables must be named in several dialects; "_t" is the alias that
        # eliminate_qualify uses for its subquery.
        return exp.select(*outer_selects).from_(expression.subquery(alias="_t")).where(keep_first)
    return expression

Convert SELECT DISTINCT ON statements to a subquery with a window function.

This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_qualify( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
 80def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
 81    """
 82    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
 83
 84    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
 85    https://docs.snowflake.com/en/sql-reference/constructs/qualify
 86
 87    Some dialects don't support window functions in the WHERE clause, so we need to include them as
 88    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
 89    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
 90    otherwise we won't be able to refer to it in the outer query's WHERE clause.
 91    """
 92    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
 93        taken = set(expression.named_selects)
 94        for select in expression.selects:
 95            if not select.alias_or_name:
 96                alias = find_new_name(taken, "_c")
 97                select.replace(exp.alias_(select.copy(), alias))
 98                taken.add(alias)
 99
100        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
101        qualify_filters = expression.args["qualify"].pop().this
102
103        for expr in qualify_filters.find_all((exp.Window, exp.Column)):
104            if isinstance(expr, exp.Window):
105                alias = find_new_name(expression.named_selects, "_w")
106                expression.select(exp.alias_(expr.copy(), alias), copy=False)
107                column = exp.column(alias)
108                if isinstance(expr.parent, exp.Qualify):
109                    qualify_filters = column
110                else:
111                    expr.replace(column)
112            elif expr.name not in expression.named_selects:
113                expression.select(expr.copy(), copy=False)
114
115        return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters)
116
117    return expression

Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify

Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause.

def remove_precision_parameterized_types( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Strip the precision from parameterized types.

    Some dialects only allow the precision for parameterized types to be defined in the DDL and not
    in other expressions. This transform rebuilds every DataType node, keeping only those of its
    parameters that are data types themselves.
    """

    def _strip_precision(node):
        if not isinstance(node, exp.DataType):
            return node

        type_args = dict(node.args)
        type_args["expressions"] = [
            param for param in node.expressions if isinstance(param, exp.DataType)
        ]
        return exp.DataType(**type_args)

    return expression.transform(_strip_precision)

Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.

def unnest_to_explode( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode (used in presto -> hive).

    Every join of a SELECT whose target is an UNNEST is removed, and one LATERAL VIEW is appended
    for each pairing of an unnested expression with an aliased column. POSEXPLODE is used when
    the UNNEST has its "ordinality" arg set, EXPLODE otherwise.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: joins are removed from the live list inside the loop, and
        # removing from a list while iterating over it skips the element that follows.
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode

                expression.args["joins"].remove(join)

                # NOTE(review): when the UNNEST has no aliased columns this loop adds nothing,
                # so the join is dropped without a replacement - confirm that this is intended.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )
    return expression

Convert cross join unnest into lateral view explode (used in presto -> hive).

def explode_to_unnest( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Convert explode/posexplode into unnest (used in hive -> presto).

    Each EXPLODE / POSEXPLODE projection of a SELECT is turned into an aliased UNNEST source. The
    source becomes the FROM clause when the query has none and is cross joined otherwise, while
    the projection itself is replaced by references to the UNNEST's columns.
    """
    if not isinstance(expression, exp.Select):
        return expression

    from sqlglot.optimizer.scope import build_scope

    used_select_names = set(expression.named_selects)
    used_source_names = set(build_scope(expression).selected_sources)

    for projection in expression.selects:
        original = projection

        position_alias = ""
        value_alias = ""

        # Look through the alias wrapper, remembering the names the user asked for.
        if isinstance(projection, exp.Alias):
            value_alias = projection.alias
            projection = projection.this
        elif isinstance(projection, exp.Aliases):
            position_alias = projection.aliases[0].name
            value_alias = projection.aliases[1].name
            projection = projection.this

        if not isinstance(projection, (exp.Explode, exp.Posexplode)):
            continue

        with_position = isinstance(projection, exp.Posexplode)

        argument = projection.this
        unnest = exp.Unnest(expressions=[argument.copy()], ordinality=with_position)

        # This ensures that we won't use [POS]EXPLODE's argument as a new selection
        if isinstance(argument, exp.Column):
            used_select_names.add(argument.output_name)

        source_alias = find_new_name(used_source_names, "_u")
        used_source_names.add(source_alias)

        if not value_alias:
            value_alias = find_new_name(used_select_names, "col")
            used_select_names.add(value_alias)

            if with_position:
                position_alias = find_new_name(used_select_names, "pos")
                used_select_names.add(position_alias)

        if with_position:
            unnest_columns = [value_alias, position_alias]
            # The single projection is dropped and two new ones are appended to the query.
            original.pop()
            expression.select(position_alias, value_alias, copy=False)
        else:
            unnest_columns = [value_alias]
            original.replace(exp.column(value_alias))

        aliased_unnest = exp.alias_(unnest, source_alias, table=unnest_columns)

        if expression.args.get("from"):
            expression.join(aliased_unnest, join_type="CROSS", copy=False)
        else:
            expression.from_(aliased_unnest, copy=False)

    return expression

Convert explode/posexplode into unnest (used in hive -> presto).

def remove_target_from_merge( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
    """
    Remove table refs from columns in when statements.

    Columns inside the WHEN clauses of a MERGE whose table is the merge target, or the target's
    alias, are replaced by columns that carry the name only.
    """
    if not isinstance(expression, exp.Merge):
        return expression

    target = expression.this
    target_alias = target.args.get("alias")

    target_names = {target.this}
    if target_alias:
        target_names.add(target_alias.this)

    def _unqualify(node):
        if isinstance(node, exp.Column) and node.args.get("table") in target_names:
            return exp.column(node.name)
        return node

    for when in expression.expressions:
        # copy=False: the WHEN clause is rewritten in place, so the result isn't needed.
        when.transform(_unqualify, copy=False)

    return expression

Remove table refs from columns in when statements.

def preprocess( transforms: List[Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]]) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using an appropriate `Generator.TRANSFORMS` function.

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is allowed, in which case the expression is generated unchanged.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        # The transforms run on a copy, so the expression handed to the generator is not the
        # object that gets passed through them.
        expression = expression.copy()
        for transform in transforms:
            expression = transform(expression)

        # A transform may return a different kind of node, so the generator method is looked up
        # from the key of the final expression.
        return getattr(self, expression.key + "_sql")(expression)

    return _to_sql

Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL, using an appropriate Generator.TRANSFORMS function.

Arguments:
  • transforms: sequence of transform functions. These will be called in order.
Returns:

Function that can be used as a generator transform.