sqlglot.transforms
"""Expression-level transforms that rewrite SQL syntax trees, e.g. for dialects lacking a feature."""

from __future__ import annotations

import typing as t

from sqlglot import expressions as exp
from sqlglot.helper import find_new_name, name_sequence

if t.TYPE_CHECKING:
    from sqlglot.generator import Generator


def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses with positional references.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Group):
        return expression

    select = expression.parent
    if not isinstance(select, exp.Select):
        return expression

    # Map each explicit alias to the 1-based position of its projection (last one wins).
    position_by_alias = {}
    for position, projection in enumerate(select.expressions, start=1):
        if isinstance(projection, exp.Alias):
            position_by_alias[projection.alias] = position

    for grouped in expression.expressions:
        # Only bare (table-less) column references can refer to a select alias.
        if not isinstance(grouped, exp.Column) or grouped.table:
            continue
        position = position_by_alias.get(grouped.name)
        if position is not None:
            grouped.replace(exp.Literal.number(position))

    return expression


def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
    The DISTINCT ON expressions become the PARTITION BY of a ROW_NUMBER() window (ordered by the
    query's ORDER BY, if any), and an outer query keeps only the rows numbered 1.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        row_number = find_new_name(expression.named_selects, "_row_number")

        # The outer query only sees the subquery's output names, so it must project those names
        # rather than the inner expressions: re-emitting e.g. `x.a` or `a AS b` outside the
        # subquery would reference a table / column that is out of scope there.
        taken = {row_number}
        taken.update(s.alias for s in expression.selects if isinstance(s, exp.Alias))

        outer_selects: t.List[exp.Expression] = []
        for select in expression.selects:
            if select.is_star:
                # The subquery already exposes every column (including the row number).
                outer_selects = [exp.Star()]
                break

            if isinstance(select, exp.Alias):
                identifier = select.args["alias"]
            elif isinstance(select, exp.Column) and select.output_name not in taken:
                identifier = select.this
            else:
                # Unnamed or clashing projections get a fresh alias the outer query can use.
                quoted = select.this.args.get("quoted") if isinstance(select, exp.Column) else None
                identifier = exp.to_identifier(
                    find_new_name(taken, select.output_name or "_c"), quoted=quoted
                )
                select.replace(exp.alias_(select, identifier))

            taken.add(identifier.name)
            outer_selects.append(exp.column(identifier.copy()))

        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        if order:
            # The ORDER BY decides which row of each DISTINCT ON group is kept, so it moves into
            # the window; the outer query is left without an ORDER BY.
            window.set("order", order.pop().copy())

        expression.select(exp.alias_(window, row_number), copy=False)

        return (
            exp.select(*outer_selects)
            .from_(expression.subquery(alias="_t"))
            # Build the filter as a node rather than parsing a quoted string, so the reference is
            # quoted exactly like the (unquoted) alias it points to.
            .where(exp.EQ(this=exp.column(row_number), expression=exp.Literal.number(1)))
        )

    return expression


def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection an alias so the outer query can refer to it.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _outer_projection(select: exp.Expression) -> t.Union[str, exp.Column]:
            # Reference the projection's output name while keeping its identifier's quoting.
            # Passing the bare name as a string would re-parse it, which breaks names that need
            # quoting (e.g. "my col") and loses case-sensitive quoting.
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(select.alias_or_name, quoted=identifier.args.get("quoted"))
            return select.alias_or_name

        outer_selects = exp.select(*[_outer_projection(select) for select in expression.selects])
        qualify_filters = expression.args["qualify"].pop().this

        # With a star projection the columns are already exposed, so only windows are projected.
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for expr in qualify_filters.find_all(select_candidates):
            if isinstance(expr, exp.Window):
                # Project the window in the subquery and filter on its alias outside.
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(expr, alias), copy=False)
                column = exp.column(alias)

                # If the window is the whole filter, replacing it inside the detached QUALIFY
                # node would not update `qualify_filters`, so rebind the variable instead.
                if isinstance(expr.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    expr.replace(column)
            elif expr.name not in expression.named_selects:
                expression.select(expr.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters)

    return expression


def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.
    """
    for data_type in expression.find_all(exp.DataType):
        kept = []
        for param in data_type.expressions:
            if not isinstance(param, exp.DataTypeSize):
                kept.append(param)
        data_type.set("expressions", kept)

    return expression


def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode (used in presto -> hive).

    Every join whose source is an UNNEST is removed and replaced by one LATERAL VIEW per unnested
    expression, each paired with the matching column of the UNNEST's alias. UNNESTs with the
    `ordinality` flag set become POSEXPLODE; others become EXPLODE.

    NOTE(review): expressions without a matching alias column (including an unaliased UNNEST)
    produce no LATERAL VIEW, so such a join is dropped entirely.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: joins are removed from the underlying list inside the loop,
        # and removing from a list while iterating it would skip the join that follows.
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode

                expression.args["joins"].remove(join)

                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression


def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Convert explode/posexplode into unnest (used in hive -> presto).

    Each `EXPLODE(arg)` / `POSEXPLODE(arg)` projection of a SELECT is turned into an `UNNEST(arg)`
    source, aliased with fresh names. It is added as the FROM clause if the query has none,
    otherwise as a CROSS JOIN. The projection is replaced by references to the unnested columns.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # NOTE(review): imported locally, presumably to avoid an import cycle -- confirm.
        from sqlglot.optimizer.scope import Scope

        # Names already used by projections / by sources in this scope; generated aliases must
        # not collide with them.
        taken_select_names = set(expression.named_selects)
        taken_source_names = {name for name, _ in Scope(expression).references}

        for select in expression.selects:
            to_replace = select

            pos_alias = ""
            explode_alias = ""

            # `EXPLODE(x) AS c` names the value column; `POSEXPLODE(x) AS (p, c)` names both.
            if isinstance(select, exp.Alias):
                explode_alias = select.alias
                select = select.this
            elif isinstance(select, exp.Aliases):
                pos_alias = select.aliases[0].name
                explode_alias = select.aliases[1].name
                select = select.this

            if isinstance(select, (exp.Explode, exp.Posexplode)):
                is_posexplode = isinstance(select, exp.Posexplode)

                explode_arg = select.this
                unnest = exp.Unnest(expressions=[explode_arg.copy()], ordinality=is_posexplode)

                # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                if isinstance(explode_arg, exp.Column):
                    taken_select_names.add(explode_arg.output_name)

                unnest_source_alias = find_new_name(taken_source_names, "_u")
                taken_source_names.add(unnest_source_alias)

                if not explode_alias:
                    explode_alias = find_new_name(taken_select_names, "col")
                    taken_select_names.add(explode_alias)

                    if is_posexplode:
                        pos_alias = find_new_name(taken_select_names, "pos")
                        taken_select_names.add(pos_alias)

                if is_posexplode:
                    # The table alias lists the value column first, then the position; the
                    # original projection is dropped and `pos, col` are appended at the end.
                    column_names = [explode_alias, pos_alias]
                    to_replace.pop()
                    expression.select(pos_alias, explode_alias, copy=False)
                else:
                    column_names = [explode_alias]
                    to_replace.replace(exp.column(explode_alias))

                unnest = exp.alias_(unnest, unnest_source_alias, table=column_names)

                if not expression.args.get("from"):
                    expression.from_(unnest, copy=False)
                else:
                    expression.join(unnest, join_type="CROSS", copy=False)

    return expression


def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite `PERCENTILE_CONT/PERCENTILE_DISC(q) WITHIN GROUP (ORDER BY x)` as `APPROX_QUANTILE(x, q)`.

    Only the value of the first ordered term is used; its sort direction is not carried over.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    if not isinstance(percentile, (exp.PercentileCont, exp.PercentileDisc)):
        return expression
    if not isinstance(expression.expression, exp.Order):
        return expression

    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx = exp.ApproxQuantile(this=ordered.this, quantile=percentile.this)
    return expression.replace(approx)


def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Give every CTE of a recursive WITH that lacks a column list one derived from its projections.

    Names come from the CTE query's projections (the left-hand side, for a UNION); projections
    without a name get generated names from a `_c_` sequence shared across the WITH clause.
    """
    if not isinstance(expression, exp.With) or not expression.recursive:
        return expression

    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        query = cte.this
        if isinstance(query, exp.Union):
            query = query.this

        names = [exp.to_identifier(select.alias_or_name or next_name()) for select in query.selects]
        table_alias.set("columns", names)

    return expression


def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace the string 'epoch' in casts to temporal types by the equivalent timestamp literal.

    E.g. CAST('epoch' AS TIMESTAMP) becomes CAST('1970-01-01 00:00:00' AS TIMESTAMP). Only string
    literals are rewritten; a column that happens to be named `epoch` is left untouched.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, (exp.Cast, exp.TryCast))
        and isinstance(expression.this, exp.Literal)
        and expression.this.is_string
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    ):
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression


def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable).

    The input expression is copied before the transforms run, so they may mutate it freely. If the
    transformed expression has the same type as the input and has no "_sql" method, the
    `Generator.TRANSFORMS` entry is not used, because it would re-enter this transform forever.

    Args:
        transforms: sequence of transform functions. These will be called in order.

    Returns:
        Function that can be used as a generator transform. It raises ValueError when the
        transformed expression cannot be converted to SQL.
    """

    def _to_sql(self: Generator, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # Work on a copy so the caller's tree is never mutated by the transforms.
        expression = expression.copy()
        # Named `transform` (not `t`) to avoid shadowing the module-level `typing as t` alias.
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            # Ensures we don't enter an infinite loop. This can happen when the original expression
            # has the same type as the final expression and there's no _sql method available for it,
            # because then it'd re-enter _to_sql.
            if expression_type is type(expression):
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses with positional references.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Group):
        return expression

    select = expression.parent
    if not isinstance(select, exp.Select):
        return expression

    # Map each explicit alias to the 1-based position of its projection (last one wins).
    position_by_alias = {}
    for position, projection in enumerate(select.expressions, start=1):
        if isinstance(projection, exp.Alias):
            position_by_alias[projection.alias] = position

    for grouped in expression.expressions:
        # Only bare (table-less) column references can refer to a select alias.
        if not isinstance(grouped, exp.Column) or grouped.table:
            continue
        position = position_by_alias.get(grouped.name)
        if position is not None:
            grouped.replace(exp.Literal.number(position))

    return expression
Replace references to select aliases in GROUP BY clauses with the (1-based) position of the aliased projection.
Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
    The DISTINCT ON expressions become the PARTITION BY of a ROW_NUMBER() window (ordered by the
    query's ORDER BY, if any), and an outer query keeps only the rows numbered 1.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        row_number = find_new_name(expression.named_selects, "_row_number")

        # The outer query only sees the subquery's output names, so it must project those names
        # rather than the inner expressions: re-emitting e.g. `x.a` or `a AS b` outside the
        # subquery would reference a table / column that is out of scope there.
        taken = {row_number}
        taken.update(s.alias for s in expression.selects if isinstance(s, exp.Alias))

        outer_selects: t.List[exp.Expression] = []
        for select in expression.selects:
            if select.is_star:
                # The subquery already exposes every column (including the row number).
                outer_selects = [exp.Star()]
                break

            if isinstance(select, exp.Alias):
                identifier = select.args["alias"]
            elif isinstance(select, exp.Column) and select.output_name not in taken:
                identifier = select.this
            else:
                # Unnamed or clashing projections get a fresh alias the outer query can use.
                quoted = select.this.args.get("quoted") if isinstance(select, exp.Column) else None
                identifier = exp.to_identifier(
                    find_new_name(taken, select.output_name or "_c"), quoted=quoted
                )
                select.replace(exp.alias_(select, identifier))

            taken.add(identifier.name)
            outer_selects.append(exp.column(identifier.copy()))

        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        if order:
            # The ORDER BY decides which row of each DISTINCT ON group is kept, so it moves into
            # the window; the outer query is left without an ORDER BY.
            window.set("order", order.pop().copy())

        expression.select(exp.alias_(window, row_number), copy=False)

        return (
            exp.select(*outer_selects)
            .from_(expression.subquery(alias="_t"))
            # Build the filter as a node rather than parsing a quoted string, so the reference is
            # quoted exactly like the (unquoted) alias it points to.
            .where(exp.EQ(this=exp.column(row_number), expression=exp.Literal.number(1)))
        )

    return expression
Convert SELECT DISTINCT ON statements to a subquery with a window function.
This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection an alias so the outer query can refer to it.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _outer_projection(select: exp.Expression) -> t.Union[str, exp.Column]:
            # Reference the projection's output name while keeping its identifier's quoting.
            # Passing the bare name as a string would re-parse it, which breaks names that need
            # quoting (e.g. "my col") and loses case-sensitive quoting.
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(select.alias_or_name, quoted=identifier.args.get("quoted"))
            return select.alias_or_name

        outer_selects = exp.select(*[_outer_projection(select) for select in expression.selects])
        qualify_filters = expression.args["qualify"].pop().this

        # With a star projection the columns are already exposed, so only windows are projected.
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for expr in qualify_filters.find_all(select_candidates):
            if isinstance(expr, exp.Window):
                # Project the window in the subquery and filter on its alias outside.
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(expr, alias), copy=False)
                column = exp.column(alias)

                # If the window is the whole filter, replacing it inside the detached QUALIFY
                # node would not update `qualify_filters`, so rebind the variable instead.
                if isinstance(expr.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    expr.replace(column)
            elif expr.name not in expression.named_selects:
                expression.select(expr.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters)

    return expression
Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify
Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause.
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.
    """
    for data_type in expression.find_all(exp.DataType):
        kept = []
        for param in data_type.expressions:
            if not isinstance(param, exp.DataTypeSize):
                kept.append(param)
        data_type.set("expressions", kept)

    return expression
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode (used in presto -> hive).

    Every join whose source is an UNNEST is removed and replaced by one LATERAL VIEW per unnested
    expression, each paired with the matching column of the UNNEST's alias. UNNESTs with the
    `ordinality` flag set become POSEXPLODE; others become EXPLODE.

    NOTE(review): expressions without a matching alias column (including an unaliased UNNEST)
    produce no LATERAL VIEW, so such a join is dropped entirely.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: joins are removed from the underlying list inside the loop,
        # and removing from a list while iterating it would skip the join that follows.
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode

                expression.args["joins"].remove(join)

                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
Convert cross join unnest into lateral view explode (used in presto -> hive).
def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Convert explode/posexplode into unnest (used in hive -> presto).

    Each `EXPLODE(arg)` / `POSEXPLODE(arg)` projection of a SELECT is turned into an `UNNEST(arg)`
    source, aliased with fresh names. It is added as the FROM clause if the query has none,
    otherwise as a CROSS JOIN. The projection is replaced by references to the unnested columns.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # NOTE(review): imported locally, presumably to avoid an import cycle -- confirm.
        from sqlglot.optimizer.scope import Scope

        # Names already used by projections / by sources in this scope; generated aliases must
        # not collide with them.
        taken_select_names = set(expression.named_selects)
        taken_source_names = {name for name, _ in Scope(expression).references}

        for select in expression.selects:
            to_replace = select

            pos_alias = ""
            explode_alias = ""

            # `EXPLODE(x) AS c` names the value column; `POSEXPLODE(x) AS (p, c)` names both.
            if isinstance(select, exp.Alias):
                explode_alias = select.alias
                select = select.this
            elif isinstance(select, exp.Aliases):
                pos_alias = select.aliases[0].name
                explode_alias = select.aliases[1].name
                select = select.this

            if isinstance(select, (exp.Explode, exp.Posexplode)):
                is_posexplode = isinstance(select, exp.Posexplode)

                explode_arg = select.this
                unnest = exp.Unnest(expressions=[explode_arg.copy()], ordinality=is_posexplode)

                # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                if isinstance(explode_arg, exp.Column):
                    taken_select_names.add(explode_arg.output_name)

                unnest_source_alias = find_new_name(taken_source_names, "_u")
                taken_source_names.add(unnest_source_alias)

                if not explode_alias:
                    explode_alias = find_new_name(taken_select_names, "col")
                    taken_select_names.add(explode_alias)

                    if is_posexplode:
                        pos_alias = find_new_name(taken_select_names, "pos")
                        taken_select_names.add(pos_alias)

                if is_posexplode:
                    # The table alias lists the value column first, then the position; the
                    # original projection is dropped and `pos, col` are appended at the end.
                    column_names = [explode_alias, pos_alias]
                    to_replace.pop()
                    expression.select(pos_alias, explode_alias, copy=False)
                else:
                    column_names = [explode_alias]
                    to_replace.replace(exp.column(explode_alias))

                unnest = exp.alias_(unnest, unnest_source_alias, table=column_names)

                if not expression.args.get("from"):
                    expression.from_(unnest, copy=False)
                else:
                    expression.join(unnest, join_type="CROSS", copy=False)

    return expression
Convert explode/posexplode into unnest (used in hive -> presto).
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite `PERCENTILE_CONT/PERCENTILE_DISC(q) WITHIN GROUP (ORDER BY x)` as `APPROX_QUANTILE(x, q)`.

    Only the value of the first ordered term is used; its sort direction is not carried over.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    if not isinstance(percentile, (exp.PercentileCont, exp.PercentileDisc)):
        return expression
    if not isinstance(expression.expression, exp.Order):
        return expression

    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx = exp.ApproxQuantile(this=ordered.this, quantile=percentile.this)
    return expression.replace(approx)
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Give every CTE of a recursive WITH that lacks a column list one derived from its projections.

    Names come from the CTE query's projections (the left-hand side, for a UNION); projections
    without a name get generated names from a `_c_` sequence shared across the WITH clause.
    """
    if not isinstance(expression, exp.With) or not expression.recursive:
        return expression

    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        query = cte.this
        if isinstance(query, exp.Union):
            query = query.this

        names = [exp.to_identifier(select.alias_or_name or next_name()) for select in query.selects]
        table_alias.set("columns", names)

    return expression
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace the string 'epoch' in casts to temporal types by the equivalent timestamp literal.

    E.g. CAST('epoch' AS TIMESTAMP) becomes CAST('1970-01-01 00:00:00' AS TIMESTAMP). Only string
    literals are rewritten; a column that happens to be named `epoch` is left untouched.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, (exp.Cast, exp.TryCast))
        and isinstance(expression.this, exp.Literal)
        and expression.this.is_string
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    ):
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable).

    The input expression is copied before the transforms run, so they may mutate it freely. If the
    transformed expression has the same type as the input and has no "_sql" method, the
    `Generator.TRANSFORMS` entry is not used, because it would re-enter this transform forever.

    Args:
        transforms: sequence of transform functions. These will be called in order.

    Returns:
        Function that can be used as a generator transform. It raises ValueError when the
        transformed expression cannot be converted to SQL.
    """

    def _to_sql(self: Generator, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # Work on a copy so the caller's tree is never mutated by the transforms.
        expression = expression.copy()
        # Named `transform` (not `t`) to avoid shadowing the module-level `typing as t` alias.
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            # Ensures we don't enter an infinite loop. This can happen when the original expression
            # has the same type as the final expression and there's no _sql method available for it,
            # because then it'd re-enter _to_sql.
            if expression_type is type(expression):
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
Creates a new transform by chaining a sequence of transformations and converts the resulting
expression to SQL, using either the "_sql" method corresponding to the resulting expression,
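or the appropriate `Generator.TRANSFORMS` function (when applicable). A ValueError is raised if
neither is available, or if the `Generator.TRANSFORMS` entry would re-enter this transform.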
Arguments:
- transforms: sequence of transform functions. These will be called in order.
Returns:
Function that can be used as a generator transform.