Edit on GitHub

sqlglot.transforms

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot import expressions as exp
  6from sqlglot.helper import find_new_name, name_sequence
  7
  8if t.TYPE_CHECKING:
  9    from sqlglot.generator import Generator
 10
 11
 12def unalias_group(expression: exp.Expression) -> exp.Expression:
 13    """
 14    Replace references to select aliases in GROUP BY clauses.
 15
 16    Example:
 17        >>> import sqlglot
 18        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
 19        'SELECT a AS b FROM x GROUP BY 1'
 20
 21    Args:
 22        expression: the expression that will be transformed.
 23
 24    Returns:
 25        The transformed expression.
 26    """
 27    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
 28        aliased_selects = {
 29            e.alias: i
 30            for i, e in enumerate(expression.parent.expressions, start=1)
 31            if isinstance(e, exp.Alias)
 32        }
 33
 34        for group_by in expression.expressions:
 35            if (
 36                isinstance(group_by, exp.Column)
 37                and not group_by.table
 38                and group_by.name in aliased_selects
 39            ):
 40                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))
 41
 42    return expression
 43
 44
 45def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
 46    """
 47    Convert SELECT DISTINCT ON statements to a subquery with a window function.
 48
 49    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
 50
 51    Args:
 52        expression: the expression that will be transformed.
 53
 54    Returns:
 55        The transformed expression.
 56    """
 57    if (
 58        isinstance(expression, exp.Select)
 59        and expression.args.get("distinct")
 60        and expression.args["distinct"].args.get("on")
 61        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
 62    ):
 63        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
 64        outer_selects = expression.selects
 65        row_number = find_new_name(expression.named_selects, "_row_number")
 66        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
 67        order = expression.args.get("order")
 68
 69        if order:
 70            window.set("order", order.pop())
 71        else:
 72            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))
 73
 74        window = exp.alias_(window, row_number)
 75        expression.select(window, copy=False)
 76
 77        return (
 78            exp.select(*outer_selects, copy=False)
 79            .from_(expression.subquery("_t", copy=False), copy=False)
 80            .where(exp.column(row_number).eq(1), copy=False)
 81        )
 82
 83    return expression
 84
 85
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.

    Args:
        expression: the expression that will be transformed.

    Returns:
        For a SELECT with a QUALIFY clause, a new outer SELECT over the original query (aliased
        `_t`) filtered by the former QUALIFY condition; otherwise `expression` unchanged.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection a fresh `_c...` alias so the outer query can refer to it.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        # Outer projection for one inner projection: a column reference that keeps the
        # identifier's quoting when the alias (or the projection's `this`) is an Identifier,
        # otherwise the bare output name.
        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        # Built before any `_w` window projections are appended, so those stay internal.
        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # With a star projection, plain columns are presumably already exposed by the subquery,
        # so only window functions are collected in that case.
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                # Inline projection aliases referenced inside the window with the aliased
                # expression itself.
                # NOTE(review): `expr` is reused without `.copy()`, so the node ends up referenced
                # from both the projection and the window — confirm this is intended.
                if expression_by_alias:
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project the window (alias_ copies it by default) under a fresh `_w...` name.
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                # If the window is the whole QUALIFY condition, the outer filter is just the
                # alias column; otherwise substitute the alias in place inside the condition.
                if isinstance(select_candidate.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # Column used by QUALIFY but not projected: add it so the outer WHERE can see it.
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
148
149
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Strip precision parameters from parameterized data types outside of DDL.

    Some dialects only accept those parameters in DDL. Every `DataType` node found in the tree
    keeps only the arguments that are not `DataTypeParam` instances.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, mutated in place.
    """
    for data_type in expression.find_all(exp.DataType):
        kept = []
        for arg in data_type.expressions:
            if not isinstance(arg, exp.DataTypeParam):
                kept.append(arg)
        data_type.set("expressions", kept)

    return expression
161
162
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

    Only UNNESTs found in the SELECT's own scope whose parent is a FROM or a JOIN are
    considered. A column whose table (or, failing that, db) part matches one of their aliases
    loses that part.
    """
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    aliases = {
        node.alias
        for node in find_all_in_scope(expression, exp.Unnest)
        if isinstance(node.parent, (exp.From, exp.Join))
    }
    if not aliases:
        return expression

    for col in expression.find_all(exp.Column):
        if col.table in aliases:
            col.set("table", None)
        elif col.db in aliases:
            col.set("db", None)

    return expression
181
182
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    Every join of a SELECT whose target is an UNNEST is removed. For each unnested expression
    paired with a column of the UNNEST's alias, a LATERAL VIEW with EXPLODE is appended to the
    SELECT's laterals. POSEXPLODE is used instead when the UNNEST carries an offset.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, mutated in place.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: joins are removed inside the loop, and removing items from
        # the list being iterated would silently skip the join that follows each removed one.
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                # Detach this exact join node (identity-based, unlike list.remove's `==`).
                join.pop()

                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
206
207
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Args:
        index_offset: start value of the generated position series (GENERATE_SERIES start).
            It also decides whether array sizes are shifted by one when computing the last
            position; presumably the target dialect's array index base — confirm at call sites.

    Returns:
        A transform that rewrites SELECTs containing [POS]EXPLODE projections; other
        expressions are returned unchanged.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names already used by projections / sources, extended as new ones are generated.
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            # Pick a fresh name based on `name` and reserve it in `names`.
            def new_name(names: t.Set[str], name: str) -> str:
                name = find_new_name(names, name)
                names.add(name)
                return name

            # ARRAY_SIZE(...) of every exploded argument; the series end is derived from them.
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # UNNEST(GENERATE_SERIES(index_offset)) AS _u(pos); its "end" is set after the loop.
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Normalize the projection into an Alias node (`alias`) and collect any
                    # user-provided names: `... AS col` or `... AS (pos, col)`.
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        # alias_ copies by default, so re-find the EXPLODE inside the copy.
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    # EXPLODE_OUTER: when the array is NULL or empty, substitute a one-element
                    # array built from a safe `arg[0]` lookup (presumably NULL), so the row is
                    # kept — TODO confirm the safe/offset bracket semantics per dialect.
                    if isinstance(explode, exp.ExplodeOuter):
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    # A position alias is always needed for the UNNEST ... WITH OFFSET join.
                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # The exploded value becomes IF(series.pos = _u_n.pos, _u_n.col), i.e. the
                    # element only on the row whose series position matches its offset.
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    # POSEXPLODE also yields the position: insert it right after the value.
                    if is_posexplode:
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # First explode found: attach the position series as a source (cross join
                    # when the SELECT already has a FROM clause, otherwise as the FROM).
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # `size` now denotes the last position: size - 1 unless index_offset == 1.
                    # NOTE(review): `size` is reused below (and the ArraySize node is also kept in
                    # `arrays`) without copying — shared nodes; confirm this is acceptable.
                    if index_offset != 1:
                        size = size - 1

                    # Keep the row whose offset matches the series position, or — once the
                    # series runs past this array's last position — the array's last row, so
                    # shorter arrays do not drop rows (their IF above then yields NULL).
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            # The series must cover the longest array: end = GREATEST(sizes), shifted by
            # (1 - index_offset) unless index_offset == 1.
            if arrays:
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
357
358
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by adding a WITHIN GROUP clause to them.

    The percentile's `this` (the value being ranked) moves into an ORDER BY, and its `expression`
    becomes the function's `this`. The result is `PERCENTILE(<expression>) WITHIN GROUP (ORDER BY
    <this>)`. Percentiles already wrapped in WITHIN GROUP, or lacking a second argument, are left
    alone.
    """
    if (
        not isinstance(expression, exp.PERCENTILES)
        or isinstance(expression.parent, exp.WithinGroup)
        or not expression.expression
    ):
        return expression

    ordered_value = expression.this.pop()
    quantile = expression.expression.pop()
    expression.set("this", quantile)

    return exp.WithinGroup(
        this=expression,
        expression=exp.Order(expressions=[exp.Ordered(this=ordered_value)]),
    )
372
373
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

    `PERCENTILE(q) WITHIN GROUP (ORDER BY x)` becomes `APPROX_QUANTILE(x, q)`. The first Ordered
    node found in the WITHIN GROUP expression supplies `x`.
    """
    is_percentile_within_group = (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, exp.PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    )
    if not is_percentile_within_group:
        return expression

    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    replacement = exp.ApproxQuantile(this=ordered.this, quantile=expression.this.this)
    return expression.replace(replacement)
386
387
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Uses projection output names in recursive CTE definitions to define the CTEs' columns.

    Only CTEs of a RECURSIVE WITH whose alias declares no columns are touched. For a UNION body,
    the left-hand query's projections are used. Unnamed projections get `_c_...` names from a
    sequence shared by the whole WITH clause.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    fallback_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        body = cte.this
        anchor = body.this if isinstance(body, exp.Union) else body
        table_alias.set(
            "columns",
            [exp.to_identifier(s.alias_or_name or fallback_name()) for s in anchor.selects],
        )

    return expression
405
406
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace 'epoch' in casts by the equivalent date literal.

    Applies to CAST / TRY_CAST whose operand is named 'epoch' (case-insensitively) and whose
    target is a temporal type. The operand becomes the string '1970-01-01 00:00:00'.
    """
    is_epoch_cast = (
        isinstance(expression, (exp.Cast, exp.TryCast))
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    )
    if is_epoch_cast:
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
417
418
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    A SEMI join with an ON condition is removed. The SELECT's WHERE clause gains
    `EXISTS (SELECT 1 FROM <joined source> WHERE <on condition>)`, or its negation for an ANTI
    join. Joins without an ON condition are kept as they are.
    """
    if not isinstance(expression, exp.Select):
        return expression

    for join in expression.args.get("joins") or []:
        condition = join.args.get("on")
        kind = join.kind
        if not condition or kind not in ("SEMI", "ANTI"):
            continue

        predicate = exp.Exists(this=exp.select("1").from_(join.this).where(condition))
        if kind == "ANTI":
            predicate = predicate.not_(copy=False)

        join.pop()
        expression.where(predicate, copy=False)

    return expression
434
435
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.

    Args:
        expression: the expression that will be transformed.

    Returns:
        For a SELECT with exactly one FULL OUTER join, a UNION of its LEFT-join and RIGHT-join
        variants; otherwise `expression` unchanged.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            expression_copy = expression.copy()

            # The copy is the last UNION operand, so its ORDER BY / LIMIT / OFFSET trail the
            # whole union. Leaving them on the first operand as well would apply them to the
            # LEFT side alone, and ORDER BY / OFFSET before UNION is invalid in most dialects.
            expression.set("order", None)
            expression.set("limit", None)
            expression.set("offset", None)

            index, full_outer_join = full_outer_joins[0]
            full_outer_join.set("side", "left")
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side

            return exp.union(expression, expression_copy, copy=False)

    return expression
460
461
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

    Args:
        expression: the root expression; it is mutated in place.

    Returns:
        The same expression, with nested WITH clauses merged into its own WITH clause.
    """
    top_level_with = expression.args.get("with")
    # NOTE: the tree is mutated while find_all is traversing it.
    for inner_with in expression.find_all(exp.With):
        # Already the top-level WITH (including one promoted earlier in this loop).
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            # No top-level WITH yet: promote this nested one wholesale.
            top_level_with = inner_with.pop()
            expression.set("with", top_level_with)
        else:
            # Recursiveness applies to the whole WITH clause, so propagate it upward.
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                # Insert the nested CTEs right before the CTE that used them, so they are
                # defined before their first use.
                # NOTE(review): `.index` compares with `==` and assumes parent_cte is already a
                # direct member of top_level_with — confirm for deeply nested CTEs.
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                # Re-set the list so the moved CTEs get their parent pointers updated.
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                # WITH nested in a subquery outside any CTE: append its CTEs at the end.
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression
499
500
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Converts numeric values used in conditions into explicit boolean expressions.

    Every node of the tree is handed to the optimizer's `canonicalize.ensure_bools` helper. That
    helper presumably invokes the callback below on nodes in a boolean context. The callback
    rewrites numeric literals, nodes typed as UNKNOWN or numeric, and untyped columns into a
    `<node> <> 0` comparison.
    """
    from sqlglot.optimizer.canonicalize import ensure_bools as _canonicalize_ensure_bools

    def _to_explicit_bool(node: exp.Expression) -> None:
        looks_numeric = (
            node.is_number
            or node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
            or (isinstance(node, exp.Column) and not node.type)
        )
        if looks_numeric:
            node.replace(node.neq(0))

    for node in expression.walk():
        _canonicalize_ensure_bools(node, _to_explicit_bool)

    return expression
517
518
def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """
    Strip table, db and catalog qualifiers from every column in the tree.

    Only the last part of each column (its own name) is kept; the tree is mutated in place.
    """
    for col in expression.find_all(exp.Column):
        qualifiers = col.parts[:-1]
        for qualifier in qualifiers:
            qualifier.pop()

    return expression
526
527
def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """
    Drop UNIQUE column constraints from a CREATE statement.

    For each UniqueColumnConstraint, the node that directly contains it is removed. That node is
    presumably its ColumnConstraint wrapper. The expression must be an `exp.Create`.
    """
    assert isinstance(expression, exp.Create)

    for unique in expression.find_all(exp.UniqueColumnConstraint):
        wrapper = unique.parent
        if wrapper:
            wrapper.pop()

    return expression
535
536
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Rewrite CREATE TEMPORARY TABLE statements for dialects built around temporary views.

    Args:
        expression: a CREATE expression.
        tmp_storage_provider: applied to a temporary-table CREATE that has no AS query; it
            defaults to the identity.

    Returns:
        A new `CREATE TEMPORARY VIEW <name> AS <query>` for a temporary CTAS; the provider's
        result for a temporary table without a query; otherwise `expression` unchanged.
    """
    assert isinstance(expression, exp.Create)

    properties = expression.args.get("properties")
    property_list = properties.expressions if properties else []
    is_temporary = any(isinstance(p, exp.TemporaryProperty) for p in property_list)

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    if not expression.expression:
        return tmp_storage_provider(expression)

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    return exp.Create(
        kind="TEMPORARY VIEW",
        this=expression.this,
        expression=expression.expression,
    )
559
560
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.

    Names are matched case-insensitively. Only TABLE/VIEW creates that already carry a schema
    are considered.
    """
    assert isinstance(expression, exp.Create)

    if not isinstance(expression.this, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if not prop or not prop.this or isinstance(prop.this, exp.Schema):
        return expression

    schema = expression.this
    wanted_names = {name_node.name.upper() for name_node in prop.this.expressions}
    partitions = [col for col in schema.expressions if col.name.upper() in wanted_names]
    remaining = [col for col in schema.expressions if col not in partitions]

    schema.set("expressions", remaining)
    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
    expression.set("this", schema)

    return expression
582
583
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.

    A PARTITIONED BY property may be a schema made up only of typed column definitions. In that
    case those definitions are appended to the table's schema, and the property is rewritten to
    a tuple of the partition column names.

    Args:
        expression: the CREATE expression that will be transformed.

    Returns:
        The same expression, mutated in place.
    """
    assert isinstance(expression, exp.Create)
    prop = expression.find(exp.PartitionedByProperty)
    if (
        prop
        and prop.this
        and isinstance(prop.this, exp.Schema)
        and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions)
    ):
        # Copy the name identifiers. The ColumnDefs that own them are moved into the table
        # schema below, and reusing one node in both trees would leave it with a stale parent.
        prop_this = exp.Tuple(
            expressions=[exp.to_identifier(e.this).copy() for e in prop.this.expressions]
        )
        schema = expression.this
        for e in prop.this.expressions:
            schema.append("expressions", e)
        prop.set("this", prop_this)

    return expression
607
608
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """
    Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

    Each `key = value` (PropertyEQ) argument of a Struct becomes `value AS key`; all other
    arguments are kept as they are.
    """
    if not isinstance(expression, exp.Struct):
        return expression

    converted = []
    for arg in expression.expressions:
        if isinstance(arg, exp.PropertyEQ):
            arg = exp.alias_(arg.expression, arg.this)
        converted.append(arg)
    expression.set("expressions", converted)

    return expression
621
622
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Chain a sequence of transformations into a single generator transform.

    The returned function applies `transforms` in order, then turns the result into SQL. It uses
    the generator's `<key>_sql` method when one exists for the resulting expression. Otherwise it
    falls back to the matching `Generator.TRANSFORMS` entry, guarded against re-entering itself.

    Args:
        transforms: sequence of transform functions. These will be called in order.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        original_type = type(expression)

        expression = transforms[0](expression)
        for step in transforms[1:]:
            expression = step(expression)

        sql_method = getattr(self, expression.key + "_sql", None)
        if sql_method:
            return sql_method(expression)

        transform_fn = self.TRANSFORMS.get(type(expression))
        if not transform_fn:
            raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

        if original_type is not type(expression):
            return transform_fn(self, expression)

        if isinstance(expression, exp.Func):
            return self.function_fallback_sql(expression)

        # Same type as the input and no _sql method: the TRANSFORMS entry would lead straight
        # back into _to_sql and loop forever, so fail loudly instead.
        raise ValueError(
            f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
        )

    return _to_sql
def unalias_group( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses with positional references.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Group) or not isinstance(expression.parent, exp.Select):
        return expression

    # 1-based position of each aliased projection, keyed by its alias name. When an alias is
    # repeated, the last projection carrying it wins.
    position_by_alias = {}
    for position, projection in enumerate(expression.parent.expressions, start=1):
        if isinstance(projection, exp.Alias):
            position_by_alias[projection.alias] = position

    for key in expression.expressions:
        # Only bare (table-less) column references can point at a projection alias.
        if not isinstance(key, exp.Column) or key.table:
            continue
        position = position_by_alias.get(key.name)
        if position is not None:
            key.replace(exp.Literal.number(position))

    return expression

Replace references to select aliases in GROUP BY clauses.

Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_distinct_on( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    The original projections are evaluated inside the subquery, each one given a name, and the
    outer query only references them by that name. This keeps projections such as `x.a` or
    `a + 1` valid outside the subquery and avoids sharing nodes between the two queries.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        row_number = find_new_name(expression.named_selects, "_row_number")
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        if order:
            window.set("order", order.pop())
        else:
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        # Reserve explicit aliases up front so a generated alias can never collide with an
        # alias that appears later in the projection list.
        taken_names = {row_number}
        taken_names.update(s.alias for s in expression.selects if isinstance(s, exp.Alias))

        outer_selects: t.List[exp.Expression] = []
        for select in expression.selects:
            if select.is_star:
                # A star projection already exposes every subquery column by name.
                outer_selects = [exp.Star()]
                break

            if not isinstance(select, exp.Alias):
                alias = find_new_name(taken_names, select.output_name or "_col")
                quoted = select.this.args.get("quoted") if isinstance(select, exp.Column) else None
                select = select.replace(exp.alias_(select, alias, quoted=quoted))

            taken_names.add(select.alias)
            # Reference the projection by name; copy the identifier so it is not shared.
            outer_selects.append(exp.column(select.args["alias"].copy()))

        expression.select(exp.alias_(window, row_number), copy=False)

        return (
            exp.select(*outer_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number).eq(1), copy=False)
        )

    return expression

Convert SELECT DISTINCT ON statements to a subquery with a window function.

This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.

    Args:
        expression: the expression that will be transformed.

    Returns:
        For a SELECT with a QUALIFY clause, a new
        `SELECT <original output names> FROM (<query without QUALIFY>) AS _t WHERE <filter>`;
        any other expression is returned unchanged.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection a fresh `_c...` alias so the outer query can refer to it.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            # Refer to a projection by its output name, keeping the identifier's quoting when
            # that name comes from an Identifier node.
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        # Built now, before the helper projections (window aliases, extra filter columns) are
        # added to the inner query below, so the outer query only exposes the original outputs.
        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # For SELECT *, only windows are collected (plain columns are presumably already exposed
        # by the star projection).
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                # The window becomes a projection of the same SELECT that defines these aliases,
                # so alias references inside it are replaced by the aliased expressions.
                if expression_by_alias:
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project (a copy of) the window under a fresh `_w...` alias; the filter then
                # refers to that alias instead of the window itself.
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    # The window was the whole QUALIFY condition.
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # Column used by the filter but not projected: expose it from the subquery.
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression

Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify

Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.

def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Strip precision parameters from every data type found inside `expression`.

    Some dialects only allow the precision for parameterized types to be defined in the DDL and not
    in other expressions. This transform removes every `exp.DataTypeParam` argument from each
    `exp.DataType` node, in place.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression object, mutated in place.
    """
    for data_type in expression.find_all(exp.DataType):
        kept_args = [arg for arg in data_type.expressions if not isinstance(arg, exp.DataTypeParam)]
        data_type.set("expressions", kept_args)

    return expression

Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.

def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

    Only UNNEST nodes sitting directly in a FROM or JOIN within this SELECT's scope are considered.
    Any column qualified by one of their aliases loses that qualifier: its `table` part, or its
    `db` part when the alias appears there instead.
    """
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    aliases = set()
    for unnest_node in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest_node.parent, (exp.From, exp.Join)):
            aliases.add(unnest_node.alias)

    if not aliases:
        return expression

    for col in expression.find_all(exp.Column):
        if col.table in aliases:
            col.set("table", None)
        elif col.db in aliases:
            col.set("db", None)

    return expression

Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    Every JOIN whose target is an UNNEST is removed from the query. For each unnested expression
    paired with a column of the UNNEST's alias, a LATERAL VIEW EXPLODE is appended (POSEXPLODE
    when the UNNEST carries an offset).

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression object, mutated in place.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: matching joins are removed from the live list inside the loop,
        # and removing from the list being iterated would skip the join right after each one.
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

                # NOTE(review): an UNNEST without an alias yields no pairs here, so it is dropped
                # without any replacement lateral view.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression

Convert cross join unnest into lateral view explode.

def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Args:
        index_offset: first value of the generated position series (the START of
            GENERATE_SERIES); it also adjusts how the series END and per-array last positions are
            computed. Presumably the target dialect's array index base — confirm at call sites.

    Returns:
        A transform that rewrites a SELECT projecting [POS]EXPLODE(...) into cross joins against
        UNNEST sources, filtered so that elements of every exploded array line up by position.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names already used by projections / sources; generated aliases must avoid them.
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Pick a fresh name derived from `name` and reserve it in `names`.
                name = find_new_name(names, name)
                names.add(name)
                return name

            # ARRAY_SIZE(...) of each exploded array, used to compute the series END.
            arrays: t.List[exp.Condition] = []
            # One position series, UNNEST(GENERATE_SERIES(index_offset, ...)) AS _u(pos), shared
            # by all exploded projections; its END is filled in after the loop.
            series_alias = new_name(taken_select_names, "pos")
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Normalise the projection into an exp.Alias node that can be renamed below.
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        # Multiple aliases: the first names the position, the second the element.
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        # The projection was wrapped by alias_ (which copies by default), so look
                        # up the EXPLODE node again inside the new tree.
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # For EXPLODE_OUTER, an empty/NULL array is replaced by a one-element
                        # array holding a safe offset access arg[0] (presumably NULL), so the
                        # row is kept instead of vanishing.
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    # Plain EXPLODE also needs a position column to line up with the series.
                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # Emit the element only on the row where the series position equals this
                    # array's position (the IF has no else branch).
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # POSEXPLODE also projects the position, right after the element column.
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # The shared series source is attached once, for the first exploded projection.
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # Turn the array size into this array's last position for the filter below.
                    if index_offset != 1:
                        size = size - 1

                    # Keep rows whose series position matches this array's position or, once the
                    # series runs past this array's last position, only the row of its last
                    # element (whose IF projection above then evaluates to NULL).
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series spans the longest exploded array.
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest

Convert explode/posexplode into unnest.

def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by adding a WITHIN GROUP clause to them.

    This applies to a percentile node that has both an input (`this`) and a quantile
    (`expression`) and is not already wrapped in WITHIN GROUP. The quantile becomes the function's
    argument, the input moves into `WITHIN GROUP (ORDER BY <input>)`, and the new WithinGroup node
    is returned.
    """
    if (
        not isinstance(expression, exp.PERCENTILES)
        or isinstance(expression.parent, exp.WithinGroup)
        or not expression.expression
    ):
        return expression

    ordered_input = expression.this.pop()
    expression.set("this", expression.expression.pop())
    return exp.WithinGroup(
        this=expression,
        expression=exp.Order(expressions=[exp.Ordered(this=ordered_input)]),
    )

Transforms percentiles by adding a WITHIN GROUP clause to them.

def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

    `<percentile>(q) WITHIN GROUP (ORDER BY x)` is replaced in the tree by an
    `exp.ApproxQuantile` whose input is `x` (the first ordered term) and whose quantile is `q`.
    """
    is_percentile_within_group = (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, exp.PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    )
    if not is_percentile_within_group:
        return expression

    quantile = expression.this.this
    first_ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    replacement = exp.ApproxQuantile(this=first_ordered.this, quantile=quantile)
    return expression.replace(replacement)

Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Uses projection output names in recursive CTE definitions to define the CTEs' columns.

    Only CTEs of a RECURSIVE WITH that have no explicit column list are touched. When the CTE body
    is a UNION, its left-hand query supplies the names. Projections without an output name get
    generated `_c_...` names.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    fallback_name = name_sequence("_c_")

    for cte in expression.expressions:
        cte_alias = cte.args["alias"]
        if cte_alias.columns:
            continue

        body = cte.this
        if isinstance(body, exp.Union):
            body = body.this

        cte_alias.set(
            "columns",
            [exp.to_identifier(proj.alias_or_name or fallback_name()) for proj in body.selects],
        )

    return expression

Uses projection output names in recursive CTE definitions to define the CTEs' columns.

def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace 'epoch' in casts by the equivalent date literal.

    For a CAST / TRY_CAST whose operand's name is 'epoch' (case-insensitive) and whose target is a
    temporal type, the operand is replaced in place by the string '1970-01-01 00:00:00'.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    if expression.name.lower() != "epoch":
        return expression

    if expression.to.this in exp.DataType.TEMPORAL_TYPES:
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression

Replace 'epoch' in casts by the equivalent date literal.

def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    Each SEMI/ANTI join with an ON condition is removed from the query and replaced by a
    `WHERE EXISTS (SELECT 1 FROM <joined source> WHERE <on>)` predicate, negated with NOT for
    ANTI joins. SEMI/ANTI joins without an ON condition are left untouched.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression object, mutated in place.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: `join.pop()` detaches joins from the query inside the loop,
        # which must not affect which joins are visited (otherwise a SEMI/ANTI join directly
        # following another one can be skipped).
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression

Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.

    The left-hand SELECT keeps the CTEs but loses its ORDER BY and LIMIT. The right-hand SELECT
    (the last one in the union) keeps ORDER BY / LIMIT but loses the CTEs. As a result, none of
    these clauses is duplicated across the two halves.

    NOTE(review): exp.union builds a UNION (not UNION ALL); duplicate result rows that a FULL
    OUTER JOIN would return are therefore collapsed — confirm this is acceptable for callers.

    Args:
        expression: the expression that will be transformed.

    Returns:
        A Union for a SELECT with exactly one FULL join; otherwise the expression unchanged.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            expression_copy = expression.copy()
            # ORDER BY / LIMIT stay only on the final (RIGHT) SELECT: an ORDER BY or LIMIT in
            # front of UNION is invalid in most dialects.
            expression.set("limit", None)
            expression.set("order", None)
            index, full_outer_join = full_outer_joins[0]
            full_outer_join.set("side", "left")
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side

            return exp.union(expression, expression_copy, copy=False)

    return expression

Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.

def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

    Args:
        expression: the root expression; nested WITH clauses are merged into its `with` arg.

    Returns:
        The same expression object, mutated in place.
    """
    top_level_with = expression.args.get("with")
    for inner_with in expression.find_all(exp.With):
        # A WITH attached directly to the root is already top-level.
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            # No top-level WITH yet: detach this nested one and attach it to the root as-is.
            top_level_with = inner_with.pop()
            expression.set("with", top_level_with)
        else:
            # Merging: mark the top-level WITH as RECURSIVE if the nested one was.
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                # The nested WITH lived inside a CTE: insert its CTEs just before that CTE.
                # NOTE(review): list.index raises ValueError if parent_cte is not itself a
                # direct member of the top-level WITH (e.g. deeper nesting) — confirm.
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                # Re-set the (mutated) list, presumably so the inserted CTEs are re-parented.
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression

Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:

SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.

TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Converts numeric values used in conditions into explicit boolean expressions.

    Every node of the tree is handed to the optimizer's canonicalize `ensure_bools` helper with a
    callback. The callback replaces a node with `<node> <> 0` when it is a number literal, has an
    UNKNOWN or numeric type, or is an untyped column.
    """
    from sqlglot.optimizer.canonicalize import ensure_bools as _canonicalize_ensure_bools

    def _compare_with_zero(node: exp.Expression) -> None:
        needs_comparison = (
            node.is_number
            or node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
            or (isinstance(node, exp.Column) and not node.type)
        )
        if needs_comparison:
            node.replace(node.neq(0))

    for node in expression.walk():
        _canonicalize_ensure_bools(node, _compare_with_zero)

    return expression

Converts numeric values used in conditions into explicit boolean expressions.

def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """
    Drop every qualifier from all column references in `expression`.

    For each column, all parts except the last one (the column name itself) are popped, i.e. the
    table, db and catalog qualifiers.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression object, mutated in place.
    """
    for col in expression.find_all(exp.Column):
        *qualifiers, _column_name = col.parts or [None]
        for qualifier in qualifiers:
            qualifier.pop()

    return expression
def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """
    Remove every UNIQUE column constraint from a CREATE statement.

    For each `exp.UniqueColumnConstraint` found, its parent node (the wrapper holding it in the
    column definition) is popped from the tree, removing the constraint with it.

    Args:
        expression: the CREATE expression to transform; must be an `exp.Create`.

    Returns:
        The same expression object, mutated in place.
    """
    assert isinstance(expression, exp.Create)

    for unique_constraint in expression.find_all(exp.UniqueColumnConstraint):
        holder = unique_constraint.parent
        if holder:
            holder.pop()

    return expression
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Rewrite temporary CREATE TABLE statements.

    A CREATE TABLE with a TemporaryProperty and a query (CTAS) becomes
    `CREATE TEMPORARY VIEW <name> AS <query>`. A temporary CREATE TABLE without a query is passed
    to `tmp_storage_provider`, whose default returns it unchanged. Anything else is returned as-is.

    Args:
        expression: the CREATE expression to transform; must be an `exp.Create`.
        tmp_storage_provider: callback applied to temporary tables that have no query.

    Returns:
        The (possibly new) expression.
    """
    assert isinstance(expression, exp.Create)

    properties = expression.args.get("properties")
    property_nodes = properties.expressions if properties else []
    is_temporary = any(isinstance(node, exp.TemporaryProperty) for node in property_nodes)

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    if not expression.expression:
        return tmp_storage_provider(expression)

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    return exp.Create(
        kind="TEMPORARY VIEW",
        this=expression.this,
        expression=expression.expression,
    )
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.

    Column names are matched case-insensitively. Only CREATE TABLE / CREATE VIEW statements whose
    target is a schema are affected.

    Args:
        expression: the CREATE expression to transform; must be an `exp.Create`.

    Returns:
        The same expression object, mutated in place.
    """
    assert isinstance(expression, exp.Create)

    if not isinstance(expression.this, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    partitioned_by = expression.find(exp.PartitionedByProperty)
    if not partitioned_by or not partitioned_by.this or isinstance(partitioned_by.this, exp.Schema):
        return expression

    table_schema = expression.this
    partition_names = {ident.name.upper() for ident in partitioned_by.this.expressions}
    partition_columns = [
        col for col in table_schema.expressions if col.name.upper() in partition_names
    ]
    table_schema.set(
        "expressions", [col for col in table_schema.expressions if col not in partition_columns]
    )
    partitioned_by.replace(
        exp.PartitionedByProperty(this=exp.Schema(expressions=partition_columns))
    )
    expression.set("this", table_schema)

    return expression

In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.

    This handles a PARTITIONED BY that holds a schema made only of typed column definitions. Those
    definitions are appended to the CREATE's own schema, and PARTITIONED BY is rewritten into a
    tuple of just the column names.

    Args:
        expression: the CREATE expression to transform; must be an `exp.Create`.

    Returns:
        The same expression object, mutated in place.
    """
    assert isinstance(expression, exp.Create)

    partitioned_by = expression.find(exp.PartitionedByProperty)
    if not (partitioned_by and partitioned_by.this and isinstance(partitioned_by.this, exp.Schema)):
        return expression

    partition_defs = partitioned_by.this.expressions
    if not all(isinstance(col_def, exp.ColumnDef) and col_def.kind for col_def in partition_defs):
        return expression

    names_only = exp.Tuple(
        expressions=[exp.to_identifier(col_def.this) for col_def in partition_defs]
    )
    table_schema = expression.this
    for col_def in partition_defs:
        table_schema.append("expressions", col_def)
    partitioned_by.set("this", names_only)

    return expression

Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

Currently, SQLGlot uses the DATASOURCE format for Spark 3.

def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """
    Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

    Each key/value argument represented as `exp.PropertyEQ` (key in `this`, value in `expression`)
    becomes `<value> AS <key>`. Other arguments are kept as they are, in their original order.
    """
    if not isinstance(expression, exp.Struct):
        return expression

    rewritten = []
    for arg in expression.expressions:
        if isinstance(arg, exp.PropertyEQ):
            arg = exp.alias_(arg.expression, arg.this)
        rewritten.append(arg)

    expression.set("expressions", rewritten)
    return expression

Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable, see next paragraph).

    The `Generator.TRANSFORMS` entry is only used when the transforms changed the expression's
    type. If the type is unchanged, that entry is presumably this very function, and calling it
    would recurse forever. In that case functions fall back to `Generator.function_fallback_sql`,
    and any other expression raises a ValueError.

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence applies no transformation.

    Returns:
        Function that can be used as a generator transform.

    Raises (from the returned function):
        ValueError: if no suitable SQL handler exists for the transformed expression.
    """

    def _to_sql(self: Generator, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # Apply the transforms in order, each one receiving the previous one's result.
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql

Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL, using either the "_sql" method corresponding to the resulting expression, or the appropriate Generator.TRANSFORMS function (only when the transforms changed the expression's type; otherwise re-entering the TRANSFORMS entry would recurse forever).

Arguments:
  • transforms: sequence of transform functions. These will be called in order.
Returns:

Function that can be used as a generator transform.