sqlglot.transforms
1from __future__ import annotations 2 3import typing as t 4 5from sqlglot import expressions as exp 6from sqlglot.helper import find_new_name, name_sequence 7 8if t.TYPE_CHECKING: 9 from sqlglot.generator import Generator 10 11 12def unalias_group(expression: exp.Expression) -> exp.Expression: 13 """ 14 Replace references to select aliases in GROUP BY clauses. 15 16 Example: 17 >>> import sqlglot 18 >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql() 19 'SELECT a AS b FROM x GROUP BY 1' 20 21 Args: 22 expression: the expression that will be transformed. 23 24 Returns: 25 The transformed expression. 26 """ 27 if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select): 28 aliased_selects = { 29 e.alias: i 30 for i, e in enumerate(expression.parent.expressions, start=1) 31 if isinstance(e, exp.Alias) 32 } 33 34 for group_by in expression.expressions: 35 if ( 36 isinstance(group_by, exp.Column) 37 and not group_by.table 38 and group_by.name in aliased_selects 39 ): 40 group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name))) 41 42 return expression 43 44 45def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression: 46 """ 47 Convert SELECT DISTINCT ON statements to a subquery with a window function. 48 49 This is useful for dialects that don't support SELECT DISTINCT ON but support window functions. 50 51 Args: 52 expression: the expression that will be transformed. 53 54 Returns: 55 The transformed expression. 
56 """ 57 if ( 58 isinstance(expression, exp.Select) 59 and expression.args.get("distinct") 60 and expression.args["distinct"].args.get("on") 61 and isinstance(expression.args["distinct"].args["on"], exp.Tuple) 62 ): 63 distinct_cols = expression.args["distinct"].pop().args["on"].expressions 64 outer_selects = expression.selects 65 row_number = find_new_name(expression.named_selects, "_row_number") 66 window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols) 67 order = expression.args.get("order") 68 69 if order: 70 window.set("order", order.pop().copy()) 71 72 window = exp.alias_(window, row_number) 73 expression.select(window, copy=False) 74 75 return exp.select(*outer_selects).from_(expression.subquery()).where(f'"{row_number}" = 1') 76 77 return expression 78 79 80def eliminate_qualify(expression: exp.Expression) -> exp.Expression: 81 """ 82 Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently. 83 84 The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: 85 https://docs.snowflake.com/en/sql-reference/constructs/qualify 86 87 Some dialects don't support window functions in the WHERE clause, so we need to include them as 88 projections in the subquery, in order to refer to them in the outer filter using aliases. Also, 89 if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, 90 otherwise we won't be able to refer to it in the outer query's WHERE clause. 
91 """ 92 if isinstance(expression, exp.Select) and expression.args.get("qualify"): 93 taken = set(expression.named_selects) 94 for select in expression.selects: 95 if not select.alias_or_name: 96 alias = find_new_name(taken, "_c") 97 select.replace(exp.alias_(select, alias)) 98 taken.add(alias) 99 100 outer_selects = exp.select(*[select.alias_or_name for select in expression.selects]) 101 qualify_filters = expression.args["qualify"].pop().this 102 103 for expr in qualify_filters.find_all((exp.Window, exp.Column)): 104 if isinstance(expr, exp.Window): 105 alias = find_new_name(expression.named_selects, "_w") 106 expression.select(exp.alias_(expr, alias), copy=False) 107 column = exp.column(alias) 108 109 if isinstance(expr.parent, exp.Qualify): 110 qualify_filters = column 111 else: 112 expr.replace(column) 113 elif expr.name not in expression.named_selects: 114 expression.select(expr.copy(), copy=False) 115 116 return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters) 117 118 return expression 119 120 121def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression: 122 """ 123 Some dialects only allow the precision for parameterized types to be defined in the DDL and not in 124 other expressions. This transforms removes the precision from parameterized types in expressions. 
125 """ 126 for node in expression.find_all(exp.DataType): 127 node.set("expressions", [e for e in node.expressions if isinstance(e, exp.DataType)]) 128 129 return expression 130 131 132def unnest_to_explode(expression: exp.Expression) -> exp.Expression: 133 """Convert cross join unnest into lateral view explode (used in presto -> hive).""" 134 if isinstance(expression, exp.Select): 135 for join in expression.args.get("joins") or []: 136 unnest = join.this 137 138 if isinstance(unnest, exp.Unnest): 139 alias = unnest.args.get("alias") 140 udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode 141 142 expression.args["joins"].remove(join) 143 144 for e, column in zip(unnest.expressions, alias.columns if alias else []): 145 expression.append( 146 "laterals", 147 exp.Lateral( 148 this=udtf(this=e), 149 view=True, 150 alias=exp.TableAlias(this=alias.this, columns=[column]), # type: ignore 151 ), 152 ) 153 154 return expression 155 156 157def explode_to_unnest(expression: exp.Expression) -> exp.Expression: 158 """Convert explode/posexplode into unnest (used in hive -> presto).""" 159 if isinstance(expression, exp.Select): 160 from sqlglot.optimizer.scope import build_scope 161 162 taken_select_names = set(expression.named_selects) 163 scope = build_scope(expression) 164 if not scope: 165 return expression 166 taken_source_names = set(scope.selected_sources) 167 168 for select in expression.selects: 169 to_replace = select 170 171 pos_alias = "" 172 explode_alias = "" 173 174 if isinstance(select, exp.Alias): 175 explode_alias = select.alias 176 select = select.this 177 elif isinstance(select, exp.Aliases): 178 pos_alias = select.aliases[0].name 179 explode_alias = select.aliases[1].name 180 select = select.this 181 182 if isinstance(select, (exp.Explode, exp.Posexplode)): 183 is_posexplode = isinstance(select, exp.Posexplode) 184 185 explode_arg = select.this 186 unnest = exp.Unnest(expressions=[explode_arg.copy()], ordinality=is_posexplode) 187 188 # 
This ensures that we won't use [POS]EXPLODE's argument as a new selection 189 if isinstance(explode_arg, exp.Column): 190 taken_select_names.add(explode_arg.output_name) 191 192 unnest_source_alias = find_new_name(taken_source_names, "_u") 193 taken_source_names.add(unnest_source_alias) 194 195 if not explode_alias: 196 explode_alias = find_new_name(taken_select_names, "col") 197 taken_select_names.add(explode_alias) 198 199 if is_posexplode: 200 pos_alias = find_new_name(taken_select_names, "pos") 201 taken_select_names.add(pos_alias) 202 203 if is_posexplode: 204 column_names = [explode_alias, pos_alias] 205 to_replace.pop() 206 expression.select(pos_alias, explode_alias, copy=False) 207 else: 208 column_names = [explode_alias] 209 to_replace.replace(exp.column(explode_alias)) 210 211 unnest = exp.alias_(unnest, unnest_source_alias, table=column_names) 212 213 if not expression.args.get("from"): 214 expression.from_(unnest, copy=False) 215 else: 216 expression.join(unnest, join_type="CROSS", copy=False) 217 218 return expression 219 220 221def remove_target_from_merge(expression: exp.Expression) -> exp.Expression: 222 """Remove table refs from columns in when statements.""" 223 if isinstance(expression, exp.Merge): 224 alias = expression.this.args.get("alias") 225 targets = {expression.this.this} 226 if alias: 227 targets.add(alias.this) 228 229 for when in expression.expressions: 230 when.transform( 231 lambda node: exp.column(node.name) 232 if isinstance(node, exp.Column) and node.args.get("table") in targets 233 else node, 234 copy=False, 235 ) 236 237 return expression 238 239 240def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression: 241 if ( 242 isinstance(expression, exp.WithinGroup) 243 and isinstance(expression.this, (exp.PercentileCont, exp.PercentileDisc)) 244 and isinstance(expression.expression, exp.Order) 245 ): 246 quantile = expression.this.this 247 input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this 
248 return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile)) 249 250 return expression 251 252 253def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression: 254 if isinstance(expression, exp.With) and expression.recursive: 255 next_name = name_sequence("_c_") 256 257 for cte in expression.expressions: 258 if not cte.args["alias"].columns: 259 query = cte.this 260 if isinstance(query, exp.Union): 261 query = query.this 262 263 cte.args["alias"].set( 264 "columns", 265 [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects], 266 ) 267 268 return expression 269 270 271def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression: 272 if ( 273 isinstance(expression, (exp.Cast, exp.TryCast)) 274 and expression.name.lower() == "epoch" 275 and expression.to.this in exp.DataType.TEMPORAL_TYPES 276 ): 277 expression.this.replace(exp.Literal.string("1970-01-01 00:00:00")) 278 279 return expression 280 281 282def preprocess( 283 transforms: t.List[t.Callable[[exp.Expression], exp.Expression]], 284) -> t.Callable[[Generator, exp.Expression], str]: 285 """ 286 Creates a new transform by chaining a sequence of transformations and converts the resulting 287 expression to SQL, using either the "_sql" method corresponding to the resulting expression, 288 or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below). 289 290 Args: 291 transforms: sequence of transform functions. These will be called in order. 292 293 Returns: 294 Function that can be used as a generator transform. 
295 """ 296 297 def _to_sql(self, expression: exp.Expression) -> str: 298 expression_type = type(expression) 299 300 expression = transforms[0](expression.copy()) 301 for t in transforms[1:]: 302 expression = t(expression) 303 304 _sql_handler = getattr(self, expression.key + "_sql", None) 305 if _sql_handler: 306 return _sql_handler(expression) 307 308 transforms_handler = self.TRANSFORMS.get(type(expression)) 309 if transforms_handler: 310 # Ensures we don't enter an infinite loop. This can happen when the original expression 311 # has the same type as the final expression and there's no _sql method available for it, 312 # because then it'd re-enter _to_sql. 313 if expression_type is type(expression): 314 raise ValueError( 315 f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed." 316 ) 317 318 return transforms_handler(self, expression) 319 320 raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.") 321 322 return _to_sql
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite GROUP BY terms that name a select alias as that projection's ordinal position.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Group):
        return expression

    if not isinstance(expression.parent, exp.Select):
        return expression

    # Alias -> 1-based position of the aliased projection in the parent SELECT list.
    positions = {}
    for ordinal, projection in enumerate(expression.parent.expressions, start=1):
        if isinstance(projection, exp.Alias):
            positions[projection.alias] = ordinal

    for term in expression.expressions:
        # Only unqualified columns can be references to a select alias.
        if not isinstance(term, exp.Column) or term.table:
            continue

        if term.name in positions:
            term.replace(exp.Literal.number(positions[term.name]))

    return expression
