Edit on GitHub

sqlglot.optimizer.eliminate_subqueries

  1import itertools
  2
  3from sqlglot import expressions as exp
  4from sqlglot.helper import find_new_name
  5from sqlglot.optimizer.scope import build_scope
  6
  7
  8def eliminate_subqueries(expression):
  9    """
 10    Rewrite derived tables as CTES, deduplicating if possible.
 11
 12    Example:
 13        >>> import sqlglot
 14        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y")
 15        >>> eliminate_subqueries(expression).sql()
 16        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'
 17
 18    This also deduplicates common subqueries:
 19        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y CROSS JOIN (SELECT * FROM x) AS z")
 20        >>> eliminate_subqueries(expression).sql()
 21        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y CROSS JOIN y AS z'
 22
 23    Args:
 24        expression (sqlglot.Expression): expression
 25    Returns:
 26        sqlglot.Expression: expression
 27    """
 28    if isinstance(expression, exp.Subquery):
 29        # It's possible to have subqueries at the root, e.g. (SELECT * FROM x) LIMIT 1
 30        eliminate_subqueries(expression.this)
 31        return expression
 32
 33    root = build_scope(expression)
 34
 35    if not root:
 36        return expression
 37
 38    # Map of alias->Scope|Table
 39    # These are all aliases that are already used in the expression.
 40    # We don't want to create new CTEs that conflict with these names.
 41    taken = {}
 42
 43    # All CTE aliases in the root scope are taken
 44    for scope in root.cte_scopes:
 45        taken[scope.expression.parent.alias] = scope
 46
 47    # All table names are taken
 48    for scope in root.traverse():
 49        taken.update(
 50            {
 51                source.name: source
 52                for _, source in scope.sources.items()
 53                if isinstance(source, exp.Table)
 54            }
 55        )
 56
 57    # Map of Expression->alias
 58    # Existing CTES in the root expression. We'll use this for deduplication.
 59    existing_ctes = {}
 60
 61    with_ = root.expression.args.get("with")
 62    recursive = False
 63    if with_:
 64        recursive = with_.args.get("recursive")
 65        for cte in with_.expressions:
 66            existing_ctes[cte.this] = cte.alias
 67    new_ctes = []
 68
 69    # We're adding more CTEs, but we want to maintain the DAG order.
 70    # Derived tables within an existing CTE need to come before the existing CTE.
 71    for cte_scope in root.cte_scopes:
 72        # Append all the new CTEs from this existing CTE
 73        for scope in cte_scope.traverse():
 74            if scope is cte_scope:
 75                # Don't try to eliminate this CTE itself
 76                continue
 77            new_cte = _eliminate(scope, existing_ctes, taken)
 78            if new_cte:
 79                new_ctes.append(new_cte)
 80
 81        # Append the existing CTE itself
 82        new_ctes.append(cte_scope.expression.parent)
 83
 84    # Now append the rest
 85    for scope in itertools.chain(root.union_scopes, root.subquery_scopes, root.table_scopes):
 86        for child_scope in scope.traverse():
 87            new_cte = _eliminate(child_scope, existing_ctes, taken)
 88            if new_cte:
 89                new_ctes.append(new_cte)
 90
 91    if new_ctes:
 92        expression.set("with", exp.With(expressions=new_ctes, recursive=recursive))
 93
 94    return expression
 95
 96
 97def _eliminate(scope, existing_ctes, taken):
 98    if scope.is_union:
 99        return _eliminate_union(scope, existing_ctes, taken)
100
101    if scope.is_derived_table:
102        return _eliminate_derived_table(scope, existing_ctes, taken)
103
104    if scope.is_cte:
105        return _eliminate_cte(scope, existing_ctes, taken)
106
107
def _eliminate_union(scope, existing_ctes, taken):
    """
    Replace a nested set operation with a SELECT over a CTE that holds it.

    If an identical expression already has a CTE alias in ``existing_ctes``,
    that alias is reused and no new CTE is returned (implicit None). Otherwise
    a fresh "cte"-based name is reserved in ``taken``, the expression is
    registered in ``existing_ctes`` and a new ``exp.CTE`` wrapping it is
    returned.
    """
    union = scope.expression
    reused_alias = existing_ctes.get(union)
    alias = reused_alias or find_new_name(taken=taken, base="cte")
    taken[alias] = scope

    # Keep the original output column names only when every projection has one;
    # otherwise fall back to SELECT *.
    names = [projection.alias_or_name for projection in union.selects]
    if all(names):
        projections = [
            exp.alias_(exp.column(name, table=alias), alias=name, copy=False) for name in names
        ]
    else:
        projections = ["*"]

    source = exp.alias_(exp.table_(alias), alias=alias, copy=False)
    union.replace(exp.select(*projections).from_(source))

    if reused_alias:
        return None

    existing_ctes[union] = alias
    return exp.CTE(this=union, alias=exp.TableAlias(this=exp.to_identifier(alias)))
136
137
def _eliminate_derived_table(scope, existing_ctes, taken):
    """
    Turn a derived table (subquery in FROM/JOIN) into a reference to a CTE.

    The wrapping node is replaced by an aliased table pointing at the CTE
    name, carrying over any "joins" arg it had. Returns the new CTE, or None
    when the subquery is reused from an existing CTE or must not be touched.
    """
    outer = scope.parent
    # Leave these alone:
    # - pivoted subqueries (rewriting would drop the "pivot" arg)
    # - lateral correlated subqueries
    if outer.pivots or isinstance(outer.expression, exp.Lateral):
        return None

    wrapper = scope.expression.parent
    name, cte = _new_cte(scope, existing_ctes, taken)

    reference = exp.alias_(exp.table_(name), alias=wrapper.alias or name)
    reference.set("joins", wrapper.args.get("joins"))
    wrapper.replace(reference)

    return cte
