sqlglot.transforms
1from __future__ import annotations 2 3import typing as t 4 5from sqlglot import expressions as exp 6from sqlglot.helper import find_new_name, name_sequence 7 8if t.TYPE_CHECKING: 9 from sqlglot.generator import Generator 10 11 12def unalias_group(expression: exp.Expression) -> exp.Expression: 13 """ 14 Replace references to select aliases in GROUP BY clauses. 15 16 Example: 17 >>> import sqlglot 18 >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql() 19 'SELECT a AS b FROM x GROUP BY 1' 20 21 Args: 22 expression: the expression that will be transformed. 23 24 Returns: 25 The transformed expression. 26 """ 27 if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select): 28 aliased_selects = { 29 e.alias: i 30 for i, e in enumerate(expression.parent.expressions, start=1) 31 if isinstance(e, exp.Alias) 32 } 33 34 for group_by in expression.expressions: 35 if ( 36 isinstance(group_by, exp.Column) 37 and not group_by.table 38 and group_by.name in aliased_selects 39 ): 40 group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name))) 41 42 return expression 43 44 45def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression: 46 """ 47 Convert SELECT DISTINCT ON statements to a subquery with a window function. 48 49 This is useful for dialects that don't support SELECT DISTINCT ON but support window functions. 50 51 Args: 52 expression: the expression that will be transformed. 53 54 Returns: 55 The transformed expression. 
56 """ 57 if ( 58 isinstance(expression, exp.Select) 59 and expression.args.get("distinct") 60 and expression.args["distinct"].args.get("on") 61 and isinstance(expression.args["distinct"].args["on"], exp.Tuple) 62 ): 63 distinct_cols = expression.args["distinct"].pop().args["on"].expressions 64 outer_selects = expression.selects 65 row_number = find_new_name(expression.named_selects, "_row_number") 66 window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols) 67 order = expression.args.get("order") 68 69 if order: 70 window.set("order", order.pop().copy()) 71 else: 72 window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols])) 73 74 window = exp.alias_(window, row_number) 75 expression.select(window, copy=False) 76 77 return ( 78 exp.select(*outer_selects) 79 .from_(expression.subquery("_t")) 80 .where(exp.column(row_number).eq(1)) 81 ) 82 83 return expression 84 85 86def eliminate_qualify(expression: exp.Expression) -> exp.Expression: 87 """ 88 Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently. 89 90 The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: 91 https://docs.snowflake.com/en/sql-reference/constructs/qualify 92 93 Some dialects don't support window functions in the WHERE clause, so we need to include them as 94 projections in the subquery, in order to refer to them in the outer filter using aliases. Also, 95 if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, 96 otherwise we won't be able to refer to it in the outer query's WHERE clause. 
97 """ 98 if isinstance(expression, exp.Select) and expression.args.get("qualify"): 99 taken = set(expression.named_selects) 100 for select in expression.selects: 101 if not select.alias_or_name: 102 alias = find_new_name(taken, "_c") 103 select.replace(exp.alias_(select, alias)) 104 taken.add(alias) 105 106 outer_selects = exp.select(*[select.alias_or_name for select in expression.selects]) 107 qualify_filters = expression.args["qualify"].pop().this 108 109 select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column) 110 for expr in qualify_filters.find_all(select_candidates): 111 if isinstance(expr, exp.Window): 112 alias = find_new_name(expression.named_selects, "_w") 113 expression.select(exp.alias_(expr, alias), copy=False) 114 column = exp.column(alias) 115 116 if isinstance(expr.parent, exp.Qualify): 117 qualify_filters = column 118 else: 119 expr.replace(column) 120 elif expr.name not in expression.named_selects: 121 expression.select(expr.copy(), copy=False) 122 123 return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters) 124 125 return expression 126 127 128def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression: 129 """ 130 Some dialects only allow the precision for parameterized types to be defined in the DDL and not in 131 other expressions. This transforms removes the precision from parameterized types in expressions. 
132 """ 133 for node in expression.find_all(exp.DataType): 134 node.set( 135 "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)] 136 ) 137 138 return expression 139 140 141def unnest_to_explode(expression: exp.Expression) -> exp.Expression: 142 """Convert cross join unnest into lateral view explode (used in presto -> hive).""" 143 if isinstance(expression, exp.Select): 144 for join in expression.args.get("joins") or []: 145 unnest = join.this 146 147 if isinstance(unnest, exp.Unnest): 148 alias = unnest.args.get("alias") 149 udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode 150 151 expression.args["joins"].remove(join) 152 153 for e, column in zip(unnest.expressions, alias.columns if alias else []): 154 expression.append( 155 "laterals", 156 exp.Lateral( 157 this=udtf(this=e), 158 view=True, 159 alias=exp.TableAlias(this=alias.this, columns=[column]), # type: ignore 160 ), 161 ) 162 163 return expression 164 165 166def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]: 167 def _explode_to_unnest(expression: exp.Expression) -> exp.Expression: 168 """Convert explode/posexplode into unnest (used in hive -> presto).""" 169 if isinstance(expression, exp.Select): 170 from sqlglot.optimizer.scope import Scope 171 172 taken_select_names = set(expression.named_selects) 173 taken_source_names = {name for name, _ in Scope(expression).references} 174 175 def new_name(names: t.Set[str], name: str) -> str: 176 name = find_new_name(names, name) 177 names.add(name) 178 return name 179 180 arrays: t.List[exp.Condition] = [] 181 series_alias = new_name(taken_select_names, "pos") 182 series = exp.alias_( 183 exp.Unnest( 184 expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))] 185 ), 186 new_name(taken_source_names, "_u"), 187 table=[series_alias], 188 ) 189 190 # we use list here because expression.selects is mutated inside the loop 191 for select in 
expression.selects.copy(): 192 explode = select.find(exp.Explode) 193 194 if explode: 195 pos_alias = "" 196 explode_alias = "" 197 198 if isinstance(select, exp.Alias): 199 explode_alias = select.alias 200 alias = select 201 elif isinstance(select, exp.Aliases): 202 pos_alias = select.aliases[0].name 203 explode_alias = select.aliases[1].name 204 alias = select.replace(exp.alias_(select.this, "", copy=False)) 205 else: 206 alias = select.replace(exp.alias_(select, "")) 207 explode = alias.find(exp.Explode) 208 assert explode 209 210 is_posexplode = isinstance(explode, exp.Posexplode) 211 explode_arg = explode.this 212 213 # This ensures that we won't use [POS]EXPLODE's argument as a new selection 214 if isinstance(explode_arg, exp.Column): 215 taken_select_names.add(explode_arg.output_name) 216 217 unnest_source_alias = new_name(taken_source_names, "_u") 218 219 if not explode_alias: 220 explode_alias = new_name(taken_select_names, "col") 221 222 if is_posexplode: 223 pos_alias = new_name(taken_select_names, "pos") 224 225 if not pos_alias: 226 pos_alias = new_name(taken_select_names, "pos") 227 228 alias.set("alias", exp.to_identifier(explode_alias)) 229 230 column = exp.If( 231 this=exp.column(series_alias).eq(exp.column(pos_alias)), 232 true=exp.column(explode_alias), 233 ) 234 235 explode.replace(column) 236 237 if is_posexplode: 238 expressions = expression.expressions 239 expressions.insert( 240 expressions.index(alias) + 1, 241 exp.If( 242 this=exp.column(series_alias).eq(exp.column(pos_alias)), 243 true=exp.column(pos_alias), 244 ).as_(pos_alias), 245 ) 246 expression.set("expressions", expressions) 247 248 if not arrays: 249 if expression.args.get("from"): 250 expression.join(series, copy=False) 251 else: 252 expression.from_(series, copy=False) 253 254 size: exp.Condition = exp.ArraySize(this=explode_arg.copy()) 255 arrays.append(size) 256 257 # trino doesn't support left join unnest with on conditions 258 # if it did, this would be much simpler 259 
expression.join( 260 exp.alias_( 261 exp.Unnest( 262 expressions=[explode_arg.copy()], 263 offset=exp.to_identifier(pos_alias), 264 ), 265 unnest_source_alias, 266 table=[explode_alias], 267 ), 268 join_type="CROSS", 269 copy=False, 270 ) 271 272 if index_offset != 1: 273 size = size - 1 274 275 expression.where( 276 exp.column(series_alias) 277 .eq(exp.column(pos_alias)) 278 .or_( 279 (exp.column(series_alias) > size).and_(exp.column(pos_alias).eq(size)) 280 ), 281 copy=False, 282 ) 283 284 if arrays: 285 end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:]) 286 287 if index_offset != 1: 288 end = end - (1 - index_offset) 289 series.expressions[0].set("end", end) 290 291 return expression 292 293 return _explode_to_unnest 294 295 296PERCENTILES = (exp.PercentileCont, exp.PercentileDisc) 297 298 299def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression: 300 if ( 301 isinstance(expression, PERCENTILES) 302 and not isinstance(expression.parent, exp.WithinGroup) 303 and expression.expression 304 ): 305 column = expression.this.pop() 306 expression.set("this", expression.expression.pop()) 307 order = exp.Order(expressions=[exp.Ordered(this=column)]) 308 expression = exp.WithinGroup(this=expression, expression=order) 309 310 return expression 311 312 313def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression: 314 if ( 315 isinstance(expression, exp.WithinGroup) 316 and isinstance(expression.this, PERCENTILES) 317 and isinstance(expression.expression, exp.Order) 318 ): 319 quantile = expression.this.this 320 input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this 321 return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile)) 322 323 return expression 324 325 326def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression: 327 if isinstance(expression, exp.With) and expression.recursive: 328 next_name = name_sequence("_c_") 329 330 for cte 
in expression.expressions: 331 if not cte.args["alias"].columns: 332 query = cte.this 333 if isinstance(query, exp.Union): 334 query = query.this 335 336 cte.args["alias"].set( 337 "columns", 338 [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects], 339 ) 340 341 return expression 342 343 344def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression: 345 if ( 346 isinstance(expression, (exp.Cast, exp.TryCast)) 347 and expression.name.lower() == "epoch" 348 and expression.to.this in exp.DataType.TEMPORAL_TYPES 349 ): 350 expression.this.replace(exp.Literal.string("1970-01-01 00:00:00")) 351 352 return expression 353 354 355def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression: 356 if isinstance(expression, exp.Select): 357 for join in expression.args.get("joins") or []: 358 on = join.args.get("on") 359 if on and join.kind in ("SEMI", "ANTI"): 360 subquery = exp.select("1").from_(join.this).where(on) 361 exists = exp.Exists(this=subquery) 362 if join.kind == "ANTI": 363 exists = exists.not_(copy=False) 364 365 join.pop() 366 expression.where(exists, copy=False) 367 368 return expression 369 370 371def preprocess( 372 transforms: t.List[t.Callable[[exp.Expression], exp.Expression]], 373) -> t.Callable[[Generator, exp.Expression], str]: 374 """ 375 Creates a new transform by chaining a sequence of transformations and converts the resulting 376 expression to SQL, using either the "_sql" method corresponding to the resulting expression, 377 or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below). 378 379 Args: 380 transforms: sequence of transform functions. These will be called in order. 381 382 Returns: 383 Function that can be used as a generator transform. 
384 """ 385 386 def _to_sql(self, expression: exp.Expression) -> str: 387 expression_type = type(expression) 388 389 expression = transforms[0](expression.copy()) 390 for t in transforms[1:]: 391 expression = t(expression) 392 393 _sql_handler = getattr(self, expression.key + "_sql", None) 394 if _sql_handler: 395 return _sql_handler(expression) 396 397 transforms_handler = self.TRANSFORMS.get(type(expression)) 398 if transforms_handler: 399 if expression_type is type(expression): 400 if isinstance(expression, exp.Func): 401 return self.function_fallback_sql(expression) 402 403 # Ensures we don't enter an infinite loop. This can happen when the original expression 404 # has the same type as the final expression and there's no _sql method available for it, 405 # because then it'd re-enter _to_sql. 406 raise ValueError( 407 f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed." 408 ) 409 410 return transforms_handler(self, expression) 411 412 raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.") 413 414 return _to_sql
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    Every unqualified GROUP BY column whose name matches an aliased projection of the enclosing
    SELECT is replaced by that projection's 1-based position.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    select = expression.parent
    if not (isinstance(expression, exp.Group) and isinstance(select, exp.Select)):
        return expression

    position_by_alias = {}
    for position, projection in enumerate(select.expressions, start=1):
        if isinstance(projection, exp.Alias):
            # A later projection with the same alias overrides an earlier one.
            position_by_alias[projection.alias] = position

    for grouped in expression.expressions:
        refers_to_alias = (
            isinstance(grouped, exp.Column)
            and not grouped.table
            and grouped.name in position_by_alias
        )
        if refers_to_alias:
            grouped.replace(exp.Literal.number(position_by_alias[grouped.name]))

    return expression
Replace references to select aliases in GROUP BY clauses.
Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    The original query is wrapped in a subquery aliased `_t` that also projects
    `ROW_NUMBER() OVER (PARTITION BY <DISTINCT ON columns> ORDER BY ...)`, and the outer query keeps
    the rows where that number is 1. The outer query refers to the subquery's projections by
    their output names, so unnamed projections are given a fresh `_c` alias first.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):

        def _reference(select: exp.Expression) -> exp.Expression:
            # Refer to a projection of `_t` by its output name, built from the projection's own
            # identifier so that its quoting is preserved.
            if select.is_star:
                return exp.Star()
            identifier = select.args.get("alias") if isinstance(select, exp.Alias) else select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(identifier.copy())
            return exp.column(select.alias_or_name)

        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        row_number = find_new_name(expression.named_selects, "_row_number")

        # The outer query only sees the columns exposed by `_t`, so repeating the original
        # projection expressions there (e.g. `t.a` or `a + 1`) would not resolve.
        taken = set(expression.named_selects)
        taken.add(row_number)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        # Built before the window is projected, so the helper column is not part of the output.
        outer_selects = [_reference(select) for select in expression.selects]

        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        if order:
            # The query's ORDER BY decides which row of each partition is kept.
            window.set("order", order.pop())
        else:
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        expression.select(exp.alias_(window, row_number), copy=False)

        return (
            exp.select(*outer_selects)
            .from_(expression.subquery("_t"))
            .where(exp.column(row_number).eq(1))
        )

    return expression
Convert SELECT DISTINCT ON statements to a subquery with a window function.
This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):

        def _reference(select: exp.Expression) -> exp.Expression:
            # Build the outer reference from the projection's identifier instead of re-parsing its
            # name as SQL: a name such as "my col" would otherwise parse as `my AS col`, and quoted
            # (case-sensitive) names would lose their quotes.
            if select.is_star:
                return exp.Star()
            identifier = select.args.get("alias") if isinstance(select, exp.Alias) else select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(identifier.copy())
            return exp.column(select.alias_or_name)

        # Every projection needs a name so that the outer query can refer to it.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        outer_selects = exp.select(*[_reference(select) for select in expression.selects])
        qualify_filters = expression.args["qualify"].pop().this

        # With SELECT * every column is already visible to the outer query.
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for expr in qualify_filters.find_all(select_candidates):
            if isinstance(expr, exp.Window):
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(expr, alias), copy=False)
                column = exp.column(alias)

                if isinstance(expr.parent, exp.Qualify):
                    # The whole filter is the window itself, e.g. `QUALIFY <window>`.
                    qualify_filters = column
                else:
                    expr.replace(column)
            elif expr.name not in expression.named_selects:
                expression.select(expr.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters)

    return expression
Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify
Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause.
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.

    Every `DataTypeParam` argument is dropped from every `DataType` node in the tree; arguments of
    any other kind are kept.
    """
    for data_type in expression.find_all(exp.DataType):
        kept_args = []
        for arg in data_type.expressions:
            if not isinstance(arg, exp.DataTypeParam):
                kept_args.append(arg)
        data_type.set("expressions", kept_args)

    return expression
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode (used in presto -> hive).

    Each joined UNNEST is removed and, for every (array, alias column) pair, a
    `LATERAL VIEW [POS]EXPLODE(array)` is appended; POSEXPLODE is used when the UNNEST has an
    offset. NOTE(review): an UNNEST without an alias is removed without producing any lateral view.
    """
    if isinstance(expression, exp.Select):
        joins = expression.args.get("joins") or []

        # Iterate over a snapshot: removing from the list being iterated would skip the join
        # that immediately follows each converted UNNEST.
        for join in list(joins):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                joins.remove(join)

                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
Convert cross join unnest into lateral view explode (used in presto -> hive).
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Build a transform that rewrites [POS]EXPLODE projections into CROSS JOIN UNNESTs.

    Args:
        index_offset: first position produced by the generated series; the size/end arithmetic
            below treats 1 specially (presumably 0- vs 1-based offsets in the target dialect —
            TODO confirm).

    Returns:
        A transform that rewrites SELECT expressions and returns anything else unchanged.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        """
        Convert explode/posexplode into unnest (used in hive -> presto).

        A single generated series of positions is joined once. Every exploded projection becomes
        `IF(<series pos> = <pos alias>, <col alias>)` over a CROSS JOIN UNNEST(<array>) WITH OFFSET,
        and a WHERE condition lines the arrays up by position.
        """
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names already in use, so that generated aliases don't collide with them.
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Pick a name not in `names` and reserve it.
                name = find_new_name(names, name)
                names.add(name)
                return name

            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # Shared series of positions; its end is set after the loop, once all array sizes
            # are known.
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # iterate over a copy because expression.selects is mutated inside the loop
            for select in expression.selects.copy():
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    if isinstance(select, exp.Alias):
                        explode_alias = select.alias
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        # POSEXPLODE(...) AS (pos, col)
                        pos_alias = select.aliases[0].name
                        explode_alias = select.aliases[1].name
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        # alias_ copied the projection here, so look the node up again in the copy.
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        # NOTE(review): the gutter-numbered listing this was recovered from lost
                        # its indentation; this `if` is reconstructed as nested under
                        # `if not explode_alias:` (otherwise a user-given `AS (pos, col)` position
                        # alias would be overwritten) — confirm.
                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    column = exp.If(
                        this=exp.column(series_alias).eq(exp.column(pos_alias)),
                        true=exp.column(explode_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # Expose the position as its own projection, right after the value.
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias).eq(exp.column(pos_alias)),
                                true=exp.column(pos_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # The series is attached only once, with the first exploded array.
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False)
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # `size` becomes this array's last position (assumes index_offset is 0 or 1).
                    if index_offset != 1:
                        size = size - 1

                    # Keep the row whose offset equals the series position; once the series runs
                    # past this array's end, keep only its last position (the IF then yields NULL).
                    expression.where(
                        exp.column(series_alias)
                        .eq(exp.column(pos_alias))
                        .or_(
                            (exp.column(series_alias) > size).and_(exp.column(pos_alias).eq(size))
                        ),
                        copy=False,
                    )

            if arrays:
                # The series must span the longest exploded array.
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite `PERCENTILE_*(<column>, <quantile>)` into
    `PERCENTILE_*(<quantile>) WITHIN GROUP (ORDER BY <column>)`.

    Calls already wrapped in WITHIN GROUP, or without a second argument, are returned unchanged.
    """
    if not isinstance(expression, PERCENTILES):
        return expression
    if isinstance(expression.parent, exp.WithinGroup) or not expression.expression:
        return expression

    ordered_column = expression.this.pop()
    quantile = expression.expression.pop()
    expression.set("this", quantile)

    return exp.WithinGroup(
        this=expression,
        expression=exp.Order(expressions=[exp.Ordered(this=ordered_column)]),
    )
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite `PERCENTILE_*(<quantile>) WITHIN GROUP (ORDER BY <value>)` into an `ApproxQuantile`
    of `<value>` at `<quantile>`.

    Only the expression of the first `Ordered` node found is used; its sort direction is ignored.
    """
    is_percentile_within_group = (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    )
    if not is_percentile_within_group:
        return expression

    quantile = expression.this.this
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approximation = exp.ApproxQuantile(this=ordered.this, quantile=quantile)
    return expression.replace(approximation)
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Give every CTE of a WITH RECURSIVE clause an explicit column list if it has none.

    The names come from the CTE query's projections (for a UNION, from its left-hand side).
    Projections without a name get one from a `name_sequence("_c_")` generator shared by all
    CTEs of the clause.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    generate_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        body = cte.this
        anchor = body.this if isinstance(body, exp.Union) else body
        column_names = [
            exp.to_identifier(projection.alias_or_name or generate_name())
            for projection in anchor.selects
        ]
        table_alias.set("columns", column_names)

    return expression
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace the special string 'epoch' in casts to temporal types by '1970-01-01 00:00:00'.

    Only a string literal is rewritten. `Cast.name` is also the name of a column operand, so
    without the literal check `CAST(epoch AS TIMESTAMP)` on a column called `epoch` would be
    turned into a constant.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The (possibly modified in place) expression.
    """
    if (
        isinstance(expression, (exp.Cast, exp.TryCast))
        and expression.this.is_string
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    ):
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite SEMI / ANTI joins that have an ON condition into [NOT] EXISTS filters.

    The join is removed from the query, and `EXISTS(SELECT 1 FROM <joined source> WHERE <on>)`
    (wrapped in NOT for ANTI joins) is added to the query's WHERE clause. Joins without an ON
    condition are left untouched.
    """
    if not isinstance(expression, exp.Select):
        return expression

    for join in expression.args.get("joins") or []:
        condition = join.args.get("on")
        if not condition or join.kind not in ("SEMI", "ANTI"):
            continue

        probe = exp.select("1").from_(join.this).where(condition)
        predicate = exp.Exists(this=probe)
        if join.kind == "ANTI":
            predicate = predicate.not_(copy=False)

        join.pop()
        expression.where(predicate, copy=False)

    return expression
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Resolution order for the transformed expression:
      1. The generator's `<key>_sql` method, if it has one.
      2. The `Generator.TRANSFORMS` entry for its type, but only if the transforms changed the
         expression's type. If the type is unchanged, calling that entry would re-enter this
         function forever, so `Func` nodes fall back to `function_fallback_sql` and any other
         type raises a ValueError.
      3. Otherwise, a ValueError is raised.

    Args:
        transforms: sequence of transform functions. These will be called in order, on a copy of
            the input expression. An empty sequence leaves the copy unchanged.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self: Generator, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # Transforms mutate their input, so they run on a copy to leave the caller's tree intact.
        # (The loop variable is not named `t`, which would shadow the `typing` module alias.)
        expression = expression.copy()
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
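Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL. It uses the "_sql" method corresponding to the resulting expression if the generator has one. Otherwise it uses the appropriate `Generator.TRANSFORMS` function, but only if the transformations changed the expression's type. If the type is unchanged, `Func` expressions fall back to generic function SQL and any other type raises a `ValueError`, so that the same transform is not re-entered forever.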
Arguments:
- transforms: sequence of transform functions. These will be called in order.
Returns:
Function that can be used as a generator transform.