diff options
Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r-- | sqlglot/optimizer/annotate_types.py | 162 | ||||
-rw-r--r-- | sqlglot/optimizer/eliminate_subqueries.py | 144 | ||||
-rw-r--r-- | sqlglot/optimizer/merge_subqueries.py (renamed from sqlglot/optimizer/merge_derived_tables.py) | 149 | ||||
-rw-r--r-- | sqlglot/optimizer/optimizer.py | 4 | ||||
-rw-r--r-- | sqlglot/optimizer/pushdown_projections.py | 2 | ||||
-rw-r--r-- | sqlglot/optimizer/schema.py | 2 | ||||
-rw-r--r-- | sqlglot/optimizer/scope.py | 87 |
7 files changed, 462 insertions, 88 deletions
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py new file mode 100644 index 0000000..3f5f089 --- /dev/null +++ b/sqlglot/optimizer/annotate_types.py @@ -0,0 +1,162 @@ +from sqlglot import exp +from sqlglot.helper import ensure_list, subclasses + + +def annotate_types(expression, schema=None, annotators=None, coerces_to=None): + """ + Recursively infer & annotate types in an expression syntax tree against a schema. + + (TODO -- replace this with a better example after adding some functionality) + Example: + >>> import sqlglot + >>> annotated_expression = annotate_types(sqlglot.parse_one('5 + 5.3')) + >>> annotated_expression.type + <Type.DOUBLE: 'DOUBLE'> + + Args: + expression (sqlglot.Expression): Expression to annotate. + schema (dict|sqlglot.optimizer.Schema): Database schema. + annotators (dict): Maps expression type to corresponding annotation function. + coerces_to (dict): Maps expression type to set of types that it can be coerced into. + Returns: + sqlglot.Expression: expression annotated with types + """ + + return TypeAnnotator(schema, annotators, coerces_to).annotate(expression) + + +class TypeAnnotator: + ANNOTATORS = { + **{ + expr_type: lambda self, expr: self._annotate_unary(expr) + for expr_type in subclasses(exp.__name__, exp.Unary) + }, + **{ + expr_type: lambda self, expr: self._annotate_binary(expr) + for expr_type in subclasses(exp.__name__, exp.Binary) + }, + exp.Cast: lambda self, expr: self._annotate_cast(expr), + exp.DataType: lambda self, expr: self._annotate_data_type(expr), + exp.Literal: lambda self, expr: self._annotate_literal(expr), + exp.Boolean: lambda self, expr: self._annotate_boolean(expr), + } + + # Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html + COERCES_TO = { + # CHAR < NCHAR < VARCHAR < NVARCHAR < TEXT + exp.DataType.Type.TEXT: set(), + exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT}, + exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT}, + exp.DataType.Type.NCHAR: {exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT}, + exp.DataType.Type.CHAR: { + exp.DataType.Type.NCHAR, + exp.DataType.Type.VARCHAR, + exp.DataType.Type.NVARCHAR, + exp.DataType.Type.TEXT, + }, + # TINYINT < SMALLINT < INT < BIGINT < DECIMAL < FLOAT < DOUBLE + exp.DataType.Type.DOUBLE: set(), + exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE}, + exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE}, + exp.DataType.Type.BIGINT: {exp.DataType.Type.DECIMAL, exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE}, + exp.DataType.Type.INT: { + exp.DataType.Type.BIGINT, + exp.DataType.Type.DECIMAL, + exp.DataType.Type.FLOAT, + exp.DataType.Type.DOUBLE, + }, + exp.DataType.Type.SMALLINT: { + exp.DataType.Type.INT, + exp.DataType.Type.BIGINT, + exp.DataType.Type.DECIMAL, + exp.DataType.Type.FLOAT, + exp.DataType.Type.DOUBLE, + }, + exp.DataType.Type.TINYINT: { + exp.DataType.Type.SMALLINT, + exp.DataType.Type.INT, + exp.DataType.Type.BIGINT, + exp.DataType.Type.DECIMAL, + exp.DataType.Type.FLOAT, + exp.DataType.Type.DOUBLE, + }, + # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ + exp.DataType.Type.TIMESTAMPLTZ: set(), + exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ}, + exp.DataType.Type.TIMESTAMP: {exp.DataType.Type.TIMESTAMPTZ, exp.DataType.Type.TIMESTAMPLTZ}, + exp.DataType.Type.DATETIME: { + exp.DataType.Type.TIMESTAMP, + exp.DataType.Type.TIMESTAMPTZ, + exp.DataType.Type.TIMESTAMPLTZ, + }, + exp.DataType.Type.DATE: { + exp.DataType.Type.DATETIME, + exp.DataType.Type.TIMESTAMP, + exp.DataType.Type.TIMESTAMPTZ, + exp.DataType.Type.TIMESTAMPLTZ, + }, + } + + def __init__(self, schema=None, annotators=None, coerces_to=None): + self.schema = schema + self.annotators = annotators or self.ANNOTATORS + self.coerces_to = coerces_to or self.COERCES_TO + + def annotate(self, expression): + if not isinstance(expression, exp.Expression): + return None + + annotator = self.annotators.get(expression.__class__) + return annotator(self, expression) if annotator else self._annotate_args(expression) + + def _annotate_args(self, expression): + for value in expression.args.values(): + for v in ensure_list(value): + self.annotate(v) + + return expression + + def _annotate_cast(self, expression): + expression.type = expression.args["to"].this + return self._annotate_args(expression) + + def _annotate_data_type(self, expression): + expression.type = expression.this + return self._annotate_args(expression) + + def _maybe_coerce(self, type1, type2): + return type2 if type2 in self.coerces_to[type1] else type1 + + def _annotate_binary(self, expression): + self._annotate_args(expression) + + if isinstance(expression, (exp.Condition, exp.Predicate)): + expression.type = exp.DataType.Type.BOOLEAN + else: + expression.type = self._maybe_coerce(expression.left.type, expression.right.type) + + return expression + + def _annotate_unary(self, expression): + self._annotate_args(expression) + + if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren): + expression.type = exp.DataType.Type.BOOLEAN + else: + expression.type = expression.this.type + + return expression + + def _annotate_literal(self, expression): + if expression.is_string: + expression.type = exp.DataType.Type.VARCHAR + elif expression.is_int: + expression.type = exp.DataType.Type.INT + else: + expression.type = exp.DataType.Type.DOUBLE + + return expression + + def _annotate_boolean(self, expression): + expression.type = exp.DataType.Type.BOOLEAN + return expression diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index 4bfb733..38e1299 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -1,48 +1,144 @@ import itertools -from sqlglot import alias, exp, select, table -from sqlglot.optimizer.scope import traverse_scope +from sqlglot import expressions as exp +from sqlglot.helper import find_new_name +from sqlglot.optimizer.scope import build_scope from sqlglot.optimizer.simplify import simplify def eliminate_subqueries(expression): """ - Rewrite duplicate subqueries from sqlglot AST. + Rewrite subqueries as CTES, deduplicating if possible. Example: >>> import sqlglot - >>> expression = sqlglot.parse_one("SELECT 1 AS x, 2 AS y UNION ALL SELECT 1 AS x, 2 AS y") + >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y") >>> eliminate_subqueries(expression).sql() - 'WITH _e_0 AS (SELECT 1 AS x, 2 AS y) SELECT * FROM _e_0 UNION ALL SELECT * FROM _e_0' + 'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y' + + This also deduplicates common subqueries: + >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y JOIN (SELECT * FROM x) AS z") + >>> eliminate_subqueries(expression).sql() + 'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y JOIN y AS z' Args: - expression (sqlglot.Expression): expression to qualify - schema (dict|sqlglot.optimizer.Schema): Database schema + expression (sqlglot.Expression): expression Returns: - sqlglot.Expression: qualified expression + sqlglot.Expression: expression """ + if isinstance(expression, exp.Subquery): + # It's possible to have subqueries at the root, e.g. (SELECT * FROM x) LIMIT 1 + eliminate_subqueries(expression.this) + return expression + expression = simplify(expression) - queries = {} + root = build_scope(expression) + + # Map of alias->Scope|Table + # These are all aliases that are already used in the expression. + # We don't want to create new CTEs that conflict with these names. + taken = {} + + # All CTE aliases in the root scope are taken + for scope in root.cte_scopes: + taken[scope.expression.parent.alias] = scope + + # All table names are taken + for scope in root.traverse(): + taken.update({source.name: source for _, source in scope.sources.items() if isinstance(source, exp.Table)}) - for scope in traverse_scope(expression): - query = scope.expression - queries[query] = queries.get(query, []) + [query] + # Map of Expression->alias + # Existing CTES in the root expression. We'll use this for deduplication. + existing_ctes = {} - sequence = itertools.count() + with_ = root.expression.args.get("with") + if with_: + for cte in with_.expressions: + existing_ctes[cte.this] = cte.alias + new_ctes = [] - for query, duplicates in queries.items(): - if len(duplicates) == 1: - continue + # We're adding more CTEs, but we want to maintain the DAG order. + # Derived tables within an existing CTE need to come before the existing CTE. + for cte_scope in root.cte_scopes: + # Append all the new CTEs from this existing CTE + for scope in cte_scope.traverse(): + new_cte = _eliminate(scope, existing_ctes, taken) + if new_cte: + new_ctes.append(new_cte) - alias_ = f"_e_{next(sequence)}" + # Append the existing CTE itself + new_ctes.append(cte_scope.expression.parent) - for dup in duplicates: - parent = dup.parent - if isinstance(parent, exp.Subquery): - parent.replace(alias(table(alias_), parent.alias_or_name, table=True)) - elif isinstance(parent, exp.Union): - dup.replace(select("*").from_(alias_)) + # Now append the rest + for scope in itertools.chain(root.union_scopes, root.subquery_scopes, root.derived_table_scopes): + for child_scope in scope.traverse(): + new_cte = _eliminate(child_scope, existing_ctes, taken) + if new_cte: + new_ctes.append(new_cte) - expression.with_(alias_, as_=query, copy=False) + if new_ctes: + expression.set("with", exp.With(expressions=new_ctes)) return expression + + +def _eliminate(scope, existing_ctes, taken): + if scope.is_union: + return _eliminate_union(scope, existing_ctes, taken) + + if scope.is_derived_table and not isinstance(scope.expression, (exp.Unnest, exp.Lateral)): + return _eliminate_derived_table(scope, existing_ctes, taken) + + +def _eliminate_union(scope, existing_ctes, taken): + duplicate_cte_alias = existing_ctes.get(scope.expression) + + alias = duplicate_cte_alias or find_new_name(taken=taken, base="cte") + + taken[alias] = scope + + # Try to maintain the selections + expressions = scope.expression.args.get("expressions") + selects = [ + exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name) + for e in expressions + if e.alias_or_name + ] + # If not all selections have an alias, just select * + if len(selects) != len(expressions): + selects = ["*"] + + scope.expression.replace(exp.select(*selects).from_(exp.alias_(exp.table_(alias), alias=alias))) + + if not duplicate_cte_alias: + existing_ctes[scope.expression] = alias + return exp.CTE( + this=scope.expression, + alias=exp.TableAlias(this=exp.to_identifier(alias)), + ) + + +def _eliminate_derived_table(scope, existing_ctes, taken): + duplicate_cte_alias = existing_ctes.get(scope.expression) + parent = scope.expression.parent + name = alias = parent.alias + + if not alias: + name = alias = find_new_name(taken=taken, base="cte") + + if duplicate_cte_alias: + name = duplicate_cte_alias + elif taken.get(alias): + name = find_new_name(taken=taken, base=alias) + + taken[name] = scope + + table = exp.alias_(exp.table_(name), alias=alias) + parent.replace(table) + + if not duplicate_cte_alias: + existing_ctes[scope.expression] = name + return exp.CTE( + this=scope.expression, + alias=exp.TableAlias(this=exp.to_identifier(name)), + ) diff --git a/sqlglot/optimizer/merge_derived_tables.py b/sqlglot/optimizer/merge_subqueries.py index 8b161fb..9d966b7 100644 --- a/sqlglot/optimizer/merge_derived_tables.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -1,72 +1,127 @@ from collections import defaultdict from sqlglot import expressions as exp -from sqlglot.optimizer.scope import traverse_scope +from sqlglot.helper import find_new_name +from sqlglot.optimizer.scope import Scope, traverse_scope from sqlglot.optimizer.simplify import simplify -def merge_derived_tables(expression): +def merge_subqueries(expression, leave_tables_isolated=False): """ Rewrite sqlglot AST to merge derived tables into the outer query. + This also merges CTEs if they are selected from only once. + Example: >>> import sqlglot - >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x)") - >>> merge_derived_tables(expression).sql() - 'SELECT x.a FROM x' + >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y") + >>> merge_subqueries(expression).sql() + 'SELECT x.a FROM x JOIN y' + + If `leave_tables_isolated` is True, this will not merge inner queries into outer + queries if it would result in multiple table selects in a single query: + >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y") + >>> merge_subqueries(expression, leave_tables_isolated=True).sql() + 'SELECT a FROM (SELECT x.a FROM x) JOIN y' Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html Args: expression (sqlglot.Expression): expression to optimize + leave_tables_isolated (bool): Returns: sqlglot.Expression: optimized expression """ + merge_ctes(expression, leave_tables_isolated) + merge_derived_tables(expression, leave_tables_isolated) + return expression + + +# If a derived table has these Select args, it can't be merged +UNMERGABLE_ARGS = set(exp.Select.arg_types) - { + "expressions", + "from", + "joins", + "where", + "order", +} + + +def merge_ctes(expression, leave_tables_isolated=False): + scopes = traverse_scope(expression) + + # All places where we select from CTEs. + # We key on the CTE scope so we can detect CTES that are selected from multiple times. + cte_selections = defaultdict(list) + for outer_scope in scopes: + for table, inner_scope in outer_scope.selected_sources.values(): + if isinstance(inner_scope, Scope) and inner_scope.is_cte: + cte_selections[id(inner_scope)].append( + ( + outer_scope, + inner_scope, + table, + ) + ) + + singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1] + for outer_scope, inner_scope, table in singular_cte_selections: + inner_select = inner_scope.expression.unnest() + if _mergeable(outer_scope, inner_select, leave_tables_isolated): + from_or_join = table.find_ancestor(exp.From, exp.Join) + + node_to_replace = table + if isinstance(node_to_replace.parent, exp.Alias): + node_to_replace = node_to_replace.parent + alias = node_to_replace.alias + else: + alias = table.name + + _rename_inner_sources(outer_scope, inner_scope, alias) + _merge_from(outer_scope, inner_scope, node_to_replace, alias) + _merge_joins(outer_scope, inner_scope, from_or_join) + _merge_expressions(outer_scope, inner_scope, alias) + _merge_where(outer_scope, inner_scope, from_or_join) + _merge_order(outer_scope, inner_scope) + _pop_cte(inner_scope) + + +def merge_derived_tables(expression, leave_tables_isolated=False): for outer_scope in traverse_scope(expression): for subquery in outer_scope.derived_tables: inner_select = subquery.unnest() - if ( - isinstance(outer_scope.expression, exp.Select) - and isinstance(inner_select, exp.Select) - and _mergeable(inner_select) - ): + if _mergeable(outer_scope, inner_select, leave_tables_isolated): alias = subquery.alias_or_name from_or_join = subquery.find_ancestor(exp.From, exp.Join) inner_scope = outer_scope.sources[alias] _rename_inner_sources(outer_scope, inner_scope, alias) - _merge_from(outer_scope, inner_scope, subquery) + _merge_from(outer_scope, inner_scope, subquery, alias) _merge_joins(outer_scope, inner_scope, from_or_join) _merge_expressions(outer_scope, inner_scope, alias) _merge_where(outer_scope, inner_scope, from_or_join) _merge_order(outer_scope, inner_scope) - return expression -# If a derived table has these Select args, it can't be merged -UNMERGABLE_ARGS = set(exp.Select.arg_types) - { - "expressions", - "from", - "joins", - "where", - "order", -} - - -def _mergeable(inner_select): +def _mergeable(outer_scope, inner_select, leave_tables_isolated): """ Return True if `inner_select` can be merged into outer query. Args: + outer_scope (Scope) inner_select (exp.Select) + leave_tables_isolated (bool) Returns: bool: True if can be merged """ return ( - isinstance(inner_select, exp.Select) + isinstance(outer_scope.expression, exp.Select) + and isinstance(inner_select, exp.Select) + and isinstance(inner_select, exp.Select) and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS) and inner_select.args.get("from") and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions) + and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1) ) @@ -84,7 +139,7 @@ def _rename_inner_sources(outer_scope, inner_scope, alias): conflicts = conflicts - {alias} for conflict in conflicts: - new_name = _find_new_name(taken, conflict) + new_name = find_new_name(taken, conflict) source, _ = inner_scope.selected_sources[conflict] new_alias = exp.to_identifier(new_name) @@ -102,34 +157,19 @@ def _rename_inner_sources(outer_scope, inner_scope, alias): inner_scope.rename_source(conflict, new_name) -def _find_new_name(taken, base): - """ - Searches for a new source name. - - Args: - taken (set[str]): set of taken names - base (str): base name to alter - """ - i = 2 - new = f"{base}_{i}" - while new in taken: - i += 1 - new = f"{base}_{i}" - return new - - -def _merge_from(outer_scope, inner_scope, subquery): +def _merge_from(outer_scope, inner_scope, node_to_replace, alias): """ Merge FROM clause of inner query into outer query. Args: outer_scope (sqlglot.optimizer.scope.Scope) inner_scope (sqlglot.optimizer.scope.Scope) - subquery (exp.Subquery) + node_to_replace (exp.Subquery|exp.Table) + alias (str) """ new_subquery = inner_scope.expression.args.get("from").expressions[0] - subquery.replace(new_subquery) - outer_scope.remove_source(subquery.alias_or_name) + node_to_replace.replace(new_subquery) + outer_scope.remove_source(alias) outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name]) @@ -176,7 +216,7 @@ def _merge_expressions(outer_scope, inner_scope, alias): inner_scope (sqlglot.optimizer.scope.Scope) alias (str) """ - # Collect all columns that for the alias of the inner query + # Collect all columns that reference the alias of the inner query outer_columns = defaultdict(list) for column in outer_scope.columns: if column.table == alias: @@ -205,7 +245,7 @@ def _merge_where(outer_scope, inner_scope, from_or_join): if not where or not where.this: return - if isinstance(from_or_join, exp.Join) and from_or_join.side: + if isinstance(from_or_join, exp.Join): # Merge predicates from an outer join to the ON clause from_or_join.on(where.this, copy=False) from_or_join.set("on", simplify(from_or_join.args.get("on"))) @@ -230,3 +270,18 @@ def _merge_order(outer_scope, inner_scope): return outer_scope.expression.set("order", inner_scope.expression.args.get("order")) + + +def _pop_cte(inner_scope): + """ + Remove CTE from the AST. + + Args: + inner_scope (sqlglot.optimizer.scope.Scope) + """ + cte = inner_scope.expression.parent + with_ = cte.parent + if len(with_.expressions) == 1: + with_.pop() + else: + cte.pop() diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index c8c2403..9a09327 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -1,7 +1,7 @@ from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects from sqlglot.optimizer.isolate_table_selects import isolate_table_selects -from sqlglot.optimizer.merge_derived_tables import merge_derived_tables +from sqlglot.optimizer.merge_subqueries import merge_subqueries from sqlglot.optimizer.normalize import normalize from sqlglot.optimizer.optimize_joins import optimize_joins from sqlglot.optimizer.pushdown_predicates import pushdown_predicates @@ -22,7 +22,7 @@ RULES = ( pushdown_predicates, optimize_joins, eliminate_subqueries, - merge_derived_tables, + merge_subqueries, quote_identities, ) diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py index 097ce04..5584830 100644 --- a/sqlglot/optimizer/pushdown_projections.py +++ b/sqlglot/optimizer/pushdown_projections.py @@ -37,7 +37,7 @@ def pushdown_projections(expression): parent_selections = {SELECT_ALL} if isinstance(scope.expression, exp.Union): - left, right = scope.union + left, right = scope.union_scopes referenced_columns[left] = parent_selections referenced_columns[right] = parent_selections diff --git a/sqlglot/optimizer/schema.py b/sqlglot/optimizer/schema.py index 1761228..1bbd86a 100644 --- a/sqlglot/optimizer/schema.py +++ b/sqlglot/optimizer/schema.py @@ -69,7 +69,7 @@ def ensure_schema(schema): def fs_get(table): - name = table.this.name.upper() + name = table.this.name if name.upper() == "READ_CSV": with csv_reader(table) as reader: diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index e816e10..be6cfb9 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -1,3 +1,4 @@ +import itertools from copy import copy from enum import Enum, auto @@ -32,10 +33,11 @@ class Scope: The inner query would have `["col1", "col2"]` for its `outer_column_list` parent (Scope): Parent scope scope_type (ScopeType): Type of this scope, relative to it's parent - subquery_scopes (list[Scope]): List of all child scopes for subqueries. - This does not include derived tables or CTEs. - union (tuple[Scope, Scope]): If this Scope is for a Union expression, this will be - a tuple of the left and right child scopes. + subquery_scopes (list[Scope]): List of all child scopes for subqueries + cte_scopes = (list[Scope]) List of all child scopes for CTEs + derived_table_scopes = (list[Scope]) List of all child scopes for derived_tables + union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be + a list of the left and right child scopes. """ def __init__( @@ -52,7 +54,9 @@ class Scope: self.parent = parent self.scope_type = scope_type self.subquery_scopes = [] - self.union = None + self.derived_table_scopes = [] + self.cte_scopes = [] + self.union_scopes = [] self.clear_cache() def clear_cache(self): @@ -197,11 +201,16 @@ class Scope: named_outputs = {e.alias_or_name for e in self.expression.expressions} - self._columns = [ - c - for c in columns + external_columns - if not (c.find_ancestor(exp.Qualify, exp.Order) and not c.table and c.name in named_outputs) - ] + self._columns = [] + for column in columns + external_columns: + ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Hint) + if ( + not ancestor + or column.table + or (column.name not in named_outputs and not isinstance(ancestor, exp.Hint)) + ): + self._columns.append(column) + return self._columns @property @@ -284,6 +293,26 @@ class Scope: return self.scope_type == ScopeType.SUBQUERY @property + def is_derived_table(self): + """Determine if this scope is a derived table""" + return self.scope_type == ScopeType.DERIVED_TABLE + + @property + def is_union(self): + """Determine if this scope is a union""" + return self.scope_type == ScopeType.UNION + + @property + def is_cte(self): + """Determine if this scope is a common table expression""" + return self.scope_type == ScopeType.CTE + + @property + def is_root(self): + """Determine if this is the root scope""" + return self.scope_type == ScopeType.ROOT + + @property def is_unnest(self): """Determine if this scope is an unnest""" return self.scope_type == ScopeType.UNNEST @@ -308,6 +337,22 @@ class Scope: self.sources.pop(name, None) self.clear_cache() + def __repr__(self): + return f"Scope<{self.expression.sql()}>" + + def traverse(self): + """ + Traverse the scope tree from this node. + + Yields: + Scope: scope instances in depth-first-search post-order + """ + for child_scope in itertools.chain( + self.cte_scopes, self.union_scopes, self.subquery_scopes, self.derived_table_scopes + ): + yield from child_scope.traverse() + yield self + def traverse_scope(expression): """ @@ -337,6 +382,18 @@ def traverse_scope(expression): return list(_traverse_scope(Scope(expression))) +def build_scope(expression): + """ + Build a scope tree. + + Args: + expression (exp.Expression): expression to build the scope tree for + Returns: + Scope: root scope + """ + return traverse_scope(expression)[-1] + + def _traverse_scope(scope): if isinstance(scope.expression, exp.Select): yield from _traverse_select(scope) @@ -370,13 +427,14 @@ def _traverse_union(scope): for right in _traverse_scope(scope.branch(scope.expression.right, scope_type=ScopeType.UNION)): yield right - scope.union = (left, right) + scope.union_scopes = [left, right] def _traverse_derived_tables(derived_tables, scope, scope_type): sources = {} for derived_table in derived_tables: + top = None for child_scope in _traverse_scope( scope.branch( derived_table if isinstance(derived_table, (exp.Unnest, exp.Lateral)) else derived_table.this, @@ -386,11 +444,16 @@ def _traverse_derived_tables(derived_tables, scope, scope_type): ) ): yield child_scope + top = child_scope # Tables without aliases will be set as "" # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything. # Until then, this means that only a single, unaliased derived table is allowed (rather, # the latest one wins. sources[derived_table.alias] = child_scope + if scope_type == ScopeType.CTE: + scope.cte_scopes.append(top) + else: + scope.derived_table_scopes.append(top) scope.sources.update(sources) @@ -407,8 +470,6 @@ def _add_table_sources(scope): if table_name in scope.sources: # This is a reference to a parent source (e.g. a CTE), not an actual table. scope.sources[source_name] = scope.sources[table_name] - elif source_name in scope.sources: - raise OptimizeError(f"Duplicate table name: {source_name}") else: sources[source_name] = table |