Edit on GitHub

sqlglot.optimizer.eliminate_subqueries

  1import itertools
  2
  3from sqlglot import expressions as exp
  4from sqlglot.helper import find_new_name
  5from sqlglot.optimizer.scope import build_scope
  6from sqlglot.optimizer.simplify import simplify
  7
  8
  9def eliminate_subqueries(expression):
 10    """
 11    Rewrite derived tables as CTES, deduplicating if possible.
 12
 13    Example:
 14        >>> import sqlglot
 15        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y")
 16        >>> eliminate_subqueries(expression).sql()
 17        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'
 18
 19    This also deduplicates common subqueries:
 20        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y JOIN (SELECT * FROM x) AS z")
 21        >>> eliminate_subqueries(expression).sql()
 22        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y JOIN y AS z'
 23
 24    Args:
 25        expression (sqlglot.Expression): expression
 26    Returns:
 27        sqlglot.Expression: expression
 28    """
 29    if isinstance(expression, exp.Subquery):
 30        # It's possible to have subqueries at the root, e.g. (SELECT * FROM x) LIMIT 1
 31        eliminate_subqueries(expression.this)
 32        return expression
 33
 34    expression = simplify(expression)
 35    root = build_scope(expression)
 36
 37    # Map of alias->Scope|Table
 38    # These are all aliases that are already used in the expression.
 39    # We don't want to create new CTEs that conflict with these names.
 40    taken = {}
 41
 42    # All CTE aliases in the root scope are taken
 43    for scope in root.cte_scopes:
 44        taken[scope.expression.parent.alias] = scope
 45
 46    # All table names are taken
 47    for scope in root.traverse():
 48        taken.update(
 49            {
 50                source.name: source
 51                for _, source in scope.sources.items()
 52                if isinstance(source, exp.Table)
 53            }
 54        )
 55
 56    # Map of Expression->alias
 57    # Existing CTES in the root expression. We'll use this for deduplication.
 58    existing_ctes = {}
 59
 60    with_ = root.expression.args.get("with")
 61    recursive = False
 62    if with_:
 63        recursive = with_.args.get("recursive")
 64        for cte in with_.expressions:
 65            existing_ctes[cte.this] = cte.alias
 66    new_ctes = []
 67
 68    # We're adding more CTEs, but we want to maintain the DAG order.
 69    # Derived tables within an existing CTE need to come before the existing CTE.
 70    for cte_scope in root.cte_scopes:
 71        # Append all the new CTEs from this existing CTE
 72        for scope in cte_scope.traverse():
 73            if scope is cte_scope:
 74                # Don't try to eliminate this CTE itself
 75                continue
 76            new_cte = _eliminate(scope, existing_ctes, taken)
 77            if new_cte:
 78                new_ctes.append(new_cte)
 79
 80        # Append the existing CTE itself
 81        new_ctes.append(cte_scope.expression.parent)
 82
 83    # Now append the rest
 84    for scope in itertools.chain(root.union_scopes, root.subquery_scopes, root.table_scopes):
 85        for child_scope in scope.traverse():
 86            new_cte = _eliminate(child_scope, existing_ctes, taken)
 87            if new_cte:
 88                new_ctes.append(new_cte)
 89
 90    if new_ctes:
 91        expression.set("with", exp.With(expressions=new_ctes, recursive=recursive))
 92
 93    return expression
 94
 95
 96def _eliminate(scope, existing_ctes, taken):
 97    if scope.is_union:
 98        return _eliminate_union(scope, existing_ctes, taken)
 99
100    if scope.is_derived_table:
101        return _eliminate_derived_table(scope, existing_ctes, taken)
102
103    if scope.is_cte:
104        return _eliminate_cte(scope, existing_ctes, taken)
105
106
def _eliminate_union(scope, existing_ctes, taken):
    """
    Replace a union with a SELECT that reads from a CTE holding that union.

    Args:
        scope: scope whose expression is the union to extract.
        existing_ctes: map of Expression->alias used for deduplication; the
            union is registered here when a new CTE is created for it.
        taken: map of alias->Scope|Table of names already in use; the chosen
            alias is recorded here.
    Returns:
        A new `exp.CTE` wrapping the union, or None when `existing_ctes`
        already holds an alias for this union.
    """
    duplicate_cte_alias = existing_ctes.get(scope.expression)

    alias = duplicate_cte_alias or find_new_name(taken=taken, base="cte")

    taken[alias] = scope

    # Try to maintain the selections
    expressions = scope.selects
    selects = [
        exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name)
        for e in expressions
        if e.alias_or_name
    ]
    # If not all selections have an alias, just select *
    if len(selects) != len(expressions):
        selects = ["*"]

    scope.expression.replace(exp.select(*selects).from_(exp.alias_(exp.table_(alias), alias=alias)))

    if duplicate_cte_alias:
        # The existing alias is reused, so there is no new CTE to add.
        return None

    existing_ctes[scope.expression] = alias
    return exp.CTE(
        this=scope.expression,
        alias=exp.TableAlias(this=exp.to_identifier(alias)),
    )
