Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r--  sqlglot/optimizer/__init__.py               1
-rw-r--r--  sqlglot/optimizer/annotate_types.py         2
-rw-r--r--  sqlglot/optimizer/eliminate_subqueries.py   2
-rw-r--r--  sqlglot/optimizer/isolate_table_selects.py  10
-rw-r--r--  sqlglot/optimizer/merge_subqueries.py       16
-rw-r--r--  sqlglot/optimizer/optimizer.py              5
-rw-r--r--  sqlglot/optimizer/pushdown_projections.py   35
-rw-r--r--  sqlglot/optimizer/qualify_columns.py        27
-rw-r--r--  sqlglot/optimizer/qualify_tables.py         2
-rw-r--r--  sqlglot/optimizer/schema.py                 180
-rw-r--r--  sqlglot/optimizer/scope.py                  11
11 files changed, 69 insertions, 222 deletions
diff --git a/sqlglot/optimizer/__init__.py b/sqlglot/optimizer/__init__.py
index d1146ca..bba0878 100644
--- a/sqlglot/optimizer/__init__.py
+++ b/sqlglot/optimizer/__init__.py
@@ -1,2 +1 @@
from sqlglot.optimizer.optimizer import RULES, optimize
-from sqlglot.optimizer.schema import Schema
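The Schema abstraction has moved out of the optimizer package; the updated imports in this diff pull ensure_schema from sqlglot.schema instead. A minimal sketch of the relocated import (table and column names are made up), using only the ensure_schema helper that the new import lines below reference:

    from sqlglot.schema import ensure_schema  # new location, was sqlglot.optimizer.schema

    # Plain nested mappings are still accepted and wrapped for the optimizer rules.
    schema = ensure_schema({"x": {"a": "INT", "b": "TEXT"}})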
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index a2cef37..30055bc 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -1,7 +1,7 @@
from sqlglot import exp
from sqlglot.helper import ensure_list, subclasses
-from sqlglot.optimizer.schema import ensure_schema
from sqlglot.optimizer.scope import Scope, traverse_scope
+from sqlglot.schema import ensure_schema
def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py
index 44cdc94..e30c263 100644
--- a/sqlglot/optimizer/eliminate_subqueries.py
+++ b/sqlglot/optimizer/eliminate_subqueries.py
@@ -86,7 +86,7 @@ def _eliminate(scope, existing_ctes, taken):
if scope.is_union:
return _eliminate_union(scope, existing_ctes, taken)
- if scope.is_derived_table and not isinstance(scope.expression, (exp.Unnest, exp.Lateral)):
+ if scope.is_derived_table and not isinstance(scope.expression, exp.UDTF):
return _eliminate_derived_table(scope, existing_ctes, taken)
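For context (not part of this commit): exp.Unnest and exp.Lateral both derive from exp.UDTF, which is what lets the two-class isinstance check collapse into a single one. A quick sketch of that assumption:

    from sqlglot import exp

    # Table functions share the exp.UDTF base class, so one isinstance check
    # covers UNNEST, LATERAL, VALUES and the other table-valued expressions.
    assert issubclass(exp.Unnest, exp.UDTF)
    assert issubclass(exp.Lateral, exp.UDTF)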
diff --git a/sqlglot/optimizer/isolate_table_selects.py b/sqlglot/optimizer/isolate_table_selects.py
index e060739..652cdef 100644
--- a/sqlglot/optimizer/isolate_table_selects.py
+++ b/sqlglot/optimizer/isolate_table_selects.py
@@ -12,18 +12,16 @@ def isolate_table_selects(expression):
if not isinstance(source, exp.Table):
continue
- if not isinstance(source.parent, exp.Alias):
+ if not source.alias:
raise OptimizeError("Tables require an alias. Run qualify_tables optimization.")
- parent = source.parent
-
- parent.replace(
+ source.replace(
exp.select("*")
.from_(
- alias(source, source.name or parent.alias, table=True),
+ alias(source.copy(), source.name or source.alias, table=True),
copy=False,
)
- .subquery(parent.alias, copy=False)
+ .subquery(source.alias, copy=False)
)
return expression
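A rough sketch of how this pass is typically driven (table and column names are made up); qualify_tables runs first so every table carries the alias the error message above asks for:

    from sqlglot import parse_one
    from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
    from sqlglot.optimizer.qualify_tables import qualify_tables

    expression = qualify_tables(parse_one("SELECT x.a, y.b FROM x JOIN y ON x.id = y.id"))
    # Each table source gets wrapped in its own SELECT * while the outer alias stays stable,
    # roughly: ... FROM (SELECT * FROM x AS x) AS x JOIN (SELECT * FROM y AS y) AS y ...
    print(isolate_table_selects(expression).sql())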
diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py
index 3c51c18..70e4629 100644
--- a/sqlglot/optimizer/merge_subqueries.py
+++ b/sqlglot/optimizer/merge_subqueries.py
@@ -70,15 +70,10 @@ def merge_ctes(expression, leave_tables_isolated=False):
inner_select = inner_scope.expression.unnest()
from_or_join = table.find_ancestor(exp.From, exp.Join)
if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
- node_to_replace = table
- if isinstance(node_to_replace.parent, exp.Alias):
- node_to_replace = node_to_replace.parent
- alias = node_to_replace.alias
- else:
- alias = table.name
+ alias = table.alias_or_name
_rename_inner_sources(outer_scope, inner_scope, alias)
- _merge_from(outer_scope, inner_scope, node_to_replace, alias)
+ _merge_from(outer_scope, inner_scope, table, alias)
_merge_expressions(outer_scope, inner_scope, alias)
_merge_joins(outer_scope, inner_scope, from_or_join)
_merge_where(outer_scope, inner_scope, from_or_join)
@@ -179,8 +174,8 @@ def _rename_inner_sources(outer_scope, inner_scope, alias):
if isinstance(source, exp.Subquery):
source.set("alias", exp.TableAlias(this=new_alias))
- elif isinstance(source, exp.Table) and isinstance(source.parent, exp.Alias):
- source.parent.set("alias", new_alias)
+ elif isinstance(source, exp.Table) and source.alias:
+ source.set("alias", new_alias)
elif isinstance(source, exp.Table):
source.replace(exp.alias_(source.copy(), new_alias))
@@ -206,8 +201,7 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
tables = join_hint.find_all(exp.Table)
for table in tables:
if table.alias_or_name == node_to_replace.alias_or_name:
- new_table = new_subquery.this if isinstance(new_subquery, exp.Alias) else new_subquery
- table.set("this", exp.to_identifier(new_table.alias_or_name))
+ table.set("this", exp.to_identifier(new_subquery.alias_or_name))
outer_scope.remove_source(alias)
outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name])
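The simplification leans on alias_or_name, which falls back to the table name when no alias is set; a small illustration with hypothetical table names:

    from sqlglot import exp, parse_one

    # With an alias, alias_or_name returns it; without one, it returns the table name.
    assert parse_one("SELECT * FROM foo AS f").find(exp.Table).alias_or_name == "f"
    assert parse_one("SELECT * FROM foo").find(exp.Table).alias_or_name == "foo"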
diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py
index 2c28ab8..5ad8f46 100644
--- a/sqlglot/optimizer/optimizer.py
+++ b/sqlglot/optimizer/optimizer.py
@@ -1,3 +1,4 @@
+import sqlglot
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
from sqlglot.optimizer.eliminate_joins import eliminate_joins
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
@@ -43,6 +44,7 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar
1. {table: {col: type}}
2. {db: {table: {col: type}}}
3. {catalog: {db: {table: {col: type}}}}
+ If no schema is provided, the default schema defined at `sqlglot.schema` will be used
db (str): specify the default database, as might be set by a `USE DATABASE db` statement
catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement
rules (list): sequence of optimizer rules to use
@@ -50,13 +52,12 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar
Returns:
sqlglot.Expression: optimized expression
"""
- possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs}
+ possible_kwargs = {"db": db, "catalog": catalog, "schema": schema or sqlglot.schema, **kwargs}
expression = expression.copy()
for rule in rules:
# Find any additional rule parameters, beyond `expression`
rule_params = rule.__code__.co_varnames
rule_kwargs = {param: possible_kwargs[param] for param in rule_params if param in possible_kwargs}
-
expression = rule(expression, **rule_kwargs)
return expression
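Usage is unchanged when a schema mapping is passed explicitly; omitting it now falls back to the module-level sqlglot.schema default. A minimal sketch with a made-up schema:

    import sqlglot
    from sqlglot.optimizer import optimize

    expression = sqlglot.parse_one("SELECT a FROM x")
    # Explicit mapping behaves as before; leaving schema=None would now use sqlglot.schema.
    print(optimize(expression, schema={"x": {"a": "INT"}}).sql())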
diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py
index 5584830..5820851 100644
--- a/sqlglot/optimizer/pushdown_projections.py
+++ b/sqlglot/optimizer/pushdown_projections.py
@@ -6,6 +6,9 @@ from sqlglot.optimizer.scope import Scope, traverse_scope
# Sentinel value that means an outer query selecting ALL columns
SELECT_ALL = object()
+# Selection to use if the selection list is empty
+DEFAULT_SELECTION = alias("1", "_")
+
def pushdown_projections(expression):
"""
@@ -25,7 +28,8 @@ def pushdown_projections(expression):
"""
# Map of Scope to all columns being selected by outer queries.
referenced_columns = defaultdict(set)
-
+ left_union = None
+ right_union = None
# We build the scope tree (which is traversed in DFS postorder), then iterate
# over the result in reverse order. This should ensure that the set of selected
# columns for a particular scope is completely built by the time we get to it.
@@ -37,12 +41,16 @@ def pushdown_projections(expression):
parent_selections = {SELECT_ALL}
if isinstance(scope.expression, exp.Union):
- left, right = scope.union_scopes
- referenced_columns[left] = parent_selections
- referenced_columns[right] = parent_selections
+ left_union, right_union = scope.union_scopes
+ referenced_columns[left_union] = parent_selections
+ referenced_columns[right_union] = parent_selections
- if isinstance(scope.expression, exp.Select):
- _remove_unused_selections(scope, parent_selections)
+ if isinstance(scope.expression, exp.Select) and scope != right_union:
+ removed_indexes = _remove_unused_selections(scope, parent_selections)
+ # The left side of the union determines the column names to select, so if we remove
+ # columns from the left we also need to remove the columns at the same positions on the right
+ if scope is left_union:
+ _remove_indexed_selections(right_union, removed_indexes)
# Group columns by source name
selects = defaultdict(set)
@@ -61,6 +69,7 @@ def pushdown_projections(expression):
def _remove_unused_selections(scope, parent_selections):
+ removed_indexes = []
order = scope.expression.args.get("order")
if order:
@@ -70,16 +79,26 @@ def _remove_unused_selections(scope, parent_selections):
order_refs = set()
new_selections = []
- for selection in scope.selects:
+ for i, selection in enumerate(scope.selects):
if (
SELECT_ALL in parent_selections
or selection.alias_or_name in parent_selections
or selection.alias_or_name in order_refs
):
new_selections.append(selection)
+ else:
+ removed_indexes.append(i)
# If there are no remaining selections, just select a single constant
if not new_selections:
- new_selections.append(alias("1", "_"))
+ new_selections.append(DEFAULT_SELECTION)
+
+ scope.expression.set("expressions", new_selections)
+ return removed_indexes
+
+def _remove_indexed_selections(scope, indexes_to_remove):
+ new_selections = [selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove]
+ if not new_selections:
+ new_selections.append(DEFAULT_SELECTION)
scope.expression.set("expressions", new_selections)
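A rough sketch of the union behavior this hunk adds (in the full pipeline this rule runs after column qualification; names below are made up): the left side of the union supplies the output column names, and positions pruned there are pruned on the right as well.

    from sqlglot import parse_one
    from sqlglot.optimizer.pushdown_projections import pushdown_projections

    sql = "SELECT u.a FROM (SELECT a, b FROM x UNION ALL SELECT a, b FROM y) AS u"
    # Only `a` is referenced by the outer query, so `b` is dropped from the left
    # select and, by position, from the right one as well, roughly:
    # SELECT u.a FROM (SELECT a FROM x UNION ALL SELECT a FROM y) AS u
    print(pushdown_projections(parse_one(sql)).sql())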
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 7d77ef1..36ba028 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -2,8 +2,8 @@ import itertools
from sqlglot import alias, exp
from sqlglot.errors import OptimizeError
-from sqlglot.optimizer.schema import ensure_schema
-from sqlglot.optimizer.scope import traverse_scope
+from sqlglot.optimizer.scope import Scope, traverse_scope
+from sqlglot.schema import ensure_schema
def qualify_columns(expression, schema):
@@ -48,7 +48,7 @@ def _pop_table_column_aliases(derived_tables):
(e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
"""
for derived_table in derived_tables:
- if isinstance(derived_table, exp.UDTF):
+ if isinstance(derived_table.unnest(), exp.UDTF):
continue
table_alias = derived_table.args.get("alias")
if table_alias:
@@ -211,6 +211,22 @@ def _qualify_columns(scope, resolver):
if column_table:
column.set("table", exp.to_identifier(column_table))
+ # Determine whether each reference in the order by clause is to a column or an alias.
+ for ordered in scope.find_all(exp.Ordered):
+ for column in ordered.find_all(exp.Column):
+ column_table = column.table
+ column_name = column.name
+
+ if column_table or column.parent is ordered or column_name not in resolver.all_columns:
+ continue
+
+ column_table = resolver.get_table(column_name)
+
+ if column_table is None:
+ raise OptimizeError(f"Ambiguous column: {column_name}")
+
+ column.set("table", exp.to_identifier(column_table))
+
def _expand_stars(scope, resolver):
"""Expand stars to lists of column selections"""
@@ -346,6 +362,11 @@ class _Resolver:
except Exception as e:
raise OptimizeError(str(e)) from e
+ if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
+ values_alias = source.expression.parent
+ if hasattr(values_alias, "alias_column_names"):
+ return values_alias.alias_column_names
+
# Otherwise, if referencing another scope, return that scope's named selects
return source.expression.named_selects
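A small sketch around the ORDER BY qualification this hunk touches (schema and names are made up): references that resolve to select aliases are left alone, while references to real source columns get table-qualified.

    from sqlglot import parse_one
    from sqlglot.optimizer.qualify_columns import qualify_columns

    schema = {"x": {"a": "INT", "b": "INT"}}
    # `c` is a select alias and stays as-is; `b` names a column of x and is
    # qualified with its source table.
    expression = parse_one("SELECT a AS c FROM x ORDER BY c, b + 1")
    print(qualify_columns(expression, schema).sql())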
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py
index 30e93ba..0e467d3 100644
--- a/sqlglot/optimizer/qualify_tables.py
+++ b/sqlglot/optimizer/qualify_tables.py
@@ -40,7 +40,7 @@ def qualify_tables(expression, db=None, catalog=None):
if not source.args.get("catalog"):
source.set("catalog", exp.to_identifier(catalog))
- if not isinstance(source.parent, exp.Alias):
+ if not source.alias:
source.replace(
alias(
source.copy(),
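For reference, a typical invocation (made-up names): unqualified tables pick up the default db/catalog and, with this change, an alias stored on the table itself.

    from sqlglot import parse_one
    from sqlglot.optimizer.qualify_tables import qualify_tables

    # Roughly: SELECT 1 FROM main.t AS t
    print(qualify_tables(parse_one("SELECT 1 FROM t"), db="main").sql())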
diff --git a/sqlglot/optimizer/schema.py b/sqlglot/optimizer/schema.py
deleted file mode 100644
index d7743c9..0000000
--- a/sqlglot/optimizer/schema.py
+++ /dev/null
@@ -1,180 +0,0 @@
-import abc
-
-from sqlglot import exp
-from sqlglot.errors import OptimizeError
-from sqlglot.helper import csv_reader
-
-
-class Schema(abc.ABC):
- """Abstract base class for database schemas"""
-
- @abc.abstractmethod
- def column_names(self, table, only_visible=False):
- """
- Get the column names for a table.
- Args:
- table (sqlglot.expressions.Table): Table expression instance
- only_visible (bool): Whether to include invisible columns
- Returns:
- list[str]: list of column names
- """
-
- @abc.abstractmethod
- def get_column_type(self, table, column):
- """
- Get the exp.DataType type of a column in the schema.
-
- Args:
- table (sqlglot.expressions.Table): The source table.
- column (sqlglot.expressions.Column): The target column.
- Returns:
- sqlglot.expressions.DataType.Type: The resulting column type.
- """
-
-
-class MappingSchema(Schema):
- """
- Schema based on a nested mapping.
-
- Args:
- schema (dict): Mapping in one of the following forms:
- 1. {table: {col: type}}
- 2. {db: {table: {col: type}}}
- 3. {catalog: {db: {table: {col: type}}}}
- visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns
- are assumed to be visible. The nesting should mirror that of the schema:
- 1. {table: set(*cols)}}
- 2. {db: {table: set(*cols)}}}
- 3. {catalog: {db: {table: set(*cols)}}}}
- dialect (str): The dialect to be used for custom type mappings.
- """
-
- def __init__(self, schema, visible=None, dialect=None):
- self.schema = schema
- self.visible = visible
- self.dialect = dialect
- self._type_mapping_cache = {}
-
- depth = _dict_depth(schema)
-
- if not depth: # {}
- self.supported_table_args = []
- elif depth == 2: # {table: {col: type}}
- self.supported_table_args = ("this",)
- elif depth == 3: # {db: {table: {col: type}}}
- self.supported_table_args = ("db", "this")
- elif depth == 4: # {catalog: {db: {table: {col: type}}}}
- self.supported_table_args = ("catalog", "db", "this")
- else:
- raise OptimizeError(f"Invalid schema shape. Depth: {depth}")
-
- self.forbidden_args = {"catalog", "db", "this"} - set(self.supported_table_args)
-
- def column_names(self, table, only_visible=False):
- if not isinstance(table.this, exp.Identifier):
- return fs_get(table)
-
- args = tuple(table.text(p) for p in self.supported_table_args)
-
- for forbidden in self.forbidden_args:
- if table.text(forbidden):
- raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}")
-
- columns = list(_nested_get(self.schema, *zip(self.supported_table_args, args)))
- if not only_visible or not self.visible:
- return columns
-
- visible = _nested_get(self.visible, *zip(self.supported_table_args, args))
- return [col for col in columns if col in visible]
-
- def get_column_type(self, table, column):
- try:
- schema_type = self.schema.get(table.name, {}).get(column.name).upper()
- return self._convert_type(schema_type)
- except:
- raise OptimizeError(f"Failed to get type for column {column.sql()}")
-
- def _convert_type(self, schema_type):
- """
- Convert a type represented as a string to the corresponding exp.DataType.Type object.
-
- Args:
- schema_type (str): The type we want to convert.
- Returns:
- sqlglot.expressions.DataType.Type: The resulting expression type.
- """
- if schema_type not in self._type_mapping_cache:
- try:
- self._type_mapping_cache[schema_type] = exp.maybe_parse(
- schema_type, into=exp.DataType, dialect=self.dialect
- ).this
- except AttributeError:
- raise OptimizeError(f"Failed to convert type {schema_type}")
-
- return self._type_mapping_cache[schema_type]
-
-
-def ensure_schema(schema):
- if isinstance(schema, Schema):
- return schema
-
- return MappingSchema(schema)
-
-
-def fs_get(table):
- name = table.this.name
-
- if name.upper() == "READ_CSV":
- with csv_reader(table) as reader:
- return next(reader)
-
- raise ValueError(f"Cannot read schema for {table}")
-
-
-def _nested_get(d, *path):
- """
- Get a value for a nested dictionary.
-
- Args:
- d (dict): dictionary
- *path (tuple[str, str]): tuples of (name, key)
- `key` is the key in the dictionary to get.
- `name` is a string to use in the error if `key` isn't found.
- """
- for name, key in path:
- d = d.get(key)
- if d is None:
- name = "table" if name == "this" else name
- raise ValueError(f"Unknown {name}")
- return d
-
-
-def _dict_depth(d):
- """
- Get the nesting depth of a dictionary.
-
- For example:
- >>> _dict_depth(None)
- 0
- >>> _dict_depth({})
- 1
- >>> _dict_depth({"a": "b"})
- 1
- >>> _dict_depth({"a": {}})
- 2
- >>> _dict_depth({"a": {"b": {}}})
- 3
-
- Args:
- d (dict): dictionary
- Returns:
- int: depth
- """
- try:
- return 1 + _dict_depth(next(iter(d.values())))
- except AttributeError:
- # d doesn't have attribute "values"
- return 0
- except StopIteration:
- # d.values() returns an empty sequence
- return 1
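The module is deleted here because the Schema machinery now lives at sqlglot/schema.py, as the updated imports elsewhere in this diff show. The nested-mapping shapes documented above remain the interface; a quick sketch against the new location (hypothetical names):

    from sqlglot.schema import ensure_schema

    ensure_schema({"t": {"c": "INT"}})                   # {table: {col: type}}
    ensure_schema({"db": {"t": {"c": "INT"}}})           # {db: {table: {col: type}}}
    ensure_schema({"cat": {"db": {"t": {"c": "INT"}}}})  # {catalog: {db: {table: {col: type}}}}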
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index 68298a0..b7eb6c2 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -257,12 +257,7 @@ class Scope:
referenced_names = []
for table in self.tables:
- referenced_names.append(
- (
- table.parent.alias if isinstance(table.parent, exp.Alias) else table.name,
- table,
- )
- )
+ referenced_names.append((table.alias_or_name, table))
for derived_table in self.derived_tables:
referenced_names.append((derived_table.alias, derived_table.unnest()))
@@ -538,8 +533,8 @@ def _add_table_sources(scope):
for table in scope.tables:
table_name = table.name
- if isinstance(table.parent, exp.Alias):
- source_name = table.parent.alias
+ if table.alias:
+ source_name = table.alias
else:
source_name = table_name
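Underlying all of these hunks is the accompanying parser change this commit assumes: a table's alias now lives on the exp.Table node itself instead of on a wrapping exp.Alias. A short sketch of that assumption:

    from sqlglot import exp, parse_one

    table = parse_one("SELECT * FROM foo AS f").find(exp.Table)
    # The alias is an arg of the Table node; there is no enclosing exp.Alias wrapper.
    assert table.alias == "f"
    assert not isinstance(table.parent, exp.Alias)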