sqlglot.transforms
1from __future__ import annotations 2 3import itertools 4import typing as t 5 6from sqlglot import expressions as exp 7from sqlglot.helper import find_new_name 8 9if t.TYPE_CHECKING: 10 from sqlglot.generator import Generator 11 12 13def unalias_group(expression: exp.Expression) -> exp.Expression: 14 """ 15 Replace references to select aliases in GROUP BY clauses. 16 17 Example: 18 >>> import sqlglot 19 >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql() 20 'SELECT a AS b FROM x GROUP BY 1' 21 22 Args: 23 expression: the expression that will be transformed. 24 25 Returns: 26 The transformed expression. 27 """ 28 if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select): 29 aliased_selects = { 30 e.alias: i 31 for i, e in enumerate(expression.parent.expressions, start=1) 32 if isinstance(e, exp.Alias) 33 } 34 35 for group_by in expression.expressions: 36 if ( 37 isinstance(group_by, exp.Column) 38 and not group_by.table 39 and group_by.name in aliased_selects 40 ): 41 group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name))) 42 43 return expression 44 45 46def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression: 47 """ 48 Convert SELECT DISTINCT ON statements to a subquery with a window function. 49 50 This is useful for dialects that don't support SELECT DISTINCT ON but support window functions. 51 52 Args: 53 expression: the expression that will be transformed. 54 55 Returns: 56 The transformed expression. 
57 """ 58 if ( 59 isinstance(expression, exp.Select) 60 and expression.args.get("distinct") 61 and expression.args["distinct"].args.get("on") 62 and isinstance(expression.args["distinct"].args["on"], exp.Tuple) 63 ): 64 distinct_cols = expression.args["distinct"].pop().args["on"].expressions 65 outer_selects = expression.selects 66 row_number = find_new_name(expression.named_selects, "_row_number") 67 window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols) 68 order = expression.args.get("order") 69 70 if order: 71 window.set("order", order.pop().copy()) 72 73 window = exp.alias_(window, row_number) 74 expression.select(window, copy=False) 75 76 return exp.select(*outer_selects).from_(expression.subquery()).where(f'"{row_number}" = 1') 77 78 return expression 79 80 81def eliminate_qualify(expression: exp.Expression) -> exp.Expression: 82 """ 83 Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently. 84 85 The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: 86 https://docs.snowflake.com/en/sql-reference/constructs/qualify 87 88 Some dialects don't support window functions in the WHERE clause, so we need to include them as 89 projections in the subquery, in order to refer to them in the outer filter using aliases. Also, 90 if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, 91 otherwise we won't be able to refer to it in the outer query's WHERE clause. 
92 """ 93 if isinstance(expression, exp.Select) and expression.args.get("qualify"): 94 taken = set(expression.named_selects) 95 for select in expression.selects: 96 if not select.alias_or_name: 97 alias = find_new_name(taken, "_c") 98 select.replace(exp.alias_(select, alias)) 99 taken.add(alias) 100 101 outer_selects = exp.select(*[select.alias_or_name for select in expression.selects]) 102 qualify_filters = expression.args["qualify"].pop().this 103 104 for expr in qualify_filters.find_all((exp.Window, exp.Column)): 105 if isinstance(expr, exp.Window): 106 alias = find_new_name(expression.named_selects, "_w") 107 expression.select(exp.alias_(expr, alias), copy=False) 108 column = exp.column(alias) 109 110 if isinstance(expr.parent, exp.Qualify): 111 qualify_filters = column 112 else: 113 expr.replace(column) 114 elif expr.name not in expression.named_selects: 115 expression.select(expr.copy(), copy=False) 116 117 return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters) 118 119 return expression 120 121 122def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression: 123 """ 124 Some dialects only allow the precision for parameterized types to be defined in the DDL and not in 125 other expressions. This transforms removes the precision from parameterized types in expressions. 
126 """ 127 for node in expression.find_all(exp.DataType): 128 node.set("expressions", [e for e in node.expressions if isinstance(e, exp.DataType)]) 129 130 return expression 131 132 133def unnest_to_explode(expression: exp.Expression) -> exp.Expression: 134 """Convert cross join unnest into lateral view explode (used in presto -> hive).""" 135 if isinstance(expression, exp.Select): 136 for join in expression.args.get("joins") or []: 137 unnest = join.this 138 139 if isinstance(unnest, exp.Unnest): 140 alias = unnest.args.get("alias") 141 udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode 142 143 expression.args["joins"].remove(join) 144 145 for e, column in zip(unnest.expressions, alias.columns if alias else []): 146 expression.append( 147 "laterals", 148 exp.Lateral( 149 this=udtf(this=e), 150 view=True, 151 alias=exp.TableAlias(this=alias.this, columns=[column]), # type: ignore 152 ), 153 ) 154 155 return expression 156 157 158def explode_to_unnest(expression: exp.Expression) -> exp.Expression: 159 """Convert explode/posexplode into unnest (used in hive -> presto).""" 160 if isinstance(expression, exp.Select): 161 from sqlglot.optimizer.scope import build_scope 162 163 taken_select_names = set(expression.named_selects) 164 scope = build_scope(expression) 165 if not scope: 166 return expression 167 taken_source_names = set(scope.selected_sources) 168 169 for select in expression.selects: 170 to_replace = select 171 172 pos_alias = "" 173 explode_alias = "" 174 175 if isinstance(select, exp.Alias): 176 explode_alias = select.alias 177 select = select.this 178 elif isinstance(select, exp.Aliases): 179 pos_alias = select.aliases[0].name 180 explode_alias = select.aliases[1].name 181 select = select.this 182 183 if isinstance(select, (exp.Explode, exp.Posexplode)): 184 is_posexplode = isinstance(select, exp.Posexplode) 185 186 explode_arg = select.this 187 unnest = exp.Unnest(expressions=[explode_arg.copy()], ordinality=is_posexplode) 188 189 # 
This ensures that we won't use [POS]EXPLODE's argument as a new selection 190 if isinstance(explode_arg, exp.Column): 191 taken_select_names.add(explode_arg.output_name) 192 193 unnest_source_alias = find_new_name(taken_source_names, "_u") 194 taken_source_names.add(unnest_source_alias) 195 196 if not explode_alias: 197 explode_alias = find_new_name(taken_select_names, "col") 198 taken_select_names.add(explode_alias) 199 200 if is_posexplode: 201 pos_alias = find_new_name(taken_select_names, "pos") 202 taken_select_names.add(pos_alias) 203 204 if is_posexplode: 205 column_names = [explode_alias, pos_alias] 206 to_replace.pop() 207 expression.select(pos_alias, explode_alias, copy=False) 208 else: 209 column_names = [explode_alias] 210 to_replace.replace(exp.column(explode_alias)) 211 212 unnest = exp.alias_(unnest, unnest_source_alias, table=column_names) 213 214 if not expression.args.get("from"): 215 expression.from_(unnest, copy=False) 216 else: 217 expression.join(unnest, join_type="CROSS", copy=False) 218 219 return expression 220 221 222def remove_target_from_merge(expression: exp.Expression) -> exp.Expression: 223 """Remove table refs from columns in when statements.""" 224 if isinstance(expression, exp.Merge): 225 alias = expression.this.args.get("alias") 226 targets = {expression.this.this} 227 if alias: 228 targets.add(alias.this) 229 230 for when in expression.expressions: 231 when.transform( 232 lambda node: exp.column(node.name) 233 if isinstance(node, exp.Column) and node.args.get("table") in targets 234 else node, 235 copy=False, 236 ) 237 238 return expression 239 240 241def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression: 242 if ( 243 isinstance(expression, exp.WithinGroup) 244 and isinstance(expression.this, (exp.PercentileCont, exp.PercentileDisc)) 245 and isinstance(expression.expression, exp.Order) 246 ): 247 quantile = expression.this.this 248 input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this 
249 return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile)) 250 251 return expression 252 253 254def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression: 255 if isinstance(expression, exp.With) and expression.recursive: 256 sequence = itertools.count() 257 next_name = lambda: f"_c_{next(sequence)}" 258 259 for cte in expression.expressions: 260 if not cte.args["alias"].columns: 261 query = cte.this 262 if isinstance(query, exp.Union): 263 query = query.this 264 265 cte.args["alias"].set( 266 "columns", 267 [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects], 268 ) 269 270 return expression 271 272 273def preprocess( 274 transforms: t.List[t.Callable[[exp.Expression], exp.Expression]], 275) -> t.Callable[[Generator, exp.Expression], str]: 276 """ 277 Creates a new transform by chaining a sequence of transformations and converts the resulting 278 expression to SQL, using either the "_sql" method corresponding to the resulting expression, 279 or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below). 280 281 Args: 282 transforms: sequence of transform functions. These will be called in order. 283 284 Returns: 285 Function that can be used as a generator transform. 286 """ 287 288 def _to_sql(self, expression: exp.Expression) -> str: 289 expression_type = type(expression) 290 291 expression = transforms[0](expression.copy()) 292 for t in transforms[1:]: 293 expression = t(expression) 294 295 _sql_handler = getattr(self, expression.key + "_sql", None) 296 if _sql_handler: 297 return _sql_handler(expression) 298 299 transforms_handler = self.TRANSFORMS.get(type(expression)) 300 if transforms_handler: 301 # Ensures we don't enter an infinite loop. This can happen when the original expression 302 # has the same type as the final expression and there's no _sql method available for it, 303 # because then it'd re-enter _to_sql. 
304 if expression_type is type(expression): 305 raise ValueError( 306 f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed." 307 ) 308 309 return transforms_handler(self, expression) 310 311 raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.") 312 313 return _to_sql
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    A GROUP BY item is rewritten into the 1-based position of a projection when it is an
    unqualified column whose name equals that projection's alias.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Group):
        return expression

    select = expression.parent
    if not isinstance(select, exp.Select):
        return expression

    # alias name -> 1-based position in the SELECT list (later duplicates win).
    positions = {}
    for position, projection in enumerate(select.expressions, start=1):
        if isinstance(projection, exp.Alias):
            positions[projection.alias] = position

    for grouped in expression.expressions:
        is_alias_ref = (
            isinstance(grouped, exp.Column) and not grouped.table and grouped.name in positions
        )
        if is_alias_ref:
            grouped.replace(exp.Literal.number(positions[grouped.name]))

    return expression
Replace references to select aliases in GROUP BY clauses.
Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
    The query is rewritten as

        SELECT <output columns> FROM (
            SELECT <projections>, ROW_NUMBER() OVER (PARTITION BY <DISTINCT ON columns>
                                                     ORDER BY <original ORDER BY>) AS _row_number
            ...
        ) AS _t WHERE _row_number = 1

    The outer query refers to the subquery's output names rather than repeating the inner
    projections, because those may reference columns the derived table doesn't expose.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions

        taken = set(expression.named_selects)
        outer_selects = []
        for select in expression.selects:
            if isinstance(select, exp.Star) or (
                isinstance(select, exp.Column) and isinstance(select.this, exp.Star)
            ):
                # A star can't be referenced by name, so re-expand it from the subquery.
                # NOTE(review): this also exposes the helper row-number column.
                outer_selects.append(exp.Star())
                continue

            if isinstance(select, exp.Alias):
                identifier = select.args["alias"].copy()
            elif isinstance(select, exp.Column):
                # Keep the column's identifier (and quoting) but drop its table qualifier,
                # which isn't in scope outside the subquery.
                identifier = select.this.copy()
            else:
                # Unnamed projection (e.g. `a + 1`): give it a fresh alias so it can be referenced.
                name = find_new_name(taken, "_c")
                taken.add(name)
                select.replace(exp.alias_(select.copy(), name))
                identifier = name

            outer_selects.append(exp.column(identifier))

        row_number = find_new_name(taken, "_row_number")
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        if order:
            # NOTE(review): the ORDER BY only picks the surviving row; the outer query is unordered.
            window.set("order", order.pop())

        expression.select(exp.alias_(window, row_number), copy=False)

        # Build the filter as an AST so it references the same (unquoted) identifier as the alias.
        keep_first = exp.EQ(this=exp.column(row_number), expression=exp.Literal.number(1))
        return exp.select(*outer_selects).from_(expression.subquery(alias="_t")).where(keep_first)

    return expression
Convert SELECT DISTINCT ON statements to a subquery with a window function.
This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause.

    Args:
        expression: the expression that will be transformed.

    Returns:
        A new outer SELECT over the subquery `_t` if `expression` had a QUALIFY clause,
        otherwise `expression` unchanged.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every projection an output name, so the outer query can refer to it.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        # Captured before window/column projections are appended below, so the outer query
        # only exposes the original projections.
        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
        qualify_filters = expression.args["qualify"].pop().this

        # The tree is mutated while `find_all` walks it, so statement order here matters.
        for expr in qualify_filters.find_all((exp.Window, exp.Column)):
            if isinstance(expr, exp.Window):
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(expr, alias), copy=False)
                column = exp.column(alias)

                # The window is the entire QUALIFY condition: the filter becomes the column itself.
                if isinstance(expr.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    expr.replace(column)
            elif expr.name not in expression.named_selects:
                # Project referenced-but-unselected columns so the outer WHERE can see them.
                expression.select(expr.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters)

    return expression
Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify
Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause.
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions:
    every parameter of a data type that is not itself a data type (e.g. the `10, 2` of
    `DECIMAL(10, 2)`) is dropped, while nested data type parameters are kept.
    """
    for data_type in expression.find_all(exp.DataType):
        nested_types = list(filter(lambda param: isinstance(param, exp.DataType), data_type.expressions))
        data_type.set("expressions", nested_types)

    return expression
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode (used in presto -> hive).

    Each UNNEST join is removed and, for every unnested expression, a LATERAL VIEW with EXPLODE
    (or POSEXPLODE when the UNNEST has `ordinality`) is appended, aliased with the UNNEST's table
    alias and the matching column alias.

    NOTE(review): an UNNEST without an alias is removed without any replacement.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: joins are removed from the live list below, and removing from
        # a list while iterating over it silently skips the element after each removal.
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode

                expression.args["joins"].remove(join)

                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
Convert cross join unnest into lateral view explode (used in presto -> hive).
def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Convert explode/posexplode into unnest (used in hive -> presto).

    Each `[POS]EXPLODE(arg)` projection is replaced by references to the column(s) of an
    `UNNEST(arg)` source (with `ordinality` set for POSEXPLODE) aliased with fresh names derived
    from `_u`, `col` and `pos`. That source becomes the FROM clause if the query has none, or is
    CROSS JOINed otherwise. The expression is modified in place.
    """
    if isinstance(expression, exp.Select):
        from sqlglot.optimizer.scope import build_scope

        taken_select_names = set(expression.named_selects)
        scope = build_scope(expression)
        if not scope:
            return expression
        taken_source_names = set(scope.selected_sources)

        for select in expression.selects:
            to_replace = select

            pos_alias = ""
            explode_alias = ""

            # Unwrap user-provided aliases: `x AS col` or `(pos, col)` for POSEXPLODE.
            if isinstance(select, exp.Alias):
                explode_alias = select.alias
                select = select.this
            elif isinstance(select, exp.Aliases):
                pos_alias = select.aliases[0].name
                explode_alias = select.aliases[1].name
                select = select.this

            if isinstance(select, (exp.Explode, exp.Posexplode)):
                is_posexplode = isinstance(select, exp.Posexplode)

                explode_arg = select.this
                unnest = exp.Unnest(expressions=[explode_arg.copy()], ordinality=is_posexplode)

                # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                if isinstance(explode_arg, exp.Column):
                    taken_select_names.add(explode_arg.output_name)

                unnest_source_alias = find_new_name(taken_source_names, "_u")
                taken_source_names.add(unnest_source_alias)

                if not explode_alias:
                    explode_alias = find_new_name(taken_select_names, "col")
                    taken_select_names.add(explode_alias)

                    if is_posexplode:
                        pos_alias = find_new_name(taken_select_names, "pos")
                        taken_select_names.add(pos_alias)

                if is_posexplode:
                    # The ordinality column comes after the value column in the UNNEST alias;
                    # the projection is removed and `pos, col` are appended at the end.
                    column_names = [explode_alias, pos_alias]
                    to_replace.pop()
                    expression.select(pos_alias, explode_alias, copy=False)
                else:
                    column_names = [explode_alias]
                    to_replace.replace(exp.column(explode_alias))

                unnest = exp.alias_(unnest, unnest_source_alias, table=column_names)

                if not expression.args.get("from"):
                    expression.from_(unnest, copy=False)
                else:
                    expression.join(unnest, join_type="CROSS", copy=False)

    return expression
Convert explode/posexplode into unnest (used in hive -> presto).
def remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
    """
    Remove table refs from columns in when statements.

    Within each WHEN clause of a MERGE, columns qualified with the target table's name (or its
    alias) are replaced by unqualified columns with the same name. The MERGE is modified in place.
    """
    if not isinstance(expression, exp.Merge):
        return expression

    target = expression.this
    target_refs = {target.this}
    target_alias = target.args.get("alias")
    if target_alias:
        target_refs.add(target_alias.this)

    def _unqualify(node):
        # Strip the table qualifier from columns that point at the MERGE target.
        if isinstance(node, exp.Column) and node.args.get("table") in target_refs:
            return exp.column(node.name)
        return node

    for when in expression.expressions:
        when.transform(_unqualify, copy=False)

    return expression
Remove target-table references from columns in the WHEN clauses of a MERGE statement.
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Replace `PERCENTILE_CONT/PERCENTILE_DISC(q) WITHIN GROUP (ORDER BY x)` with an
    `ApproxQuantile` of `x` at quantile `q`.

    Only the expression of the first ordered term is kept; its sort direction is discarded.
    """
    is_percentile_within_group = (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, (exp.PercentileCont, exp.PercentileDisc))
        and isinstance(expression.expression, exp.Order)
    )
    if not is_percentile_within_group:
        return expression

    percentile = expression.this
    first_ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approximation = exp.ApproxQuantile(this=first_ordered.this, quantile=percentile.this)
    return expression.replace(approximation)
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Give explicit column names to the CTEs of a WITH RECURSIVE clause that don't declare any.

    Names are taken from the output names of the CTE query's projections (the left-hand side
    when the query is a UNION); unnamed projections get `_c_0`, `_c_1`, ... with a counter shared
    by all CTEs of the clause.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    counter = itertools.count()

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        query = cte.this
        source = query.this if isinstance(query, exp.Union) else query

        names = []
        for projection in source.selects:
            names.append(exp.to_identifier(projection.alias_or_name or f"_c_{next(counter)}"))
        table_alias.set("columns", names)

    return expression
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable).

    The `Generator.TRANSFORMS` function is only used when the transformed expression's type differs
    from the original one's; otherwise it would dispatch straight back into the returned function
    and recurse forever, so a `ValueError` is raised instead.

    Args:
        transforms: sequence of transform functions. These will be called in order, starting from
            a copy of the input expression. If empty, the copy itself is dispatched.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self: Generator, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # Work on a copy so the caller's tree is never mutated by the transforms.
        expression = expression.copy()
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            # Ensures we don't enter an infinite loop. This can happen when the original expression
            # has the same type as the final expression and there's no _sql method available for it,
            # because then it'd re-enter _to_sql.
            if expression_type is type(expression):
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
Creates a new transform by chaining a sequence of transformations and converts the resulting
expression to SQL, using either the "_sql" method corresponding to the resulting expression,
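or the appropriate `Generator.TRANSFORMS` function (when applicable: it is only used when the transformed expression's type differs from the original one's, since otherwise it would recurse back into this transform).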
Arguments:
- transforms: sequence of transform functions. These will be called in order.
Returns:
Function that can be used as a generator transform.