sqlglot.optimizer.eliminate_subqueries
import itertools

from sqlglot import expressions as exp
from sqlglot.helper import find_new_name
from sqlglot.optimizer.scope import build_scope
from sqlglot.optimizer.simplify import simplify


def eliminate_subqueries(expression):
    """
    Rewrite derived tables as CTEs, deduplicating if possible.

    Unions and CTEs nested below the root are hoisted into root-level CTEs too.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y")
        >>> eliminate_subqueries(expression).sql()
        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'

    This also deduplicates common subqueries:
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y JOIN (SELECT * FROM x) AS z")
        >>> eliminate_subqueries(expression).sql()
        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y JOIN y AS z'

    Args:
        expression (sqlglot.Expression): the query to rewrite. The tree is modified
            in place; always use the returned expression.
    Returns:
        sqlglot.Expression: the rewritten expression
    """
    if isinstance(expression, exp.Subquery):
        # It's possible to have subqueries at the root, e.g. (SELECT * FROM x) LIMIT 1.
        # Re-attach the result: the recursive call returns the (simplified) inner
        # query, which is not guaranteed to be the same node we passed in.
        expression.set("this", eliminate_subqueries(expression.this))
        return expression

    expression = simplify(expression)
    root = build_scope(expression)

    # Map of alias->Scope|Table
    # These are all aliases that are already used in the expression.
    # We don't want to create new CTEs that conflict with these names.
    taken = {}

    # All CTE aliases in the root scope are taken
    for scope in root.cte_scopes:
        taken[scope.expression.parent.alias] = scope

    # All table names are taken
    for scope in root.traverse():
        taken.update(
            {
                source.name: source
                for source in scope.sources.values()
                if isinstance(source, exp.Table)
            }
        )

    # Map of Expression->alias
    # Existing CTEs in the root expression. We'll use this for deduplication.
    existing_ctes = {}

    with_ = root.expression.args.get("with")
    recursive = False
    if with_:
        recursive = with_.args.get("recursive")
        for cte in with_.expressions:
            existing_ctes[cte.this] = cte.alias
    new_ctes = []

    # We're adding more CTEs, but we want to maintain the DAG order.
    # Derived tables within an existing CTE need to come before the existing CTE.
    for cte_scope in root.cte_scopes:
        # Append all the new CTEs from this existing CTE
        for scope in cte_scope.traverse():
            if scope is cte_scope:
                # Don't try to eliminate this CTE itself
                continue
            new_cte = _eliminate(scope, existing_ctes, taken)
            if new_cte:
                new_ctes.append(new_cte)

        # Append the existing CTE itself
        new_ctes.append(cte_scope.expression.parent)

    # Now append the rest
    for scope in itertools.chain(
        root.union_scopes, root.subquery_scopes, root.derived_table_scopes
    ):
        for child_scope in scope.traverse():
            new_cte = _eliminate(child_scope, existing_ctes, taken)
            if new_cte:
                new_ctes.append(new_cte)

    if new_ctes:
        # Rebuild the root WITH clause in DAG order, keeping the RECURSIVE flag.
        expression.set("with", exp.With(expressions=new_ctes, recursive=recursive))

    return expression


def _eliminate(scope, existing_ctes, taken):
    """
    Dispatch `scope` to the elimination routine matching its kind.

    Returns:
        exp.CTE | None: a new CTE for the root WITH clause, or None when no new CTE
        is needed (the scope duplicates an existing CTE, is a UDTF derived table,
        or is not a union / derived table / CTE).
    """
    if scope.is_union:
        return _eliminate_union(scope, existing_ctes, taken)

    if scope.is_derived_table and not isinstance(scope.expression, exp.UDTF):
        return _eliminate_derived_table(scope, existing_ctes, taken)

    if scope.is_cte:
        return _eliminate_cte(scope, existing_ctes, taken)

    # Any other kind of scope is left untouched.
    return None


def _eliminate_union(scope, existing_ctes, taken):
    """
    Move a union into a CTE and replace it with `SELECT <cols> FROM <cte> AS <cte>`.

    If an identical expression is already a CTE, the union is replaced by a reference
    to that CTE and None is returned; otherwise the new exp.CTE is returned.
    """
    duplicate_cte_alias = existing_ctes.get(scope.expression)

    alias = duplicate_cte_alias or find_new_name(taken=taken, base="cte")

    taken[alias] = scope

    # Try to maintain the selections
    expressions = scope.selects
    selects = [
        exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name)
        for e in expressions
        if e.alias_or_name
    ]
    # If not all selections have an alias, just select *
    if len(selects) != len(expressions):
        selects = ["*"]

    # scope.expression keeps pointing at the original union node after the
    # replacement, so it can still serve as the body of the new CTE below.
    scope.expression.replace(
        exp.select(*selects).from_(exp.alias_(exp.table_(alias), alias=alias))
    )

    if not duplicate_cte_alias:
        existing_ctes[scope.expression] = alias
        return exp.CTE(
            this=scope.expression,
            alias=exp.TableAlias(this=exp.to_identifier(alias)),
        )
    return None


def _eliminate_derived_table(scope, existing_ctes, taken):
    """
    Replace a derived table with a reference to a (new or existing) CTE.

    The reference keeps the derived table's alias, falling back to the CTE name.
    Returns the new exp.CTE, or None if an identical CTE already existed.
    """
    parent = scope.expression.parent
    name, cte = _new_cte(scope, existing_ctes, taken)

    table = exp.alias_(exp.table_(name), alias=parent.alias or name)
    parent.replace(table)

    return cte


