Edit on GitHub

sqlglot.optimizer.eliminate_subqueries

  1import itertools
  2
  3from sqlglot import expressions as exp
  4from sqlglot.helper import find_new_name
  5from sqlglot.optimizer.scope import build_scope
  6
  7
  8def eliminate_subqueries(expression):
  9    """
 10    Rewrite derived tables as CTES, deduplicating if possible.
 11
 12    Example:
 13        >>> import sqlglot
 14        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y")
 15        >>> eliminate_subqueries(expression).sql()
 16        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'
 17
 18    This also deduplicates common subqueries:
 19        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y JOIN (SELECT * FROM x) AS z")
 20        >>> eliminate_subqueries(expression).sql()
 21        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y JOIN y AS z'
 22
 23    Args:
 24        expression (sqlglot.Expression): expression
 25    Returns:
 26        sqlglot.Expression: expression
 27    """
 28    if isinstance(expression, exp.Subquery):
 29        # It's possible to have subqueries at the root, e.g. (SELECT * FROM x) LIMIT 1
 30        eliminate_subqueries(expression.this)
 31        return expression
 32
 33    root = build_scope(expression)
 34
 35    # Map of alias->Scope|Table
 36    # These are all aliases that are already used in the expression.
 37    # We don't want to create new CTEs that conflict with these names.
 38    taken = {}
 39
 40    # All CTE aliases in the root scope are taken
 41    for scope in root.cte_scopes:
 42        taken[scope.expression.parent.alias] = scope
 43
 44    # All table names are taken
 45    for scope in root.traverse():
 46        taken.update(
 47            {
 48                source.name: source
 49                for _, source in scope.sources.items()
 50                if isinstance(source, exp.Table)
 51            }
 52        )
 53
 54    # Map of Expression->alias
 55    # Existing CTES in the root expression. We'll use this for deduplication.
 56    existing_ctes = {}
 57
 58    with_ = root.expression.args.get("with")
 59    recursive = False
 60    if with_:
 61        recursive = with_.args.get("recursive")
 62        for cte in with_.expressions:
 63            existing_ctes[cte.this] = cte.alias
 64    new_ctes = []
 65
 66    # We're adding more CTEs, but we want to maintain the DAG order.
 67    # Derived tables within an existing CTE need to come before the existing CTE.
 68    for cte_scope in root.cte_scopes:
 69        # Append all the new CTEs from this existing CTE
 70        for scope in cte_scope.traverse():
 71            if scope is cte_scope:
 72                # Don't try to eliminate this CTE itself
 73                continue
 74            new_cte = _eliminate(scope, existing_ctes, taken)
 75            if new_cte:
 76                new_ctes.append(new_cte)
 77
 78        # Append the existing CTE itself
 79        new_ctes.append(cte_scope.expression.parent)
 80
 81    # Now append the rest
 82    for scope in itertools.chain(root.union_scopes, root.subquery_scopes, root.table_scopes):
 83        for child_scope in scope.traverse():
 84            new_cte = _eliminate(child_scope, existing_ctes, taken)
 85            if new_cte:
 86                new_ctes.append(new_cte)
 87
 88    if new_ctes:
 89        expression.set("with", exp.With(expressions=new_ctes, recursive=recursive))
 90
 91    return expression
 92
 93
 94def _eliminate(scope, existing_ctes, taken):
 95    if scope.is_union:
 96        return _eliminate_union(scope, existing_ctes, taken)
 97
 98    if scope.is_derived_table:
 99        return _eliminate_derived_table(scope, existing_ctes, taken)
