Edit on GitHub

sqlglot.transforms

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import expressions as exp
  7from sqlglot.helper import find_new_name
  8
  9if t.TYPE_CHECKING:
 10    from sqlglot.generator import Generator
 11
 12
 13def unalias_group(expression: exp.Expression) -> exp.Expression:
 14    """
 15    Replace references to select aliases in GROUP BY clauses.
 16
 17    Example:
 18        >>> import sqlglot
 19        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
 20        'SELECT a AS b FROM x GROUP BY 1'
 21
 22    Args:
 23        expression: the expression that will be transformed.
 24
 25    Returns:
 26        The transformed expression.
 27    """
 28    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
 29        aliased_selects = {
 30            e.alias: i
 31            for i, e in enumerate(expression.parent.expressions, start=1)
 32            if isinstance(e, exp.Alias)
 33        }
 34
 35        for group_by in expression.expressions:
 36            if (
 37                isinstance(group_by, exp.Column)
 38                and not group_by.table
 39                and group_by.name in aliased_selects
 40            ):
 41                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))
 42
 43    return expression
 44
 45
 46def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
 47    """
 48    Convert SELECT DISTINCT ON statements to a subquery with a window function.
 49
 50    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
 51
 52    Args:
 53        expression: the expression that will be transformed.
 54
 55    Returns:
 56        The transformed expression.
 57    """
 58    if (
 59        isinstance(expression, exp.Select)
 60        and expression.args.get("distinct")
 61        and expression.args["distinct"].args.get("on")
 62        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
 63    ):
 64        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
 65        outer_selects = expression.selects
 66        row_number = find_new_name(expression.named_selects, "_row_number")
 67        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
 68        order = expression.args.get("order")
 69
 70        if order:
 71            window.set("order", order.pop().copy())
 72
 73        window = exp.alias_(window, row_number)
 74        expression.select(window, copy=False)
 75
 76        return exp.select(*outer_selects).from_(expression.subquery()).where(f'"{row_number}" = 1')
 77
 78    return expression
 79
 80
 81def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
 82    """
 83    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
 84
 85    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
 86    https://docs.snowflake.com/en/sql-reference/constructs/qualify
 87
 88    Some dialects don't support window functions in the WHERE clause, so we need to include them as
 89    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
 90    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
 91    otherwise we won't be able to refer to it in the outer query's WHERE clause.
 92    """
 93    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
 94        taken = set(expression.named_selects)
 95        for select in expression.selects:
 96            if not select.alias_or_name:
 97                alias = find_new_name(taken, "_c")
 98                select.replace(exp.alias_(select, alias))
 99                taken.add(alias)
100
101        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
102        qualify_filters = expression.args["qualify"].pop().this
103
104        for expr in qualify_filters.find_all((exp.Window, exp.Column)):
105            if isinstance(expr, exp.Window):
106                alias = find_new_name(expression.named_selects, "_w")
107                expression.select(exp.alias_(expr, alias), copy=False)
108                column = exp.column(alias)
109
110                if isinstance(expr.parent, exp.Qualify):
111                    qualify_filters = column
112                else:
113                    expr.replace(column)
114            elif expr.name not in expression.named_selects:
115                expression.select(expr.copy(), copy=False)
116
117        return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters)
118
119    return expression
120
121
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.

    Every DataType node found in the tree keeps only those of its arguments that are
    themselves DataType nodes; all other arguments are dropped.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    for data_type in expression.find_all(exp.DataType):
        nested_types = [arg for arg in data_type.expressions if isinstance(arg, exp.DataType)]
        data_type.set("expressions", nested_types)

    return expression
131
132
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode (used in presto -> hive).

    Every join of a SELECT whose target is an UNNEST is removed and replaced by one
    LATERAL VIEW per (unnested expression, alias column) pair. POSEXPLODE is used when the
    UNNEST has its "ordinality" argument set, EXPLODE otherwise.

    NOTE: the pairs come from zipping the UNNEST expressions with the alias columns, so an
    UNNEST without an alias produces no lateral view even though its join is removed.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: joins are removed from the live list inside the loop, and
        # removing from a list while iterating it makes the iterator skip the next element
        # (i.e. the second of two consecutive UNNEST joins would be left untouched).
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode

                expression.args["joins"].remove(join)

                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
