Edit on GitHub

sqlglot.transforms

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot import expressions as exp
  6from sqlglot.helper import find_new_name, name_sequence
  7
  8if t.TYPE_CHECKING:
  9    from sqlglot.generator import Generator
 10
 11
 12def unalias_group(expression: exp.Expression) -> exp.Expression:
 13    """
 14    Replace references to select aliases in GROUP BY clauses.
 15
 16    Example:
 17        >>> import sqlglot
 18        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
 19        'SELECT a AS b FROM x GROUP BY 1'
 20
 21    Args:
 22        expression: the expression that will be transformed.
 23
 24    Returns:
 25        The transformed expression.
 26    """
 27    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
 28        aliased_selects = {
 29            e.alias: i
 30            for i, e in enumerate(expression.parent.expressions, start=1)
 31            if isinstance(e, exp.Alias)
 32        }
 33
 34        for group_by in expression.expressions:
 35            if (
 36                isinstance(group_by, exp.Column)
 37                and not group_by.table
 38                and group_by.name in aliased_selects
 39            ):
 40                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))
 41
 42    return expression
 43
 44
 45def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
 46    """
 47    Convert SELECT DISTINCT ON statements to a subquery with a window function.
 48
 49    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
 50
 51    Args:
 52        expression: the expression that will be transformed.
 53
 54    Returns:
 55        The transformed expression.
 56    """
 57    if (
 58        isinstance(expression, exp.Select)
 59        and expression.args.get("distinct")
 60        and expression.args["distinct"].args.get("on")
 61        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
 62    ):
 63        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
 64        outer_selects = expression.selects
 65        row_number = find_new_name(expression.named_selects, "_row_number")
 66        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
 67        order = expression.args.get("order")
 68
 69        if order:
 70            window.set("order", order.pop().copy())
 71        else:
 72            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))
 73
 74        window = exp.alias_(window, row_number)
 75        expression.select(window, copy=False)
 76
 77        return (
 78            exp.select(*outer_selects)
 79            .from_(expression.subquery())
 80            .where(exp.column(row_number).eq(1))
 81        )
 82
 83    return expression
 84
 85
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause.

    Args:
        expression: the expression that will be transformed. Only a `Select` that has a
            `qualify` arg is rewritten; anything else is returned unchanged.

    Returns:
        `SELECT <original projection names> FROM (<original query>) AS _t WHERE <filter>` when a
        rewrite happened, otherwise the input expression.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection a generated `_c*` alias so the outer query can name it.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        # The outer projection list is built from the names known at this point, i.e. before the
        # helper window/column projections are appended below.
        # NOTE(review): for `SELECT *` the outer `*` presumably still exposes those helper columns.
        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
        # Pop the QUALIFY clause off the inner query; its condition becomes the outer WHERE.
        qualify_filters = expression.args["qualify"].pop().this

        # With `SELECT *` every column is already projected, so only windows need hoisting.
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        # The tree is mutated while find_all is still being iterated, so statement order matters.
        for expr in qualify_filters.find_all(select_candidates):
            if isinstance(expr, exp.Window):
                # Project the window in the inner query under a fresh `_w*` alias...
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(expr, alias), copy=False)
                column = exp.column(alias)

                # ...and refer to it by that alias in the outer filter. When the window is the
                # whole QUALIFY condition, the filter itself becomes the alias column.
                if isinstance(expr.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    expr.replace(column)
            elif expr.name not in expression.named_selects:
                # Columns used by the filter must be projected so the outer WHERE can see them.
                expression.select(expr.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters)

    return expression
126
127
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Strip the precision parameters from every parameterized data type inside `expression`.

    Some dialects only allow the precision for parameterized types to be defined in the DDL and not
    in other expressions. This transform drops every `DataTypeParam` argument of each `DataType`
    node, and keeps the node's other arguments untouched.
    """

    def _is_kept(argument: exp.Expression) -> bool:
        return not isinstance(argument, exp.DataTypeParam)

    for data_type in expression.find_all(exp.DataType):
        data_type.set("expressions", list(filter(_is_kept, data_type.expressions)))

    return expression
139
140
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode (used in presto -> hive).

    Every join whose target is an `UNNEST(...)` is removed from the SELECT. Each unnested
    expression is paired with a column of the unnest's table alias, and a `LATERAL VIEW` is
    appended for each pair. It uses `EXPLODE(...)`, or `POSEXPLODE(...)` when the unnest has the
    `ordinality` flag.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot. Joins are removed from the underlying list inside the loop,
        # and removing from the list being iterated would silently skip the following join.
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if not isinstance(unnest, exp.Unnest):
                continue

            alias = unnest.args.get("alias")
            udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode

            expression.args["joins"].remove(join)

            # NOTE(review): with no alias, or fewer alias columns than unnested expressions, the
            # unmatched expressions produce no LATERAL VIEW and so vanish from the query.
            for e, column in zip(unnest.expressions, alias.columns if alias else []):
                expression.append(
                    "laterals",
                    exp.Lateral(
                        this=udtf(this=e),
                        view=True,
                        alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                    ),
                )

    return expression
164
165
def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Convert explode/posexplode into unnest (used in hive -> presto).

    Each `EXPLODE(x)` / `POSEXPLODE(x)` projection of a SELECT becomes an `UNNEST(x)` source
    (with ordinality for POSEXPLODE), aliased as a fresh `_u*` table. The source becomes the FROM
    clause when the query has none and is CROSS JOINed otherwise. The projection itself is
    replaced by references to the unnest's column aliases.

    Args:
        expression: the expression that will be transformed. Non-SELECT expressions are returned
            unchanged.

    Returns:
        The transformed expression (modified in place).
    """
    if isinstance(expression, exp.Select):
        # Local import; presumably avoids an import cycle with the optimizer package.
        from sqlglot.optimizer.scope import Scope

        # Names already in use, so generated aliases never collide with existing ones.
        taken_select_names = set(expression.named_selects)
        taken_source_names = {name for name, _ in Scope(expression).references}

        for select in expression.selects:
            # Keep a handle on the projection node itself (alias included), so it can still be
            # replaced/removed after `select` is rebound to the unaliased inner expression.
            to_replace = select

            pos_alias = ""
            explode_alias = ""

            if isinstance(select, exp.Alias):
                # EXPLODE(x) AS col
                explode_alias = select.alias
                select = select.this
            elif isinstance(select, exp.Aliases):
                # POSEXPLODE(x) AS (pos, col): the first alias is treated as the position.
                pos_alias = select.aliases[0].name
                explode_alias = select.aliases[1].name
                select = select.this

            if isinstance(select, (exp.Explode, exp.Posexplode)):
                is_posexplode = isinstance(select, exp.Posexplode)

                explode_arg = select.this
                unnest = exp.Unnest(expressions=[explode_arg.copy()], ordinality=is_posexplode)

                # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                if isinstance(explode_arg, exp.Column):
                    taken_select_names.add(explode_arg.output_name)

                unnest_source_alias = find_new_name(taken_source_names, "_u")
                taken_source_names.add(unnest_source_alias)

                # Generate column aliases only when the user supplied none.
                # NOTE(review): POSEXPLODE(x) AS col (a single alias) leaves pos_alias == "",
                # so an empty column name reaches the UNNEST alias below; confirm this input
                # cannot occur.
                if not explode_alias:
                    explode_alias = find_new_name(taken_select_names, "col")
                    taken_select_names.add(explode_alias)

                    if is_posexplode:
                        pos_alias = find_new_name(taken_select_names, "pos")
                        taken_select_names.add(pos_alias)

                if is_posexplode:
                    # The UNNEST alias lists the value column before the position column,
                    # the reverse of the (pos, col) projection order used below.
                    column_names = [explode_alias, pos_alias]
                    # NOTE(review): the projection is removed and the two columns are re-added
                    # through select(), which appears to append them to the end of the list, so
                    # this projection may move relative to its siblings.
                    to_replace.pop()
                    expression.select(pos_alias, explode_alias, copy=False)
                else:
                    column_names = [explode_alias]
                    to_replace.replace(exp.column(explode_alias))

                unnest = exp.alias_(unnest, unnest_source_alias, table=column_names)

                if not expression.args.get("from"):
                    expression.from_(unnest, copy=False)
                else:
                    expression.join(unnest, join_type="CROSS", copy=False)

    return expression
225
226
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Convert `PERCENTILE_CONT/PERCENTILE_DISC(q) WITHIN GROUP (ORDER BY x)` into
    `APPROX_QUANTILE(x, q)`.

    APPROX_QUANTILE takes no sort direction, so a descending ordering is folded into the quantile.
    The q-th fraction of the values in descending order is the (1 - q)-th fraction in ascending
    order.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The `ApproxQuantile` replacement, or the input expression when it doesn't match.
    """
    if (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, (exp.PercentileCont, exp.PercentileDisc))
        and isinstance(expression.expression, exp.Order)
    ):
        quantile = expression.this.this
        # Look for the sort key only inside the ORDER BY clause.
        ordered = t.cast(exp.Ordered, expression.expression.find(exp.Ordered))
        input_value = ordered.this

        if ordered.args.get("desc"):
            # Dropping DESC silently would change the result; use 1 - q instead.
            if not isinstance(quantile, exp.Literal):
                quantile = exp.Paren(this=quantile)
            quantile = exp.Sub(this=exp.Literal.number(1), expression=quantile)

        return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))

    return expression
238
239
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Give explicit column names to the CTEs of a `WITH RECURSIVE` clause that don't declare them.

    For each such CTE, the column names are taken from the projections of its left-most (anchor)
    query. Unnamed projections get names produced by `name_sequence("_c_")`, which is shared by
    all CTEs of the clause.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.With) and expression.recursive:
        next_name = name_sequence("_c_")

        for cte in expression.expressions:
            if not cte.args["alias"].columns:
                query = cte.this
                # Walk down to the left-most operand, so chains such as `a UNION b UNION c`
                # (presumably parsed as nested unions) also resolve to the anchor query.
                while isinstance(query, exp.Union):
                    query = query.this

                cte.args["alias"].set(
                    "columns",
                    [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
                )

    return expression
256
257
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite a CAST/TRY_CAST of 'epoch' to a temporal type into a cast of an explicit timestamp.

    The operand whose name is 'epoch' (compared case-insensitively) is replaced with the string
    literal '1970-01-01 00:00:00'. Every other expression is returned unchanged.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    is_epoch = expression.name.lower() == "epoch"
    if is_epoch and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
267
268
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function when no such method exists.

    The transforms run on a copy of the input expression. Falling back to `Generator.TRANSFORMS`
    is only allowed when the transforms changed the expression's type. Otherwise the same
    handler, presumably this very function, would be re-entered forever, so a ValueError is
    raised instead.

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence simply renders a copy of the input expression.

    Returns:
        Function that can be used as a generator transform. That function raises ValueError when
        the resulting expression can't be rendered.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # Work on a copy so the caller's tree is not mutated by the transforms. The loop variable
        # is deliberately not named `t`, which would shadow the module's `typing as t` alias.
        expression = expression.copy()
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            # Ensures we don't enter an infinite loop. This can happen when the original expression
            # has the same type as the final expression and there's no _sql method available for it,
            # because then it'd re-enter _to_sql.
            if expression_type is type(expression):
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    Every unqualified GROUP BY column whose name equals the alias of a projection is rewritten
    into the 1-based ordinal position of that projection in the SELECT list.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Group):
        return expression

    select = expression.parent
    if not isinstance(select, exp.Select):
        return expression

    # alias -> 1-based position of the aliased projection (a later duplicate alias wins)
    positions = {}
    for position, projection in enumerate(select.expressions, start=1):
        if isinstance(projection, exp.Alias):
            positions[projection.alias] = position

    for key in expression.expressions:
        # Qualified columns (e.g. `t.b`) always refer to a source column, never to an alias.
        if not isinstance(key, exp.Column) or key.table:
            continue
        if key.name in positions:
            key.replace(exp.Literal.number(positions[key.name]))

    return expression

Replace references to select aliases in GROUP BY clauses.

Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    The original query becomes a subquery aliased as `_t`. That subquery additionally projects
    `ROW_NUMBER() OVER (PARTITION BY <DISTINCT ON expressions> ORDER BY ...)`. The query's own
    ORDER BY is moved into that window; if there is none, the DISTINCT ON expressions are used.
    The outer query keeps only the rows whose row number is 1. It re-selects the projections by
    name, so projections without a name (e.g. `a + 1`) first get generated `_c*` aliases.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions

        # The outer query only sees the subquery's output columns. It must therefore reference
        # the projections by name rather than repeat them. Repeating e.g. `b + 1` would refer to
        # a column `b` that the subquery does not necessarily expose.
        taken = set(expression.named_selects)
        inner_selects = []
        outer_selects = []
        for select in expression.selects:
            if select.is_star:
                # A star can't be referenced by name; expand it against the subquery instead.
                inner_selects.append(select)
                outer_selects.append(exp.Star())
                continue

            if not isinstance(select, exp.Column) and not select.alias:
                alias = find_new_name(taken, "_c")
                taken.add(alias)
                select = exp.alias_(select.copy(), alias)

            inner_selects.append(select)
            outer_selects.append(exp.column(select.alias_or_name))

        expression.set("expressions", inner_selects)

        row_number = find_new_name(taken, "_row_number")
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        if order:
            # The popped node is already detached from the query, so no copy is needed.
            window.set("order", order.pop())
        else:
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        expression.select(exp.alias_(window, row_number), copy=False)

        # Alias the derived table (as eliminate_qualify does): several dialects reject
        # unaliased subqueries in FROM.
        return (
            exp.select(*outer_selects)
            .from_(expression.subquery(alias="_t"))
            .where(exp.column(row_number).eq(1))
        )

    return expression

Convert SELECT DISTINCT ON statements to a subquery with a window function.

This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause.

    Args:
        expression: the expression that will be transformed. Only a `Select` that has a
            `qualify` arg is rewritten; anything else is returned unchanged.

    Returns:
        `SELECT <original projection names> FROM (<original query>) AS _t WHERE <filter>` when a
        rewrite happened, otherwise the input expression.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection a generated `_c*` alias so the outer query can name it.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        # The outer projection list is built from the names known at this point, i.e. before the
        # helper window/column projections are appended below.
        # NOTE(review): for `SELECT *` the outer `*` presumably still exposes those helper columns.
        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
        # Pop the QUALIFY clause off the inner query; its condition becomes the outer WHERE.
        qualify_filters = expression.args["qualify"].pop().this

        # With `SELECT *` every column is already projected, so only windows need hoisting.
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        # The tree is mutated while find_all is still being iterated, so statement order matters.
        for expr in qualify_filters.find_all(select_candidates):
            if isinstance(expr, exp.Window):
                # Project the window in the inner query under a fresh `_w*` alias...
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(expr, alias), copy=False)
                column = exp.column(alias)

                # ...and refer to it by that alias in the outer filter. When the window is the
                # whole QUALIFY condition, the filter itself becomes the alias column.
                if isinstance(expr.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    expr.replace(column)
            elif expr.name not in expression.named_selects:
                # Columns used by the filter must be projected so the outer WHERE can see them.
                expression.select(expr.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters)

    return expression

Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify

Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause.

def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Strip the precision parameters from every parameterized data type inside `expression`.

    Some dialects only allow the precision for parameterized types to be defined in the DDL and not
    in other expressions. This transform drops every `DataTypeParam` argument of each `DataType`
    node, and keeps the node's other arguments untouched.
    """

    def _is_kept(argument: exp.Expression) -> bool:
        return not isinstance(argument, exp.DataTypeParam)

    for data_type in expression.find_all(exp.DataType):
        data_type.set("expressions", list(filter(_is_kept, data_type.expressions)))

    return expression

Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.

def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode (used in presto -> hive).

    Every join whose target is an `UNNEST(...)` is removed from the SELECT. Each unnested
    expression is paired with a column of the unnest's table alias, and a `LATERAL VIEW` is
    appended for each pair. It uses `EXPLODE(...)`, or `POSEXPLODE(...)` when the unnest has the
    `ordinality` flag.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot. Joins are removed from the underlying list inside the loop,
        # and removing from the list being iterated would silently skip the following join.
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if not isinstance(unnest, exp.Unnest):
                continue

            alias = unnest.args.get("alias")
            udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode

            expression.args["joins"].remove(join)

            # NOTE(review): with no alias, or fewer alias columns than unnested expressions, the
            # unmatched expressions produce no LATERAL VIEW and so vanish from the query.
            for e, column in zip(unnest.expressions, alias.columns if alias else []):
                expression.append(
                    "laterals",
                    exp.Lateral(
                        this=udtf(this=e),
                        view=True,
                        alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                    ),
                )

    return expression

Convert cross join unnest into lateral view explode (used in presto -> hive).

def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Convert explode/posexplode into unnest (used in hive -> presto).

    Each `EXPLODE(x)` / `POSEXPLODE(x)` projection of a SELECT becomes an `UNNEST(x)` source
    (with ordinality for POSEXPLODE), aliased as a fresh `_u*` table. The source becomes the FROM
    clause when the query has none and is CROSS JOINed otherwise. The projection itself is
    replaced by references to the unnest's column aliases.

    Args:
        expression: the expression that will be transformed. Non-SELECT expressions are returned
            unchanged.

    Returns:
        The transformed expression (modified in place).
    """
    if isinstance(expression, exp.Select):
        # Local import; presumably avoids an import cycle with the optimizer package.
        from sqlglot.optimizer.scope import Scope

        # Names already in use, so generated aliases never collide with existing ones.
        taken_select_names = set(expression.named_selects)
        taken_source_names = {name for name, _ in Scope(expression).references}

        for select in expression.selects:
            # Keep a handle on the projection node itself (alias included), so it can still be
            # replaced/removed after `select` is rebound to the unaliased inner expression.
            to_replace = select

            pos_alias = ""
            explode_alias = ""

            if isinstance(select, exp.Alias):
                # EXPLODE(x) AS col
                explode_alias = select.alias
                select = select.this
            elif isinstance(select, exp.Aliases):
                # POSEXPLODE(x) AS (pos, col): the first alias is treated as the position.
                pos_alias = select.aliases[0].name
                explode_alias = select.aliases[1].name
                select = select.this

            if isinstance(select, (exp.Explode, exp.Posexplode)):
                is_posexplode = isinstance(select, exp.Posexplode)

                explode_arg = select.this
                unnest = exp.Unnest(expressions=[explode_arg.copy()], ordinality=is_posexplode)

                # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                if isinstance(explode_arg, exp.Column):
                    taken_select_names.add(explode_arg.output_name)

                unnest_source_alias = find_new_name(taken_source_names, "_u")
                taken_source_names.add(unnest_source_alias)

                # Generate column aliases only when the user supplied none.
                # NOTE(review): POSEXPLODE(x) AS col (a single alias) leaves pos_alias == "",
                # so an empty column name reaches the UNNEST alias below; confirm this input
                # cannot occur.
                if not explode_alias:
                    explode_alias = find_new_name(taken_select_names, "col")
                    taken_select_names.add(explode_alias)

                    if is_posexplode:
                        pos_alias = find_new_name(taken_select_names, "pos")
                        taken_select_names.add(pos_alias)

                if is_posexplode:
                    # The UNNEST alias lists the value column before the position column,
                    # the reverse of the (pos, col) projection order used below.
                    column_names = [explode_alias, pos_alias]
                    # NOTE(review): the projection is removed and the two columns are re-added
                    # through select(), which appears to append them to the end of the list, so
                    # this projection may move relative to its siblings.
                    to_replace.pop()
                    expression.select(pos_alias, explode_alias, copy=False)
                else:
                    column_names = [explode_alias]
                    to_replace.replace(exp.column(explode_alias))

                unnest = exp.alias_(unnest, unnest_source_alias, table=column_names)

                if not expression.args.get("from"):
                    expression.from_(unnest, copy=False)
                else:
                    expression.join(unnest, join_type="CROSS", copy=False)

    return expression

Convert explode/posexplode into unnest (used in hive -> presto).

def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Convert `PERCENTILE_CONT/PERCENTILE_DISC(q) WITHIN GROUP (ORDER BY x)` into
    `APPROX_QUANTILE(x, q)`.

    APPROX_QUANTILE takes no sort direction, so a descending ordering is folded into the quantile.
    The q-th fraction of the values in descending order is the (1 - q)-th fraction in ascending
    order.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The `ApproxQuantile` replacement, or the input expression when it doesn't match.
    """
    if (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, (exp.PercentileCont, exp.PercentileDisc))
        and isinstance(expression.expression, exp.Order)
    ):
        quantile = expression.this.this
        # Look for the sort key only inside the ORDER BY clause.
        ordered = t.cast(exp.Ordered, expression.expression.find(exp.Ordered))
        input_value = ordered.this

        if ordered.args.get("desc"):
            # Dropping DESC silently would change the result; use 1 - q instead.
            if not isinstance(quantile, exp.Literal):
                quantile = exp.Paren(this=quantile)
            quantile = exp.Sub(this=exp.Literal.number(1), expression=quantile)

        return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))

    return expression
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Give explicit column names to the CTEs of a `WITH RECURSIVE` clause that don't declare them.

    For each such CTE, the column names are taken from the projections of its left-most (anchor)
    query. Unnamed projections get names produced by `name_sequence("_c_")`, which is shared by
    all CTEs of the clause.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.With) and expression.recursive:
        next_name = name_sequence("_c_")

        for cte in expression.expressions:
            if not cte.args["alias"].columns:
                query = cte.this
                # Walk down to the left-most operand, so chains such as `a UNION b UNION c`
                # (presumably parsed as nested unions) also resolve to the anchor query.
                while isinstance(query, exp.Union):
                    query = query.this

                cte.args["alias"].set(
                    "columns",
                    [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
                )

    return expression
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite a CAST/TRY_CAST of 'epoch' to a temporal type into a cast of an explicit timestamp.

    The operand whose name is 'epoch' (compared case-insensitively) is replaced with the string
    literal '1970-01-01 00:00:00'. Every other expression is returned unchanged.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    is_epoch = expression.name.lower() == "epoch"
    if is_epoch and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function when no such method exists.

    The transforms run on a copy of the input expression. Falling back to `Generator.TRANSFORMS`
    is only allowed when the transforms changed the expression's type. Otherwise the same
    handler, presumably this very function, would be re-entered forever, so a ValueError is
    raised instead.

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence simply renders a copy of the input expression.

    Returns:
        Function that can be used as a generator transform. That function raises ValueError when
        the resulting expression can't be rendered.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # Work on a copy so the caller's tree is not mutated by the transforms. The loop variable
        # is deliberately not named `t`, which would shadow the module's `typing as t` alias.
        expression = expression.copy()
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            # Ensures we don't enter an infinite loop. This can happen when the original expression
            # has the same type as the final expression and there's no _sql method available for it,
            # because then it'd re-enter _to_sql.
            if expression_type is type(expression):
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql

Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL, using either the "_sql" method corresponding to the resulting expression, or the appropriate Generator.TRANSFORMS function (only when the transformations changed the expression's type; otherwise a ValueError is raised to avoid re-entering the same transform forever).

Arguments:
  • transforms: sequence of transform functions. These will be called in order.
Returns:

Function that can be used as a generator transform.