100
101    if scope.is_cte:
102        return _eliminate_cte(scope, existing_ctes, taken)
103
104
def _eliminate_union(scope, existing_ctes, taken):
    """
    Replace a union scope with a SELECT over a CTE that holds the union.

    The CTE name is taken from ``existing_ctes`` when an equal expression is
    already a CTE. Otherwise it is generated with ``find_new_name`` from the
    base ``"cte"``. Either way the name is recorded in ``taken``.

    The union node is swapped in place for
    ``SELECT <cols> FROM <name> AS <name>``, where ``<cols>`` re-selects each
    projection by its name, or is ``*`` if any projection has no name.

    Returns:
        A new ``exp.CTE`` wrapping the original union, or None when an
        existing CTE was reused.
    """
    reused_alias = existing_ctes.get(scope.expression)
    cte_name = reused_alias or find_new_name(taken=taken, base="cte")
    taken[cte_name] = scope

    # Try to keep the original projection names visible to the outer query.
    projections = scope.selects
    preserved = []
    for projection in projections:
        column_name = projection.alias_or_name
        if column_name:
            preserved.append(
                exp.alias_(exp.column(column_name, table=cte_name), alias=column_name)
            )
    # If any projection is unnamed, fall back to selecting everything.
    if len(preserved) != len(projections):
        preserved = ["*"]

    source_table = exp.alias_(exp.table_(cte_name), alias=cte_name)
    scope.expression.replace(exp.select(*preserved).from_(source_table))

    if reused_alias:
        return None

    # scope.expression still refers to the original union node here
    # (assumed: replace() detaches it from the tree without altering the object).
    existing_ctes[scope.expression] = cte_name
    return exp.CTE(
        this=scope.expression,
        alias=exp.TableAlias(this=exp.to_identifier(cte_name)),
    )
127        return exp.CTE(
128            this=scope.expression,
129            alias=exp.TableAlias(this=exp.to_identifier(alias)),
130        )
131
132
def _eliminate_derived_table(scope, existing_ctes, taken):
    """
    Replace a derived-table subquery with a reference to a (possibly reused) CTE.

    The node wrapping ``scope.expression`` is swapped for
    ``<cte name> AS <alias>``. The alias is the wrapper's own alias, falling
    back to the CTE name when it has none.

    Returns:
        The new ``exp.CTE``, or None when the subquery duplicates an existing
        CTE and that one is reused.
    """
    # Capture the wrapper before _new_cte runs; it builds a CTE around
    # scope.expression, which may re-parent it.
    wrapper = scope.expression.parent
    cte_name, new_cte = _new_cte(scope, existing_ctes, taken)

    table_alias = wrapper.alias or cte_name
    wrapper.replace(exp.alias_(exp.table_(cte_name), alias=table_alias))
    return new_cte
141
142
def _eliminate_cte(scope, existing_ctes, taken):
    """
    Move a CTE found below the root up to the root WITH clause.

    The CTE definition is popped from its current WITH clause, and that clause
    is dropped entirely if it becomes empty. Every table in ``scope.parent``'s
    subtree that selects from this CTE is then replaced with
    ``<new name> AS <old alias or name>``.

    Returns:
        The new ``exp.CTE`` to add at the root, or None if the CTE duplicates
        one already registered in ``existing_ctes``.
    """
    # Capture the CTE node before _new_cte builds a new CTE around
    # scope.expression (which may re-parent it).
    cte_node = scope.expression.parent
    cte_name, new_cte = _new_cte(scope, existing_ctes, taken)

    enclosing_with = cte_node.parent
    cte_node.pop()
    if not enclosing_with.expressions:
        enclosing_with.pop()

    # Point every reference to the old CTE at the hoisted name. The previous
    # alias is kept, presumably so existing column qualifiers still resolve.
    for child_scope in scope.parent.traverse():
        for table, source in child_scope.selected_sources.values():
            if source is not scope:
                continue
            table.replace(exp.alias_(exp.table_(cte_name), alias=table.alias_or_name))

    return new_cte
160
161
def _new_cte(scope, existing_ctes, taken):
    """
    Choose a root-level name for ``scope`` and build a CTE for it if needed.

    Naming rules:
      * if ``scope.expression`` is already in ``existing_ctes``, reuse that alias;
      * otherwise start from the alias of the node wrapping ``scope.expression``
        (or a ``find_new_name`` result based on ``"cte"`` when it has none), and
        derive a new name via ``find_new_name`` if that one is already in ``taken``.

    The chosen name is recorded in ``taken``. A newly built CTE is also
    recorded in ``existing_ctes``.

    Returns:
        tuple of (name, cte)
        where `name` is a new name for this CTE in the root scope and `cte` is a new CTE instance.
        If this CTE duplicates an existing CTE, `cte` will be None.
    """
    query = scope.expression
    reused_alias = existing_ctes.get(query)

    # Read the wrapper's alias before a CTE is built around ``query`` below.
    cte_name = query.parent.alias
    if not cte_name:
        cte_name = find_new_name(taken=taken, base="cte")

    if reused_alias:
        cte_name = reused_alias
    elif taken.get(cte_name):
        # Name already in use: derive a variant of it.
        cte_name = find_new_name(taken=taken, base=cte_name)

    taken[cte_name] = scope

    if reused_alias:
        return cte_name, None

    existing_ctes[query] = cte_name
    return cte_name, exp.CTE(
        this=query,
        alias=exp.TableAlias(this=exp.to_identifier(cte_name)),
    )