156
157
def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Convert explode/posexplode into unnest (used in hive -> presto).

    Each EXPLODE / POSEXPLODE projection of a SELECT is replaced by a reference to the
    column(s) of an aliased UNNEST source, which becomes the FROM clause when the query
    has none and a CROSS JOIN otherwise. Missing column and source aliases are generated
    so that they do not clash with names already in use.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    from sqlglot.optimizer.scope import build_scope

    used_select_names = set(expression.named_selects)
    scope = build_scope(expression)
    if not scope:
        return expression
    used_source_names = set(scope.selected_sources)

    for projection in expression.selects:
        # Keep a handle on the projection as it appears in the SELECT list; `projection`
        # itself may be unwrapped below.
        original = projection

        pos_alias = ""
        explode_alias = ""

        if isinstance(projection, exp.Alias):
            explode_alias = projection.alias
            projection = projection.this
        elif isinstance(projection, exp.Aliases):
            pos_alias = projection.aliases[0].name
            explode_alias = projection.aliases[1].name
            projection = projection.this

        if not isinstance(projection, (exp.Explode, exp.Posexplode)):
            continue

        with_position = isinstance(projection, exp.Posexplode)

        argument = projection.this
        unnest = exp.Unnest(expressions=[argument.copy()], ordinality=with_position)

        # This ensures that we won't use [POS]EXPLODE's argument as a new selection
        if isinstance(argument, exp.Column):
            used_select_names.add(argument.output_name)

        source_alias = find_new_name(used_source_names, "_u")
        used_source_names.add(source_alias)

        if not explode_alias:
            explode_alias = find_new_name(used_select_names, "col")
            used_select_names.add(explode_alias)

            if with_position:
                pos_alias = find_new_name(used_select_names, "pos")
                used_select_names.add(pos_alias)

        if with_position:
            column_names = [explode_alias, pos_alias]
            original.pop()
            expression.select(pos_alias, explode_alias, copy=False)
        else:
            column_names = [explode_alias]
            original.replace(exp.column(explode_alias))

        unnest = exp.alias_(unnest, source_alias, table=column_names)

        if expression.args.get("from"):
            expression.join(unnest, join_type="CROSS", copy=False)
        else:
            expression.from_(unnest, copy=False)

    return expression