Replace references to select aliases in GROUP BY clauses.
Example:
>>> import sqlglot >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql() 'SELECT a AS b FROM x GROUP BY 1'
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite SELECT DISTINCT ON as a subquery filtered on a ROW_NUMBER() window.

    The DISTINCT ON columns become the window's PARTITION BY, the query's ORDER BY (if any) is
    moved into the window, and the outer query keeps only the rows numbered 1. This is useful
    for dialects that don't support SELECT DISTINCT ON but support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    distinct = expression.args.get("distinct")
    if not distinct:
        return expression

    on = distinct.args.get("on")
    if not on or not isinstance(on, exp.Tuple):
        return expression

    # Detach the DISTINCT clause from the query; its ON columns drive the partitioning.
    distinct.pop()
    partition_columns = on.expressions

    # Capture the projections before the helper window column is appended.
    outer_selects = expression.selects
    row_number = find_new_name(expression.named_selects, "_row_number")
    numbering = exp.Window(this=exp.RowNumber(), partition_by=partition_columns)

    order = expression.args.get("order")
    if order:
        numbering.set("order", order.pop().copy())

    expression.select(exp.alias_(numbering, row_number), copy=False)

    outer_query = exp.select(*outer_selects).from_(expression.subquery())
    return outer_query.where(f'"{row_number}" = 1')
