Edit on GitHub

sqlglot.transforms

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot import expressions as exp
  6from sqlglot.helper import find_new_name, name_sequence
  7
  8if t.TYPE_CHECKING:
  9    from sqlglot.generator import Generator
 10
 11
 12def unalias_group(expression: exp.Expression) -> exp.Expression:
 13    """
 14    Replace references to select aliases in GROUP BY clauses.
 15
 16    Example:
 17        >>> import sqlglot
 18        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
 19        'SELECT a AS b FROM x GROUP BY 1'
 20
 21    Args:
 22        expression: the expression that will be transformed.
 23
 24    Returns:
 25        The transformed expression.
 26    """
 27    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
 28        aliased_selects = {
 29            e.alias: i
 30            for i, e in enumerate(expression.parent.expressions, start=1)
 31            if isinstance(e, exp.Alias)
 32        }
 33
 34        for group_by in expression.expressions:
 35            if (
 36                isinstance(group_by, exp.Column)
 37                and not group_by.table
 38                and group_by.name in aliased_selects
 39            ):
 40                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))
 41
 42    return expression
 43
 44
 45def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
 46    """
 47    Convert SELECT DISTINCT ON statements to a subquery with a window function.
 48
 49    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
 50
 51    Args:
 52        expression: the expression that will be transformed.
 53
 54    Returns:
 55        The transformed expression.
 56    """
 57    if (
 58        isinstance(expression, exp.Select)
 59        and expression.args.get("distinct")
 60        and expression.args["distinct"].args.get("on")
 61        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
 62    ):
 63        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
 64        outer_selects = expression.selects
 65        row_number = find_new_name(expression.named_selects, "_row_number")
 66        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
 67        order = expression.args.get("order")
 68
 69        if order:
 70            window.set("order", order.pop().copy())
 71        else:
 72            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))
 73
 74        window = exp.alias_(window, row_number)
 75        expression.select(window, copy=False)
 76
 77        return (
 78            exp.select(*outer_selects)
 79            .from_(expression.subquery("_t"))
 80            .where(exp.column(row_number).eq(1))
 81        )
 82
 83    return expression
 84
 85
 86def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
 87    """
 88    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
 89
 90    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
 91    https://docs.snowflake.com/en/sql-reference/constructs/qualify
 92
 93    Some dialects don't support window functions in the WHERE clause, so we need to include them as
 94    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
 95    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
 96    otherwise we won't be able to refer to it in the outer query's WHERE clause.
 97    """
 98    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
 99        taken = set(expression.named_selects)
