From 28cc22419e32a65fea2d1678400265b8cabc3aff Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Thu, 15 Sep 2022 18:46:17 +0200 Subject: Adding upstream version 6.0.4. Signed-off-by: Daniel Baumann --- sqlglot/optimizer/__init__.py | 2 + sqlglot/optimizer/eliminate_subqueries.py | 48 +++ sqlglot/optimizer/expand_multi_table_selects.py | 16 + sqlglot/optimizer/isolate_table_selects.py | 31 ++ sqlglot/optimizer/normalize.py | 136 ++++++++ sqlglot/optimizer/optimize_joins.py | 75 ++++ sqlglot/optimizer/optimizer.py | 43 +++ sqlglot/optimizer/pushdown_predicates.py | 176 ++++++++++ sqlglot/optimizer/pushdown_projections.py | 85 +++++ sqlglot/optimizer/qualify_columns.py | 422 +++++++++++++++++++++++ sqlglot/optimizer/qualify_tables.py | 54 +++ sqlglot/optimizer/quote_identities.py | 25 ++ sqlglot/optimizer/schema.py | 129 +++++++ sqlglot/optimizer/scope.py | 438 ++++++++++++++++++++++++ sqlglot/optimizer/simplify.py | 383 +++++++++++++++++++++ sqlglot/optimizer/unnest_subqueries.py | 220 ++++++++++++ 16 files changed, 2283 insertions(+) create mode 100644 sqlglot/optimizer/__init__.py create mode 100644 sqlglot/optimizer/eliminate_subqueries.py create mode 100644 sqlglot/optimizer/expand_multi_table_selects.py create mode 100644 sqlglot/optimizer/isolate_table_selects.py create mode 100644 sqlglot/optimizer/normalize.py create mode 100644 sqlglot/optimizer/optimize_joins.py create mode 100644 sqlglot/optimizer/optimizer.py create mode 100644 sqlglot/optimizer/pushdown_predicates.py create mode 100644 sqlglot/optimizer/pushdown_projections.py create mode 100644 sqlglot/optimizer/qualify_columns.py create mode 100644 sqlglot/optimizer/qualify_tables.py create mode 100644 sqlglot/optimizer/quote_identities.py create mode 100644 sqlglot/optimizer/schema.py create mode 100644 sqlglot/optimizer/scope.py create mode 100644 sqlglot/optimizer/simplify.py create mode 100644 sqlglot/optimizer/unnest_subqueries.py (limited to 'sqlglot/optimizer') diff --git 
a/sqlglot/optimizer/__init__.py b/sqlglot/optimizer/__init__.py new file mode 100644 index 0000000..a4c4cc2 --- /dev/null +++ b/sqlglot/optimizer/__init__.py @@ -0,0 +1,2 @@ +from sqlglot.optimizer.optimizer import optimize +from sqlglot.optimizer.schema import Schema diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py new file mode 100644 index 0000000..4bfb733 --- /dev/null +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -0,0 +1,48 @@ +import itertools + +from sqlglot import alias, exp, select, table +from sqlglot.optimizer.scope import traverse_scope +from sqlglot.optimizer.simplify import simplify + + +def eliminate_subqueries(expression): + """ + Rewrite duplicate subqueries from sqlglot AST. + + Example: + >>> import sqlglot + >>> expression = sqlglot.parse_one("SELECT 1 AS x, 2 AS y UNION ALL SELECT 1 AS x, 2 AS y") + >>> eliminate_subqueries(expression).sql() + 'WITH _e_0 AS (SELECT 1 AS x, 2 AS y) SELECT * FROM _e_0 UNION ALL SELECT * FROM _e_0' + + Args: + expression (sqlglot.Expression): expression to qualify + schema (dict|sqlglot.optimizer.Schema): Database schema + Returns: + sqlglot.Expression: qualified expression + """ + expression = simplify(expression) + queries = {} + + for scope in traverse_scope(expression): + query = scope.expression + queries[query] = queries.get(query, []) + [query] + + sequence = itertools.count() + + for query, duplicates in queries.items(): + if len(duplicates) == 1: + continue + + alias_ = f"_e_{next(sequence)}" + + for dup in duplicates: + parent = dup.parent + if isinstance(parent, exp.Subquery): + parent.replace(alias(table(alias_), parent.alias_or_name, table=True)) + elif isinstance(parent, exp.Union): + dup.replace(select("*").from_(alias_)) + + expression.with_(alias_, as_=query, copy=False) + + return expression diff --git a/sqlglot/optimizer/expand_multi_table_selects.py b/sqlglot/optimizer/expand_multi_table_selects.py new file mode 100644 index 
0000000..ba562df --- /dev/null +++ b/sqlglot/optimizer/expand_multi_table_selects.py @@ -0,0 +1,16 @@ +from sqlglot import exp + + +def expand_multi_table_selects(expression): + for from_ in expression.find_all(exp.From): + parent = from_.parent + + for query in from_.expressions[1:]: + parent.join( + query, + join_type="CROSS", + copy=False, + ) + from_.expressions.remove(query) + + return expression diff --git a/sqlglot/optimizer/isolate_table_selects.py b/sqlglot/optimizer/isolate_table_selects.py new file mode 100644 index 0000000..c2e021e --- /dev/null +++ b/sqlglot/optimizer/isolate_table_selects.py @@ -0,0 +1,31 @@ +from sqlglot import alias, exp +from sqlglot.errors import OptimizeError +from sqlglot.optimizer.scope import traverse_scope + + +def isolate_table_selects(expression): + for scope in traverse_scope(expression): + if len(scope.selected_sources) == 1: + continue + + for (_, source) in scope.selected_sources.values(): + if not isinstance(source, exp.Table): + continue + + if not isinstance(source.parent, exp.Alias): + raise OptimizeError( + "Tables require an alias. Run qualify_tables optimization." + ) + + parent = source.parent + + parent.replace( + exp.select("*") + .from_( + alias(source, source.name or parent.alias, table=True), + copy=False, + ) + .subquery(parent.alias, copy=False) + ) + + return expression diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py new file mode 100644 index 0000000..2c9f89c --- /dev/null +++ b/sqlglot/optimizer/normalize.py @@ -0,0 +1,136 @@ +from sqlglot import exp +from sqlglot.helper import while_changing +from sqlglot.optimizer.simplify import flatten, simplify, uniq_sort + + +def normalize(expression, dnf=False, max_distance=128): + """ + Rewrite sqlglot AST into conjunctive normal form. 
+ + Example: + >>> import sqlglot + >>> expression = sqlglot.parse_one("(x AND y) OR z") + >>> normalize(expression).sql() + '(x OR z) AND (y OR z)' + + Args: + expression (sqlglot.Expression): expression to normalize + dnf (bool): rewrite in disjunctive normal form instead + max_distance (int): the maximal estimated distance from cnf to attempt conversion + Returns: + sqlglot.Expression: normalized expression + """ + expression = simplify(expression) + + expression = while_changing( + expression, lambda e: distributive_law(e, dnf, max_distance) + ) + return simplify(expression) + + +def normalized(expression, dnf=False): + ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And) + + return not any( + connector.find_ancestor(ancestor) for connector in expression.find_all(root) + ) + + +def normalization_distance(expression, dnf=False): + """ + The difference in the number of predicates between the current expression and the normalized form. + + This is used as an estimate of the cost of the conversion which is exponential in complexity. + + Example: + >>> import sqlglot + >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)") + >>> normalization_distance(expression) + 4 + + Args: + expression (sqlglot.Expression): expression to compute distance + dnf (bool): compute to dnf distance instead + Returns: + int: difference + """ + return sum(_predicate_lengths(expression, dnf)) - ( + len(list(expression.find_all(exp.Connector))) + 1 + ) + + +def _predicate_lengths(expression, dnf): + """ + Returns a list of predicate lengths when expanded to normalized form. + + (A AND B) OR C -> [2, 2] because len(A OR C), len(B OR C). 
+ """ + expression = expression.unnest() + + if not isinstance(expression, exp.Connector): + return [1] + + left, right = expression.args.values() + + if isinstance(expression, exp.And if dnf else exp.Or): + x = [ + a + b + for a in _predicate_lengths(left, dnf) + for b in _predicate_lengths(right, dnf) + ] + return x + return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf) + + +def distributive_law(expression, dnf, max_distance): + """ + x OR (y AND z) -> (x OR y) AND (x OR z) + (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z) + """ + if isinstance(expression.unnest(), exp.Connector): + if normalization_distance(expression, dnf) > max_distance: + return expression + + to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or) + + exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance)) + + if isinstance(expression, from_exp): + a, b = expression.unnest_operands() + + from_func = exp.and_ if from_exp == exp.And else exp.or_ + to_func = exp.and_ if to_exp == exp.And else exp.or_ + + if isinstance(a, to_exp) and isinstance(b, to_exp): + if len(tuple(a.find_all(exp.Connector))) > len( + tuple(b.find_all(exp.Connector)) + ): + return _distribute(a, b, from_func, to_func) + return _distribute(b, a, from_func, to_func) + if isinstance(a, to_exp): + return _distribute(b, a, from_func, to_func) + if isinstance(b, to_exp): + return _distribute(a, b, from_func, to_func) + + return expression + + +def _distribute(a, b, from_func, to_func): + if isinstance(a, exp.Connector): + exp.replace_children( + a, + lambda c: to_func( + exp.paren(from_func(c, b.left)), + exp.paren(from_func(c, b.right)), + ), + ) + else: + a = to_func(from_func(a, b.left), from_func(a, b.right)) + + return _simplify(a) + + +def _simplify(node): + node = uniq_sort(flatten(node)) + exp.replace_children(node, _simplify) + return node diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py new file mode 
100644 index 0000000..40e4ab1 --- /dev/null +++ b/sqlglot/optimizer/optimize_joins.py @@ -0,0 +1,75 @@ +from sqlglot import exp +from sqlglot.helper import tsort +from sqlglot.optimizer.simplify import simplify + + +def optimize_joins(expression): + """ + Removes cross joins if possible and reorder joins based on predicate dependencies. + """ + for select in expression.find_all(exp.Select): + references = {} + cross_joins = [] + + for join in select.args.get("joins", []): + name = join.this.alias_or_name + tables = other_table_names(join, name) + + if tables: + for table in tables: + references[table] = references.get(table, []) + [join] + else: + cross_joins.append((name, join)) + + for name, join in cross_joins: + for dep in references.get(name, []): + on = dep.args["on"] + on = on.replace(simplify(on)) + + if isinstance(on, exp.Connector): + for predicate in on.flatten(): + if name in exp.column_table_names(predicate): + predicate.replace(exp.TRUE) + join.on(predicate, copy=False) + + expression = reorder_joins(expression) + expression = normalize(expression) + return expression + + +def reorder_joins(expression): + """ + Reorder joins by topological sort order based on predicate references. + """ + for from_ in expression.find_all(exp.From): + head = from_.expressions[0] + parent = from_.parent + joins = {join.this.alias_or_name: join for join in parent.args.get("joins", [])} + dag = {head.alias_or_name: []} + + for name, join in joins.items(): + dag[name] = other_table_names(join, name) + + parent.set( + "joins", + [joins[name] for name in tsort(dag) if name != head.alias_or_name], + ) + return expression + + +def normalize(expression): + """ + Remove INNER and OUTER from joins as they are optional. 
+ """ + for join in expression.find_all(exp.Join): + if join.kind != "CROSS": + join.set("kind", None) + return expression + + +def other_table_names(join, exclude): + return [ + name + for name in (exp.column_table_names(join.args.get("on") or exp.TRUE)) + if name != exclude + ] diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py new file mode 100644 index 0000000..c03fe3c --- /dev/null +++ b/sqlglot/optimizer/optimizer.py @@ -0,0 +1,43 @@ +from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries +from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects +from sqlglot.optimizer.isolate_table_selects import isolate_table_selects +from sqlglot.optimizer.normalize import normalize +from sqlglot.optimizer.optimize_joins import optimize_joins +from sqlglot.optimizer.pushdown_predicates import pushdown_predicates +from sqlglot.optimizer.pushdown_projections import pushdown_projections +from sqlglot.optimizer.qualify_columns import qualify_columns +from sqlglot.optimizer.qualify_tables import qualify_tables +from sqlglot.optimizer.quote_identities import quote_identities +from sqlglot.optimizer.unnest_subqueries import unnest_subqueries + + +def optimize(expression, schema=None, db=None, catalog=None): + """ + Rewrite a sqlglot AST into an optimized form. + + Args: + expression (sqlglot.Expression): expression to optimize + schema (dict|sqlglot.optimizer.Schema): database schema. + This can either be an instance of `sqlglot.optimizer.Schema` or a mapping in one of + the following forms: + 1. {table: {col: type}} + 2. {db: {table: {col: type}}} + 3. 
{catalog: {db: {table: {col: type}}}} + db (str): specify the default database, as might be set by a `USE DATABASE db` statement + catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement + Returns: + sqlglot.Expression: optimized expression + """ + expression = expression.copy() + expression = qualify_tables(expression, db=db, catalog=catalog) + expression = isolate_table_selects(expression) + expression = qualify_columns(expression, schema) + expression = pushdown_projections(expression) + expression = normalize(expression) + expression = unnest_subqueries(expression) + expression = expand_multi_table_selects(expression) + expression = pushdown_predicates(expression) + expression = optimize_joins(expression) + expression = eliminate_subqueries(expression) + expression = quote_identities(expression) + return expression diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py new file mode 100644 index 0000000..e757322 --- /dev/null +++ b/sqlglot/optimizer/pushdown_predicates.py @@ -0,0 +1,176 @@ +from sqlglot import exp +from sqlglot.optimizer.normalize import normalized +from sqlglot.optimizer.scope import traverse_scope +from sqlglot.optimizer.simplify import simplify + + +def pushdown_predicates(expression): + """ + Rewrite sqlglot AST to pushdown predicates in FROMS and JOINS + + Example: + >>> import sqlglot + >>> sql = "SELECT * FROM (SELECT * FROM x AS x) AS y WHERE y.a = 1" + >>> expression = sqlglot.parse_one(sql) + >>> pushdown_predicates(expression).sql() + 'SELECT * FROM (SELECT * FROM x AS x WHERE y.a = 1) AS y WHERE TRUE' + + Args: + expression (sqlglot.Expression): expression to optimize + Returns: + sqlglot.Expression: optimized expression + """ + for scope in reversed(traverse_scope(expression)): + select = scope.expression + where = select.args.get("where") + if where: + pushdown(where.this, scope.selected_sources) + + # joins should only pushdown into itself, not to other 
joins + # so we limit the selected sources to only itself + for join in select.args.get("joins") or []: + name = join.this.alias_or_name + pushdown(join.args.get("on"), {name: scope.selected_sources[name]}) + + return expression + + +def pushdown(condition, sources): + if not condition: + return + + condition = condition.replace(simplify(condition)) + cnf_like = normalized(condition) or not normalized(condition, dnf=True) + + predicates = list( + condition.flatten() + if isinstance(condition, exp.And if cnf_like else exp.Or) + else [condition] + ) + + if cnf_like: + pushdown_cnf(predicates, sources) + else: + pushdown_dnf(predicates, sources) + + +def pushdown_cnf(predicates, scope): + """ + If the predicates are in CNF like form, we can simply replace each block in the parent. + """ + for predicate in predicates: + for node in nodes_for_predicate(predicate, scope).values(): + if isinstance(node, exp.Join): + predicate.replace(exp.TRUE) + node.on(predicate, copy=False) + break + if isinstance(node, exp.Select): + predicate.replace(exp.TRUE) + node.where(replace_aliases(node, predicate), copy=False) + + +def pushdown_dnf(predicates, scope): + """ + If the predicates are in DNF form, we can only push down conditions that are in all blocks. + Additionally, we can't remove predicates from their original form. 
+ """ + # find all the tables that can be pushdown too + # these are tables that are referenced in all blocks of a DNF + # (a.x AND b.x) OR (a.y AND c.y) + # only table a can be push down + pushdown_tables = set() + + for a in predicates: + a_tables = set(exp.column_table_names(a)) + + for b in predicates: + a_tables &= set(exp.column_table_names(b)) + + pushdown_tables.update(a_tables) + + conditions = {} + + # for every pushdown table, find all related conditions in all predicates + # combine them with ORS + # (a.x AND and a.y AND b.x) OR (a.z AND c.y) -> (a.x AND a.y) OR (a.z) + for table in sorted(pushdown_tables): + for predicate in predicates: + nodes = nodes_for_predicate(predicate, scope) + + if table not in nodes: + continue + + predicate_condition = None + + for column in predicate.find_all(exp.Column): + if column.table == table: + condition = column.find_ancestor(exp.Condition) + predicate_condition = ( + exp.and_(predicate_condition, condition) + if predicate_condition + else condition + ) + + if predicate_condition: + conditions[table] = ( + exp.or_(conditions[table], predicate_condition) + if table in conditions + else predicate_condition + ) + + for name, node in nodes.items(): + if name not in conditions: + continue + + predicate = conditions[name] + + if isinstance(node, exp.Join): + node.on(predicate, copy=False) + elif isinstance(node, exp.Select): + node.where(replace_aliases(node, predicate), copy=False) + + +def nodes_for_predicate(predicate, sources): + nodes = {} + tables = exp.column_table_names(predicate) + where_condition = isinstance( + predicate.find_ancestor(exp.Join, exp.Where), exp.Where + ) + + for table in tables: + node, source = sources.get(table) or (None, None) + + # if the predicate is in a where statement we can try to push it down + # we want to find the root join or from statement + if node and where_condition: + node = node.find_ancestor(exp.Join, exp.From) + + # a node can reference a CTE which should be push down + if 
isinstance(node, exp.From) and not isinstance(source, exp.Table): + node = source.expression + + if isinstance(node, exp.Join): + if node.side: + return {} + nodes[table] = node + elif isinstance(node, exp.Select) and len(tables) == 1: + if not node.args.get("group"): + nodes[table] = node + return nodes + + +def replace_aliases(source, predicate): + aliases = {} + + for select in source.selects: + if isinstance(select, exp.Alias): + aliases[select.alias] = select.this + else: + aliases[select.name] = select + + def _replace_alias(column): + if isinstance(column, exp.Column) and column.name in aliases: + return aliases[column.name] + return column + + return predicate.transform(_replace_alias) diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py new file mode 100644 index 0000000..097ce04 --- /dev/null +++ b/sqlglot/optimizer/pushdown_projections.py @@ -0,0 +1,85 @@ +from collections import defaultdict + +from sqlglot import alias, exp +from sqlglot.optimizer.scope import Scope, traverse_scope + +# Sentinel value that means an outer query selecting ALL columns +SELECT_ALL = object() + + +def pushdown_projections(expression): + """ + Rewrite sqlglot AST to remove unused columns projections. + + Example: + >>> import sqlglot + >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y" + >>> expression = sqlglot.parse_one(sql) + >>> pushdown_projections(expression).sql() + 'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y' + + Args: + expression (sqlglot.Expression): expression to optimize + Returns: + sqlglot.Expression: optimized expression + """ + # Map of Scope to all columns being selected by outer queries. + referenced_columns = defaultdict(set) + + # We build the scope tree (which is traversed in DFS postorder), then iterate + # over the result in reverse order. This should ensure that the set of selected + # columns for a particular scope are completely build by the time we get to it. 
+ for scope in reversed(traverse_scope(expression)): + parent_selections = referenced_columns.get(scope, {SELECT_ALL}) + + if scope.expression.args.get("distinct"): + # We can't remove columns SELECT DISTINCT nor UNION DISTINCT + parent_selections = {SELECT_ALL} + + if isinstance(scope.expression, exp.Union): + left, right = scope.union + referenced_columns[left] = parent_selections + referenced_columns[right] = parent_selections + + if isinstance(scope.expression, exp.Select): + _remove_unused_selections(scope, parent_selections) + + # Group columns by source name + selects = defaultdict(set) + for col in scope.columns: + table_name = col.table + col_name = col.name + selects[table_name].add(col_name) + + # Push the selected columns down to the next scope + for name, (_, source) in scope.selected_sources.items(): + if isinstance(source, Scope): + columns = selects.get(name) or set() + referenced_columns[source].update(columns) + + return expression + + +def _remove_unused_selections(scope, parent_selections): + order = scope.expression.args.get("order") + + if order: + # Assume columns without a qualified table are references to output columns + order_refs = {c.name for c in order.find_all(exp.Column) if not c.table} + else: + order_refs = set() + + new_selections = [] + for selection in scope.selects: + if ( + SELECT_ALL in parent_selections + or selection.alias_or_name in parent_selections + or selection.alias_or_name in order_refs + ): + new_selections.append(selection) + + # If there are no remaining selections, just select a single constant + if not new_selections: + new_selections.append(alias("1", "_")) + + scope.expression.set("expressions", new_selections) diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py new file mode 100644 index 0000000..394f49e --- /dev/null +++ b/sqlglot/optimizer/qualify_columns.py @@ -0,0 +1,422 @@ +import itertools + +from sqlglot import alias, exp +from sqlglot.errors import OptimizeError 
+from sqlglot.optimizer.schema import ensure_schema +from sqlglot.optimizer.scope import traverse_scope + +SKIP_QUALIFY = (exp.Unnest, exp.Lateral) + + +def qualify_columns(expression, schema): + """ + Rewrite sqlglot AST to have fully qualified columns. + + Example: + >>> import sqlglot + >>> schema = {"tbl": {"col": "INT"}} + >>> expression = sqlglot.parse_one("SELECT col FROM tbl") + >>> qualify_columns(expression, schema).sql() + 'SELECT tbl.col AS col FROM tbl' + + Args: + expression (sqlglot.Expression): expression to qualify + schema (dict|sqlglot.optimizer.Schema): Database schema + Returns: + sqlglot.Expression: qualified expression + """ + schema = ensure_schema(schema) + + for scope in traverse_scope(expression): + resolver = _Resolver(scope, schema) + _pop_table_column_aliases(scope.ctes) + _pop_table_column_aliases(scope.derived_tables) + _expand_using(scope, resolver) + _expand_group_by(scope, resolver) + _expand_order_by(scope) + _qualify_columns(scope, resolver) + if not isinstance(scope.expression, SKIP_QUALIFY): + _expand_stars(scope, resolver) + _qualify_outputs(scope) + _check_unknown_tables(scope) + + return expression + + +def _pop_table_column_aliases(derived_tables): + """ + Remove table column aliases. + + (e.g. SELECT ... FROM (SELECT ...) 
AS foo(col1, col2) + """ + for derived_table in derived_tables: + if isinstance(derived_table, SKIP_QUALIFY): + continue + table_alias = derived_table.args.get("alias") + if table_alias: + table_alias.args.pop("columns", None) + + +def _expand_using(scope, resolver): + joins = list(scope.expression.find_all(exp.Join)) + names = {join.this.alias for join in joins} + ordered = [key for key in scope.selected_sources if key not in names] + + # Mapping of automatically joined column names to source names + column_tables = {} + + for join in joins: + using = join.args.get("using") + + if not using: + continue + + join_table = join.this.alias_or_name + + columns = {} + + for k in scope.selected_sources: + if k in ordered: + for column in resolver.get_source_columns(k): + if column not in columns: + columns[column] = k + + ordered.append(join_table) + join_columns = resolver.get_source_columns(join_table) + conditions = [] + + for identifier in using: + identifier = identifier.name + table = columns.get(identifier) + + if not table or identifier not in join_columns: + raise OptimizeError(f"Cannot automatically join: {identifier}") + + conditions.append( + exp.condition( + exp.EQ( + this=exp.column(identifier, table=table), + expression=exp.column(identifier, table=join_table), + ) + ) + ) + + tables = column_tables.setdefault(identifier, []) + if table not in tables: + tables.append(table) + if join_table not in tables: + tables.append(join_table) + + join.args.pop("using") + join.set("on", exp.and_(*conditions)) + + if column_tables: + for column in scope.columns: + if not column.table and column.name in column_tables: + tables = column_tables[column.name] + coalesce = [exp.column(column.name, table=table) for table in tables] + replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]) + + # Ensure selects keep their output name + if isinstance(column.parent, exp.Select): + replacement = exp.alias_(replacement, alias=column.name) + + scope.replace(column, 
replacement) + + +def _expand_group_by(scope, resolver): + group = scope.expression.args.get("group") + if not group: + return + + # Replace references to select aliases + def transform(node, *_): + if isinstance(node, exp.Column) and not node.table: + table = resolver.get_table(node.name) + + # Source columns get priority over select aliases + if table: + node.set("table", exp.to_identifier(table)) + return node + + selects = {s.alias_or_name: s for s in scope.selects} + + select = selects.get(node.name) + if select: + scope.clear_cache() + if isinstance(select, exp.Alias): + select = select.this + return select.copy() + + return node + + group.transform(transform, copy=False) + group.set("expressions", _expand_positional_references(scope, group.expressions)) + scope.expression.set("group", group) + + +def _expand_order_by(scope): + order = scope.expression.args.get("order") + if not order: + return + + ordereds = order.expressions + for ordered, new_expression in zip( + ordereds, + _expand_positional_references(scope, (o.this for o in ordereds)), + ): + ordered.set("this", new_expression) + + +def _expand_positional_references(scope, expressions): + new_nodes = [] + for node in expressions: + if node.is_int: + try: + select = scope.selects[int(node.name) - 1] + except IndexError: + raise OptimizeError(f"Unknown output column: {node.name}") + if isinstance(select, exp.Alias): + select = select.this + new_nodes.append(select.copy()) + scope.clear_cache() + else: + new_nodes.append(node) + + return new_nodes + + +def _qualify_columns(scope, resolver): + """Disambiguate columns, ensuring each column specifies a source""" + for column in scope.columns: + column_table = column.table + column_name = column.name + + if ( + column_table + and column_table in scope.sources + and column_name not in resolver.get_source_columns(column_table) + ): + raise OptimizeError(f"Unknown column: {column_name}") + + if not column_table: + column_table = resolver.get_table(column_name) + 
+ if not scope.is_subquery and not scope.is_unnest: + if column_name not in resolver.all_columns: + raise OptimizeError(f"Unknown column: {column_name}") + + if column_table is None: + raise OptimizeError(f"Ambiguous column: {column_name}") + + # column_table can be a '' because bigquery unnest has no table alias + if column_table: + column.set("table", exp.to_identifier(column_table)) + + +def _expand_stars(scope, resolver): + """Expand stars to lists of column selections""" + + new_selections = [] + except_columns = {} + replace_columns = {} + + for expression in scope.selects: + if isinstance(expression, exp.Star): + tables = list(scope.selected_sources) + _add_except_columns(expression, tables, except_columns) + _add_replace_columns(expression, tables, replace_columns) + elif isinstance(expression, exp.Column) and isinstance( + expression.this, exp.Star + ): + tables = [expression.table] + _add_except_columns(expression.this, tables, except_columns) + _add_replace_columns(expression.this, tables, replace_columns) + else: + new_selections.append(expression) + continue + + for table in tables: + if table not in scope.sources: + raise OptimizeError(f"Unknown table: {table}") + columns = resolver.get_source_columns(table) + table_id = id(table) + for name in columns: + if name not in except_columns.get(table_id, set()): + alias_ = replace_columns.get(table_id, {}).get(name, name) + column = exp.column(name, table) + new_selections.append( + alias(column, alias_) if alias_ != name else column + ) + + scope.expression.set("expressions", new_selections) + + +def _add_except_columns(expression, tables, except_columns): + except_ = expression.args.get("except") + + if not except_: + return + + columns = {e.name for e in except_} + + for table in tables: + except_columns[id(table)] = columns + + +def _add_replace_columns(expression, tables, replace_columns): + replace = expression.args.get("replace") + + if not replace: + return + + columns = {e.this.name: e.alias for e 
in replace} + + for table in tables: + replace_columns[id(table)] = columns + + +def _qualify_outputs(scope): + """Ensure all output columns are aliased""" + new_selections = [] + + for i, (selection, aliased_column) in enumerate( + itertools.zip_longest(scope.selects, scope.outer_column_list) + ): + if isinstance(selection, exp.Column): + # convoluted setter because a simple selection.replace(alias) would require a copy + alias_ = alias(exp.column(""), alias=selection.name) + alias_.set("this", selection) + selection = alias_ + elif not isinstance(selection, exp.Alias): + alias_ = alias(exp.column(""), f"_col_{i}") + alias_.set("this", selection) + selection = alias_ + + if aliased_column: + selection.set("alias", exp.to_identifier(aliased_column)) + + new_selections.append(selection) + + scope.expression.set("expressions", new_selections) + + +def _check_unknown_tables(scope): + if ( + scope.external_columns + and not scope.is_unnest + and not scope.is_correlated_subquery + ): + raise OptimizeError(f"Unknown table: {scope.external_columns[0].text('table')}") + + +class _Resolver: + """ + Helper for resolving columns. + + This is a class so we can lazily load some things and easily share them across functions. + """ + + def __init__(self, scope, schema): + self.scope = scope + self.schema = schema + self._source_columns = None + self._unambiguous_columns = None + self._all_columns = None + + def get_table(self, column_name): + """ + Get the table for a column name. 
+ + Args: + column_name (str) + Returns: + (str) table name + """ + if self._unambiguous_columns is None: + self._unambiguous_columns = self._get_unambiguous_columns( + self._get_all_source_columns() + ) + return self._unambiguous_columns.get(column_name) + + @property + def all_columns(self): + """All available columns of all sources in this scope""" + if self._all_columns is None: + self._all_columns = set( + column + for columns in self._get_all_source_columns().values() + for column in columns + ) + return self._all_columns + + def get_source_columns(self, name): + """Resolve the source columns for a given source `name`""" + if name not in self.scope.sources: + raise OptimizeError(f"Unknown table: {name}") + + source = self.scope.sources[name] + + # If referencing a table, return the columns from the schema + if isinstance(source, exp.Table): + try: + return self.schema.column_names(source) + except Exception as e: + raise OptimizeError(str(e)) from e + + # Otherwise, if referencing another scope, return that scope's named selects + return source.expression.named_selects + + def _get_all_source_columns(self): + if self._source_columns is None: + self._source_columns = { + k: self.get_source_columns(k) for k in self.scope.selected_sources + } + return self._source_columns + + def _get_unambiguous_columns(self, source_columns): + """ + Find all the unambiguous columns in sources. 
+ + Args: + source_columns (dict): Mapping of names to source columns + Returns: + dict: Mapping of column name to source name + """ + if not source_columns: + return {} + + source_columns = list(source_columns.items()) + + first_table, first_columns = source_columns[0] + unambiguous_columns = { + col: first_table for col in self._find_unique_columns(first_columns) + } + all_columns = set(unambiguous_columns) + + for table, columns in source_columns[1:]: + unique = self._find_unique_columns(columns) + ambiguous = set(all_columns).intersection(unique) + all_columns.update(columns) + for column in ambiguous: + unambiguous_columns.pop(column, None) + for column in unique.difference(ambiguous): + unambiguous_columns[column] = table + + return unambiguous_columns + + @staticmethod + def _find_unique_columns(columns): + """ + Find the unique columns in a list of columns. + + Example: + >>> sorted(_Resolver._find_unique_columns(["a", "b", "b", "c"])) + ['a', 'c'] + + This is necessary because duplicate column names are ambiguous. + """ + counts = {} + for column in columns: + counts[column] = counts.get(column, 0) + 1 + return {column for column, count in counts.items() if count == 1} diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py new file mode 100644 index 0000000..9f8b9f5 --- /dev/null +++ b/sqlglot/optimizer/qualify_tables.py @@ -0,0 +1,54 @@ +import itertools + +from sqlglot import alias, exp +from sqlglot.optimizer.scope import traverse_scope + + +def qualify_tables(expression, db=None, catalog=None): + """ + Rewrite sqlglot AST to have fully qualified tables. 
+ + Example: + >>> import sqlglot + >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl") + >>> qualify_tables(expression, db="db").sql() + 'SELECT 1 FROM db.tbl AS tbl' + + Args: + expression (sqlglot.Expression): expression to qualify + db (str): Database name + catalog (str): Catalog name + Returns: + sqlglot.Expression: qualified expression + """ + sequence = itertools.count() + + for scope in traverse_scope(expression): + for derived_table in scope.ctes + scope.derived_tables: + if not derived_table.args.get("alias"): + alias_ = f"_q_{next(sequence)}" + derived_table.set( + "alias", exp.TableAlias(this=exp.to_identifier(alias_)) + ) + scope.rename_source(None, alias_) + + for source in scope.sources.values(): + if isinstance(source, exp.Table): + identifier = isinstance(source.this, exp.Identifier) + + if identifier: + if not source.args.get("db"): + source.set("db", exp.to_identifier(db)) + if not source.args.get("catalog"): + source.set("catalog", exp.to_identifier(catalog)) + + if not isinstance(source.parent, exp.Alias): + source.replace( + alias( + source.copy(), + source.this if identifier else f"_q_{next(sequence)}", + table=True, + ) + ) + + return expression diff --git a/sqlglot/optimizer/quote_identities.py b/sqlglot/optimizer/quote_identities.py new file mode 100644 index 0000000..17623cc --- /dev/null +++ b/sqlglot/optimizer/quote_identities.py @@ -0,0 +1,25 @@ +from sqlglot import exp + + +def quote_identities(expression): + """ + Rewrite sqlglot AST to ensure all identities are quoted. 
+ + Example: + >>> import sqlglot + >>> expression = sqlglot.parse_one("SELECT x.a AS a FROM db.x") + >>> quote_identities(expression).sql() + 'SELECT "x"."a" AS "a" FROM "db"."x"' + + Args: + expression (sqlglot.Expression): expression to quote + Returns: + sqlglot.Expression: quoted expression + """ + + def qualify(node): + if isinstance(node, exp.Identifier): + node.set("quoted", True) + return node + + return expression.transform(qualify, copy=False) diff --git a/sqlglot/optimizer/schema.py b/sqlglot/optimizer/schema.py new file mode 100644 index 0000000..9968108 --- /dev/null +++ b/sqlglot/optimizer/schema.py @@ -0,0 +1,129 @@ +import abc + +from sqlglot import exp +from sqlglot.errors import OptimizeError +from sqlglot.helper import csv_reader + + +class Schema(abc.ABC): + """Abstract base class for database schemas""" + + @abc.abstractmethod + def column_names(self, table): + """ + Get the column names for a table. + + Args: + table (sqlglot.expressions.Table): Table expression instance + Returns: + list[str]: list of column names + """ + + +class MappingSchema(Schema): + """ + Schema based on a nested mapping. + + Args: + schema (dict): Mapping in one of the following forms: + 1. {table: {col: type}} + 2. {db: {table: {col: type}}} + 3. {catalog: {db: {table: {col: type}}}} + """ + + def __init__(self, schema): + self.schema = schema + + depth = _dict_depth(schema) + + if not depth: # {} + self.supported_table_args = [] + elif depth == 2: # {table: {col: type}} + self.supported_table_args = ("this",) + elif depth == 3: # {db: {table: {col: type}}} + self.supported_table_args = ("db", "this") + elif depth == 4: # {catalog: {db: {table: {col: type}}}} + self.supported_table_args = ("catalog", "db", "this") + else: + raise OptimizeError(f"Invalid schema shape. 
Depth: {depth}") + + self.forbidden_args = {"catalog", "db", "this"} - set(self.supported_table_args) + + def column_names(self, table): + if not isinstance(table.this, exp.Identifier): + return fs_get(table) + + args = tuple(table.text(p) for p in self.supported_table_args) + + for forbidden in self.forbidden_args: + if table.text(forbidden): + raise ValueError( + f"Schema doesn't support {forbidden}. Received: {table.sql()}" + ) + return list(_nested_get(self.schema, *zip(self.supported_table_args, args))) + + +def ensure_schema(schema): + if isinstance(schema, Schema): + return schema + + return MappingSchema(schema) + + +def fs_get(table): + name = table.this.name.upper() + + if name.upper() == "READ_CSV": + with csv_reader(table) as reader: + return next(reader) + + raise ValueError(f"Cannot read schema for {table}") + + +def _nested_get(d, *path): + """ + Get a value for a nested dictionary. + + Args: + d (dict): dictionary + *path (tuple[str, str]): tuples of (name, key) + `key` is the key in the dictionary to get. + `name` is a string to use in the error if `key` isn't found. + """ + for name, key in path: + d = d.get(key) + if d is None: + name = "table" if name == "this" else name + raise ValueError(f"Unknown {name}") + return d + + +def _dict_depth(d): + """ + Get the nesting depth of a dictionary. 
+ + For example: + >>> _dict_depth(None) + 0 + >>> _dict_depth({}) + 1 + >>> _dict_depth({"a": "b"}) + 1 + >>> _dict_depth({"a": {}}) + 2 + >>> _dict_depth({"a": {"b": {}}}) + 3 + + Args: + d (dict): dictionary + Returns: + int: depth + """ + try: + return 1 + _dict_depth(next(iter(d.values()))) + except AttributeError: + # d doesn't have attribute "values" + return 0 + except StopIteration: + # d.values() returns an empty sequence + return 1 diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py new file mode 100644 index 0000000..f6f59e8 --- /dev/null +++ b/sqlglot/optimizer/scope.py @@ -0,0 +1,438 @@ +from copy import copy +from enum import Enum, auto + +from sqlglot import exp +from sqlglot.errors import OptimizeError + + +class ScopeType(Enum): + ROOT = auto() + SUBQUERY = auto() + DERIVED_TABLE = auto() + CTE = auto() + UNION = auto() + UNNEST = auto() + + +class Scope: + """ + Selection scope. + + Attributes: + expression (exp.Select|exp.Union): Root expression of this scope + sources (dict[str, exp.Table|Scope]): Mapping of source name to either + a Table expression or another Scope instance. For example: + SELECT * FROM x {"x": Table(this="x")} + SELECT * FROM x AS y {"y": Table(this="x")} + SELECT * FROM (SELECT ...) AS y {"y": Scope(...)} + outer_column_list (list[str]): If this is a derived table or CTE, and the outer query + defines a column list of it's alias of this scope, this is that list of columns. + For example: + SELECT * FROM (SELECT ...) AS y(col1, col2) + The inner query would have `["col1", "col2"]` for its `outer_column_list` + parent (Scope): Parent scope + scope_type (ScopeType): Type of this scope, relative to it's parent + subquery_scopes (list[Scope]): List of all child scopes for subqueries. + This does not include derived tables or CTEs. + union (tuple[Scope, Scope]): If this Scope is for a Union expression, this will be + a tuple of the left and right child scopes. 
+ """ + + def __init__( + self, + expression, + sources=None, + outer_column_list=None, + parent=None, + scope_type=ScopeType.ROOT, + ): + self.expression = expression + self.sources = sources or {} + self.outer_column_list = outer_column_list or [] + self.parent = parent + self.scope_type = scope_type + self.subquery_scopes = [] + self.union = None + self.clear_cache() + + def clear_cache(self): + self._collected = False + self._raw_columns = None + self._derived_tables = None + self._tables = None + self._ctes = None + self._subqueries = None + self._selected_sources = None + self._columns = None + self._external_columns = None + + def branch(self, expression, scope_type, add_sources=None, **kwargs): + """Branch from the current scope to a new, inner scope""" + sources = copy(self.sources) + if add_sources: + sources.update(add_sources) + return Scope( + expression=expression.unnest(), + sources=sources, + parent=self, + scope_type=scope_type, + **kwargs, + ) + + def _collect(self): + self._tables = [] + self._ctes = [] + self._subqueries = [] + self._derived_tables = [] + self._raw_columns = [] + + # We'll use this variable to pass state into the dfs generator. + # Whenever we set it to True, we exclude a subtree from traversal. 
+ prune = False + + for node, parent, _ in self.expression.dfs(prune=lambda *_: prune): + prune = False + + if node is self.expression: + continue + if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star): + self._raw_columns.append(node) + elif isinstance(node, exp.Table): + self._tables.append(node) + elif isinstance(node, (exp.Unnest, exp.Lateral)): + self._derived_tables.append(node) + elif isinstance(node, exp.CTE): + self._ctes.append(node) + prune = True + elif isinstance(node, exp.Subquery) and isinstance( + parent, (exp.From, exp.Join) + ): + self._derived_tables.append(node) + prune = True + elif isinstance(node, exp.Subqueryable): + self._subqueries.append(node) + prune = True + + self._collected = True + + def _ensure_collected(self): + if not self._collected: + self._collect() + + def replace(self, old, new): + """ + Replace `old` with `new`. + + This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date. + + Args: + old (exp.Expression): old node + new (exp.Expression): new node + """ + old.replace(new) + self.clear_cache() + + @property + def tables(self): + """ + List of tables in this scope. + + Returns: + list[exp.Table]: tables + """ + self._ensure_collected() + return self._tables + + @property + def ctes(self): + """ + List of CTEs in this scope. + + Returns: + list[exp.CTE]: ctes + """ + self._ensure_collected() + return self._ctes + + @property + def derived_tables(self): + """ + List of derived tables in this scope. + + For example: + SELECT * FROM (SELECT ...) <- that's a derived table + + Returns: + list[exp.Subquery]: derived tables + """ + self._ensure_collected() + return self._derived_tables + + @property + def subqueries(self): + """ + List of subqueries in this scope. + + For example: + SELECT * FROM x WHERE a IN (SELECT ...) 
<- that's a subquery + + Returns: + list[exp.Subqueryable]: subqueries + """ + self._ensure_collected() + return self._subqueries + + @property + def columns(self): + """ + List of columns in this scope. + + Returns: + list[exp.Column]: Column instances in this scope, plus any + Columns that reference this scope from correlated subqueries. + """ + if self._columns is None: + self._ensure_collected() + columns = self._raw_columns + + external_columns = [ + column + for scope in self.subquery_scopes + for column in scope.external_columns + ] + + named_outputs = {e.alias_or_name for e in self.expression.expressions} + + self._columns = [ + c + for c in columns + external_columns + if not ( + c.find_ancestor(exp.Qualify, exp.Order) and c.name in named_outputs + ) + ] + return self._columns + + @property + def selected_sources(self): + """ + Mapping of nodes and sources that are actually selected from in this scope. + + That is, all tables in a schema are selectable at any point. But a + table only becomes a selected source if it's included in a FROM or JOIN clause. + + Returns: + dict[str, (exp.Table|exp.Subquery, exp.Table|Scope)]: selected sources and nodes + """ + if self._selected_sources is None: + referenced_names = [] + + for table in self.tables: + referenced_names.append( + ( + table.parent.alias + if isinstance(table.parent, exp.Alias) + else table.name, + table, + ) + ) + for derived_table in self.derived_tables: + referenced_names.append((derived_table.alias, derived_table.unnest())) + + result = {} + + for name, node in referenced_names: + if name in self.sources: + result[name] = (node, self.sources[name]) + + self._selected_sources = result + return self._selected_sources + + @property + def selects(self): + """ + Select expressions of this scope. + + For example, for the following expression: + SELECT 1 as a, 2 as b FROM x + + The outputs are the "1 as a" and "2 as b" expressions. 
+ + Returns: + list[exp.Expression]: expressions + """ + if isinstance(self.expression, exp.Union): + return [] + return self.expression.selects + + @property + def external_columns(self): + """ + Columns that appear to reference sources in outer scopes. + + Returns: + list[exp.Column]: Column instances that don't reference + sources in the current scope. + """ + if self._external_columns is None: + self._external_columns = [ + c for c in self.columns if c.table not in self.selected_sources + ] + return self._external_columns + + def source_columns(self, source_name): + """ + Get all columns in the current scope for a particular source. + + Args: + source_name (str): Name of the source + Returns: + list[exp.Column]: Column instances that reference `source_name` + """ + return [column for column in self.columns if column.table == source_name] + + @property + def is_subquery(self): + """Determine if this scope is a subquery""" + return self.scope_type == ScopeType.SUBQUERY + + @property + def is_unnest(self): + """Determine if this scope is an unnest""" + return self.scope_type == ScopeType.UNNEST + + @property + def is_correlated_subquery(self): + """Determine if this scope is a correlated subquery""" + return bool(self.is_subquery and self.external_columns) + + def rename_source(self, old_name, new_name): + """Rename a source in this scope""" + columns = self.sources.pop(old_name or "", []) + self.sources[new_name] = columns + + +def traverse_scope(expression): + """ + Traverse an expression by it's "scopes". + + "Scope" represents the current context of a Select statement. + + This is helpful for optimizing queries, where we need more information than + the expression tree itself. For example, we might care about the source + names within a subquery. Returns a list because a generator could result in + incomplete properties which is confusing. 
+ + Examples: + >>> import sqlglot + >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y") + >>> scopes = traverse_scope(expression) + >>> scopes[0].expression.sql(), list(scopes[0].sources) + ('SELECT a FROM x', ['x']) + >>> scopes[1].expression.sql(), list(scopes[1].sources) + ('SELECT a FROM (SELECT a FROM x) AS y', ['y']) + + Args: + expression (exp.Expression): expression to traverse + Returns: + List[Scope]: scope instances + """ + return list(_traverse_scope(Scope(expression))) + + +def _traverse_scope(scope): + if isinstance(scope.expression, exp.Select): + yield from _traverse_select(scope) + elif isinstance(scope.expression, exp.Union): + yield from _traverse_union(scope) + elif isinstance(scope.expression, (exp.Lateral, exp.Unnest)): + pass + elif isinstance(scope.expression, exp.Subquery): + yield from _traverse_subqueries(scope) + else: + raise OptimizeError(f"Unexpected expression type: {type(scope.expression)}") + yield scope + + +def _traverse_select(scope): + yield from _traverse_derived_tables(scope.ctes, scope, ScopeType.CTE) + yield from _traverse_subqueries(scope) + yield from _traverse_derived_tables( + scope.derived_tables, scope, ScopeType.DERIVED_TABLE + ) + _add_table_sources(scope) + + +def _traverse_union(scope): + yield from _traverse_derived_tables(scope.ctes, scope, scope_type=ScopeType.CTE) + + # The last scope to be yield should be the top most scope + left = None + for left in _traverse_scope( + scope.branch(scope.expression.left, scope_type=ScopeType.UNION) + ): + yield left + + right = None + for right in _traverse_scope( + scope.branch(scope.expression.right, scope_type=ScopeType.UNION) + ): + yield right + + scope.union = (left, right) + + +def _traverse_derived_tables(derived_tables, scope, scope_type): + sources = {} + + for derived_table in derived_tables: + for child_scope in _traverse_scope( + scope.branch( + derived_table + if isinstance(derived_table, (exp.Unnest, exp.Lateral)) + else 
derived_table.this, + add_sources=sources if scope_type == ScopeType.CTE else None, + outer_column_list=derived_table.alias_column_names, + scope_type=ScopeType.UNNEST + if isinstance(derived_table, exp.Unnest) + else scope_type, + ) + ): + yield child_scope + # Tables without aliases will be set as "" + # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything. + # Until then, this means that only a single, unaliased derived table is allowed (rather, + # the latest one wins. + sources[derived_table.alias] = child_scope + scope.sources.update(sources) + + +def _add_table_sources(scope): + sources = {} + for table in scope.tables: + table_name = table.name + + if isinstance(table.parent, exp.Alias): + source_name = table.parent.alias + else: + source_name = table_name + + if table_name in scope.sources: + # This is a reference to a parent source (e.g. a CTE), not an actual table. + scope.sources[source_name] = scope.sources[table_name] + elif source_name in scope.sources: + raise OptimizeError(f"Duplicate table name: {source_name}") + else: + sources[source_name] = table + + scope.sources.update(sources) + + +def _traverse_subqueries(scope): + for subquery in scope.subqueries: + top = None + for child_scope in _traverse_scope( + scope.branch(subquery, scope_type=ScopeType.SUBQUERY) + ): + yield child_scope + top = child_scope + scope.subquery_scopes.append(top) diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py new file mode 100644 index 0000000..6771153 --- /dev/null +++ b/sqlglot/optimizer/simplify.py @@ -0,0 +1,383 @@ +import datetime +import functools +import itertools +from collections import deque +from decimal import Decimal + +from sqlglot import exp +from sqlglot.expressions import FALSE, NULL, TRUE +from sqlglot.generator import Generator +from sqlglot.helper import while_changing + +GENERATOR = Generator(normalize=True, identify=True) + + +def simplify(expression): + """ + Rewrite sqlglot AST to 
simplify expressions. + + Example: + >>> import sqlglot + >>> expression = sqlglot.parse_one("TRUE AND TRUE") + >>> simplify(expression).sql() + 'TRUE' + + Args: + expression (sqlglot.Expression): expression to simplify + Returns: + sqlglot.Expression: simplified expression + """ + + def _simplify(expression, root=True): + node = expression + node = uniq_sort(node) + node = absorb_and_eliminate(node) + exp.replace_children(node, lambda e: _simplify(e, False)) + node = simplify_not(node) + node = flatten(node) + node = simplify_connectors(node) + node = remove_compliments(node) + node.parent = expression.parent + node = simplify_literals(node) + node = simplify_parens(node) + if root: + expression.replace(node) + return node + + expression = while_changing(expression, _simplify) + remove_where_true(expression) + return expression + + +def simplify_not(expression): + """ + Demorgan's Law + NOT (x OR y) -> NOT x AND NOT y + NOT (x AND y) -> NOT x OR NOT y + """ + if isinstance(expression, exp.Not): + if isinstance(expression.this, exp.Paren): + condition = expression.this.unnest() + if isinstance(condition, exp.And): + return exp.or_(exp.not_(condition.left), exp.not_(condition.right)) + if isinstance(condition, exp.Or): + return exp.and_(exp.not_(condition.left), exp.not_(condition.right)) + if always_true(expression.this): + return FALSE + if expression.this == FALSE: + return TRUE + if isinstance(expression.this, exp.Not): + # double negation + # NOT NOT x -> x + return expression.this.this + return expression + + +def flatten(expression): + """ + A AND (B AND C) -> A AND B AND C + A OR (B OR C) -> A OR B OR C + """ + if isinstance(expression, exp.Connector): + for node in expression.args.values(): + child = node.unnest() + if isinstance(child, expression.__class__): + node.replace(child) + return expression + + +def simplify_connectors(expression): + if isinstance(expression, exp.Connector): + left = expression.left + right = expression.right + + if left == right: 
+ return left + + if isinstance(expression, exp.And): + if NULL in (left, right): + return NULL + if FALSE in (left, right): + return FALSE + if always_true(left) and always_true(right): + return TRUE + if always_true(left): + return right + if always_true(right): + return left + elif isinstance(expression, exp.Or): + if always_true(left) or always_true(right): + return TRUE + if left == FALSE and right == FALSE: + return FALSE + if ( + (left == NULL and right == NULL) + or (left == NULL and right == FALSE) + or (left == FALSE and right == NULL) + ): + return NULL + if left == FALSE: + return right + if right == FALSE: + return left + return expression + + +def remove_compliments(expression): + """ + Removing compliments. + + A AND NOT A -> FALSE + A OR NOT A -> TRUE + """ + if isinstance(expression, exp.Connector): + compliment = FALSE if isinstance(expression, exp.And) else TRUE + + for a, b in itertools.permutations(expression.flatten(), 2): + if is_complement(a, b): + return compliment + return expression + + +def uniq_sort(expression): + """ + Uniq and sort a connector. 
+ + C AND A AND B AND B -> A AND B AND C + """ + if isinstance(expression, exp.Connector): + result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_ + flattened = tuple(expression.flatten()) + deduped = {GENERATOR.generate(e): e for e in flattened} + arr = tuple(deduped.items()) + + # check if the operands are already sorted, if not sort them + # A AND C AND B -> A AND B AND C + for i, (sql, e) in enumerate(arr[1:]): + if sql < arr[i][0]: + expression = result_func(*(deduped[sql] for sql in sorted(deduped))) + break + else: + # we didn't have to sort but maybe we need to dedup + if len(deduped) < len(flattened): + expression = result_func(*deduped.values()) + + return expression + + +def absorb_and_eliminate(expression): + """ + absorption: + A AND (A OR B) -> A + A OR (A AND B) -> A + A AND (NOT A OR B) -> A AND B + A OR (NOT A AND B) -> A OR B + elimination: + (A AND B) OR (A AND NOT B) -> A + (A OR B) AND (A OR NOT B) -> A + """ + if isinstance(expression, exp.Connector): + kind = exp.Or if isinstance(expression, exp.And) else exp.And + + for a, b in itertools.permutations(expression.flatten(), 2): + if isinstance(a, kind): + aa, ab = a.unnest_operands() + + # absorb + if is_complement(b, aa): + aa.replace(exp.TRUE if kind == exp.And else exp.FALSE) + elif is_complement(b, ab): + ab.replace(exp.TRUE if kind == exp.And else exp.FALSE) + elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set( + a.flatten() + ): + a.replace(exp.FALSE if kind == exp.And else exp.TRUE) + elif isinstance(b, kind): + # eliminate + rhs = b.unnest_operands() + ba, bb = rhs + + if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)): + a.replace(aa) + b.replace(aa) + elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)): + a.replace(ab) + b.replace(ab) + + return expression + + +def simplify_literals(expression): + if isinstance(expression, exp.Binary): + operands = [] + queue = deque(expression.flatten(unnest=False)) + size = len(queue) + + 
while queue: + a = queue.popleft() + + for b in queue: + result = _simplify_binary(expression, a, b) + + if result: + queue.remove(b) + queue.append(result) + break + else: + operands.append(a) + + if len(operands) < size: + return functools.reduce( + lambda a, b: expression.__class__(this=a, expression=b), operands + ) + elif isinstance(expression, exp.Neg): + this = expression.this + if this.is_number: + value = this.name + if value[0] == "-": + return exp.Literal.number(value[1:]) + return exp.Literal.number(f"-{value}") + + return expression + + +def _simplify_binary(expression, a, b): + if isinstance(expression, exp.Is): + if isinstance(b, exp.Not): + c = b.this + not_ = True + else: + c = b + not_ = False + + if c == NULL: + if isinstance(a, exp.Literal): + return TRUE if not_ else FALSE + if a == NULL: + return FALSE if not_ else TRUE + elif NULL in (a, b): + return NULL + + if isinstance(expression, exp.EQ) and a == b: + return TRUE + + if a.is_number and b.is_number: + a = int(a.name) if a.is_int else Decimal(a.name) + b = int(b.name) if b.is_int else Decimal(b.name) + + if isinstance(expression, exp.Add): + return exp.Literal.number(a + b) + if isinstance(expression, exp.Sub): + return exp.Literal.number(a - b) + if isinstance(expression, exp.Mul): + return exp.Literal.number(a * b) + if isinstance(expression, exp.Div): + if isinstance(a, int) and isinstance(b, int): + return exp.Literal.number(a // b) + return exp.Literal.number(a / b) + + boolean = eval_boolean(expression, a, b) + + if boolean: + return boolean + elif a.is_string and b.is_string: + boolean = eval_boolean(expression, a, b) + + if boolean: + return boolean + elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval): + a, b = extract_date(a), extract_interval(b) + if b: + if isinstance(expression, exp.Add): + return date_literal(a + b) + if isinstance(expression, exp.Sub): + return date_literal(a - b) + elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast): + a, b = 
extract_interval(a), extract_date(b) + # you cannot subtract a date from an interval + if a and isinstance(expression, exp.Add): + return date_literal(a + b) + + return None + + +def simplify_parens(expression): + if ( + isinstance(expression, exp.Paren) + and not isinstance(expression.this, exp.Select) + and ( + not isinstance(expression.parent, (exp.Condition, exp.Binary)) + or isinstance(expression.this, (exp.Is, exp.Like)) + or not isinstance(expression.this, exp.Binary) + ) + ): + return expression.this + return expression + + +def remove_where_true(expression): + for where in expression.find_all(exp.Where): + if always_true(where.this): + where.parent.set("where", None) + for join in expression.find_all(exp.Join): + if always_true(join.args.get("on")): + join.set("kind", "CROSS") + join.set("on", None) + + +def always_true(expression): + return expression == TRUE or isinstance(expression, exp.Literal) + + +def is_complement(a, b): + return isinstance(b, exp.Not) and b.this == a + + +def eval_boolean(expression, a, b): + if isinstance(expression, (exp.EQ, exp.Is)): + return boolean_literal(a == b) + if isinstance(expression, exp.NEQ): + return boolean_literal(a != b) + if isinstance(expression, exp.GT): + return boolean_literal(a > b) + if isinstance(expression, exp.GTE): + return boolean_literal(a >= b) + if isinstance(expression, exp.LT): + return boolean_literal(a < b) + if isinstance(expression, exp.LTE): + return boolean_literal(a <= b) + return None + + +def extract_date(cast): + if cast.args["to"].this == exp.DataType.Type.DATE: + return datetime.date.fromisoformat(cast.name) + return None + + +def extract_interval(interval): + try: + from dateutil.relativedelta import relativedelta + except ModuleNotFoundError: + return None + + n = int(interval.name) + unit = interval.text("unit").lower() + + if unit == "year": + return relativedelta(years=n) + if unit == "month": + return relativedelta(months=n) + if unit == "week": + return relativedelta(weeks=n) + 
# NOTE(review): the statements below are the tail of a definition from
# sqlglot/optimizer/simplify.py whose header lies before this chunk. The
# enclosing function is not visible here, so they are preserved as a comment
# rather than re-indented by guesswork:
#
#     if unit == "day":
#         return relativedelta(days=n)
#     return None


def date_literal(date):
    """
    Build a ``CAST('<date>' AS DATE)`` expression.

    Args:
        date: value passed to ``exp.Literal.string`` to form the string literal.
    Returns:
        sqlglot.Expression: an ``exp.Cast`` of the string literal to ``DATE``.
    """
    return exp.Cast(this=exp.Literal.string(date), to=exp.DataType.build("DATE"))


def boolean_literal(condition):
    """
    Map a Python truth value onto the ``TRUE`` / ``FALSE`` constants.

    ``TRUE`` and ``FALSE`` are defined outside this chunk -- presumably the
    boolean expressions declared at the top of simplify.py (TODO confirm).
    """
    return TRUE if condition else FALSE


# ---------------------------------------------------------------------------
# diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py
# new file mode 100644
# index 0000000..55c81c5
# --- /dev/null
# +++ b/sqlglot/optimizer/unnest_subqueries.py
# @@ -0,0 +1,220 @@
# ---------------------------------------------------------------------------
import itertools

from sqlglot import exp
from sqlglot.optimizer.scope import traverse_scope


def unnest_subqueries(expression):
    """
    Rewrite sqlglot AST to convert some predicates with subqueries into joins.

    Convert the subquery into a group by so it is not a many to many left join.
    Unnesting can only occur if the subquery does not have LIMIT or OFFSET.
    Unnesting non correlated subqueries only happens on IN statements or = ANY statements.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
        >>> unnest_subqueries(expression).sql()
        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a)\
 AS "_u_0" ON x.a = "_u_0".a WHERE ("_u_0".a = 1 AND NOT "_u_0".a IS NULL)'

    Args:
        expression (sqlglot.Expression): expression to unnest
    Returns:
        sqlglot.Expression: unnested expression
    """
    # One shared counter so every generated "_u_<n>" alias is unique within
    # this call.
    sequence = itertools.count()

    for scope in traverse_scope(expression):
        select = scope.expression
        parent = select.parent_select
        # Scopes that reference columns from an outer query are correlated
        # and need the heavier decorrelation rewrite.
        if scope.external_columns:
            decorrelate(select, parent, scope.external_columns, sequence)
        else:
            unnest(select, parent, sequence)

    return expression


def unnest(select, parent_select, sequence):
    """
    Turn a non correlated ``IN (subquery)`` / ``= ANY (subquery)`` into a LEFT JOIN.

    The rewrite is skipped (nothing is modified) when the subquery is not
    nested in an IN / ANY predicate belonging to ``parent_select``, when it
    projects more than one expression, or when it contains LIMIT / OFFSET.
    Otherwise the tree is mutated in place: the predicate is replaced by a
    ``NOT <joined column> IS NULL`` check and the subquery, grouped by its
    projected value, is joined onto ``parent_select``.

    Args:
        select: the subquery's select expression.
        parent_select: the select that contains the subquery.
        sequence: iterator yielding numbers for the generated "_u_<n>" aliases.
    """
    predicate = select.find_ancestor(exp.In, exp.Any)

    if not predicate or parent_select is not predicate.parent_select:
        return

    if len(select.selects) > 1 or select.find(exp.Limit, exp.Offset):
        return

    if isinstance(predicate, exp.Any):
        # ANY only wraps the subquery; the predicate to rewrite is the
        # equality around it.
        predicate = predicate.find_ancestor(exp.EQ)

        if not predicate or parent_select is not predicate.parent_select:
            return

    column = _other_operand(predicate)
    # NOTE(review): assumes the single projection is aliased (value.alias /
    # value.this are used below) -- confirm projections are qualified first.
    value = select.selects[0]
    alias = _alias(sequence)

    on = exp.condition(f'{column} = "{alias}"."{value.alias}"')
    # A LEFT JOIN yields NULL for the joined column when there is no match,
    # so "not null" stands in for the original membership test.
    _replace(predicate, f"NOT {on.right} IS NULL")

    parent_select.join(
        select.group_by(value.this, copy=False),
        on=on,
        join_type="LEFT",
        join_alias=alias,
        copy=False,
    )


def decorrelate(select, parent_select, external_columns, sequence):
    """
    Turn a correlated subquery into a LEFT JOIN on ``parent_select``.

    The rewrite is skipped (nothing is modified) unless the subquery has a
    WHERE clause without OR, has no LIMIT / OFFSET, every external column is
    used inside a binary predicate of that WHERE clause, and at least one of
    those predicates is an equality. Otherwise the tree is mutated in place:
    the correlated predicates become ``TRUE`` inside the subquery, their keys
    are projected (grouped, or collected with ``ARRAY_AGG``), the predicate
    around the subquery is rewritten against the joined columns, and the
    subquery is joined using the equality predicates.

    Args:
        select: the correlated subquery's select expression.
        parent_select: the select that contains the subquery.
        external_columns: columns in the subquery that refer to an outer scope.
        sequence: iterator yielding numbers for the generated "_u_<n>" aliases.
    """
    where = select.args.get("where")

    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
        return

    table_alias = _alias(sequence)
    keys = []

    # for all external columns in the where statement,
    # split out the relevant data to convert it into a join
    for column in external_columns:
        if column.find_ancestor(exp.Where) is not where:
            return

        predicate = column.find_ancestor(exp.Predicate)

        if not predicate or predicate.find_ancestor(exp.Where) is not where:
            return

        if isinstance(predicate, exp.Binary):
            # the key is whichever side does NOT contain the external column
            key = (
                predicate.right
                if any(node is column for node, *_ in predicate.left.walk())
                else predicate.left
            )
        else:
            return

        keys.append((key, column, predicate))

    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
        return

    value = select.selects[0]
    key_aliases = {}
    group_by = []

    for key, _, predicate in keys:
        # if we filter on the value of the subquery, it needs to be unique
        if key == value.this:
            key_aliases[key] = value.alias
            group_by.append(key)
        else:
            if key not in key_aliases:
                key_aliases[key] = _alias(sequence)
            # all predicates that are equalities must also be in the unique
            # so that we don't do a many to many join
            if isinstance(predicate, exp.EQ) and key not in group_by:
                group_by.append(key)

    parent_predicate = select.find_ancestor(exp.Predicate)

    # if the value of the subquery is not an agg or a key, we need to collect it into an array
    # so that it can be grouped
    if not value.find(exp.AggFunc) and value.this not in group_by:
        select.select(
            f"ARRAY_AGG({value.this}) AS {value.alias}", append=False, copy=False
        )

    # exists queries should not have any selects as it only checks if there are any rows
    # all selects will be added by the optimizer and only used for join keys
    if isinstance(parent_predicate, exp.Exists):
        select.args["expressions"] = []

    for key, key_alias in key_aliases.items():
        if key in group_by:
            # add all keys to the projections of the subquery
            # so that we can use it as a join key
            if isinstance(parent_predicate, exp.Exists) or key != value.this:
                select.select(f"{key} AS {key_alias}", copy=False)
        else:
            select.select(f"ARRAY_AGG({key}) AS {key_alias}", copy=False)

    # column through which the parent query reads the subquery's value
    alias = exp.column(value.alias, table_alias)
    other = _other_operand(parent_predicate)

    if isinstance(parent_predicate, exp.Exists):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
        else:
            parent_predicate = _replace(parent_predicate, "TRUE")
    elif isinstance(parent_predicate, exp.All):
        parent_predicate = _replace(
            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
        )
    elif isinstance(parent_predicate, exp.Any):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
        else:
            # NOTE(review): unlike the branch above this replaces only the
            # ANY node, not the comparison around it -- confirm intended.
            parent_predicate = _replace(
                parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})"
            )
    elif isinstance(parent_predicate, exp.In):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate,
                f"ARRAY_ANY({alias}, _x -> _x = {other})",
            )
    else:
        # plain comparison against the subquery: swap the subquery node for
        # the joined column and leave the comparison itself untouched
        select.parent.replace(alias)

    for key, column, predicate in keys:
        # the correlation moves out of the subquery's WHERE ...
        predicate.replace(exp.TRUE)
        nested = exp.column(key_aliases[key], table_alias)

        # ... and is re-expressed on the parent side of the join
        if key in group_by:
            key.replace(nested)
            parent_predicate = _replace(
                parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)"
            )
        elif isinstance(predicate, exp.EQ):
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
            )
        else:
            key.replace(exp.to_identifier("_x"))
            parent_predicate = _replace(
                parent_predicate,
                f'({parent_predicate} AND ARRAY_ANY({nested}, "_x" -> {predicate}))',
            )

    parent_select.join(
        select.group_by(*group_by, copy=False),
        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
        join_type="LEFT",
        join_alias=table_alias,
        copy=False,
    )


