diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-10-21 09:29:26 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-10-21 09:29:26 +0000 |
commit | 8b4272814fb4585be120f183eb7c26bb8acde974 (patch) | |
tree | 85d56a8f5ac4ac94ab924d5bbc578586eeb2a998 /sqlglot/optimizer | |
parent | Releasing debian version 7.1.3-1. (diff) | |
download | sqlglot-8b4272814fb4585be120f183eb7c26bb8acde974.tar.xz sqlglot-8b4272814fb4585be120f183eb7c26bb8acde974.zip |
Merging upstream version 9.0.1.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to '')
-rw-r--r-- | sqlglot/optimizer/__init__.py | 1 | ||||
-rw-r--r-- | sqlglot/optimizer/annotate_types.py | 2 | ||||
-rw-r--r-- | sqlglot/optimizer/eliminate_subqueries.py | 2 | ||||
-rw-r--r-- | sqlglot/optimizer/isolate_table_selects.py | 10 | ||||
-rw-r--r-- | sqlglot/optimizer/merge_subqueries.py | 16 | ||||
-rw-r--r-- | sqlglot/optimizer/optimizer.py | 5 | ||||
-rw-r--r-- | sqlglot/optimizer/pushdown_projections.py | 35 | ||||
-rw-r--r-- | sqlglot/optimizer/qualify_columns.py | 27 | ||||
-rw-r--r-- | sqlglot/optimizer/qualify_tables.py | 2 | ||||
-rw-r--r-- | sqlglot/optimizer/schema.py | 180 | ||||
-rw-r--r-- | sqlglot/optimizer/scope.py | 11 |
11 files changed, 69 insertions, 222 deletions
diff --git a/sqlglot/optimizer/__init__.py b/sqlglot/optimizer/__init__.py index d1146ca..bba0878 100644 --- a/sqlglot/optimizer/__init__.py +++ b/sqlglot/optimizer/__init__.py @@ -1,2 +1 @@ from sqlglot.optimizer.optimizer import RULES, optimize -from sqlglot.optimizer.schema import Schema diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index a2cef37..30055bc 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -1,7 +1,7 @@ from sqlglot import exp from sqlglot.helper import ensure_list, subclasses -from sqlglot.optimizer.schema import ensure_schema from sqlglot.optimizer.scope import Scope, traverse_scope +from sqlglot.schema import ensure_schema def annotate_types(expression, schema=None, annotators=None, coerces_to=None): diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index 44cdc94..e30c263 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -86,7 +86,7 @@ def _eliminate(scope, existing_ctes, taken): if scope.is_union: return _eliminate_union(scope, existing_ctes, taken) - if scope.is_derived_table and not isinstance(scope.expression, (exp.Unnest, exp.Lateral)): + if scope.is_derived_table and not isinstance(scope.expression, exp.UDTF): return _eliminate_derived_table(scope, existing_ctes, taken) diff --git a/sqlglot/optimizer/isolate_table_selects.py b/sqlglot/optimizer/isolate_table_selects.py index e060739..652cdef 100644 --- a/sqlglot/optimizer/isolate_table_selects.py +++ b/sqlglot/optimizer/isolate_table_selects.py @@ -12,18 +12,16 @@ def isolate_table_selects(expression): if not isinstance(source, exp.Table): continue - if not isinstance(source.parent, exp.Alias): + if not source.alias: raise OptimizeError("Tables require an alias. Run qualify_tables optimization.") - parent = source.parent - - parent.replace( + source.replace( exp.select("*") .from_( - alias(source, source.name or parent.alias, table=True), + alias(source.copy(), source.name or source.alias, table=True), copy=False, ) - .subquery(parent.alias, copy=False) + .subquery(source.alias, copy=False) ) return expression diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index 3c51c18..70e4629 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -70,15 +70,10 @@ def merge_ctes(expression, leave_tables_isolated=False): inner_select = inner_scope.expression.unnest() from_or_join = table.find_ancestor(exp.From, exp.Join) if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): - node_to_replace = table - if isinstance(node_to_replace.parent, exp.Alias): - node_to_replace = node_to_replace.parent - alias = node_to_replace.alias - else: - alias = table.name + alias = table.alias_or_name _rename_inner_sources(outer_scope, inner_scope, alias) - _merge_from(outer_scope, inner_scope, node_to_replace, alias) + _merge_from(outer_scope, inner_scope, table, alias) _merge_expressions(outer_scope, inner_scope, alias) _merge_joins(outer_scope, inner_scope, from_or_join) _merge_where(outer_scope, inner_scope, from_or_join) @@ -179,8 +174,8 @@ def _rename_inner_sources(outer_scope, inner_scope, alias): if isinstance(source, exp.Subquery): source.set("alias", exp.TableAlias(this=new_alias)) - elif isinstance(source, exp.Table) and isinstance(source.parent, exp.Alias): - source.parent.set("alias", new_alias) + elif isinstance(source, exp.Table) and source.alias: + source.set("alias", new_alias) elif isinstance(source, exp.Table): source.replace(exp.alias_(source.copy(), new_alias)) @@ -206,8 +201,7 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias): tables = join_hint.find_all(exp.Table) for table in tables: if table.alias_or_name == node_to_replace.alias_or_name: - new_table = new_subquery.this if isinstance(new_subquery, exp.Alias) else new_subquery - table.set("this", exp.to_identifier(new_table.alias_or_name)) + table.set("this", exp.to_identifier(new_subquery.alias_or_name)) outer_scope.remove_source(alias) outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name]) diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index 2c28ab8..5ad8f46 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -1,3 +1,4 @@ +import sqlglot from sqlglot.optimizer.eliminate_ctes import eliminate_ctes from sqlglot.optimizer.eliminate_joins import eliminate_joins from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries @@ -43,6 +44,7 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar 1. {table: {col: type}} 2. {db: {table: {col: type}}} 3. {catalog: {db: {table: {col: type}}}} + If no schema is provided then the default schema defined at `sqlgot.schema` will be used db (str): specify the default database, as might be set by a `USE DATABASE db` statement catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement rules (list): sequence of optimizer rules to use @@ -50,13 +52,12 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar Returns: sqlglot.Expression: optimized expression """ - possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs} + possible_kwargs = {"db": db, "catalog": catalog, "schema": schema or sqlglot.schema, **kwargs} expression = expression.copy() for rule in rules: # Find any additional rule parameters, beyond `expression` rule_params = rule.__code__.co_varnames rule_kwargs = {param: possible_kwargs[param] for param in rule_params if param in possible_kwargs} - expression = rule(expression, **rule_kwargs) return expression diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py index 5584830..5820851 100644 --- a/sqlglot/optimizer/pushdown_projections.py +++ b/sqlglot/optimizer/pushdown_projections.py @@ -6,6 +6,9 @@ from sqlglot.optimizer.scope import Scope, traverse_scope # Sentinel value that means an outer query selecting ALL columns SELECT_ALL = object() +# SELECTION TO USE IF SELECTION LIST IS EMPTY +DEFAULT_SELECTION = alias("1", "_") + def pushdown_projections(expression): """ @@ -25,7 +28,8 @@ def pushdown_projections(expression): """ # Map of Scope to all columns being selected by outer queries. referenced_columns = defaultdict(set) - + left_union = None + right_union = None # We build the scope tree (which is traversed in DFS postorder), then iterate # over the result in reverse order. This should ensure that the set of selected # columns for a particular scope are completely build by the time we get to it. @@ -37,12 +41,16 @@ def pushdown_projections(expression): parent_selections = {SELECT_ALL} if isinstance(scope.expression, exp.Union): - left, right = scope.union_scopes - referenced_columns[left] = parent_selections - referenced_columns[right] = parent_selections + left_union, right_union = scope.union_scopes + referenced_columns[left_union] = parent_selections + referenced_columns[right_union] = parent_selections - if isinstance(scope.expression, exp.Select): - _remove_unused_selections(scope, parent_selections) + if isinstance(scope.expression, exp.Select) and scope != right_union: + removed_indexes = _remove_unused_selections(scope, parent_selections) + # The left union is used for column names to select and if we remove columns from the left + # we need to also remove those same columns in the right that were at the same position + if scope is left_union: + _remove_indexed_selections(right_union, removed_indexes) # Group columns by source name selects = defaultdict(set) @@ -61,6 +69,7 @@ def pushdown_projections(expression): def _remove_unused_selections(scope, parent_selections): + removed_indexes = [] order = scope.expression.args.get("order") if order: @@ -70,16 +79,26 @@ def _remove_unused_selections(scope, parent_selections): order_refs = set() new_selections = [] - for selection in scope.selects: + for i, selection in enumerate(scope.selects): if ( SELECT_ALL in parent_selections or selection.alias_or_name in parent_selections or selection.alias_or_name in order_refs ): new_selections.append(selection) + else: + removed_indexes.append(i) # If there are no remaining selections, just select a single constant if not new_selections: - new_selections.append(alias("1", "_")) + new_selections.append(DEFAULT_SELECTION) + + scope.expression.set("expressions", new_selections) + return removed_indexes + +def _remove_indexed_selections(scope, indexes_to_remove): + new_selections = [selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove] + if not new_selections: + new_selections.append(DEFAULT_SELECTION) scope.expression.set("expressions", new_selections) diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 7d77ef1..36ba028 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -2,8 +2,8 @@ import itertools from sqlglot import alias, exp from sqlglot.errors import OptimizeError -from sqlglot.optimizer.schema import ensure_schema -from sqlglot.optimizer.scope import traverse_scope +from sqlglot.optimizer.scope import Scope, traverse_scope +from sqlglot.schema import ensure_schema def qualify_columns(expression, schema): @@ -48,7 +48,7 @@ def _pop_table_column_aliases(derived_tables): (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2) """ for derived_table in derived_tables: - if isinstance(derived_table, exp.UDTF): + if isinstance(derived_table.unnest(), exp.UDTF): continue table_alias = derived_table.args.get("alias") if table_alias: @@ -211,6 +211,22 @@ def _qualify_columns(scope, resolver): if column_table: column.set("table", exp.to_identifier(column_table)) + # Determine whether each reference in the order by clause is to a column or an alias. + for ordered in scope.find_all(exp.Ordered): + for column in ordered.find_all(exp.Column): + column_table = column.table + column_name = column.name + + if column_table or column.parent is ordered or column_name not in resolver.all_columns: + continue + + column_table = resolver.get_table(column_name) + + if column_table is None: + raise OptimizeError(f"Ambiguous column: {column_name}") + + column.set("table", exp.to_identifier(column_table)) + def _expand_stars(scope, resolver): """Expand stars to lists of column selections""" @@ -346,6 +362,11 @@ class _Resolver: except Exception as e: raise OptimizeError(str(e)) from e + if isinstance(source, Scope) and isinstance(source.expression, exp.Values): + values_alias = source.expression.parent + if hasattr(values_alias, "alias_column_names"): + return values_alias.alias_column_names + # Otherwise, if referencing another scope, return that scope's named selects return source.expression.named_selects diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index 30e93ba..0e467d3 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -40,7 +40,7 @@ def qualify_tables(expression, db=None, catalog=None): if not source.args.get("catalog"): source.set("catalog", exp.to_identifier(catalog)) - if not isinstance(source.parent, exp.Alias): + if not source.alias: source.replace( alias( source.copy(), diff --git a/sqlglot/optimizer/schema.py b/sqlglot/optimizer/schema.py deleted file mode 100644 index d7743c9..0000000 --- a/sqlglot/optimizer/schema.py +++ /dev/null @@ -1,180 +0,0 @@ -import abc - -from sqlglot import exp -from sqlglot.errors import OptimizeError -from sqlglot.helper import csv_reader - - -class Schema(abc.ABC): - """Abstract base class for database schemas""" - - @abc.abstractmethod - def column_names(self, table, only_visible=False): - """ - Get the column names for a table. - Args: - table (sqlglot.expressions.Table): Table expression instance - only_visible (bool): Whether to include invisible columns - Returns: - list[str]: list of column names - """ - - @abc.abstractmethod - def get_column_type(self, table, column): - """ - Get the exp.DataType type of a column in the schema. - - Args: - table (sqlglot.expressions.Table): The source table. - column (sqlglot.expressions.Column): The target column. - Returns: - sqlglot.expressions.DataType.Type: The resulting column type. - """ - - -class MappingSchema(Schema): - """ - Schema based on a nested mapping. - - Args: - schema (dict): Mapping in one of the following forms: - 1. {table: {col: type}} - 2. {db: {table: {col: type}}} - 3. {catalog: {db: {table: {col: type}}}} - visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns - are assumed to be visible. The nesting should mirror that of the schema: - 1. {table: set(*cols)}} - 2. {db: {table: set(*cols)}}} - 3. {catalog: {db: {table: set(*cols)}}}} - dialect (str): The dialect to be used for custom type mappings. - """ - - def __init__(self, schema, visible=None, dialect=None): - self.schema = schema - self.visible = visible - self.dialect = dialect - self._type_mapping_cache = {} - - depth = _dict_depth(schema) - - if not depth: # {} - self.supported_table_args = [] - elif depth == 2: # {table: {col: type}} - self.supported_table_args = ("this",) - elif depth == 3: # {db: {table: {col: type}}} - self.supported_table_args = ("db", "this") - elif depth == 4: # {catalog: {db: {table: {col: type}}}} - self.supported_table_args = ("catalog", "db", "this") - else: - raise OptimizeError(f"Invalid schema shape. Depth: {depth}") - - self.forbidden_args = {"catalog", "db", "this"} - set(self.supported_table_args) - - def column_names(self, table, only_visible=False): - if not isinstance(table.this, exp.Identifier): - return fs_get(table) - - args = tuple(table.text(p) for p in self.supported_table_args) - - for forbidden in self.forbidden_args: - if table.text(forbidden): - raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}") - - columns = list(_nested_get(self.schema, *zip(self.supported_table_args, args))) - if not only_visible or not self.visible: - return columns - - visible = _nested_get(self.visible, *zip(self.supported_table_args, args)) - return [col for col in columns if col in visible] - - def get_column_type(self, table, column): - try: - schema_type = self.schema.get(table.name, {}).get(column.name).upper() - return self._convert_type(schema_type) - except: - raise OptimizeError(f"Failed to get type for column {column.sql()}") - - def _convert_type(self, schema_type): - """ - Convert a type represented as a string to the corresponding exp.DataType.Type object. - - Args: - schema_type (str): The type we want to convert. - Returns: - sqlglot.expressions.DataType.Type: The resulting expression type. - """ - if schema_type not in self._type_mapping_cache: - try: - self._type_mapping_cache[schema_type] = exp.maybe_parse( - schema_type, into=exp.DataType, dialect=self.dialect - ).this - except AttributeError: - raise OptimizeError(f"Failed to convert type {schema_type}") - - return self._type_mapping_cache[schema_type] - - -def ensure_schema(schema): - if isinstance(schema, Schema): - return schema - - return MappingSchema(schema) - - -def fs_get(table): - name = table.this.name - - if name.upper() == "READ_CSV": - with csv_reader(table) as reader: - return next(reader) - - raise ValueError(f"Cannot read schema for {table}") - - -def _nested_get(d, *path): - """ - Get a value for a nested dictionary. - - Args: - d (dict): dictionary - *path (tuple[str, str]): tuples of (name, key) - `key` is the key in the dictionary to get. - `name` is a string to use in the error if `key` isn't found. - """ - for name, key in path: - d = d.get(key) - if d is None: - name = "table" if name == "this" else name - raise ValueError(f"Unknown {name}") - return d - - -def _dict_depth(d): - """ - Get the nesting depth of a dictionary. - - For example: - >>> _dict_depth(None) - 0 - >>> _dict_depth({}) - 1 - >>> _dict_depth({"a": "b"}) - 1 - >>> _dict_depth({"a": {}}) - 2 - >>> _dict_depth({"a": {"b": {}}}) - 3 - - Args: - d (dict): dictionary - Returns: - int: depth - """ - try: - return 1 + _dict_depth(next(iter(d.values()))) - except AttributeError: - # d doesn't have attribute "values" - return 0 - except StopIteration: - # d.values() returns an empty sequence - return 1 diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 68298a0..b7eb6c2 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -257,12 +257,7 @@ class Scope: referenced_names = [] for table in self.tables: - referenced_names.append( - ( - table.parent.alias if isinstance(table.parent, exp.Alias) else table.name, - table, - ) - ) + referenced_names.append((table.alias_or_name, table)) for derived_table in self.derived_tables: referenced_names.append((derived_table.alias, derived_table.unnest())) @@ -538,8 +533,8 @@ def _add_table_sources(scope): for table in scope.tables: table_name = table.name - if isinstance(table.parent, exp.Alias): - source_name = table.parent.alias + if table.alias: + source_name = table.alias else: source_name = table_name |