Edit on GitHub

sqlglot.optimizer.eliminate_subqueries

  1import itertools
  2
  3from sqlglot import expressions as exp
  4from sqlglot.helper import find_new_name
  5from sqlglot.optimizer.scope import build_scope
  6from sqlglot.optimizer.simplify import simplify
  7
  8
  9def eliminate_subqueries(expression):
 10    """
 11    Rewrite derived tables as CTES, deduplicating if possible.
 12
 13    Example:
 14        >>> import sqlglot
 15        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y")
 16        >>> eliminate_subqueries(expression).sql()
 17        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'
 18
 19    This also deduplicates common subqueries:
 20        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y JOIN (SELECT * FROM x) AS z")
 21        >>> eliminate_subqueries(expression).sql()
 22        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y JOIN y AS z'
 23
 24    Args:
 25        expression (sqlglot.Expression): expression
 26    Returns:
 27        sqlglot.Expression: expression
 28    """
 29    if isinstance(expression, exp.Subquery):
 30        # It's possible to have subqueries at the root, e.g. (SELECT * FROM x) LIMIT 1
 31        eliminate_subqueries(expression.this)
 32        return expression
 33
 34    expression = simplify(expression)
 35    root = build_scope(expression)
 36
 37    # Map of alias->Scope|Table
 38    # These are all aliases that are already used in the expression.
 39    # We don't want to create new CTEs that conflict with these names.
 40    taken = {}
 41
 42    # All CTE aliases in the root scope are taken
 43    for scope in root.cte_scopes:
 44        taken[scope.expression.parent.alias] = scope
 45
 46    # All table names are taken
 47    for scope in root.traverse():
 48        taken.update(
 49            {
 50                source.name: source
 51                for _, source in scope.sources.items()
 52                if isinstance(source, exp.Table)
 53            }
 54        )
 55
 56    # Map of Expression->alias
 57    # Existing CTES in the root expression. We'll use this for deduplication.
 58    existing_ctes = {}
 59
 60    with_ = root.expression.args.get("with")
 61    recursive = False
 62    if with_:
 63        recursive = with_.args.get("recursive")
 64        for cte in with_.expressions:
 65            existing_ctes[cte.this] = cte.alias
 66    new_ctes = []
 67
 68    # We're adding more CTEs, but we want to maintain the DAG order.
 69    # Derived tables within an existing CTE need to come before the existing CTE.
 70    for cte_scope in root.cte_scopes:
 71        # Append all the new CTEs from this existing CTE
 72        for scope in cte_scope.traverse():
 73            if scope is cte_scope:
 74                # Don't try to eliminate this CTE itself
 75                continue
 76            new_cte = _eliminate(scope, existing_ctes, taken)
 77            if new_cte:
 78                new_ctes.append(new_cte)
 79
 80        # Append the existing CTE itself
 81        new_ctes.append(cte_scope.expression.parent)
 82
 83    # Now append the rest
 84    for scope in itertools.chain(
 85        root.union_scopes, root.subquery_scopes, root.derived_table_scopes
 86    ):
 87        for child_scope in scope.traverse():
 88            new_cte = _eliminate(child_scope, existing_ctes, taken)
 89            if new_cte:
 90                new_ctes.append(new_cte)
 91
 92    if new_ctes:
 93        expression.set("with", exp.With(expressions=new_ctes, recursive=recursive))
 94
 95    return expression
 96
 97
 98def _eliminate(scope, existing_ctes, taken):
 99    if scope.is_union:
100        return _eliminate_union(scope, existing_ctes, taken)
101
102    if scope.is_derived_table and not isinstance(scope.expression, exp.UDTF):
103        return _eliminate_derived_table(scope, existing_ctes, taken)
104
105    if scope.is_cte:
106        return _eliminate_cte(scope, existing_ctes, taken)
107
108
def _eliminate_union(scope, existing_ctes, taken):
    """
    Replace the union of `scope` with a SELECT from a CTE alias.

    The alias of an identical, already registered CTE is reused when there is
    one; otherwise a fresh name based on "cte" is generated.

    Returns:
        a new CTE wrapping the union, or None when an existing CTE was reused.
    """
    duplicate_cte_alias = existing_ctes.get(scope.expression)

    alias = duplicate_cte_alias or find_new_name(taken=taken, base="cte")
    taken[alias] = scope

    # Try to keep the same projections, re-selected from the CTE alias
    projections = scope.selects
    selects = []
    for projection in projections:
        column_name = projection.alias_or_name
        if column_name:
            selects.append(
                exp.alias_(exp.column(column_name, table=alias), alias=column_name)
            )

    # At least one projection has no usable name: fall back to selecting *
    if len(selects) != len(projections):
        selects = ["*"]

    replacement = exp.select(*selects).from_(exp.alias_(exp.table_(alias), alias=alias))
    scope.expression.replace(replacement)

    if duplicate_cte_alias:
        return None

    existing_ctes[scope.expression] = alias
    return exp.CTE(
        this=scope.expression,
        alias=exp.TableAlias(this=exp.to_identifier(alias)),
    )
135
136
def _eliminate_derived_table(scope, existing_ctes, taken):
    """
    Swap the derived table that owns `scope` for a reference to a CTE.

    The replacement table keeps the derived table's alias, or uses the CTE
    name as its alias when the derived table had none.

    Returns:
        the CTE produced by `_new_cte`, which may be None.
    """
    derived_table = scope.expression.parent
    name, cte = _new_cte(scope, existing_ctes, taken)

    derived_table.replace(
        exp.alias_(exp.table_(name), alias=derived_table.alias or name)
    )

    return cte