100        for select in expression.selects:
101            if not select.alias_or_name:
102                alias = find_new_name(taken, "_c")
103                select.replace(exp.alias_(select, alias))
104                taken.add(alias)
105
106        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
107        qualify_filters = expression.args["qualify"].pop().this
108
109        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
110        for expr in qualify_filters.find_all(select_candidates):
111            if isinstance(expr, exp.Window):
112                alias = find_new_name(expression.named_selects, "_w")
113                expression.select(exp.alias_(expr, alias), copy=False)
114                column = exp.column(alias)
115
116                if isinstance(expr.parent, exp.Qualify):
117                    qualify_filters = column
118                else:
119                    expr.replace(column)
120            elif expr.name not in expression.named_selects:
121                expression.select(expr.copy(), copy=False)
122
123        return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters)
124
125    return expression
126
127
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.

    Every `exp.DataType` in the tree loses its `exp.DataTypeParam` children; all other children are
    kept. The tree is modified in place and the same root is returned.
    """
    for data_type in expression.find_all(exp.DataType):
        kept = [arg for arg in data_type.expressions if not isinstance(arg, exp.DataTypeParam)]
        data_type.set("expressions", kept)

    return expression
139
140
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode (used in presto -> hive).

    Every join whose source is an UNNEST is removed from the SELECT. Each unnested expression is then
    re-added as a LATERAL VIEW, paired positionally with the UNNEST's column aliases. POSEXPLODE is
    used when the UNNEST has an offset, EXPLODE otherwise.

    NOTE(review): pairing uses zip(), so an UNNEST without an alias, or with fewer column aliases than
    expressions, loses the unmatched expressions. Confirm that this is acceptable for callers.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot. The loop removes joins from the live list, and removing items
        # from a list while iterating over it silently skips the item after each removal (e.g. the
        # second of two consecutive UNNEST joins).
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
164
165
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Build a transform that converts EXPLODE/POSEXPLODE projections into UNNEST cross joins.

    Each exploded array becomes a CROSS JOIN UNNEST(...) with an offset column. All of them are
    matched against one shared position series. The WHERE filters added below keep, for each series
    position, one row per array, so several explodes in the same SELECT line up by position.

    Args:
        index_offset: start value of the generated position series. When it is not 1, the array
            sizes used in the filters and in the series end bound are shifted down accordingly.
            NOTE(review): presumably the array index base (0 or 1) of the target; confirm.

    Returns:
        A transform that rewrites `exp.Select` nodes in place; other expressions pass through unchanged.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        """Convert explode/posexplode into unnest (used in hive -> presto)."""
        if isinstance(expression, exp.Select):
            # Local import; presumably to avoid a circular import at module load time (TODO confirm).
            from sqlglot.optimizer.scope import Scope

            # Names already used by projections and by sources, so generated aliases don't clash.
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Pick an unused variant of `name` and reserve it in `names`.
                name = find_new_name(names, name)
                names.add(name)
                return name

            # One ARRAY_SIZE(...) per exploded array; used after the loop to size the series.
            arrays: t.List[exp.Condition] = []
            # Shared position series: UNNEST(GENERATE_SERIES(index_offset, <end>)) with a fresh
            # source alias and a single column named `series_alias`. Its end bound is filled in
            # after the loop, once all exploded arrays are known.
            series_alias = new_name(taken_select_names, "pos")
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # Iterate over a copy because expression.selects is mutated inside the loop.
            for select in expression.selects.copy():
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Normalize the projection into an alias node (`alias`) whose name becomes the
                    # exploded value's column; user-supplied aliases are reused when present.
                    if isinstance(select, exp.Alias):
                        explode_alias = select.alias
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        # Multiple aliases, e.g. POSEXPLODE(x) AS (pos, col):
                        # the first names the position, the second the value.
                        pos_alias = select.aliases[0].name
                        explode_alias = select.aliases[1].name
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        # alias_ presumably copies `select` here (the branch above opts out with
                        # copy=False), so the explode node is looked up again in the new tree.
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    # A position column is always needed: it is the UNNEST's offset column that
                    # gets compared against the shared series.
                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # The explode call becomes IF(<series pos> = <array pos>, <value>): the value is
                    # only produced on the row where the series position matches this array's position.
                    column = exp.If(
                        this=exp.column(series_alias).eq(exp.column(pos_alias)),
                        true=exp.column(explode_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # POSEXPLODE also projects the position, directly after the value column.
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias).eq(exp.column(pos_alias)),
                                true=exp.column(pos_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # The series source is attached once, before the first array's UNNEST: as a
                    # join (no join type given) if there is a FROM clause, otherwise as the FROM.
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False)
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # Presumably converts the array size into the last element's position when
                    # positions are not 1-based (TODO confirm).
                    if index_offset != 1:
                        size = size - 1

                    # Keep the row whose array position equals the series position. Once the series
                    # has run past this array's end, keep only its last element instead; the IF
                    # projection above then yields no value for this array on that row.
                    expression.where(
                        exp.column(series_alias)
                        .eq(exp.column(pos_alias))
                        .or_(
                            (exp.column(series_alias) > size).and_(exp.column(pos_alias).eq(size))
                        ),
                        copy=False,
                    )

            if arrays:
                # The series must cover the longest exploded array.
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                # expressions[0] is the GenerateSeries created above.
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
294
295
# Percentile aggregate expression types handled by the WITHIN GROUP transforms in this module.
PERCENTILES = (
    exp.PercentileCont,
    exp.PercentileDisc,
)
297
298
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite a two-argument percentile call into the WITHIN GROUP (ORDER BY ...) form.

    For a PERCENTILE_CONT/PERCENTILE_DISC node that has a second argument (`expression`) and is
    not already wrapped in WITHIN GROUP, the second argument becomes the function's only argument.
    The first argument (`this`) becomes the ORDER BY key of a new enclosing `exp.WithinGroup`,
    which is returned. Any other expression is returned unchanged.
    """
    already_wrapped = isinstance(expression.parent, exp.WithinGroup)
    if not isinstance(expression, PERCENTILES) or already_wrapped or not expression.expression:
        return expression

    ordered_column = expression.this.pop()
    argument = expression.expression.pop()
    expression.set("this", argument)

    return exp.WithinGroup(
        this=expression,
        expression=exp.Order(expressions=[exp.Ordered(this=ordered_column)]),
    )
