Edit on GitHub

sqlglot.transforms

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot import expressions as exp
  6from sqlglot.helper import find_new_name, name_sequence
  7
  8if t.TYPE_CHECKING:
  9    from sqlglot.generator import Generator
 10
 11
 12def unalias_group(expression: exp.Expression) -> exp.Expression:
 13    """
 14    Replace references to select aliases in GROUP BY clauses.
 15
 16    Example:
 17        >>> import sqlglot
 18        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
 19        'SELECT a AS b FROM x GROUP BY 1'
 20
 21    Args:
 22        expression: the expression that will be transformed.
 23
 24    Returns:
 25        The transformed expression.
 26    """
 27    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
 28        aliased_selects = {
 29            e.alias: i
 30            for i, e in enumerate(expression.parent.expressions, start=1)
 31            if isinstance(e, exp.Alias)
 32        }
 33
 34        for group_by in expression.expressions:
 35            if (
 36                isinstance(group_by, exp.Column)
 37                and not group_by.table
 38                and group_by.name in aliased_selects
 39            ):
 40                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))
 41
 42    return expression
 43
 44
 45def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
 46    """
 47    Convert SELECT DISTINCT ON statements to a subquery with a window function.
 48
 49    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
 50
 51    Args:
 52        expression: the expression that will be transformed.
 53
 54    Returns:
 55        The transformed expression.
 56    """
 57    if (
 58        isinstance(expression, exp.Select)
 59        and expression.args.get("distinct")
 60        and expression.args["distinct"].args.get("on")
 61        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
 62    ):
 63        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
 64        outer_selects = expression.selects
 65        row_number = find_new_name(expression.named_selects, "_row_number")
 66        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
 67        order = expression.args.get("order")
 68
 69        if order:
 70            window.set("order", order.pop().copy())
 71
 72        window = exp.alias_(window, row_number)
 73        expression.select(window, copy=False)
 74
 75        return exp.select(*outer_selects).from_(expression.subquery()).where(f'"{row_number}" = 1')
 76
 77    return expression
 78
 79
 80def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
 81    """
 82    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
 83
 84    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
 85    https://docs.snowflake.com/en/sql-reference/constructs/qualify
 86
 87    Some dialects don't support window functions in the WHERE clause, so we need to include them as
 88    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
 89    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
 90    otherwise we won't be able to refer to it in the outer query's WHERE clause.
 91    """
 92    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
 93        taken = set(expression.named_selects)
 94        for select in expression.selects:
 95            if not select.alias_or_name:
 96                alias = find_new_name(taken, "_c")
 97                select.replace(exp.alias_(select, alias))
 98                taken.add(alias)
 99
100        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
101        qualify_filters = expression.args["qualify"].pop().this
102
103        for expr in qualify_filters.find_all((exp.Window, exp.Column)):
104            if isinstance(expr, exp.Window):
105                alias = find_new_name(expression.named_selects, "_w")
106                expression.select(exp.alias_(expr, alias), copy=False)
107                column = exp.column(alias)
108
109                if isinstance(expr.parent, exp.Qualify):
110                    qualify_filters = column
111                else:
112                    expr.replace(column)
113            elif expr.name not in expression.named_selects:
114                expression.select(expr.copy(), copy=False)
115
116        return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters)
117
118    return expression
119
120
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and
    not in other expressions. This transform removes the precision from parameterized types by
    dropping every `DataTypeSize` argument from each `DataType` node found in the expression.
    """
    for data_type in expression.find_all(exp.DataType):
        remaining = [arg for arg in data_type.expressions if not isinstance(arg, exp.DataTypeSize)]
        data_type.set("expressions", remaining)

    return expression
132
133
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode (used in presto -> hive).

    Every join of a SELECT whose target is an `Unnest` is removed from the query's joins and
    replaced by one `Lateral` (with `view=True`) per pair of unnested expression and alias
    column. `Posexplode` is used when the unnest has `ordinality` set, `Explode` otherwise.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, modified in place when it is a SELECT with unnest joins.
    """
    if isinstance(expression, exp.Select):
        joins = expression.args.get("joins") or []

        # Iterate over a snapshot: entries are removed from the live list inside the loop, and
        # removing from a list while iterating it would skip the join that follows each match.
        for join in list(joins):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode

                joins.remove(join)

                # NOTE(review): when the unnest has no alias the zip below is empty, so the
                # join is removed without any lateral view taking its place -- confirm intent.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