133
134
def _eliminate_derived_table(scope, existing_ctes, taken):
    """
    Swap a derived table for a table reference to its CTE.

    Returns:
        The CTE produced by `_new_cte` for this scope (None for a duplicate).
    """
    derived_table = scope.expression.parent
    cte_name, new_cte = _new_cte(scope, existing_ctes, taken)

    # Keep the alias of the replaced node if it has one, else use the CTE name.
    reference_alias = derived_table.alias or cte_name
    derived_table.replace(exp.alias_(exp.table_(cte_name), alias=reference_alias))

    return new_cte
143
144
def _eliminate_cte(scope, existing_ctes, taken):
    """
    Pop a CTE out of its WITH clause and re-point its references.

    Returns:
        The CTE produced by `_new_cte` for this scope (None for a duplicate).
    """
    cte_node = scope.expression.parent
    cte_name, new_cte = _new_cte(scope, existing_ctes, taken)

    # Read the enclosing WITH before `pop()` is called on the CTE.
    with_clause = cte_node.parent
    cte_node.pop()
    if not with_clause.expressions:
        # No CTEs are left in it, so pop the WITH as well.
        with_clause.pop()

    # Every table that selects from this scope is replaced by a reference
    # to the (possibly renamed) CTE, keeping the table's own alias.
    for descendant in scope.parent.traverse():
        for table, source in descendant.selected_sources.values():
            if source is not scope:
                continue
            table.replace(exp.alias_(exp.table_(cte_name), alias=table.alias_or_name))

    return new_cte
162
163
def _new_cte(scope, existing_ctes, taken):
    """
    Pick the name this scope gets as a CTE and build the CTE when one is needed.

    Returns:
        tuple of (name, cte)
        `name` is the name to use for this CTE and `cte` is a new CTE instance.
        When the scope's expression is already in `existing_ctes`, `name` is
        the alias stored there and `cte` is None.
    """
    node = scope.expression
    known_alias = existing_ctes.get(node)

    # Start from the alias of the enclosing node, or make one up if it has none.
    name = node.parent.alias or find_new_name(taken=taken, base="cte")

    if known_alias:
        name = known_alias
    elif taken.get(name):
        # The preferred name is already in use, so derive a fresh one from it.
        name = find_new_name(taken=taken, base=name)

    taken[name] = scope

    if known_alias:
        return name, None

    existing_ctes[node] = name
    new_cte = exp.CTE(
        this=node,
        alias=exp.TableAlias(this=exp.to_identifier(name)),
    )
    return name, new_cte
def eliminate_subqueries(expression):
def eliminate_subqueries(expression):
    """
    Hoist derived tables into CTEs, sharing one CTE between identical subqueries.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y")
        >>> eliminate_subqueries(expression).sql()
        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'

    Identical subqueries collapse into a single CTE:
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y JOIN (SELECT * FROM x) AS z")
        >>> eliminate_subqueries(expression).sql()
        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y JOIN y AS z'

    Args:
        expression (sqlglot.Expression): expression to rewrite
    Returns:
        sqlglot.Expression: the rewritten expression
    """
    if isinstance(expression, exp.Subquery):
        # A subquery can sit at the root, e.g. (SELECT * FROM x) LIMIT 1.
        # Rewrite the query it wraps and hand back the wrapper itself.
        eliminate_subqueries(expression.this)
        return expression

    expression = simplify(expression)
    root = build_scope(expression)

    # alias -> Scope | Table for every name that is already in use.
    # Newly created CTEs must not collide with any of them.
    # Seed it with the aliases of the CTEs in the root scope ...
    taken = {cte_scope.expression.parent.alias: cte_scope for cte_scope in root.cte_scopes}

    # ... then add the name of every table source found while traversing the scopes.
    for scope in root.traverse():
        for source in scope.sources.values():
            if isinstance(source, exp.Table):
                taken[source.name] = source

    # Expression -> alias for the CTEs the root expression already has.
    # This is what deduplication looks new candidates up in.
    with_ = root.expression.args.get("with")
    if with_:
        recursive = with_.args.get("recursive")
        existing_ctes = {cte.this: cte.alias for cte in with_.expressions}
    else:
        recursive = False
        existing_ctes = {}

    new_ctes = []

    def collect(candidate):
        # Rewrite one scope and keep the CTE it produced, if it produced one.
        produced = _eliminate(candidate, existing_ctes, taken)
        if produced:
            new_ctes.append(produced)

    # The CTE list has to stay in dependency (DAG) order: anything extracted
    # from inside an existing CTE is emitted before that CTE.
    for cte_scope in root.cte_scopes:
        for scope in cte_scope.traverse():
            # The existing CTE itself is kept as is; only its descendants are rewritten.
            if scope is not cte_scope:
                collect(scope)

        new_ctes.append(cte_scope.expression.parent)

    # Everything that does not live inside an existing CTE comes last.
    for scope in itertools.chain(root.union_scopes, root.subquery_scopes, root.table_scopes):
        for child_scope in scope.traverse():
            collect(child_scope)

    if new_ctes:
        expression.set("with", exp.With(expressions=new_ctes, recursive=recursive))

    return expression

Rewrite derived tables as CTEs, deduplicating if possible.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y")
>>> eliminate_subqueries(expression).sql()
'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'
This also deduplicates common subqueries:
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y JOIN (SELECT * FROM x) AS z")
>>> eliminate_subqueries(expression).sql()
'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y JOIN y AS z'
Arguments:
  • expression (sqlglot.Expression): expression
Returns:

sqlglot.Expression: expression