def eliminate_subqueries(expression):
    """
    Rewrite derived tables as CTEs, deduplicating if possible.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y")
        >>> eliminate_subqueries(expression).sql()
        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'

    This also deduplicates common subqueries:
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y JOIN (SELECT * FROM x) AS z")
        >>> eliminate_subqueries(expression).sql()
        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y JOIN y AS z'

    Derived tables found inside an existing root-level CTE are emitted ahead of
    that CTE, so the resulting WITH clause keeps its dependency order. The
    expression is modified in place.

    Args:
        expression (sqlglot.Expression): expression to rewrite. A root-level
            ``exp.Subquery`` is handled by rewriting its inner query.
    Returns:
        sqlglot.Expression: the same ``expression`` object, rewritten.
    """
    if isinstance(expression, exp.Subquery):
        # It's possible to have subqueries at the root, e.g. (SELECT * FROM x) LIMIT 1
        eliminate_subqueries(expression.this)
        return expression

    root = build_scope(expression)
    if root is None:
        # Defensive: nothing to rewrite when no scope could be built.
        # NOTE(review): depends on build_scope's contract for non-query input — confirm.
        return expression

    # Map of alias -> Scope | Table.
    # These are all aliases that are already used in the expression.
    # We don't want to create new CTEs that conflict with these names.
    taken = {}

    # All CTE aliases in the root scope are taken
    for cte_scope in root.cte_scopes:
        taken[cte_scope.expression.parent.alias] = cte_scope

    # All table names are taken (only the source values matter, not their keys)
    for scope in root.traverse():
        for source in scope.sources.values():
            if isinstance(source, exp.Table):
                taken[source.name] = source

    # Map of Expression -> alias.
    # Existing CTEs in the root expression. We'll use this for deduplication.
    existing_ctes = {}

    with_ = root.expression.args.get("with")
    recursive = False
    if with_:
        recursive = with_.args.get("recursive")
        for cte in with_.expressions:
            existing_ctes[cte.this] = cte.alias

    new_ctes = []

    # We're adding more CTEs, but we want to maintain the DAG order.
    # Derived tables within an existing CTE need to come before the existing CTE.
    for cte_scope in root.cte_scopes:
        # Append all the new CTEs from this existing CTE
        for scope in cte_scope.traverse():
            if scope is cte_scope:
                # Don't try to eliminate this CTE itself
                continue
            new_cte = _eliminate(scope, existing_ctes, taken)
            if new_cte:
                new_ctes.append(new_cte)

        # Append the existing CTE itself
        new_ctes.append(cte_scope.expression.parent)

    # Now append the rest
    for scope in itertools.chain(root.union_scopes, root.subquery_scopes, root.table_scopes):
        for child_scope in scope.traverse():
            new_cte = _eliminate(child_scope, existing_ctes, taken)
            if new_cte:
                new_ctes.append(new_cte)

    if new_ctes:
        # The new WITH replaces any existing one; the root-level CTEs were
        # re-appended to new_ctes above, in order.
        expression.set("with", exp.With(expressions=new_ctes, recursive=recursive))

    return expression

Rewrite derived tables as CTEs, deduplicating them where possible.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y")
>>> eliminate_subqueries(expression).sql()
'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'
This also deduplicates common subqueries:
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y JOIN (SELECT * FROM x) AS z")
>>> eliminate_subqueries(expression).sql()
'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y JOIN y AS z'
Arguments:
  • expression (sqlglot.Expression): the expression to rewrite
Returns:
  sqlglot.Expression: the same expression, rewritten in place