157
158
def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Convert explode/posexplode into unnest (used in hive -> presto).

    Each EXPLODE / POSEXPLODE projection of a SELECT (optionally wrapped in an alias) is
    replaced by column references, and an aliased `Unnest` source is added to the query: as
    the FROM clause when the query has none, otherwise as a CROSS join. Names that are not
    given by the query are generated so that they don't clash with names already taken.
    """
    if not isinstance(expression, exp.Select):
        return expression

    from sqlglot.optimizer.scope import Scope

    used_projection_names = set(expression.named_selects)
    used_source_names = {name for name, _ in Scope(expression).references}

    for projection in expression.selects:
        position_name = ""
        value_name = ""

        # Look through an alias wrapper to find the function call, remembering its names.
        if isinstance(projection, exp.Alias):
            value_name = projection.alias
            call = projection.this
        elif isinstance(projection, exp.Aliases):
            position_name = projection.aliases[0].name
            value_name = projection.aliases[1].name
            call = projection.this
        else:
            call = projection

        if not isinstance(call, (exp.Explode, exp.Posexplode)):
            continue

        with_position = isinstance(call, exp.Posexplode)

        argument = call.this
        unnest = exp.Unnest(expressions=[argument.copy()], ordinality=with_position)

        # This ensures that we won't use [POS]EXPLODE's argument as a new selection
        if isinstance(argument, exp.Column):
            used_projection_names.add(argument.output_name)

        source_name = find_new_name(used_source_names, "_u")
        used_source_names.add(source_name)

        if not value_name:
            value_name = find_new_name(used_projection_names, "col")
            used_projection_names.add(value_name)

            if with_position:
                position_name = find_new_name(used_projection_names, "pos")
                used_projection_names.add(position_name)

        if with_position:
            unnest_columns = [value_name, position_name]
            # One projection turns into two, so drop it and append both columns instead.
            projection.pop()
            expression.select(position_name, value_name, copy=False)
        else:
            unnest_columns = [value_name]
            projection.replace(exp.column(value_name))

        unnest = exp.alias_(unnest, source_name, table=unnest_columns)

        if expression.args.get("from"):
            expression.join(unnest, join_type="CROSS", copy=False)
        else:
            expression.from_(unnest, copy=False)

    return expression
218
219
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Replace a WITHIN GROUP wrapped around PERCENTILE_CONT / PERCENTILE_DISC with `ApproxQuantile`.

    The quantile is the percentile function's own argument and the input value is taken from
    an `Ordered` node found under the WITHIN GROUP expression. Any other expression is
    returned unchanged.
    """
    applies = (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, (exp.PercentileCont, exp.PercentileDisc))
        and isinstance(expression.expression, exp.Order)
    )
    if not applies:
        return expression

    quantile = expression.this.this
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx_quantile = exp.ApproxQuantile(this=ordered.this, quantile=quantile)
    return expression.replace(approx_quantile)
231
232
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Give every CTE of a recursive WITH clause an explicit column list on its alias.

    CTEs whose alias already lists columns are left alone. For the others, the names come from
    the projections of the CTE's query (the left-hand side when the query is a union), and a
    projection without a name gets a generated one prefixed with "_c_".
    """
    if not isinstance(expression, exp.With) or not expression.recursive:
        return expression

    generate_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        # For a union, the column names are read from its first operand.
        body = cte.this.this if isinstance(cte.this, exp.Union) else cte.this

        column_names = [
            exp.to_identifier(projection.alias_or_name or generate_name())
            for projection in body.selects
        ]
        table_alias.set("columns", column_names)

    return expression
249
250
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace the operand of a cast of "epoch" to a temporal type with '1970-01-01 00:00:00'.

    Applies to `Cast` and `TryCast` nodes whose name is "epoch" (case-insensitive) and whose
    target type is one of `exp.DataType.TEMPORAL_TYPES`; the cast itself is kept.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    if expression.name.lower() != "epoch":
        return expression

    if expression.to.this not in exp.DataType.TEMPORAL_TYPES:
        return expression

    expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))
    return expression
260
261
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is allowed, in which case a copy of the expression is generated as is.

    Returns:
        Function that can be used as a generator transform.

    Raises:
        ValueError: raised by the returned function when the final expression has neither a
            "_sql" method nor a usable `Generator.TRANSFORMS` entry.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # The transforms receive a copy, so the expression passed in by the caller is not the
        # object they operate on.
        expression = expression.copy()
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            # Ensures we don't enter an infinite loop. This can happen when the original expression
            # has the same type as the final expression and there's no _sql method available for it,
            # because then it'd re-enter _to_sql.
            if expression_type is type(expression):
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
def unalias_group( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Swap select-alias references inside a GROUP BY for the projection's ordinal position.

    Only unqualified columns (no table part) whose name matches an aliased projection of the
    parent SELECT are rewritten; the replacement is the 1-based index of that projection.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, with matching GROUP BY terms replaced in place.
    """
    if not isinstance(expression, exp.Group) or not isinstance(expression.parent, exp.Select):
        return expression

    # Map each alias to the 1-based position of its projection in the parent SELECT.
    alias_positions = {}
    for position, projection in enumerate(expression.parent.expressions, start=1):
        if isinstance(projection, exp.Alias):
            alias_positions[projection.alias] = position

    for term in expression.expressions:
        # Qualified columns (t.b) can't be references to a select alias, so leave them alone.
        if not isinstance(term, exp.Column) or term.table:
            continue

        if term.name in alias_positions:
            term.replace(exp.Literal.number(alias_positions[term.name]))

    return expression