Convert SELECT DISTINCT ON statements to a subquery with a window function.
This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite a SELECT that has a QUALIFY clause as an equivalently filtered subquery.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so each window in the
    QUALIFY condition is added to the subquery as an aliased projection and referenced by that
    alias in the outer filter. A column that is referenced in the QUALIFY clause but not
    selected is added to the subquery too, otherwise the outer WHERE clause could not refer
    to it.
    """
    if not (isinstance(expression, exp.Select) and expression.args.get("qualify")):
        return expression

    # Give every unnamed projection an alias so the outer query can select it by name.
    used_names = set(expression.named_selects)
    for projection in expression.selects:
        if projection.alias_or_name:
            continue

        generated = find_new_name(used_names, "_c")
        projection.replace(exp.alias_(projection, generated))
        used_names.add(generated)

    outer_names = [projection.alias_or_name for projection in expression.selects]
    outer_query = exp.select(*outer_names)
    condition = expression.args["qualify"].pop().this

    for node in condition.find_all((exp.Window, exp.Column)):
        if not isinstance(node, exp.Window):
            # Plain column: make sure the subquery exposes it.
            if node.name not in expression.named_selects:
                expression.select(node.copy(), copy=False)
            continue

        window_alias = find_new_name(expression.named_selects, "_w")
        expression.select(exp.alias_(node, window_alias), copy=False)
        reference = exp.column(window_alias)

        if isinstance(node.parent, exp.Qualify):
            # The window is the entire condition, so the alias reference becomes the filter.
            condition = reference
        else:
            node.replace(reference)

    return outer_query.from_(expression.subquery(alias="_t")).where(condition)
Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify
Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause.
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and
    not in other expressions. This transform removes the precision from parameterized types in
    expressions, keeping only the type parameters that are themselves data types.
    """
    for data_type in expression.find_all(exp.DataType):
        nested_types = [
            parameter for parameter in data_type.expressions if isinstance(parameter, exp.DataType)
        ]
        data_type.set("expressions", nested_types)

    return expression
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode (used in presto -> hive).

    Every join of a SELECT whose target is an UNNEST is removed from the join list and replaced
    by one LATERAL VIEW per (unnested expression, alias column) pair. POSEXPLODE is used when
    the UNNEST carries the "ordinality" flag, EXPLODE otherwise.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression (modified in place when it is a SELECT).
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: joins are removed from the live list below, and removing
        # from a list while iterating it makes the loop skip the element that follows each
        # removal, which left consecutive UNNEST joins unconverted.
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode

                expression.args["joins"].remove(join)

                # Without an alias there are no columns to pair with, so no lateral is added.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