311
312
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Replace a percentile WITHIN GROUP (ORDER BY ...) call with an approximate-quantile node.

    For `exp.WithinGroup(this=<PERCENTILE_CONT|PERCENTILE_DISC>, expression=<ORDER BY>)`, the whole
    node is replaced by `exp.ApproxQuantile`. Its `this` is the first ORDER BY key found and its
    `quantile` is the percentile's argument. Any other expression is returned unchanged.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    if not isinstance(percentile, PERCENTILES) or not isinstance(expression.expression, exp.Order):
        return expression

    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx_quantile = exp.ApproxQuantile(this=ordered.this, quantile=percentile.this)
    return expression.replace(approx_quantile)
324
325
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Give every CTE of a WITH RECURSIVE clause that lacks a column list an explicit one.

    Column names come from the CTE query's projections; for a UNION, the left-hand query's
    projections are used. Projections without a name get names generated by `name_sequence("_c_")`,
    one sequence shared by all CTEs of the clause. Other expressions are returned unchanged.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    fresh_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        body = cte.this
        if isinstance(body, exp.Union):
            body = body.this

        names = [projection.alias_or_name or fresh_name() for projection in body.selects]
        table_alias.set("columns", [exp.to_identifier(name) for name in names])

    return expression
342
343
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace the special 'epoch' string in casts to temporal types with the equivalent literal.

    `CAST('epoch' AS <temporal type>)` (or TRY_CAST) has its operand replaced by the string literal
    '1970-01-01 00:00:00'. Only a string literal operand is rewritten: a column or identifier that
    happens to be named `epoch` is left untouched.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, (exp.Cast, exp.TryCast))
        # The cast's name is its operand's name, so without this check a column named `epoch`
        # would also be replaced by a literal.
        and expression.this.is_string
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    ):
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
353
354
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins that have an ON condition into EXISTS filters.

    Each such join is removed from the SELECT and replaced in its WHERE clause by
    EXISTS(SELECT 1 FROM <joined source> WHERE <on condition>), negated with NOT for ANTI joins.
    Other joins, and SEMI/ANTI joins without an ON condition, are left unchanged.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: converted joins are popped from the query inside the loop, and
        # the iteration must not depend on whether popping mutates the live joins list.
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if not on or join.kind not in ("SEMI", "ANTI"):
                continue

            subquery = exp.select("1").from_(join.this).where(on)
            exists = exp.Exists(this=subquery)
            if join.kind == "ANTI":
                exists = exists.not_(copy=False)

            join.pop()
            expression.where(exists, copy=False)

    return expression