def _alias(sequence):
    """Return the next generated alias, ``_u_<n>``, drawing ``n`` from ``sequence``."""
    return f"_u_{next(sequence)}"


def _replace(expression, condition):
    """
    Replace ``expression`` in its tree with the parsed ``condition``.

    Args:
        expression: node to be replaced.
        condition: value handed to ``exp.condition`` to build the replacement.
    Returns:
        Whatever ``expression.replace`` returns; callers in this module use
        it as the node now occupying the old position.
    """
    return expression.replace(exp.condition(condition))


def _other_operand(expression):
    """
    Return the operand of a predicate that sits opposite its subquery.

    * ``x IN (subquery)``: ``x``.
    * ``ANY (subquery)`` / ``ALL (subquery)``: the other operand of the
      comparison that encloses the ANY / ALL node.
    * binary comparison: the right operand when the left one is the
      subquery (or an ANY / ALL / EXISTS wrapper), otherwise the left one.

    Returns ``None`` for anything else.
    """
    if isinstance(expression, exp.In):
        return expression.this

    if isinstance(expression, (exp.Any, exp.All)):
        # ANY / ALL only wrap the subquery; the value they are compared with
        # is an operand of the enclosing comparison.
        return _other_operand(expression.parent)

    if isinstance(expression, exp.Binary):
        # Decide by where the subquery actually is. The comparison's own
        # arg_key only says where the comparison sits inside *its* parent
        # (e.g. first vs. second operand of an AND), which is unrelated.
        subquery_like = (exp.Subquery, exp.Any, exp.All, exp.Exists)
        if isinstance(expression.left, subquery_like):
            return expression.right
        return expression.left

    return None


# --
# cgit v1.2.3