sqlglot.transforms
from __future__ import annotations

import typing as t

from sqlglot import expressions as exp
from sqlglot.helper import find_new_name, name_sequence

if t.TYPE_CHECKING:
    from sqlglot.generator import Generator


def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
        # alias name -> 1-based position of the aliased projection in the SELECT list
        aliased_selects = {
            e.alias: i
            for i, e in enumerate(expression.parent.expressions, start=1)
            if isinstance(e, exp.Alias)
        }

        for group_by in expression.expressions:
            # Only unqualified columns can refer to a projection alias.
            if (
                isinstance(group_by, exp.Column)
                and not group_by.table
                and group_by.name in aliased_selects
            ):
                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))

    return expression


def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        # Detach the DISTINCT node; its ON tuple becomes the window's PARTITION BY.
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        outer_selects = expression.selects
        row_number = find_new_name(expression.named_selects, "_row_number")
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        # The query's own ORDER BY (moved into the window) picks the row kept per partition;
        # without one, order by the DISTINCT ON keys themselves.
        if order:
            window.set("order", order.pop())
        else:
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        window = exp.alias_(window, row_number)
        expression.select(window, copy=False)

        # Keep only the first-numbered row of each partition.
        return (
            exp.select(*outer_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number).eq(1), copy=False)
        )

    return expression


def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give unnamed projections a fresh `_c` alias so the outer query can refer to them.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            # Outer projection referencing an inner one, keeping the identifier's quoting.
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # With a star projection, only windows need to be lifted into the subquery.
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                # Inline projection aliases used inside the window; they cannot be referenced
                # from another projection of the same SELECT.
                if expression_by_alias:
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project (a copy of) the window under a fresh alias and filter on that alias.
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    # The window is the whole QUALIFY condition.
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # Expose columns the filter needs but the subquery does not project.
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression


def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.
    """
    for node in expression.find_all(exp.DataType):
        node.set(
            "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)]
        )

    return expression


def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """Remove references to unnest table aliases, added by the optimizer's qualify_columns step."""
    from sqlglot.optimizer.scope import find_all_in_scope

    if isinstance(expression, exp.Select):
        # Aliases of UNNESTs used directly as a FROM/JOIN source in this scope.
        unnest_aliases = {
            unnest.alias
            for unnest in find_all_in_scope(expression, exp.Unnest)
            if isinstance(unnest.parent, (exp.From, exp.Join))
        }
        if unnest_aliases:
            for column in expression.find_all(exp.Column):
                if column.table in unnest_aliases:
                    column.set("table", None)
                elif column.db in unnest_aliases:
                    column.set("db", None)

    return expression


def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    Each join whose source is an UNNEST is removed. For every unnested expression paired with
    a column of the unnest's table alias, a LATERAL VIEW EXPLODE (POSEXPLODE when the unnest
    has an offset) is appended instead.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: matching joins are removed from the joins list below, and
        # removing from the list being iterated would silently skip the following join.
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

                # NOTE(review): an UNNEST without a table alias produces no LATERAL VIEW at all.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression


def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Args:
        index_offset: start value of the generated position series, also used when comparing
            positions against array sizes (presumably the target dialect's array index base).

    Returns:
        A transform that rewrites [POS]EXPLODE projections of a SELECT into CROSS JOIN UNNESTs.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Pick an unused name and reserve it.
                name = find_new_name(names, name)
                names.add(name)
                return name

            arrays: t.List[exp.Condition] = []
            # One shared position series; its end is set after all explodes are seen.
            series_alias = new_name(taken_select_names, "pos")
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        # POSEXPLODE(...) AS (pos, col)
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # Substitute a one-element array for an empty/NULL array, presumably
                        # so one (NULL) row is still produced as with EXPLODE_OUTER.
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # Emit the element only on rows where series position == array offset.
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # POSEXPLODE also projects the position, right after the element.
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # The series is attached once, on the first explode.
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the positions line up, or where the series has run past
                    # this array's last position (paired with that last element).
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series must cover the longest of the exploded arrays.
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest


def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by adding a WITHIN GROUP clause to them."""
    if (
        isinstance(expression, exp.PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    ):
        # The first argument becomes the WITHIN GROUP ordering; the second becomes `this`.
        column = expression.this.pop()
        expression.set("this", expression.expression.pop())
        order = exp.Order(expressions=[exp.Ordered(this=column)])
        expression = exp.WithinGroup(this=expression, expression=order)

    return expression


def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause."""
    if (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, exp.PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    ):
        # Note: the result is always an ApproxQuantile node, whatever the percentile function.
        quantile = expression.this.this
        input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this
        return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))

    return expression


def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """Uses projection output names in recursive CTE definitions to define the CTEs' columns."""
    if isinstance(expression, exp.With) and expression.recursive:
        next_name = name_sequence("_c_")

        for cte in expression.expressions:
            if not cte.args["alias"].columns:
                query = cte.this
                # For a UNION, the left-hand (anchor) query names the columns.
                if isinstance(query, exp.Union):
                    query = query.this

                cte.args["alias"].set(
                    "columns",
                    [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
                )

    return expression


def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """Replace 'epoch' in casts by the equivalent date literal."""
    if (
        isinstance(expression, (exp.Cast, exp.TryCast))
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    ):
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression


def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead."""
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot because matching joins are popped from the joins list.
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression


def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            expression_copy = expression.copy()
            # LIMIT is kept only on the right-hand query.
            expression.set("limit", None)
            index, full_outer_join = full_outer_joins[0]
            full_outer_join.set("side", "left")
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side

            return exp.union(expression, expression_copy, copy=False)

    return expression


def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with")
    for inner_with in expression.find_all(exp.With):
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            top_level_with = inner_with.pop()
            expression.set("with", top_level_with)
        else:
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                # Insert the nested CTEs before the CTE that used them.
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression


def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """Converts numeric values used in conditions into explicit boolean expressions."""
    from sqlglot.optimizer.canonicalize import ensure_bools

    def _ensure_bool(node: exp.Expression) -> None:
        if (
            node.is_number
            or node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
            or (isinstance(node, exp.Column) and not node.type)
        ):
            node.replace(node.neq(0))

    for node in expression.walk():
        ensure_bools(node, _ensure_bool)

    return expression


def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """Drop every qualifier (all parts except the last) from each column reference."""
    for column in expression.find_all(exp.Column):
        # We only wanna pop off the table, db, catalog args
        for part in column.parts[:-1]:
            part.pop()

    return expression


def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """
    Remove UNIQUE column constraints from a CREATE statement.

    For each `UniqueColumnConstraint` found, its parent node is popped (presumably the
    wrapping column-constraint node).
    """
    assert isinstance(expression, exp.Create)
    for constraint in expression.find_all(exp.UniqueColumnConstraint):
        if constraint.parent:
            constraint.parent.pop()

    return expression


def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Rewrite CREATE TEMPORARY TABLE statements.

    A temporary CREATE TABLE ... AS <query> becomes CREATE TEMPORARY VIEW with the same name
    and query; other properties are not carried over. A temporary CREATE TABLE without a
    query is passed to `tmp_storage_provider`. Anything else is returned unchanged.
    """
    assert isinstance(expression, exp.Create)
    properties = expression.args.get("properties")
    temporary = any(
        isinstance(prop, exp.TemporaryProperty)
        for prop in (properties.expressions if properties else [])
    )

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    if expression.kind == "TABLE" and temporary:
        if expression.expression:
            return exp.Create(
                kind="TEMPORARY VIEW",
                this=expression.this,
                expression=expression.expression,
            )
        return tmp_storage_provider(expression)

    return expression


def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.
    """
    assert isinstance(expression, exp.Create)
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.kind in {"TABLE", "VIEW"}

    if has_schema and is_partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            # Column names are matched case-insensitively.
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return expression


def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.

    When every PARTITIONED BY entry is a typed column definition, those definitions are
    appended to the table schema and PARTITIONED BY is reduced to a tuple of column names.
    """
    assert isinstance(expression, exp.Create)
    prop = expression.find(exp.PartitionedByProperty)
    if (
        prop
        and prop.this
        and isinstance(prop.this, exp.Schema)
        and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions)
    ):
        prop_this = exp.Tuple(
            expressions=[exp.to_identifier(e.this) for e in prop.this.expressions]
        )
        schema = expression.this
        for e in prop.this.expressions:
            schema.append("expressions", e)
        prop.set("this", prop_this)

    return expression


def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """Converts struct arguments to aliases, e.g. STRUCT(1 AS y)."""
    if isinstance(expression, exp.Struct):
        expression.set(
            "expressions",
            [
                exp.alias_(e.expression, e.this) if isinstance(e, exp.PropertyEQ) else e
                for e in expression.expressions
            ],
        )

    return expression


def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        expression = transforms[0](expression)
        for transform in transforms[1:]:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite GROUP BY references to projection aliases as positional references.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    select = expression.parent
    if not (isinstance(expression, exp.Group) and isinstance(select, exp.Select)):
        return expression

    # alias -> 1-based position of that aliased projection (a later duplicate alias wins)
    positions = {}
    for position, projection in enumerate(select.expressions, start=1):
        if isinstance(projection, exp.Alias):
            positions[projection.alias] = position

    for key in expression.expressions:
        # Qualified columns (t.b) never refer to a projection alias.
        if not isinstance(key, exp.Column) or key.table:
            continue
        if key.name in positions:
            key.replace(exp.Literal.number(positions[key.name]))

    return expression
Replace references to select aliases in GROUP BY clauses.
Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite SELECT DISTINCT ON into a subquery filtered on a ROW_NUMBER() window.

    Rows are numbered per DISTINCT ON key and only the first row of each partition is kept,
    which suits dialects that support window functions but not SELECT DISTINCT ON.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    distinct = expression.args.get("distinct")
    if not distinct:
        return expression

    on = distinct.args.get("on")
    if not on or not isinstance(on, exp.Tuple):
        return expression

    # Detaching DISTINCT leaves its ON tuple as the partition keys.
    partition_keys = distinct.pop().args["on"].expressions
    projections = expression.selects
    rank_alias = find_new_name(expression.named_selects, "_row_number")
    window = exp.Window(this=exp.RowNumber(), partition_by=partition_keys)

    # The query's ORDER BY (moved into the window) decides which row survives;
    # without one, the DISTINCT ON keys are used as the ordering.
    existing_order = expression.args.get("order")
    window.set(
        "order",
        existing_order.pop()
        if existing_order
        else exp.Order(expressions=[key.copy() for key in partition_keys]),
    )

    expression.select(exp.alias_(window, rank_alias), copy=False)

    outer = exp.select(*projections, copy=False)
    outer = outer.from_(expression.subquery("_t", copy=False), copy=False)
    return outer.where(exp.column(rank_alias).eq(1), copy=False)
Convert SELECT DISTINCT ON statements to a subquery with a window function.
This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.

    Args:
        expression: the expression that will be transformed.

    Returns:
        `SELECT <projections> FROM (<original query>) AS _t WHERE <qualify filter>` when the
        input is a SELECT with QUALIFY; otherwise the input unchanged.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give unnamed projections a fresh `_c` alias so the outer query can refer to them.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            # Outer projection referencing an inner one, keeping the identifier's quoting.
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        # Built before any window/column projections are added below, so only the
        # original projections appear in the outer SELECT.
        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # With a star projection only windows are lifted into the subquery; columns are
        # additionally lifted when the projection is explicit.
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                # Inline projection aliases referenced inside the window expression.
                if expression_by_alias:
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project the window under a fresh `_w` alias (exp.alias_ copies by default,
                # so select_candidate stays in the filter tree) and filter on that alias.
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    # The window is the entire QUALIFY condition.
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # Expose a column the filter needs but the subquery does not project.
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify
Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Drop precision parameters from parameterized types appearing in expressions.

    Some dialects only accept a parameterized type's precision in DDL, not in other
    expressions. Every `exp.DataTypeParam` argument is removed from each `exp.DataType`
    node found under `expression`; other type arguments are kept in order.
    """
    for data_type in expression.find_all(exp.DataType):
        kept = []
        for argument in data_type.expressions:
            if not isinstance(argument, exp.DataTypeParam):
                kept.append(argument)
        data_type.set("expressions", kept)

    return expression
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """Strip UNNEST table-alias qualifiers (as added by qualify_columns) from column references."""
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    # Aliases of UNNESTs that appear directly as a FROM/JOIN source in this scope.
    aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            aliases.add(unnest.alias)

    if not aliases:
        return expression

    for column in expression.find_all(exp.Column):
        if column.table in aliases:
            column.set("table", None)
        elif column.db in aliases:
            column.set("db", None)

    return expression
Remove references to unnest table aliases, added by the optimizer's qualify_columns step.
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    Each join whose source is an UNNEST is removed. For every unnested expression paired with
    a column of the unnest's table alias, a LATERAL VIEW EXPLODE (POSEXPLODE when the unnest
    has an offset) is appended instead.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: matching joins are removed from the joins list below, and
        # removing from the list being iterated would silently skip the following join
        # (e.g. the second of two consecutive CROSS JOIN UNNESTs).
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

                # NOTE(review): an UNNEST without a table alias produces no LATERAL VIEW at all.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
Convert cross join unnest into lateral view explode.
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Args:
        index_offset: start value of the generated position series; also used when comparing
            positions against array sizes (presumably the target dialect's array index base).

    Returns:
        A transform that rewrites [POS]EXPLODE projections of a SELECT into CROSS JOIN UNNESTs
        aligned on a shared position series.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Pick an unused name and reserve it.
                name = find_new_name(names, name)
                names.add(name)
                return name

            arrays: t.List[exp.Condition] = []
            # One shared position series; its end is set after all explodes are seen.
            series_alias = new_name(taken_select_names, "pos")
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        # POSEXPLODE(...) AS (pos, col)
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # Substitute a one-element array for an empty/NULL array, presumably so
                        # that one (NULL) row is still produced, as with EXPLODE_OUTER.
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # Emit the element only on rows where series position == array offset.
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # POSEXPLODE also projects the position, right after the element.
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # The series is attached once, on the first explode.
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the positions line up, or where the series has run past
                    # this array's last position (paired with that last element).
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series must cover the longest of the exploded arrays.
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
Convert explode/posexplode into unnest.
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by adding a WITHIN GROUP clause to them.

    A percentile whose `this` holds the value and whose `expression` holds the quantile,
    i.e. FUNC(value, quantile), is rewritten as FUNC(quantile) WITHIN GROUP (ORDER BY value).
    Percentiles that are already wrapped, or that have no second argument, are left alone.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The new WithinGroup node when a rewrite happened, otherwise the original expression.
    """
    is_percentile = isinstance(expression, exp.PERCENTILES)
    already_wrapped = isinstance(expression.parent, exp.WithinGroup)
    if not is_percentile or already_wrapped or not expression.expression:
        return expression

    # Detach the value first; it becomes the ORDER BY key.
    ordered_value = expression.this.pop()
    # The quantile moves from the second argument slot into `this`.
    expression.set("this", expression.expression.pop())

    return exp.WithinGroup(
        this=expression,
        expression=exp.Order(expressions=[exp.Ordered(this=ordered_value)]),
    )
Transforms percentiles by adding a WITHIN GROUP clause to them.
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

    FUNC(quantile) WITHIN GROUP (ORDER BY value) is replaced by an ApproxQuantile node
    with `this=value` and `quantile=quantile`.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The replacing ApproxQuantile node when a rewrite happened, otherwise the original
        expression.
    """
    if (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, exp.PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    ):
        # Search only the ORDER BY clause, so the lookup cannot reach into the
        # percentile's own arguments.
        ordered = expression.expression.find(exp.Ordered)
        if ordered is None:
            # No ordering key to turn into the quantile input; leave the node untouched
            # rather than failing with an AttributeError.
            return expression

        quantile = expression.this.this
        return expression.replace(exp.ApproxQuantile(this=ordered.this, quantile=quantile))

    return expression
Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Uses projection output names in recursive CTE definitions to define the CTEs' columns.

    Every CTE of a WITH RECURSIVE clause that lacks an explicit column list gets one built
    from its query's projections. For a UNION, the left-hand query's projections are used.
    Projections without an output name get names from a `_c_`-prefixed sequence that is
    shared by the whole WITH clause.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    # A single generator for the whole WITH clause, so fallback names never repeat.
    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        query = cte.this.this if isinstance(cte.this, exp.Union) else cte.this
        column_names = [
            exp.to_identifier(projection.alias_or_name or next_name())
            for projection in query.selects
        ]
        table_alias.set("columns", column_names)

    return expression
Uses projection output names in recursive CTE definitions to define the CTEs' columns.
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace 'epoch' in casts by the equivalent date literal.

    Only the string literal 'epoch' (case-insensitive) cast to a temporal type is rewritten:
    CAST('epoch' AS TIMESTAMP) becomes CAST('1970-01-01 00:00:00' AS TIMESTAMP). A column
    that merely happens to be named `epoch`, e.g. CAST(epoch AS TIMESTAMP), is left alone.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, (exp.Cast, exp.TryCast))
        # Cast.name forwards to the operand's name, which is also "epoch" for a column
        # reference; require an actual string literal.
        and expression.this.is_string
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    ):
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
Replace 'epoch' in casts by the equivalent date literal.
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    `SELECT ... FROM a SEMI JOIN b ON cond` becomes
    `SELECT ... FROM a WHERE EXISTS(SELECT 1 FROM b WHERE cond)`. ANTI joins become
    `NOT EXISTS(...)`. SEMI/ANTI joins without an ON condition are left untouched.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: `join.pop()` below removes the join from the Select's
        # "joins" list. Iterating the live list would then skip the element right after
        # each removed join, e.g. the second of two consecutive SEMI joins.
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if not on or join.kind not in ("SEMI", "ANTI"):
                continue

            subquery = exp.select("1").from_(join.this).where(on)
            exists = exp.Exists(this=subquery)
            if join.kind == "ANTI":
                exists = exists.not_(copy=False)

            join.pop()
            expression.where(exists, copy=False)

    return expression
Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.

    Args:
        expression: the expression that will be transformed.

    Returns:
        A union of the LEFT and RIGHT variants of the query, or the original expression
        when it is not a SELECT with exactly one FULL join.
    """
    if not isinstance(expression, exp.Select):
        return expression

    full_joins = [
        (position, join)
        for position, join in enumerate(expression.args.get("joins") or [])
        if join.side == "FULL"
    ]
    if len(full_joins) != 1:
        return expression

    # Copy before any mutation: the copy becomes the RIGHT-join half and keeps the LIMIT.
    right_query = expression.copy()
    expression.set("limit", None)

    position, full_join = full_joins[0]
    full_join.set("side", "left")
    right_query.args["joins"][position].set("side", "right")
    # CTEs stay on the LEFT query only.
    right_query.args.pop("with", None)

    # NOTE(review): if exp.union defaults to a DISTINCT union, duplicate rows that a FULL
    # OUTER JOIN would keep get collapsed — confirm this is acceptable.
    return exp.union(expression, right_query, copy=False)
Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

    Args:
        expression: the root expression. Its nested WITH clauses are hoisted, mutating the
            tree in place.

    Returns:
        The same expression, with nested CTEs merged into its top-level WITH clause.
    """
    top_level_with = expression.args.get("with")
    # NOTE(review): nested WITH nodes are popped while find_all is still traversing the
    # tree; this relies on the traversal tolerating detached nodes — confirm.
    for inner_with in expression.find_all(exp.With):
        # The root's own WITH clause is already where it needs to be.
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            # No top-level WITH yet: promote the first nested one wholesale.
            top_level_with = inner_with.pop()
            expression.set("with", top_level_with)
        else:
            # Keep RECURSIVE if any merged WITH clause had it.
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            # Find the enclosing CTE before popping, while the ancestry is still intact.
            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                # Insert the inner CTEs immediately before the CTE whose body contained them.
                # NOTE(review): list.index raises ValueError if parent_cte is not one of the
                # top-level CTEs.
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                # Not nested inside a CTE: append after the existing top-level CTEs.
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression
Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:
SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq
are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.
TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Converts numeric values used in conditions into explicit boolean expressions.

    Each condition operand that is a number literal, has a numeric or UNKNOWN type, or is an
    untyped column is replaced by `<operand> <> 0`.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    # Imported under an alias so it doesn't shadow this function's own name.
    from sqlglot.optimizer.canonicalize import ensure_bools as _apply_to_conditions

    def _as_bool(operand: exp.Expression) -> None:
        if not (
            operand.is_number
            or operand.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
            or (isinstance(operand, exp.Column) and not operand.type)
        ):
            return
        operand.replace(operand.neq(0))

    for current in expression.walk():
        _apply_to_conditions(current, _as_bool)

    return expression
Converts numeric values used in conditions into explicit boolean expressions.
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Map CREATE TEMPORARY TABLE statements onto temporary views.

    Behavior by case:
    - A temporary CTAS (CREATE TEMPORARY TABLE ... AS <query>) becomes
      CREATE TEMPORARY VIEW ... AS <query>. The new Create node only carries over `this`
      and `expression`; no other argument of the original statement is copied.
    - A temporary table without a query is passed to `tmp_storage_provider`, and its result
      is returned.
    - Anything else is returned unchanged.

    Args:
        expression: the CREATE statement to transform; must be an exp.Create.
        tmp_storage_provider: callback for temporary tables that have no AS <query>.
            Defaults to the identity function.

    Returns:
        The transformed expression.
    """
    assert isinstance(expression, exp.Create)

    properties = expression.args.get("properties")
    property_nodes = properties.expressions if properties else []
    is_temporary = any(isinstance(prop, exp.TemporaryProperty) for prop in property_nodes)

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    query = expression.expression
    if not query:
        return tmp_storage_provider(expression)

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    return exp.Create(kind="TEMPORARY VIEW", this=expression.this, expression=query)
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.

    Column names are matched case-insensitively. Only TABLE and VIEW statements that have
    a schema are considered.

    Args:
        expression: the CREATE statement to transform; must be an exp.Create.

    Returns:
        The transformed expression.
    """
    assert isinstance(expression, exp.Create)

    if not isinstance(expression.this, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    # Nothing to do without a PARTITIONED BY list, or when it already is a schema.
    if not prop or not prop.this or isinstance(prop.this, exp.Schema):
        return expression

    schema = expression.this
    partition_names = {name_node.name.upper() for name_node in prop.this.expressions}
    partition_columns = [
        column for column in schema.expressions if column.name.upper() in partition_names
    ]
    remaining_columns = [
        column for column in schema.expressions if column not in partition_columns
    ]

    schema.set("expressions", remaining_columns)
    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partition_columns)))
    expression.set("this", schema)

    return expression
In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3. When PARTITIONED BY holds a
    schema made only of typed column definitions, those definitions are appended to the
    table's schema and the property is reduced to a tuple of the column names.

    Args:
        expression: the CREATE statement to transform; must be an exp.Create.

    Returns:
        The transformed expression.
    """
    assert isinstance(expression, exp.Create)

    prop = expression.find(exp.PartitionedByProperty)
    if not prop or not prop.this or not isinstance(prop.this, exp.Schema):
        return expression

    column_defs = prop.this.expressions
    if not all(isinstance(e, exp.ColumnDef) and e.kind for e in column_defs):
        return expression

    names_only = exp.Tuple(expressions=[exp.to_identifier(e.this) for e in column_defs])
    table_schema = expression.this
    for column_def in column_defs:
        table_schema.append("expressions", column_def)
    prop.set("this", names_only)

    return expression
Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
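Currently, SQLGlot uses the DATASOURCE format for Spark 3. This transform moves typed column definitions from PARTITIONED BY into the table's schema, and PARTITIONED BY keeps only the column names.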
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """
    Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

    Key/value arguments (exp.PropertyEQ, whose `this` is the key and whose `expression` is
    the value) are rewritten as `value AS key`. All other arguments are kept as they are.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Struct):
        return expression

    def _as_alias(arg: exp.Expression) -> exp.Expression:
        if isinstance(arg, exp.PropertyEQ):
            return exp.alias_(arg.expression, arg.this)
        return arg

    expression.set("expressions", [_as_alias(arg) for arg in expression.expressions])
    return expression
Converts struct arguments to aliases, e.g. STRUCT(1 AS y).
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is allowed; the expression is then generated unchanged.

    Returns:
        Function that can be used as a generator transform.

    Once all transforms have run, the SQL is produced by the first of the following that
    applies:
        1. The generator's `<key>_sql` method for the resulting expression.
        2. The `Generator.TRANSFORMS` entry for the resulting expression's type. If that type
           is the same as the input's, calling the entry would re-enter this function, so
           functions fall back to `function_fallback_sql` and anything else raises ValueError.
        3. Otherwise a ValueError is raised.
    """

    def _to_sql(self: Generator, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # Apply every transform in order. A single loop (rather than transforms[0] followed
        # by transforms[1:]) also handles an empty list without raising IndexError.
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
Creates a new transform by chaining a sequence of transformations and converts the resulting
expression to SQL, using either the "_sql" method corresponding to the resulting expression,
or the appropriate `Generator.TRANSFORMS` function (when applicable).
Arguments:
- transforms: sequence of transform functions. These will be called in order.
Returns:
Function that can be used as a generator transform.