369
370
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Chain several transforms into one generator transform that also renders the result to SQL.

    The returned function copies the incoming expression and runs every transform over it in order.
    It then renders the result with the generator's `<key>_sql` method when one exists. Otherwise it
    uses the generator's `TRANSFORMS` entry for the result's type, but only if the transforms changed
    the expression's type: with an unchanged type that entry would dispatch back into this same
    function. An unchanged `exp.Func` is rendered with `function_fallback_sql` instead.

    Args:
        transforms: sequence of transform functions. These will be called in order.

    Returns:
        Function that can be used as a generator transform. It raises ValueError when the
        transformed expression cannot be rendered.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        original_type = type(expression)

        transformed = transforms[0](expression.copy())
        for transform in transforms[1:]:
            transformed = transform(transformed)

        sql_method = getattr(self, f"{transformed.key}_sql", None)
        if sql_method:
            return sql_method(transformed)

        fallback = self.TRANSFORMS.get(type(transformed))
        if not fallback:
            raise ValueError(f"Unsupported expression type {transformed.__class__.__name__}.")

        if type(transformed) is not original_type:
            return fallback(self, transformed)

        if isinstance(transformed, exp.Func):
            return self.function_fallback_sql(transformed)

        # Ensures we don't enter an infinite loop: the expression kept its original type and has no
        # _sql method, so its TRANSFORMS entry would re-enter _to_sql.
        raise ValueError(
            f"Expression type {transformed.__class__.__name__} requires a _sql method in order to be transformed."
        )

    return _to_sql