153
154
def _eliminate_cte(scope, existing_ctes, taken):
    """
    Hoist a nested CTE up to the root WITH clause.

    The CTE node is detached from its enclosing WITH (which is itself removed
    once empty), and every table in the parent scope tree that resolved to this
    CTE is replaced by a reference to the (possibly renamed) root-level CTE,
    keeping the table's original alias or name. Returns the new CTE, or None
    when it duplicates an existing one.
    """
    cte_node = scope.expression.parent
    name, cte = _new_cte(scope, existing_ctes, taken)

    # Detach from the nested WITH; drop that WITH if nothing is left in it.
    with_clause = cte_node.parent
    cte_node.pop()
    if not with_clause.expressions:
        with_clause.pop()

    # Re-point every reference to this CTE at its root-level name.
    for child_scope in scope.parent.traverse():
        for table, source in child_scope.selected_sources.values():
            if source is not scope:
                continue
            table.replace(exp.alias_(exp.table_(name), alias=table.alias_or_name, copy=False))

    return cte
172
173
def _new_cte(scope, existing_ctes, taken):
    """
    Pick a root-level name for a scope's expression and build its CTE.

    The name is, in order of preference: the alias of an identical expression
    already in ``existing_ctes``; the expression's parent alias (renamed via
    ``find_new_name`` if already taken); or a fresh "cte"-based name. The name
    is recorded in ``taken``.

    Returns:
        tuple of (name, cte)
        where `name` is a new name for this CTE in the root scope and `cte` is a new CTE instance.
        If this CTE duplicates an existing CTE, `cte` will be None.
    """
    query = scope.expression
    reused_alias = existing_ctes.get(query)
    name = query.parent.alias or find_new_name(taken=taken, base="cte")

    if reused_alias:
        name = reused_alias
    elif taken.get(name):
        name = find_new_name(taken=taken, base=name)

    taken[name] = scope

    if reused_alias:
        return name, None

    existing_ctes[query] = name
    return name, exp.CTE(this=query, alias=exp.TableAlias(this=exp.to_identifier(name)))
def eliminate_subqueries(expression):
 9def eliminate_subqueries(expression):
10    """
11    Rewrite derived tables as CTES, deduplicating if possible.
12
13    Example:
14        >>> import sqlglot
15        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y")
16        >>> eliminate_subqueries(expression).sql()
17        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'
18
19    This also deduplicates common subqueries:
20        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y CROSS JOIN (SELECT * FROM x) AS z")
21        >>> eliminate_subqueries(expression).sql()
22        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y CROSS JOIN y AS z'
23
24    Args:
25        expression (sqlglot.Expression): expression
26    Returns:
27        sqlglot.Expression: expression
28    """
29    if isinstance(expression, exp.Subquery):
30        # It's possible to have subqueries at the root, e.g. (SELECT * FROM x) LIMIT 1
31        eliminate_subqueries(expression.this)
32        return expression
33
34    root = build_scope(expression)
35
36    if not root:
37        return expression
38
39    # Map of alias->Scope|Table
40    # These are all aliases that are already used in the expression.
41    # We don't want to create new CTEs that conflict with these names.
42    taken = {}
43
44    # All CTE aliases in the root scope are taken
45    for scope in root.cte_scopes:
46        taken[scope.expression.parent.alias] = scope
47
48    # All table names are taken
49    for scope in root.traverse():
50        taken.update(
51            {
52                source.name: source
53                for _, source in scope.sources.items()
54                if isinstance(source, exp.Table)
55            }
56        )
57
58    # Map of Expression->alias
59    # Existing CTES in the root expression. We'll use this for deduplication.
60    existing_ctes = {}
61
62    with_ = root.expression.args.get("with")
63    recursive = False
64    if with_:
65        recursive = with_.args.get("recursive")
66        for cte in with_.expressions:
67            existing_ctes[cte.this] = cte.alias
68    new_ctes = []
69
70    # We're adding more CTEs, but we want to maintain the DAG order.
71    # Derived tables within an existing CTE need to come before the existing CTE.
72    for cte_scope in root.cte_scopes:
73        # Append all the new CTEs from this existing CTE
74        for scope in cte_scope.traverse():
75            if scope is cte_scope:
76                # Don't try to eliminate this CTE itself
77                continue
78            new_cte = _eliminate(scope, existing_ctes, taken)
79            if new_cte:
80                new_ctes.append(new_cte)
81
82        # Append the existing CTE itself
83        new_ctes.append(cte_scope.expression.parent)
84
85    # Now append the rest
86    for scope in itertools.chain(root.union_scopes, root.subquery_scopes, root.table_scopes):
87        for child_scope in scope.traverse():
88            new_cte = _eliminate(child_scope, existing_ctes, taken)
89            if new_cte:
90                new_ctes.append(new_cte)
91
92    if new_ctes:
93        expression.set("with", exp.With(expressions=new_ctes, recursive=recursive))
94
95    return expression

Rewrite derived tables as CTEs, deduplicating if possible.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y")
>>> eliminate_subqueries(expression).sql()
'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'
This also deduplicates common subqueries:
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y CROSS JOIN (SELECT * FROM x) AS z")
>>> eliminate_subqueries(expression).sql()
'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y CROSS JOIN y AS z'
Arguments:
  • expression (sqlglot.Expression): expression
Returns:
  • sqlglot.Expression: expression