Replace references to select aliases in GROUP BY clauses.

Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_distinct_on( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite SELECT DISTINCT ON (...) as a filtered subquery over a ROW_NUMBER() window.

    The DISTINCT ON columns become the window's PARTITION BY, an existing ORDER BY is moved
    into the window, and the window is projected under a fresh alias derived from
    "_row_number". The returned outer query selects the original projections from the
    subquery and keeps only the rows whose row number equals 1.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The new outer query when the rewrite applies, otherwise the expression unchanged.
    """
    if not isinstance(expression, exp.Select):
        return expression

    distinct = expression.args.get("distinct")
    if not distinct:
        return expression

    on_clause = distinct.args.get("on")
    if not on_clause or not isinstance(on_clause, exp.Tuple):
        return expression

    # Detach DISTINCT from the query and keep the columns it was keyed on.
    partition_columns = distinct.pop().args["on"].expressions

    # Captured before the window projection is appended, so the outer query exposes
    # only what the original query selected.
    outer_projections = expression.selects
    row_number_alias = find_new_name(expression.named_selects, "_row_number")
    row_number_window = exp.Window(this=exp.RowNumber(), partition_by=partition_columns)

    ordering = expression.args.get("order")
    if ordering:
        row_number_window.set("order", ordering.pop().copy())

    expression.select(exp.alias_(row_number_window, row_number_alias), copy=False)

    outer_query = exp.select(*outer_projections)
    outer_query = outer_query.from_(expression.subquery())
    return outer_query.where(f'"{row_number_alias}" = 1')

Convert SELECT DISTINCT ON statements to a subquery with a window function.

This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_qualify( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
 81def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
 82    """
 83    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
 84
 85    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
 86    https://docs.snowflake.com/en/sql-reference/constructs/qualify
 87
 88    Some dialects don't support window functions in the WHERE clause, so we need to include them as
 89    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
 90    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
 91    otherwise we won't be able to refer to it in the outer query's WHERE clause.
 92    """
 93    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
 94        taken = set(expression.named_selects)
 95        for select in expression.selects:
 96            if not select.alias_or_name:
 97                alias = find_new_name(taken, "_c")
 98                select.replace(exp.alias_(select, alias))
 99                taken.add(alias)
100
101        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
102        qualify_filters = expression.args["qualify"].pop().this
103
104        for expr in qualify_filters.find_all((exp.Window, exp.Column)):
105            if isinstance(expr, exp.Window):
106                alias = find_new_name(expression.named_selects, "_w")
107                expression.select(exp.alias_(expr, alias), copy=False)
108                column = exp.column(alias)
109
110                if isinstance(expr.parent, exp.Qualify):
111                    qualify_filters = column
112                else:
113                    expr.replace(column)
114            elif expr.name not in expression.named_selects:
115                expression.select(expr.copy(), copy=False)
116
117        return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters)
118
119    return expression

Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify

Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause.

def remove_precision_parameterized_types( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and
    not in other expressions. This transform removes the precision from parameterized types by
    dropping every `DataTypeSize` argument from each `DataType` node found in the expression.
    """
    for data_type in expression.find_all(exp.DataType):
        remaining = [arg for arg in data_type.expressions if not isinstance(arg, exp.DataTypeSize)]
        data_type.set("expressions", remaining)

    return expression

Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.

def unnest_to_explode( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode (used in presto -> hive).

    Every join of a SELECT whose target is an `Unnest` is removed from the query's joins and
    replaced by one `Lateral` (with `view=True`) per pair of unnested expression and alias
    column. `Posexplode` is used when the unnest has `ordinality` set, `Explode` otherwise.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, modified in place when it is a SELECT with unnest joins.
    """
    if isinstance(expression, exp.Select):
        joins = expression.args.get("joins") or []

        # Iterate over a snapshot: entries are removed from the live list inside the loop, and
        # removing from a list while iterating it would skip the join that follows each match.
        for join in list(joins):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode

                joins.remove(join)

                # NOTE(review): when the unnest has no alias the zip below is empty, so the
                # join is removed without any lateral view taking its place -- confirm intent.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression

Convert cross join unnest into lateral view explode (used in presto -> hive).

def explode_to_unnest( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Convert explode/posexplode into unnest (used in hive -> presto).

    Each EXPLODE / POSEXPLODE projection of a SELECT (optionally wrapped in an alias) is
    replaced by column references, and an aliased `Unnest` source is added to the query: as
    the FROM clause when the query has none, otherwise as a CROSS join. Names that are not
    given by the query are generated so that they don't clash with names already taken.
    """
    if not isinstance(expression, exp.Select):
        return expression

    from sqlglot.optimizer.scope import Scope

    used_projection_names = set(expression.named_selects)
    used_source_names = {name for name, _ in Scope(expression).references}

    for projection in expression.selects:
        position_name = ""
        value_name = ""

        # Look through an alias wrapper to find the function call, remembering its names.
        if isinstance(projection, exp.Alias):
            value_name = projection.alias
            call = projection.this
        elif isinstance(projection, exp.Aliases):
            position_name = projection.aliases[0].name
            value_name = projection.aliases[1].name
            call = projection.this
        else:
            call = projection

        if not isinstance(call, (exp.Explode, exp.Posexplode)):
            continue

        with_position = isinstance(call, exp.Posexplode)

        argument = call.this
        unnest = exp.Unnest(expressions=[argument.copy()], ordinality=with_position)

        # This ensures that we won't use [POS]EXPLODE's argument as a new selection
        if isinstance(argument, exp.Column):
            used_projection_names.add(argument.output_name)

        source_name = find_new_name(used_source_names, "_u")
        used_source_names.add(source_name)

        if not value_name:
            value_name = find_new_name(used_projection_names, "col")
            used_projection_names.add(value_name)

            if with_position:
                position_name = find_new_name(used_projection_names, "pos")
                used_projection_names.add(position_name)

        if with_position:
            unnest_columns = [value_name, position_name]
            # One projection turns into two, so drop it and append both columns instead.
            projection.pop()
            expression.select(position_name, value_name, copy=False)
        else:
            unnest_columns = [value_name]
            projection.replace(exp.column(value_name))

        unnest = exp.alias_(unnest, source_name, table=unnest_columns)

        if expression.args.get("from"):
            expression.join(unnest, join_type="CROSS", copy=False)
        else:
            expression.from_(unnest, copy=False)

    return expression

Convert explode/posexplode into unnest (used in hive -> presto).

def remove_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Replace a WITHIN GROUP wrapped around PERCENTILE_CONT / PERCENTILE_DISC with `ApproxQuantile`.

    The quantile is the percentile function's own argument and the input value is taken from
    an `Ordered` node found under the WITHIN GROUP expression. Any other expression is
    returned unchanged.
    """
    applies = (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, (exp.PercentileCont, exp.PercentileDisc))
        and isinstance(expression.expression, exp.Order)
    )
    if not applies:
        return expression

    quantile = expression.this.this
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx_quantile = exp.ApproxQuantile(this=ordered.this, quantile=quantile)
    return expression.replace(approx_quantile)
def add_recursive_cte_column_names( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Give every CTE of a recursive WITH clause an explicit column list on its alias.

    CTEs whose alias already lists columns are left alone. For the others, the names come from
    the projections of the CTE's query (the left-hand side when the query is a union), and a
    projection without a name gets a generated one prefixed with "_c_".
    """
    if not isinstance(expression, exp.With) or not expression.recursive:
        return expression

    generate_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        # For a union, the column names are read from its first operand.
        body = cte.this.this if isinstance(cte.this, exp.Union) else cte.this

        column_names = [
            exp.to_identifier(projection.alias_or_name or generate_name())
            for projection in body.selects
        ]
        table_alias.set("columns", column_names)

    return expression
def epoch_cast_to_ts( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace the operand of a cast of "epoch" to a temporal type with '1970-01-01 00:00:00'.

    Applies to `Cast` and `TryCast` nodes whose name is "epoch" (case-insensitive) and whose
    target type is one of `exp.DataType.TEMPORAL_TYPES`; the cast itself is kept.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    if expression.name.lower() != "epoch":
        return expression

    if expression.to.this not in exp.DataType.TEMPORAL_TYPES:
        return expression

    expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))
    return expression
def preprocess( transforms: List[Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]]) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is allowed, in which case a copy of the expression is generated as is.

    Returns:
        Function that can be used as a generator transform.

    Raises:
        ValueError: raised by the returned function when the final expression has neither a
            "_sql" method nor a usable `Generator.TRANSFORMS` entry.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # The transforms receive a copy, so the expression passed in by the caller is not the
        # object they operate on.
        expression = expression.copy()
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            # Ensures we don't enter an infinite loop. This can happen when the original expression
            # has the same type as the final expression and there's no _sql method available for it,
            # because then it'd re-enter _to_sql.
            if expression_type is type(expression):
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql

Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL, using either the "_sql" method corresponding to the resulting expression, or the appropriate Generator.TRANSFORMS function (when applicable -- see below).

Arguments:
  • transforms: sequence of transform functions. These will be called in order.
Returns:

Function that can be used as a generator transform.