145
146
def _eliminate_cte(scope, existing_ctes, taken):
    """
    Remove the CTE that owns `scope` from its WITH clause and repoint every
    table selecting from it to the name chosen by `_new_cte`.

    Returns:
        the CTE produced by `_new_cte`, which may be None.
    """
    cte_node = scope.expression.parent
    name, cte = _new_cte(scope, existing_ctes, taken)

    # Grab the WITH clause before popping the CTE out of it
    with_ = cte_node.parent
    cte_node.pop()
    if not with_.expressions:
        # Nothing left in the WITH clause, so drop it as well
        with_.pop()

    # Rename references to this CTE
    for child_scope in scope.parent.traverse():
        for table, source in child_scope.selected_sources.values():
            if source is not scope:
                continue
            table.replace(exp.alias_(exp.table_(name), alias=table.alias_or_name))

    return cte
164
165
def _new_cte(scope, existing_ctes, taken):
    """
    Choose a name for the expression of `scope` and, unless it duplicates a
    registered CTE, build a CTE for it.

    Both `taken` and `existing_ctes` are updated in place.

    Returns:
        tuple of (name, cte)
        where `name` is the name to use for this CTE in the root scope and
        `cte` is a new CTE instance, or None when the expression duplicates
        an existing CTE.
    """
    duplicate_cte_alias = existing_ctes.get(scope.expression)
    parent = scope.expression.parent

    # Start from the parent's alias; generate a name when there is none
    name = parent.alias or find_new_name(taken=taken, base="cte")

    if duplicate_cte_alias:
        name = duplicate_cte_alias
    elif taken.get(name):
        # The preferred name is already in use, so derive a new one from it
        name = find_new_name(taken=taken, base=name)

    taken[name] = scope

    if duplicate_cte_alias:
        return name, None

    existing_ctes[scope.expression] = name
    cte = exp.CTE(
        this=scope.expression,
        alias=exp.TableAlias(this=exp.to_identifier(name)),
    )
    return name, cte
def eliminate_subqueries(expression):
def eliminate_subqueries(expression):
    """
    Rewrite derived tables as CTEs, deduplicating if possible.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y")
        >>> eliminate_subqueries(expression).sql()
        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'

    This also deduplicates common subqueries:
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y JOIN (SELECT * FROM x) AS z")
        >>> eliminate_subqueries(expression).sql()
        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y JOIN y AS z'

    Args:
        expression (sqlglot.Expression): expression to rewrite
    Returns:
        sqlglot.Expression: the rewritten expression
    """
    if isinstance(expression, exp.Subquery):
        # A subquery may sit at the root, e.g. (SELECT * FROM x) LIMIT 1.
        # Rewrite the wrapped query and hand the wrapper itself back.
        eliminate_subqueries(expression.this)
        return expression

    expression = simplify(expression)
    root = build_scope(expression)

    # alias -> Scope | Table
    # Every name already used in the expression. Newly created CTEs must not
    # collide with any of these.
    taken = {}

    # CTE aliases of the root scope
    for cte_scope in root.cte_scopes:
        taken[cte_scope.expression.parent.alias] = cte_scope

    # Names of all tables referenced anywhere in the scope tree
    for scope in root.traverse():
        for source in scope.sources.values():
            if isinstance(source, exp.Table):
                taken[source.name] = source

    # Expression -> alias
    # CTEs that already exist on the root expression; used for deduplication.
    with_ = root.expression.args.get("with")
    if with_:
        recursive = with_.args.get("recursive")
        existing_ctes = {cte.this: cte.alias for cte in with_.expressions}
    else:
        recursive = False
        existing_ctes = {}

    new_ctes = []

    def _collect(candidate_scope):
        # Eliminate one scope and keep the CTE it produced, if any.
        produced = _eliminate(candidate_scope, existing_ctes, taken)
        if produced:
            new_ctes.append(produced)

    # The extra CTEs must keep the DAG order: anything extracted from inside
    # an existing CTE has to be listed before that existing CTE.
    for cte_scope in root.cte_scopes:
        for scope in cte_scope.traverse():
            # The existing CTE itself is not a candidate for elimination
            if scope is not cte_scope:
                _collect(scope)

        # Then the existing CTE itself
        new_ctes.append(cte_scope.expression.parent)

    # Everything that does not live inside an existing CTE comes last
    remaining_scopes = itertools.chain(
        root.union_scopes, root.subquery_scopes, root.derived_table_scopes
    )
    for scope in remaining_scopes:
        for child_scope in scope.traverse():
            _collect(child_scope)

    if new_ctes:
        expression.set("with", exp.With(expressions=new_ctes, recursive=recursive))

    return expression

Rewrite derived tables as CTEs, deduplicating if possible.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y")
>>> eliminate_subqueries(expression).sql()
'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'
This also deduplicates common subqueries:
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y JOIN (SELECT * FROM x) AS z")
>>> eliminate_subqueries(expression).sql()
'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y JOIN y AS z'
Arguments:
  • expression (sqlglot.Expression): the expression whose derived tables are rewritten as CTEs
Returns:

sqlglot.Expression: the rewritten expression