Convert cross join unnest into lateral view explode (used in presto -> hive).
def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Convert explode/posexplode into unnest (used in hive -> presto).

    Each EXPLODE / POSEXPLODE projection of a SELECT is turned into an aliased UNNEST source.
    The UNNEST becomes the FROM clause when the query has none, otherwise it is cross joined.
    The projection itself is replaced by column references to the UNNEST's output columns.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression. It is returned unchanged when it is not a SELECT or when
        no scope can be built for it.
    """
    if isinstance(expression, exp.Select):
        # NOTE(review): imported locally, presumably to avoid a circular import -- confirm.
        from sqlglot.optimizer.scope import build_scope

        # Names already used by projections / sources; generated names must not collide.
        taken_select_names = set(expression.named_selects)
        scope = build_scope(expression)
        if not scope:
            return expression
        taken_source_names = set(scope.selected_sources)

        # NOTE(review): the loop body pops projections and appends new ones while iterating
        # `expression.selects`; whether that is a live list is not visible here -- confirm.
        for select in expression.selects:
            # Keep a handle on the original projection; `select` is unwrapped below.
            to_replace = select

            pos_alias = ""
            explode_alias = ""

            if isinstance(select, exp.Alias):
                explode_alias = select.alias
                select = select.this
            elif isinstance(select, exp.Aliases):
                # Multi-alias form: first alias names the position, second the exploded value.
                pos_alias = select.aliases[0].name
                explode_alias = select.aliases[1].name
                select = select.this

            if isinstance(select, (exp.Explode, exp.Posexplode)):
                is_posexplode = isinstance(select, exp.Posexplode)

                explode_arg = select.this
                unnest = exp.Unnest(expressions=[explode_arg.copy()], ordinality=is_posexplode)

                # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                if isinstance(explode_arg, exp.Column):
                    taken_select_names.add(explode_arg.output_name)

                unnest_source_alias = find_new_name(taken_source_names, "_u")
                taken_source_names.add(unnest_source_alias)

                # Aliases are generated only when the projection was not aliased at all.
                # NOTE(review): a POSEXPLODE wrapped in a single Alias keeps pos_alias == ""
                # because the position alias is only generated inside this branch -- confirm
                # that case cannot occur or is handled by the caller.
                if not explode_alias:
                    explode_alias = find_new_name(taken_select_names, "col")
                    taken_select_names.add(explode_alias)

                    if is_posexplode:
                        pos_alias = find_new_name(taken_select_names, "pos")
                        taken_select_names.add(pos_alias)

                if is_posexplode:
                    # The UNNEST exposes the value column first, then the position column.
                    column_names = [explode_alias, pos_alias]
                    # One projection becomes two, so drop it and append both columns.
                    to_replace.pop()
                    expression.select(pos_alias, explode_alias, copy=False)
                else:
                    column_names = [explode_alias]
                    to_replace.replace(exp.column(explode_alias))

                unnest = exp.alias_(unnest, unnest_source_alias, table=column_names)

                if not expression.args.get("from"):
                    expression.from_(unnest, copy=False)
                else:
                    expression.join(unnest, join_type="CROSS", copy=False)

    return expression
Convert explode/posexplode into unnest (used in hive -> presto).
def remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
    """
    Remove table refs from columns in when statements.

    Columns inside the WHEN clauses of a MERGE whose table qualifier is the merge target (or
    the target's alias) are replaced by unqualified columns with the same name.
    """
    if not isinstance(expression, exp.Merge):
        return expression

    target = expression.this
    targets = {target.this}

    alias = target.args.get("alias")
    if alias:
        targets.add(alias.this)

    def _strip_target(node):
        # Only columns qualified with the merge target (or its alias) are rewritten.
        if isinstance(node, exp.Column) and node.args.get("table") in targets:
            return exp.column(node.name)
        return node

    for when in expression.expressions:
        when.transform(_strip_target, copy=False)

    return expression
Remove table refs from columns in when statements.
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Replace `PERCENTILE_CONT/DISC(q) WITHIN GROUP (ORDER BY x)` with an ApproxQuantile node.

    The quantile is taken from the percentile function's argument and the input value from the
    first ordered term found under the WITHIN GROUP expression. Any other expression is
    returned unchanged.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    if not isinstance(percentile, (exp.PercentileCont, exp.PercentileDisc)):
        return expression

    if not isinstance(expression.expression, exp.Order):
        return expression

    quantile = percentile.this
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx_quantile = exp.ApproxQuantile(this=ordered.this, quantile=quantile)

    return expression.replace(approx_quantile)
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Give every CTE of a recursive WITH clause an explicit column list.

    CTEs whose alias already declares columns are left alone. Otherwise the column names come
    from the projections of the CTE's query (the left side when the query is a union), and
    unnamed projections get generated names with the "_c_" prefix.
    """
    if not isinstance(expression, exp.With) or not expression.recursive:
        return expression

    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        query = cte.this.this if isinstance(cte.this, exp.Union) else cte.this

        column_names = []
        for projection in query.selects:
            column_names.append(exp.to_identifier(projection.alias_or_name or next_name()))

        table_alias.set("columns", column_names)

    return expression
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace the operand of a cast of "epoch" to a temporal type with an explicit timestamp.

    When a CAST / TRY_CAST operand is named "epoch" (case-insensitive) and the target type is
    one of `exp.DataType.TEMPORAL_TYPES`, the operand is replaced in place by the string
    literal '1970-01-01 00:00:00'. The (possibly modified) expression is returned.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    is_epoch = expression.name.lower() == "epoch"
    if is_epoch and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is allowed, in which case the expression is generated as-is.

    Returns:
        Function that can be used as a generator transform.

    Raises:
        ValueError: raised by the returned function when the transformed expression has neither
            a "_sql" method nor a `Generator.TRANSFORMS` entry, or when only a `TRANSFORMS`
            entry exists and the expression type is unchanged (which would recurse forever).
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # The transforms run on a copy, so the caller's tree is never mutated.
        expression = expression.copy()
        # Named `transform` rather than `t` so the module-level `typing as t` alias is not shadowed.
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            # Ensures we don't enter an infinite loop. This can happen when the original expression
            # has the same type as the final expression and there's no _sql method available for it,
            # because then it'd re-enter _to_sql.
            if expression_type is type(expression):
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
Creates a new transform by chaining a sequence of transformations and converts the resulting
expression to SQL, using either the "_sql" method corresponding to the resulting expression,
or the appropriate Generator.TRANSFORMS
function (when applicable -- see below).
Arguments:
- transforms: sequence of transform functions. These will be called in order.
Returns:
Function that can be used as a generator transform.