def unalias_group( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    An unqualified GROUP BY column whose name equals the alias of one of the enclosing SELECT's
    projections is rewritten into the 1-based ordinal position of that projection. If several
    projections share an alias, the last one wins.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    select = expression.parent
    if not isinstance(expression, exp.Group) or not isinstance(select, exp.Select):
        return expression

    # alias name -> 1-based position of the aliased projection
    positions = {}
    for position, projection in enumerate(select.expressions, start=1):
        if isinstance(projection, exp.Alias):
            positions[projection.alias] = position

    for key in expression.expressions:
        # Qualified columns (t.b) always refer to a source column, never to a select alias.
        if not isinstance(key, exp.Column) or key.table:
            continue

        position = positions.get(key.name)
        if position is not None:
            key.replace(exp.Literal.number(position))

    return expression

Replace references to select aliases in GROUP BY clauses.

Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_distinct_on( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
    The original query (without DISTINCT ON and ORDER BY) becomes a subquery aliased `_t`. It also
    projects ROW_NUMBER() partitioned by the DISTINCT ON columns and ordered by the original ORDER BY,
    or by the DISTINCT ON columns when there is none. The outer query keeps the rows numbered 1 and
    refers to the subquery's projections by their output names.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        row_number = find_new_name(expression.named_selects, "_row_number")

        # The outer query only sees the subquery's output columns. It must therefore reference each
        # projection by name instead of re-evaluating the original expression, which may use source
        # columns or table qualifiers (e.g. `x.a`, `b + 1 AS c`) that are not visible outside `_t`.
        taken = set(expression.named_selects)
        taken.add(row_number)
        outer_selects: t.List[exp.Expression] = []
        for select in expression.selects:
            if isinstance(select, exp.Star) or (
                isinstance(select, exp.Column) and isinstance(select.this, exp.Star)
            ):
                outer_selects.append(exp.Star())
                continue

            if isinstance(select, exp.Column):
                # Reuse the identifier so its original quoting is preserved.
                outer_selects.append(exp.column(select.this.copy()))
                continue

            if not isinstance(select, exp.Alias):
                # Computed projections get an explicit alias so the outer query can reference them.
                name = select.alias_or_name or find_new_name(taken, "_c")
                taken.add(name)
                select = select.replace(exp.alias_(select, name))

            outer_selects.append(exp.column(select.args["alias"].copy()))

        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        if order:
            # The query's ORDER BY decides which row of each DISTINCT ON group is kept.
            window.set("order", order.pop())
        else:
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        window = exp.alias_(window, row_number)
        expression.select(window, copy=False)

        return (
            exp.select(*outer_selects)
            .from_(expression.subquery("_t"))
            .where(exp.column(row_number).eq(1))
        )

    return expression

Convert SELECT DISTINCT ON statements to a subquery with a window function.

This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_qualify( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
 87def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
 88    """
 89    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
 90
 91    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
 92    https://docs.snowflake.com/en/sql-reference/constructs/qualify
 93
 94    Some dialects don't support window functions in the WHERE clause, so we need to include them as
 95    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
 96    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
 97    otherwise we won't be able to refer to it in the outer query's WHERE clause.
 98    """
 99    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
100        taken = set(expression.named_selects)
101        for select in expression.selects:
102            if not select.alias_or_name:
103                alias = find_new_name(taken, "_c")
104                select.replace(exp.alias_(select, alias))
105                taken.add(alias)
106
107        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
108        qualify_filters = expression.args["qualify"].pop().this
109
110        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
111        for expr in qualify_filters.find_all(select_candidates):
112            if isinstance(expr, exp.Window):
113                alias = find_new_name(expression.named_selects, "_w")
114                expression.select(exp.alias_(expr, alias), copy=False)
115                column = exp.column(alias)
116
117                if isinstance(expr.parent, exp.Qualify):
118                    qualify_filters = column
119                else:
120                    expr.replace(column)
121            elif expr.name not in expression.named_selects:
122                expression.select(expr.copy(), copy=False)
123
124        return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters)
125
126    return expression

Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify

Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause.

def remove_precision_parameterized_types( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.

    Every `exp.DataType` in the tree loses its `exp.DataTypeParam` children; all other children are
    kept. The tree is modified in place and the same root is returned.
    """
    for data_type in expression.find_all(exp.DataType):
        kept = [arg for arg in data_type.expressions if not isinstance(arg, exp.DataTypeParam)]
        data_type.set("expressions", kept)

    return expression

Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.

def unnest_to_explode( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode (used in presto -> hive).

    Every join whose source is an UNNEST is removed from the SELECT. Each unnested expression is then
    re-added as a LATERAL VIEW, paired positionally with the UNNEST's column aliases. POSEXPLODE is
    used when the UNNEST has an offset, EXPLODE otherwise.

    NOTE(review): pairing uses zip(), so an UNNEST without an alias, or with fewer column aliases than
    expressions, loses the unmatched expressions. Confirm that this is acceptable for callers.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot. The loop removes joins from the live list, and removing items
        # from a list while iterating over it silently skips the item after each removal (e.g. the
        # second of two consecutive UNNEST joins).
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression

Convert cross join unnest into lateral view explode (used in presto -> hive).

def explode_to_unnest( index_offset: int = 0) -> Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]:
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Build a transform that converts EXPLODE/POSEXPLODE projections into UNNEST cross joins.

    Each exploded array becomes a CROSS JOIN UNNEST(...) with an offset column. All of them are
    matched against one shared position series. The WHERE filters added below keep, for each series
    position, one row per array, so several explodes in the same SELECT line up by position.

    Args:
        index_offset: start value of the generated position series. When it is not 1, the array
            sizes used in the filters and in the series end bound are shifted down accordingly.
            NOTE(review): presumably the array index base (0 or 1) of the target; confirm.

    Returns:
        A transform that rewrites `exp.Select` nodes in place; other expressions pass through unchanged.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        """Convert explode/posexplode into unnest (used in hive -> presto)."""
        if isinstance(expression, exp.Select):
            # Local import; presumably to avoid a circular import at module load time (TODO confirm).
            from sqlglot.optimizer.scope import Scope

            # Names already used by projections and by sources, so generated aliases don't clash.
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Pick an unused variant of `name` and reserve it in `names`.
                name = find_new_name(names, name)
                names.add(name)
                return name

            # One ARRAY_SIZE(...) per exploded array; used after the loop to size the series.
            arrays: t.List[exp.Condition] = []
            # Shared position series: UNNEST(GENERATE_SERIES(index_offset, <end>)) with a fresh
            # source alias and a single column named `series_alias`. Its end bound is filled in
            # after the loop, once all exploded arrays are known.
            series_alias = new_name(taken_select_names, "pos")
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # Iterate over a copy because expression.selects is mutated inside the loop.
            for select in expression.selects.copy():
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Normalize the projection into an alias node (`alias`) whose name becomes the
                    # exploded value's column; user-supplied aliases are reused when present.
                    if isinstance(select, exp.Alias):
                        explode_alias = select.alias
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        # Multiple aliases, e.g. POSEXPLODE(x) AS (pos, col):
                        # the first names the position, the second the value.
                        pos_alias = select.aliases[0].name
                        explode_alias = select.aliases[1].name
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        # alias_ presumably copies `select` here (the branch above opts out with
                        # copy=False), so the explode node is looked up again in the new tree.
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    # A position column is always needed: it is the UNNEST's offset column that
                    # gets compared against the shared series.
                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # The explode call becomes IF(<series pos> = <array pos>, <value>): the value is
                    # only produced on the row where the series position matches this array's position.
                    column = exp.If(
                        this=exp.column(series_alias).eq(exp.column(pos_alias)),
                        true=exp.column(explode_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # POSEXPLODE also projects the position, directly after the value column.
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias).eq(exp.column(pos_alias)),
                                true=exp.column(pos_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # The series source is attached once, before the first array's UNNEST: as a
                    # join (no join type given) if there is a FROM clause, otherwise as the FROM.
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False)
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # Presumably converts the array size into the last element's position when
                    # positions are not 1-based (TODO confirm).
                    if index_offset != 1:
                        size = size - 1

                    # Keep the row whose array position equals the series position. Once the series
                    # has run past this array's end, keep only its last element instead; the IF
                    # projection above then yields no value for this array on that row.
                    expression.where(
                        exp.column(series_alias)
                        .eq(exp.column(pos_alias))
                        .or_(
                            (exp.column(series_alias) > size).and_(exp.column(pos_alias).eq(size))
                        ),
                        copy=False,
                    )

            if arrays:
                # The series must cover the longest exploded array.
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                # expressions[0] is the GenerateSeries created above.
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
def add_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite a two-argument percentile call into the WITHIN GROUP (ORDER BY ...) form.

    For a PERCENTILE_CONT/PERCENTILE_DISC node that has a second argument (`expression`) and is
    not already wrapped in WITHIN GROUP, the second argument becomes the function's only argument.
    The first argument (`this`) becomes the ORDER BY key of a new enclosing `exp.WithinGroup`,
    which is returned. Any other expression is returned unchanged.
    """
    already_wrapped = isinstance(expression.parent, exp.WithinGroup)
    if not isinstance(expression, PERCENTILES) or already_wrapped or not expression.expression:
        return expression

    ordered_column = expression.this.pop()
    argument = expression.expression.pop()
    expression.set("this", argument)

    return exp.WithinGroup(
        this=expression,
        expression=exp.Order(expressions=[exp.Ordered(this=ordered_column)]),
    )
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite ``<percentile>(q) WITHIN GROUP (ORDER BY col)`` as ``APPROX_QUANTILE(col, q)``.

    Only ``exp.WithinGroup`` nodes whose ``this`` is one of the ``PERCENTILES`` types and
    whose ``expression`` is an ``exp.Order`` are rewritten. The column is taken from the
    first ``exp.Ordered`` node found under the WITHIN GROUP node, and the quantile is the
    percentile function's ``this``.

    NOTE(review): only ``Ordered.this`` is read, so an ORDER BY ... DESC direction is not
    reflected in the resulting quantile — confirm whether that is intended.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The replacement ``exp.ApproxQuantile`` node, or ``expression`` unchanged.
    """
    is_percentile_within_group = (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    )
    if not is_percentile_within_group:
        return expression

    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx_quantile = exp.ApproxQuantile(this=ordered.this, quantile=expression.this.this)
    return expression.replace(approx_quantile)
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Add explicit column lists to the CTEs of a recursive WITH clause.

    A CTE whose alias already lists its columns is left alone. Every other CTE gets a
    column list built from the projections of its query. When that query is an
    ``exp.Union``, the projections of its left-hand side (``query.this``) are used.
    Each column takes the projection's alias or name. When that is empty, it takes the
    next name produced by ``name_sequence("_c_")`` instead.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, modified in place when it is a recursive ``exp.With``.
    """
    if not isinstance(expression, exp.With) or not expression.recursive:
        return expression

    generate_name = name_sequence("_c_")

    for cte in expression.expressions:
        cte_alias = cte.args["alias"]
        if cte_alias.columns:
            # Column names were spelled out explicitly; keep them.
            continue

        source = cte.this
        base_query = source.this if isinstance(source, exp.Union) else source

        column_names = [
            exp.to_identifier(projection.alias_or_name or generate_name())
            for projection in base_query.selects
        ]
        cte_alias.set("columns", column_names)

    return expression
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace the special ``'epoch'`` string in temporal casts with an explicit timestamp.

    ``CAST('epoch' AS <temporal type>)`` and ``TRY_CAST('epoch' AS <temporal type>)`` have
    their ``'epoch'`` operand replaced by the string ``'1970-01-01 00:00:00'``. The match on
    ``'epoch'`` is case-insensitive.

    Only string literals are rewritten. An identifier or column that merely happens to be
    named ``epoch`` (e.g. ``CAST(epoch AS TIMESTAMP)``) also has ``name == "epoch"``, but it
    refers to data rather than the special value, so it must be left untouched.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression; its operand is replaced in place when it matches.
    """
    if (
        isinstance(expression, (exp.Cast, exp.TryCast))
        and expression.this.is_string
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    ):
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into ``[NOT] EXISTS`` filters.

    For every ``SEMI`` or ``ANTI`` join of a SELECT that has an ON condition, the join
    is removed and ``EXISTS (SELECT 1 FROM <joined table> WHERE <on condition>)`` is
    added to the outer SELECT through ``Select.where``. For ANTI joins the EXISTS is
    negated. Joins without an ON condition, and all other join kinds, are left as-is.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, modified in place when it is an ``exp.Select``.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot of the joins: ``join.pop()`` below detaches the join
        # from this SELECT. If that removes it from the very list being iterated, the
        # join right after it would otherwise be skipped.
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if not on or join.kind not in ("SEMI", "ANTI"):
                continue

            subquery = exp.select("1").from_(join.this).where(on)
            exists = exp.Exists(this=subquery)
            if join.kind == "ANTI":
                exists = exists.not_(copy=False)

            join.pop()
            expression.where(exists, copy=False)

    return expression
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function.

    The returned function copies the incoming expression once and passes the copy through
    every transform in order, so the caller's expression is not mutated by the transforms.
    SQL for the result is then produced by the first applicable option:

    1. The generator's ``<expression.key>_sql`` method, if it exists.
    2. The ``Generator.TRANSFORMS`` entry for the resulting expression's type, but only when
       the transforms changed the expression's type. If the type is unchanged, that entry
       could be this very function, and re-entering it would loop forever. Instead,
       ``exp.Func`` nodes fall back to ``Generator.function_fallback_sql`` and anything else
       raises ``ValueError``.

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is allowed; the expression is then generated as-is (from a copy).

    Returns:
        Function that can be used as a generator transform. That function raises
        ``ValueError`` when no SQL handler is available for the resulting expression, or when
        dispatching through ``Generator.TRANSFORMS`` would re-enter itself.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # Copy once up front so the transforms may mutate the tree in place without
        # affecting the caller's expression.
        expression = expression.copy()
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql

Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL, using either the "_sql" method corresponding to the resulting expression, or the appropriate Generator.TRANSFORMS function (the latter only when the transformations changed the expression's type; otherwise re-dispatching could loop back into the same transform).

Arguments:
  • transforms: sequence of transform functions. These will be called in order.
Returns:

Function that can be used as a generator transform.