220
221
def remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
    """
    Remove table refs from columns in when statements.

    Columns inside the WHEN clauses of a MERGE whose table part is the merge target (or
    the target's alias) are replaced by unqualified columns with the same name.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Merge):
        return expression

    target = expression.this
    target_alias = target.args.get("alias")
    target_names = {target.this}
    if target_alias:
        target_names.add(target_alias.this)

    def _strip_target(node):
        # Drop the table qualifier only from columns that point at the merge target.
        if isinstance(node, exp.Column) and node.args.get("table") in target_names:
            return exp.column(node.name)
        return node

    for when in expression.expressions:
        when.transform(_strip_target, copy=False)

    return expression
239
240
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Replace a percentile function applied WITHIN GROUP by an ApproxQuantile call.

    The rewrite applies to a WithinGroup node that wraps PercentileCont or PercentileDisc
    and whose second operand is an Order node. The quantile is the percentile function's
    argument, and the input value is taken from the first Ordered node found under the
    WithinGroup.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    ordering = expression.expression
    if not isinstance(percentile, (exp.PercentileCont, exp.PercentileDisc)):
        return expression
    if not isinstance(ordering, exp.Order):
        return expression

    quantile = percentile.this
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx_quantile = exp.ApproxQuantile(this=ordered.this, quantile=quantile)
    return expression.replace(approx_quantile)
252
253
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Give every CTE of a recursive WITH clause an explicit column list.

    CTEs whose alias already lists columns are left alone. For the others, the column
    names come from the projections of the CTE's query (its left-hand side when the query
    is a UNION); projections without a name get a generated `_c_<n>` name, with `n` counted
    across the whole WITH clause.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    # Lazily produces "_c_0", "_c_1", ... and is only advanced for unnamed projections.
    generated_names = (f"_c_{index}" for index in itertools.count())

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        query = cte.this
        if isinstance(query, exp.Union):
            query = query.this

        column_names = []
        for projection in query.selects:
            name = projection.alias_or_name or next(generated_names)
            column_names.append(exp.to_identifier(name))

        table_alias.set("columns", column_names)

    return expression
271
272
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is accepted, in which case the (copied) expression is generated as is.

    Returns:
        Function that can be used as a generator transform.

    Raises:
        ValueError: raised by the returned function when the transformed expression has no
            "_sql" method and either has no `Generator.TRANSFORMS` entry, or has one but is
            of the same type as the original expression (which would recurse forever).
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # The transforms operate on a copy rather than on the caller's expression.
        expression = expression.copy()
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            # Ensures we don't enter an infinite loop. This can happen when the original expression
            # has the same type as the final expression and there's no _sql method available for it,
            # because then it'd re-enter _to_sql.
            if expression_type is type(expression):
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
def unalias_group( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
14def unalias_group(expression: exp.Expression) -> exp.Expression:
15    """
16    Replace references to select aliases in GROUP BY clauses.
17
18    Example:
19        >>> import sqlglot
20        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
21        'SELECT a AS b FROM x GROUP BY 1'
22
23    Args:
24        expression: the expression that will be transformed.
25
26    Returns:
27        The transformed expression.
28    """
29    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
30        aliased_selects = {
31            e.alias: i
32            for i, e in enumerate(expression.parent.expressions, start=1)
33            if isinstance(e, exp.Alias)
34        }
35
36        for group_by in expression.expressions:
37            if (
38                isinstance(group_by, exp.Column)
39                and not group_by.table
40                and group_by.name in aliased_selects
41            ):
42                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))
43
44    return expression

Replace references to select aliases in GROUP BY clauses.

Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_distinct_on( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
47def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
48    """
49    Convert SELECT DISTINCT ON statements to a subquery with a window function.
50
51    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
52
53    Args:
54        expression: the expression that will be transformed.
55
56    Returns:
57        The transformed expression.
58    """
59    if (
60        isinstance(expression, exp.Select)
61        and expression.args.get("distinct")
62        and expression.args["distinct"].args.get("on")
63        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
64    ):
65        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
66        outer_selects = expression.selects
67        row_number = find_new_name(expression.named_selects, "_row_number")
68        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
69        order = expression.args.get("order")
70
71        if order:
72            window.set("order", order.pop().copy())
73
74        window = exp.alias_(window, row_number)
75        expression.select(window, copy=False)
76
77        return exp.select(*outer_selects).from_(expression.subquery()).where(f'"{row_number}" = 1')
78
79    return expression

Convert SELECT DISTINCT ON statements to a subquery with a window function.

This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_qualify( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
 82def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
 83    """
 84    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
 85
 86    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
 87    https://docs.snowflake.com/en/sql-reference/constructs/qualify
 88
 89    Some dialects don't support window functions in the WHERE clause, so we need to include them as
 90    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
 91    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
 92    otherwise we won't be able to refer to it in the outer query's WHERE clause.
 93    """
 94    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
 95        taken = set(expression.named_selects)
 96        for select in expression.selects:
 97            if not select.alias_or_name:
 98                alias = find_new_name(taken, "_c")
 99                select.replace(exp.alias_(select, alias))
100                taken.add(alias)
101
102        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
103        qualify_filters = expression.args["qualify"].pop().this
104
105        for expr in qualify_filters.find_all((exp.Window, exp.Column)):
106            if isinstance(expr, exp.Window):
107                alias = find_new_name(expression.named_selects, "_w")
108                expression.select(exp.alias_(expr, alias), copy=False)
109                column = exp.column(alias)
110
111                if isinstance(expr.parent, exp.Qualify):
112                    qualify_filters = column
113                else:
114                    expr.replace(column)
115            elif expr.name not in expression.named_selects:
116                expression.select(expr.copy(), copy=False)
117
118        return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters)
119
120    return expression

Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify

Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause.

def remove_precision_parameterized_types( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
123def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
124    """
125    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
126    other expressions. This transforms removes the precision from parameterized types in expressions.
127    """
128    for node in expression.find_all(exp.DataType):
129        node.set("expressions", [e for e in node.expressions if isinstance(e, exp.DataType)])
130
131    return expression

Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.

def unnest_to_explode( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
134def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
135    """Convert cross join unnest into lateral view explode (used in presto -> hive)."""
136    if isinstance(expression, exp.Select):
137        for join in expression.args.get("joins") or []:
138            unnest = join.this
139
140            if isinstance(unnest, exp.Unnest):
141                alias = unnest.args.get("alias")
142                udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode
143
144                expression.args["joins"].remove(join)
145
146                for e, column in zip(unnest.expressions, alias.columns if alias else []):
147                    expression.append(
148                        "laterals",
149                        exp.Lateral(
150                            this=udtf(this=e),
151                            view=True,
152                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
153                        ),
154                    )
155
156    return expression

Convert cross join unnest into lateral view explode (used in presto -> hive).

def explode_to_unnest( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
159def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
160    """Convert explode/posexplode into unnest (used in hive -> presto)."""
161    if isinstance(expression, exp.Select):
162        from sqlglot.optimizer.scope import build_scope
163
164        taken_select_names = set(expression.named_selects)
165        scope = build_scope(expression)
166        if not scope:
167            return expression
168        taken_source_names = set(scope.selected_sources)
169
170        for select in expression.selects:
171            to_replace = select
172
173            pos_alias = ""
174            explode_alias = ""
175
176            if isinstance(select, exp.Alias):
177                explode_alias = select.alias
178                select = select.this
179            elif isinstance(select, exp.Aliases):
180                pos_alias = select.aliases[0].name
181                explode_alias = select.aliases[1].name
182                select = select.this
183
184            if isinstance(select, (exp.Explode, exp.Posexplode)):
185                is_posexplode = isinstance(select, exp.Posexplode)
186
187                explode_arg = select.this
188                unnest = exp.Unnest(expressions=[explode_arg.copy()], ordinality=is_posexplode)
189
190                # This ensures that we won't use [POS]EXPLODE's argument as a new selection
191                if isinstance(explode_arg, exp.Column):
192                    taken_select_names.add(explode_arg.output_name)
193
194                unnest_source_alias = find_new_name(taken_source_names, "_u")
195                taken_source_names.add(unnest_source_alias)
196
197                if not explode_alias:
198                    explode_alias = find_new_name(taken_select_names, "col")
199                    taken_select_names.add(explode_alias)
200
201                    if is_posexplode:
202                        pos_alias = find_new_name(taken_select_names, "pos")
203                        taken_select_names.add(pos_alias)
204
205                if is_posexplode:
206                    column_names = [explode_alias, pos_alias]
207                    to_replace.pop()
208                    expression.select(pos_alias, explode_alias, copy=False)
209                else:
210                    column_names = [explode_alias]
211                    to_replace.replace(exp.column(explode_alias))
212
213                unnest = exp.alias_(unnest, unnest_source_alias, table=column_names)
214
215                if not expression.args.get("from"):
216                    expression.from_(unnest, copy=False)
217                else:
218                    expression.join(unnest, join_type="CROSS", copy=False)
219
220    return expression

Convert explode/posexplode into unnest (used in hive -> presto).

def remove_target_from_merge( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
223def remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
224    """Remove table refs from columns in when statements."""
225    if isinstance(expression, exp.Merge):
226        alias = expression.this.args.get("alias")
227        targets = {expression.this.this}
228        if alias:
229            targets.add(alias.this)
230
231        for when in expression.expressions:
232            when.transform(
233                lambda node: exp.column(node.name)
234                if isinstance(node, exp.Column) and node.args.get("table") in targets
235                else node,
236                copy=False,
237            )
238
239    return expression

Remove table refs from columns in when statements.

def remove_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
242def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
243    if (
244        isinstance(expression, exp.WithinGroup)
245        and isinstance(expression.this, (exp.PercentileCont, exp.PercentileDisc))
246        and isinstance(expression.expression, exp.Order)
247    ):
248        quantile = expression.this.this
249        input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this
250        return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))
251
252    return expression
def add_recursive_cte_column_names( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
255def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
256    if isinstance(expression, exp.With) and expression.recursive:
257        sequence = itertools.count()
258        next_name = lambda: f"_c_{next(sequence)}"
259
260        for cte in expression.expressions:
261            if not cte.args["alias"].columns:
262                query = cte.this
263                if isinstance(query, exp.Union):
264                    query = query.this
265
266                cte.args["alias"].set(
267                    "columns",
268                    [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
269                )
270
271    return expression
def preprocess( transforms: List[Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]]) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
274def preprocess(
275    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
276) -> t.Callable[[Generator, exp.Expression], str]:
277    """
278    Creates a new transform by chaining a sequence of transformations and converts the resulting
279    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
280    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).
281
282    Args:
283        transforms: sequence of transform functions. These will be called in order.
284
285    Returns:
286        Function that can be used as a generator transform.
287    """
288
289    def _to_sql(self, expression: exp.Expression) -> str:
290        expression_type = type(expression)
291
292        expression = transforms[0](expression.copy())
293        for t in transforms[1:]:
294            expression = t(expression)
295
296        _sql_handler = getattr(self, expression.key + "_sql", None)
297        if _sql_handler:
298            return _sql_handler(expression)
299
300        transforms_handler = self.TRANSFORMS.get(type(expression))
301        if transforms_handler:
302            # Ensures we don't enter an infinite loop. This can happen when the original expression
303            # has the same type as the final expression and there's no _sql method available for it,
304            # because then it'd re-enter _to_sql.
305            if expression_type is type(expression):
306                raise ValueError(
307                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
308                )
309
310            return transforms_handler(self, expression)
311
312        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")
313
314    return _to_sql

Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL, using either the "_sql" method corresponding to the resulting expression, or the appropriate Generator.TRANSFORMS function (when applicable -- see below).

Arguments:
  • transforms: sequence of transform functions. These will be called in order.
Returns:

Function that can be used as a generator transform.