def _eliminate_cte(scope, existing_ctes, taken):
    """
    Remove a nested CTE from its WITH clause so it can be hoisted to the root.

    Drops the WITH clause entirely once it has no CTEs left, and rewrites every
    table that referenced this CTE to `<new name> AS <old alias or name>`.
    Returns the new exp.CTE, or None if an identical CTE already existed.
    """
    parent = scope.expression.parent
    name, cte = _new_cte(scope, existing_ctes, taken)

    with_ = parent.parent
    parent.pop()
    if not with_.expressions:
        with_.pop()

    # Rename references to this CTE
    for child_scope in scope.parent.traverse():
        for table, source in child_scope.selected_sources.values():
            if source is scope:
                new_table = exp.alias_(exp.table_(name), alias=table.alias_or_name)
                table.replace(new_table)

    return cte


def _new_cte(scope, existing_ctes, taken):
    """
    Returns:
        tuple of (name, cte)
        where `name` is a new name for this CTE in the root scope and `cte` is a new CTE instance.
        If this CTE duplicates an existing CTE, `cte` will be None.
    """
    duplicate_cte_alias = existing_ctes.get(scope.expression)
    parent = scope.expression.parent
    name = parent.alias

    if not name:
        name = find_new_name(taken=taken, base="cte")

    if duplicate_cte_alias:
        name = duplicate_cte_alias
    elif name in taken:
        # Use a membership test rather than the truthiness of the stored
        # Scope/Table, so a collision can never be silently missed.
        name = find_new_name(taken=taken, base=name)

    taken[name] = scope

    if not duplicate_cte_alias:
        existing_ctes[scope.expression] = name
        cte = exp.CTE(
            this=scope.expression,
            alias=exp.TableAlias(this=exp.to_identifier(name)),
        )
    else:
        cte = None
    return name, cte
def eliminate_subqueries(expression):
def eliminate_subqueries(expression):
    """
    Rewrite derived tables as CTEs, deduplicating if possible.

    Unions and CTEs nested below the root are hoisted into root-level CTEs too.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y")
        >>> eliminate_subqueries(expression).sql()
        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'

    This also deduplicates common subqueries:
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y JOIN (SELECT * FROM x) AS z")
        >>> eliminate_subqueries(expression).sql()
        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y JOIN y AS z'

    Args:
        expression (sqlglot.Expression): the query to rewrite. The tree is modified
            in place; always use the returned expression.
    Returns:
        sqlglot.Expression: the rewritten expression
    """
    if isinstance(expression, exp.Subquery):
        # It's possible to have subqueries at the root, e.g. (SELECT * FROM x) LIMIT 1.
        # Re-attach the result: the recursive call returns the (simplified) inner
        # query, which is not guaranteed to be the same node we passed in.
        expression.set("this", eliminate_subqueries(expression.this))
        return expression

    expression = simplify(expression)
    root = build_scope(expression)

    # Map of alias->Scope|Table
    # These are all aliases that are already used in the expression.
    # We don't want to create new CTEs that conflict with these names.
    taken = {}

    # All CTE aliases in the root scope are taken
    for scope in root.cte_scopes:
        taken[scope.expression.parent.alias] = scope

    # All table names are taken
    for scope in root.traverse():
        taken.update(
            {
                source.name: source
                for source in scope.sources.values()
                if isinstance(source, exp.Table)
            }
        )

    # Map of Expression->alias
    # Existing CTEs in the root expression. We'll use this for deduplication.
    existing_ctes = {}

    with_ = root.expression.args.get("with")
    recursive = False
    if with_:
        recursive = with_.args.get("recursive")
        for cte in with_.expressions:
            existing_ctes[cte.this] = cte.alias
    new_ctes = []

    # We're adding more CTEs, but we want to maintain the DAG order.
    # Derived tables within an existing CTE need to come before the existing CTE.
    for cte_scope in root.cte_scopes:
        # Append all the new CTEs from this existing CTE
        for scope in cte_scope.traverse():
            if scope is cte_scope:
                # Don't try to eliminate this CTE itself
                continue
            new_cte = _eliminate(scope, existing_ctes, taken)
            if new_cte:
                new_ctes.append(new_cte)

        # Append the existing CTE itself
        new_ctes.append(cte_scope.expression.parent)

    # Now append the rest
    for scope in itertools.chain(
        root.union_scopes, root.subquery_scopes, root.derived_table_scopes
    ):
        for child_scope in scope.traverse():
            new_cte = _eliminate(child_scope, existing_ctes, taken)
            if new_cte:
                new_ctes.append(new_cte)

    if new_ctes:
        # Rebuild the root WITH clause in DAG order, keeping the RECURSIVE flag.
        expression.set("with", exp.With(expressions=new_ctes, recursive=recursive))

    return expression
Rewrite derived tables as CTEs, deduplicating if possible. Unions and CTEs nested below the root are hoisted into root-level CTEs as well.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y")
>>> eliminate_subqueries(expression).sql()
'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'
This also deduplicates common subqueries:
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y JOIN (SELECT * FROM x) AS z")
>>> eliminate_subqueries(expression).sql()
'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y JOIN y AS z'
Arguments:
- expression (sqlglot.Expression): the query to rewrite. The tree is modified in place; always use the returned expression.
Returns:
sqlglot.Expression: the rewritten expression