summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r--sqlglot/optimizer/__init__.py2
-rw-r--r--sqlglot/optimizer/eliminate_subqueries.py48
-rw-r--r--sqlglot/optimizer/expand_multi_table_selects.py16
-rw-r--r--sqlglot/optimizer/isolate_table_selects.py31
-rw-r--r--sqlglot/optimizer/normalize.py136
-rw-r--r--sqlglot/optimizer/optimize_joins.py75
-rw-r--r--sqlglot/optimizer/optimizer.py43
-rw-r--r--sqlglot/optimizer/pushdown_predicates.py176
-rw-r--r--sqlglot/optimizer/pushdown_projections.py85
-rw-r--r--sqlglot/optimizer/qualify_columns.py422
-rw-r--r--sqlglot/optimizer/qualify_tables.py54
-rw-r--r--sqlglot/optimizer/quote_identities.py25
-rw-r--r--sqlglot/optimizer/schema.py129
-rw-r--r--sqlglot/optimizer/scope.py438
-rw-r--r--sqlglot/optimizer/simplify.py383
-rw-r--r--sqlglot/optimizer/unnest_subqueries.py220
16 files changed, 2283 insertions, 0 deletions
diff --git a/sqlglot/optimizer/__init__.py b/sqlglot/optimizer/__init__.py
new file mode 100644
index 0000000..a4c4cc2
--- /dev/null
+++ b/sqlglot/optimizer/__init__.py
@@ -0,0 +1,2 @@
+from sqlglot.optimizer.optimizer import optimize
+from sqlglot.optimizer.schema import Schema
diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py
new file mode 100644
index 0000000..4bfb733
--- /dev/null
+++ b/sqlglot/optimizer/eliminate_subqueries.py
@@ -0,0 +1,48 @@
+import itertools
+
+from sqlglot import alias, exp, select, table
+from sqlglot.optimizer.scope import traverse_scope
+from sqlglot.optimizer.simplify import simplify
+
+
+def eliminate_subqueries(expression):
+ """
+ Rewrite duplicate subqueries from sqlglot AST.
+
+ Example:
+ >>> import sqlglot
+ >>> expression = sqlglot.parse_one("SELECT 1 AS x, 2 AS y UNION ALL SELECT 1 AS x, 2 AS y")
+ >>> eliminate_subqueries(expression).sql()
+ 'WITH _e_0 AS (SELECT 1 AS x, 2 AS y) SELECT * FROM _e_0 UNION ALL SELECT * FROM _e_0'
+
+ Args:
+ expression (sqlglot.Expression): expression to qualify
+ schema (dict|sqlglot.optimizer.Schema): Database schema
+ Returns:
+ sqlglot.Expression: qualified expression
+ """
+ expression = simplify(expression)
+ queries = {}
+
+ for scope in traverse_scope(expression):
+ query = scope.expression
+ queries[query] = queries.get(query, []) + [query]
+
+ sequence = itertools.count()
+
+ for query, duplicates in queries.items():
+ if len(duplicates) == 1:
+ continue
+
+ alias_ = f"_e_{next(sequence)}"
+
+ for dup in duplicates:
+ parent = dup.parent
+ if isinstance(parent, exp.Subquery):
+ parent.replace(alias(table(alias_), parent.alias_or_name, table=True))
+ elif isinstance(parent, exp.Union):
+ dup.replace(select("*").from_(alias_))
+
+ expression.with_(alias_, as_=query, copy=False)
+
+ return expression
diff --git a/sqlglot/optimizer/expand_multi_table_selects.py b/sqlglot/optimizer/expand_multi_table_selects.py
new file mode 100644
index 0000000..ba562df
--- /dev/null
+++ b/sqlglot/optimizer/expand_multi_table_selects.py
@@ -0,0 +1,16 @@
+from sqlglot import exp
+
+
+def expand_multi_table_selects(expression):
+ for from_ in expression.find_all(exp.From):
+ parent = from_.parent
+
+ for query in from_.expressions[1:]:
+ parent.join(
+ query,
+ join_type="CROSS",
+ copy=False,
+ )
+ from_.expressions.remove(query)
+
+ return expression
diff --git a/sqlglot/optimizer/isolate_table_selects.py b/sqlglot/optimizer/isolate_table_selects.py
new file mode 100644
index 0000000..c2e021e
--- /dev/null
+++ b/sqlglot/optimizer/isolate_table_selects.py
@@ -0,0 +1,31 @@
+from sqlglot import alias, exp
+from sqlglot.errors import OptimizeError
+from sqlglot.optimizer.scope import traverse_scope
+
+
+def isolate_table_selects(expression):
+ for scope in traverse_scope(expression):
+ if len(scope.selected_sources) == 1:
+ continue
+
+ for (_, source) in scope.selected_sources.values():
+ if not isinstance(source, exp.Table):
+ continue
+
+ if not isinstance(source.parent, exp.Alias):
+ raise OptimizeError(
+ "Tables require an alias. Run qualify_tables optimization."
+ )
+
+ parent = source.parent
+
+ parent.replace(
+ exp.select("*")
+ .from_(
+ alias(source, source.name or parent.alias, table=True),
+ copy=False,
+ )
+ .subquery(parent.alias, copy=False)
+ )
+
+ return expression
diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py
new file mode 100644
index 0000000..2c9f89c
--- /dev/null
+++ b/sqlglot/optimizer/normalize.py
@@ -0,0 +1,136 @@
+from sqlglot import exp
+from sqlglot.helper import while_changing
+from sqlglot.optimizer.simplify import flatten, simplify, uniq_sort
+
+
+def normalize(expression, dnf=False, max_distance=128):
+ """
+ Rewrite sqlglot AST into conjunctive normal form.
+
+ Example:
+ >>> import sqlglot
+ >>> expression = sqlglot.parse_one("(x AND y) OR z")
+ >>> normalize(expression).sql()
+ '(x OR z) AND (y OR z)'
+
+ Args:
+ expression (sqlglot.Expression): expression to normalize
+ dnf (bool): rewrite in disjunctive normal form instead
+ max_distance (int): the maximal estimated distance from cnf to attempt conversion
+ Returns:
+ sqlglot.Expression: normalized expression
+ """
+ expression = simplify(expression)
+
+ expression = while_changing(
+ expression, lambda e: distributive_law(e, dnf, max_distance)
+ )
+ return simplify(expression)
+
+
+def normalized(expression, dnf=False):
+ ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And)
+
+ return not any(
+ connector.find_ancestor(ancestor) for connector in expression.find_all(root)
+ )
+
+
+def normalization_distance(expression, dnf=False):
+ """
+ The difference in the number of predicates between the current expression and the normalized form.
+
+ This is used as an estimate of the cost of the conversion which is exponential in complexity.
+
+ Example:
+ >>> import sqlglot
+ >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
+ >>> normalization_distance(expression)
+ 4
+
+ Args:
+ expression (sqlglot.Expression): expression to compute distance
+ dnf (bool): compute to dnf distance instead
+ Returns:
+ int: difference
+ """
+ return sum(_predicate_lengths(expression, dnf)) - (
+ len(list(expression.find_all(exp.Connector))) + 1
+ )
+
+
+def _predicate_lengths(expression, dnf):
+ """
+ Returns a list of predicate lengths when expanded to normalized form.
+
+ (A AND B) OR C -> [2, 2] because len(A OR C), len(B OR C).
+ """
+ expression = expression.unnest()
+
+ if not isinstance(expression, exp.Connector):
+ return [1]
+
+ left, right = expression.args.values()
+
+ if isinstance(expression, exp.And if dnf else exp.Or):
+ x = [
+ a + b
+ for a in _predicate_lengths(left, dnf)
+ for b in _predicate_lengths(right, dnf)
+ ]
+ return x
+ return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
+
+
+def distributive_law(expression, dnf, max_distance):
+ """
+ x OR (y AND z) -> (x OR y) AND (x OR z)
+ (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)
+ """
+ if isinstance(expression.unnest(), exp.Connector):
+ if normalization_distance(expression, dnf) > max_distance:
+ return expression
+
+ to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)
+
+ exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance))
+
+ if isinstance(expression, from_exp):
+ a, b = expression.unnest_operands()
+
+ from_func = exp.and_ if from_exp == exp.And else exp.or_
+ to_func = exp.and_ if to_exp == exp.And else exp.or_
+
+ if isinstance(a, to_exp) and isinstance(b, to_exp):
+ if len(tuple(a.find_all(exp.Connector))) > len(
+ tuple(b.find_all(exp.Connector))
+ ):
+ return _distribute(a, b, from_func, to_func)
+ return _distribute(b, a, from_func, to_func)
+ if isinstance(a, to_exp):
+ return _distribute(b, a, from_func, to_func)
+ if isinstance(b, to_exp):
+ return _distribute(a, b, from_func, to_func)
+
+ return expression
+
+
+def _distribute(a, b, from_func, to_func):
+ if isinstance(a, exp.Connector):
+ exp.replace_children(
+ a,
+ lambda c: to_func(
+ exp.paren(from_func(c, b.left)),
+ exp.paren(from_func(c, b.right)),
+ ),
+ )
+ else:
+ a = to_func(from_func(a, b.left), from_func(a, b.right))
+
+ return _simplify(a)
+
+
+def _simplify(node):
+ node = uniq_sort(flatten(node))
+ exp.replace_children(node, _simplify)
+ return node
diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py
new file mode 100644
index 0000000..40e4ab1
--- /dev/null
+++ b/sqlglot/optimizer/optimize_joins.py
@@ -0,0 +1,75 @@
+from sqlglot import exp
+from sqlglot.helper import tsort
+from sqlglot.optimizer.simplify import simplify
+
+
+def optimize_joins(expression):
+ """
+ Removes cross joins if possible and reorder joins based on predicate dependencies.
+ """
+ for select in expression.find_all(exp.Select):
+ references = {}
+ cross_joins = []
+
+ for join in select.args.get("joins", []):
+ name = join.this.alias_or_name
+ tables = other_table_names(join, name)
+
+ if tables:
+ for table in tables:
+ references[table] = references.get(table, []) + [join]
+ else:
+ cross_joins.append((name, join))
+
+ for name, join in cross_joins:
+ for dep in references.get(name, []):
+ on = dep.args["on"]
+ on = on.replace(simplify(on))
+
+ if isinstance(on, exp.Connector):
+ for predicate in on.flatten():
+ if name in exp.column_table_names(predicate):
+ predicate.replace(exp.TRUE)
+ join.on(predicate, copy=False)
+
+ expression = reorder_joins(expression)
+ expression = normalize(expression)
+ return expression
+
+
+def reorder_joins(expression):
+ """
+ Reorder joins by topological sort order based on predicate references.
+ """
+ for from_ in expression.find_all(exp.From):
+ head = from_.expressions[0]
+ parent = from_.parent
+ joins = {join.this.alias_or_name: join for join in parent.args.get("joins", [])}
+ dag = {head.alias_or_name: []}
+
+ for name, join in joins.items():
+ dag[name] = other_table_names(join, name)
+
+ parent.set(
+ "joins",
+ [joins[name] for name in tsort(dag) if name != head.alias_or_name],
+ )
+ return expression
+
+
+def normalize(expression):
+ """
+ Remove INNER and OUTER from joins as they are optional.
+ """
+ for join in expression.find_all(exp.Join):
+ if join.kind != "CROSS":
+ join.set("kind", None)
+ return expression
+
+
+def other_table_names(join, exclude):
+ return [
+ name
+ for name in (exp.column_table_names(join.args.get("on") or exp.TRUE))
+ if name != exclude
+ ]
diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py
new file mode 100644
index 0000000..c03fe3c
--- /dev/null
+++ b/sqlglot/optimizer/optimizer.py
@@ -0,0 +1,43 @@
+from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
+from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
+from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
+from sqlglot.optimizer.normalize import normalize
+from sqlglot.optimizer.optimize_joins import optimize_joins
+from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
+from sqlglot.optimizer.pushdown_projections import pushdown_projections
+from sqlglot.optimizer.qualify_columns import qualify_columns
+from sqlglot.optimizer.qualify_tables import qualify_tables
+from sqlglot.optimizer.quote_identities import quote_identities
+from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
+
+
+def optimize(expression, schema=None, db=None, catalog=None):
+ """
+ Rewrite a sqlglot AST into an optimized form.
+
+ Args:
+ expression (sqlglot.Expression): expression to optimize
+ schema (dict|sqlglot.optimizer.Schema): database schema.
+ This can either be an instance of `sqlglot.optimizer.Schema` or a mapping in one of
+ the following forms:
+ 1. {table: {col: type}}
+ 2. {db: {table: {col: type}}}
+ 3. {catalog: {db: {table: {col: type}}}}
+ db (str): specify the default database, as might be set by a `USE DATABASE db` statement
+ catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement
+ Returns:
+ sqlglot.Expression: optimized expression
+ """
+ expression = expression.copy()
+ expression = qualify_tables(expression, db=db, catalog=catalog)
+ expression = isolate_table_selects(expression)
+ expression = qualify_columns(expression, schema)
+ expression = pushdown_projections(expression)
+ expression = normalize(expression)
+ expression = unnest_subqueries(expression)
+ expression = expand_multi_table_selects(expression)
+ expression = pushdown_predicates(expression)
+ expression = optimize_joins(expression)
+ expression = eliminate_subqueries(expression)
+ expression = quote_identities(expression)
+ return expression
diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py
new file mode 100644
index 0000000..e757322
--- /dev/null
+++ b/sqlglot/optimizer/pushdown_predicates.py
@@ -0,0 +1,176 @@
+from sqlglot import exp
+from sqlglot.optimizer.normalize import normalized
+from sqlglot.optimizer.scope import traverse_scope
+from sqlglot.optimizer.simplify import simplify
+
+
+def pushdown_predicates(expression):
+ """
+ Rewrite sqlglot AST to pushdown predicates in FROMS and JOINS
+
+ Example:
+ >>> import sqlglot
+ >>> sql = "SELECT * FROM (SELECT * FROM x AS x) AS y WHERE y.a = 1"
+ >>> expression = sqlglot.parse_one(sql)
+ >>> pushdown_predicates(expression).sql()
+ 'SELECT * FROM (SELECT * FROM x AS x WHERE y.a = 1) AS y WHERE TRUE'
+
+ Args:
+ expression (sqlglot.Expression): expression to optimize
+ Returns:
+ sqlglot.Expression: optimized expression
+ """
+ for scope in reversed(traverse_scope(expression)):
+ select = scope.expression
+ where = select.args.get("where")
+ if where:
+ pushdown(where.this, scope.selected_sources)
+
+ # joins should only pushdown into itself, not to other joins
+ # so we limit the selected sources to only itself
+ for join in select.args.get("joins") or []:
+ name = join.this.alias_or_name
+ pushdown(join.args.get("on"), {name: scope.selected_sources[name]})
+
+ return expression
+
+
+def pushdown(condition, sources):
+ if not condition:
+ return
+
+ condition = condition.replace(simplify(condition))
+ cnf_like = normalized(condition) or not normalized(condition, dnf=True)
+
+ predicates = list(
+ condition.flatten()
+ if isinstance(condition, exp.And if cnf_like else exp.Or)
+ else [condition]
+ )
+
+ if cnf_like:
+ pushdown_cnf(predicates, sources)
+ else:
+ pushdown_dnf(predicates, sources)
+
+
+def pushdown_cnf(predicates, scope):
+ """
+ If the predicates are in CNF like form, we can simply replace each block in the parent.
+ """
+ for predicate in predicates:
+ for node in nodes_for_predicate(predicate, scope).values():
+ if isinstance(node, exp.Join):
+ predicate.replace(exp.TRUE)
+ node.on(predicate, copy=False)
+ break
+ if isinstance(node, exp.Select):
+ predicate.replace(exp.TRUE)
+ node.where(replace_aliases(node, predicate), copy=False)
+
+
+def pushdown_dnf(predicates, scope):
+ """
+ If the predicates are in DNF form, we can only push down conditions that are in all blocks.
+ Additionally, we can't remove predicates from their original form.
+ """
+ # find all the tables that can be pushdown too
+ # these are tables that are referenced in all blocks of a DNF
+ # (a.x AND b.x) OR (a.y AND c.y)
+ # only table a can be push down
+ pushdown_tables = set()
+
+ for a in predicates:
+ a_tables = set(exp.column_table_names(a))
+
+ for b in predicates:
+ a_tables &= set(exp.column_table_names(b))
+
+ pushdown_tables.update(a_tables)
+
+ conditions = {}
+
+ # for every pushdown table, find all related conditions in all predicates
+ # combine them with ORS
+ # (a.x AND and a.y AND b.x) OR (a.z AND c.y) -> (a.x AND a.y) OR (a.z)
+ for table in sorted(pushdown_tables):
+ for predicate in predicates:
+ nodes = nodes_for_predicate(predicate, scope)
+
+ if table not in nodes:
+ continue
+
+ predicate_condition = None
+
+ for column in predicate.find_all(exp.Column):
+ if column.table == table:
+ condition = column.find_ancestor(exp.Condition)
+ predicate_condition = (
+ exp.and_(predicate_condition, condition)
+ if predicate_condition
+ else condition
+ )
+
+ if predicate_condition:
+ conditions[table] = (
+ exp.or_(conditions[table], predicate_condition)
+ if table in conditions
+ else predicate_condition
+ )
+
+ for name, node in nodes.items():
+ if name not in conditions:
+ continue
+
+ predicate = conditions[name]
+
+ if isinstance(node, exp.Join):
+ node.on(predicate, copy=False)
+ elif isinstance(node, exp.Select):
+ node.where(replace_aliases(node, predicate), copy=False)
+
+
+def nodes_for_predicate(predicate, sources):
+ nodes = {}
+ tables = exp.column_table_names(predicate)
+ where_condition = isinstance(
+ predicate.find_ancestor(exp.Join, exp.Where), exp.Where
+ )
+
+ for table in tables:
+ node, source = sources.get(table) or (None, None)
+
+ # if the predicate is in a where statement we can try to push it down
+ # we want to find the root join or from statement
+ if node and where_condition:
+ node = node.find_ancestor(exp.Join, exp.From)
+
+ # a node can reference a CTE which should be push down
+ if isinstance(node, exp.From) and not isinstance(source, exp.Table):
+ node = source.expression
+
+ if isinstance(node, exp.Join):
+ if node.side:
+ return {}
+ nodes[table] = node
+ elif isinstance(node, exp.Select) and len(tables) == 1:
+ if not node.args.get("group"):
+ nodes[table] = node
+ return nodes
+
+
+def replace_aliases(source, predicate):
+ aliases = {}
+
+ for select in source.selects:
+ if isinstance(select, exp.Alias):
+ aliases[select.alias] = select.this
+ else:
+ aliases[select.name] = select
+
+ def _replace_alias(column):
+ if isinstance(column, exp.Column) and column.name in aliases:
+ return aliases[column.name]
+ return column
+
+ return predicate.transform(_replace_alias)
diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py
new file mode 100644
index 0000000..097ce04
--- /dev/null
+++ b/sqlglot/optimizer/pushdown_projections.py
@@ -0,0 +1,85 @@
+from collections import defaultdict
+
+from sqlglot import alias, exp
+from sqlglot.optimizer.scope import Scope, traverse_scope
+
+# Sentinel value that means an outer query selecting ALL columns
+SELECT_ALL = object()
+
+
+def pushdown_projections(expression):
+ """
+ Rewrite sqlglot AST to remove unused columns projections.
+
+ Example:
+ >>> import sqlglot
+ >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
+ >>> expression = sqlglot.parse_one(sql)
+ >>> pushdown_projections(expression).sql()
+ 'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
+
+ Args:
+ expression (sqlglot.Expression): expression to optimize
+ Returns:
+ sqlglot.Expression: optimized expression
+ """
+ # Map of Scope to all columns being selected by outer queries.
+ referenced_columns = defaultdict(set)
+
+ # We build the scope tree (which is traversed in DFS postorder), then iterate
+ # over the result in reverse order. This should ensure that the set of selected
+ # columns for a particular scope are completely build by the time we get to it.
+ for scope in reversed(traverse_scope(expression)):
+ parent_selections = referenced_columns.get(scope, {SELECT_ALL})
+
+ if scope.expression.args.get("distinct"):
+ # We can't remove columns SELECT DISTINCT nor UNION DISTINCT
+ parent_selections = {SELECT_ALL}
+
+ if isinstance(scope.expression, exp.Union):
+ left, right = scope.union
+ referenced_columns[left] = parent_selections
+ referenced_columns[right] = parent_selections
+
+ if isinstance(scope.expression, exp.Select):
+ _remove_unused_selections(scope, parent_selections)
+
+ # Group columns by source name
+ selects = defaultdict(set)
+ for col in scope.columns:
+ table_name = col.table
+ col_name = col.name
+ selects[table_name].add(col_name)
+
+ # Push the selected columns down to the next scope
+ for name, (_, source) in scope.selected_sources.items():
+ if isinstance(source, Scope):
+ columns = selects.get(name) or set()
+ referenced_columns[source].update(columns)
+
+ return expression
+
+
+def _remove_unused_selections(scope, parent_selections):
+ order = scope.expression.args.get("order")
+
+ if order:
+ # Assume columns without a qualified table are references to output columns
+ order_refs = {c.name for c in order.find_all(exp.Column) if not c.table}
+ else:
+ order_refs = set()
+
+ new_selections = []
+ for selection in scope.selects:
+ if (
+ SELECT_ALL in parent_selections
+ or selection.alias_or_name in parent_selections
+ or selection.alias_or_name in order_refs
+ ):
+ new_selections.append(selection)
+
+ # If there are no remaining selections, just select a single constant
+ if not new_selections:
+ new_selections.append(alias("1", "_"))
+
+ scope.expression.set("expressions", new_selections)
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
new file mode 100644
index 0000000..394f49e
--- /dev/null
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -0,0 +1,422 @@
+import itertools
+
+from sqlglot import alias, exp
+from sqlglot.errors import OptimizeError
+from sqlglot.optimizer.schema import ensure_schema
+from sqlglot.optimizer.scope import traverse_scope
+
+SKIP_QUALIFY = (exp.Unnest, exp.Lateral)
+
+
+def qualify_columns(expression, schema):
+ """
+ Rewrite sqlglot AST to have fully qualified columns.
+
+ Example:
+ >>> import sqlglot
+ >>> schema = {"tbl": {"col": "INT"}}
+ >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
+ >>> qualify_columns(expression, schema).sql()
+ 'SELECT tbl.col AS col FROM tbl'
+
+ Args:
+ expression (sqlglot.Expression): expression to qualify
+ schema (dict|sqlglot.optimizer.Schema): Database schema
+ Returns:
+ sqlglot.Expression: qualified expression
+ """
+ schema = ensure_schema(schema)
+
+ for scope in traverse_scope(expression):
+ resolver = _Resolver(scope, schema)
+ _pop_table_column_aliases(scope.ctes)
+ _pop_table_column_aliases(scope.derived_tables)
+ _expand_using(scope, resolver)
+ _expand_group_by(scope, resolver)
+ _expand_order_by(scope)
+ _qualify_columns(scope, resolver)
+ if not isinstance(scope.expression, SKIP_QUALIFY):
+ _expand_stars(scope, resolver)
+ _qualify_outputs(scope)
+ _check_unknown_tables(scope)
+
+ return expression
+
+
+def _pop_table_column_aliases(derived_tables):
+ """
+ Remove table column aliases.
+
+ (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
+ """
+ for derived_table in derived_tables:
+ if isinstance(derived_table, SKIP_QUALIFY):
+ continue
+ table_alias = derived_table.args.get("alias")
+ if table_alias:
+ table_alias.args.pop("columns", None)
+
+
+def _expand_using(scope, resolver):
+ joins = list(scope.expression.find_all(exp.Join))
+ names = {join.this.alias for join in joins}
+ ordered = [key for key in scope.selected_sources if key not in names]
+
+ # Mapping of automatically joined column names to source names
+ column_tables = {}
+
+ for join in joins:
+ using = join.args.get("using")
+
+ if not using:
+ continue
+
+ join_table = join.this.alias_or_name
+
+ columns = {}
+
+ for k in scope.selected_sources:
+ if k in ordered:
+ for column in resolver.get_source_columns(k):
+ if column not in columns:
+ columns[column] = k
+
+ ordered.append(join_table)
+ join_columns = resolver.get_source_columns(join_table)
+ conditions = []
+
+ for identifier in using:
+ identifier = identifier.name
+ table = columns.get(identifier)
+
+ if not table or identifier not in join_columns:
+ raise OptimizeError(f"Cannot automatically join: {identifier}")
+
+ conditions.append(
+ exp.condition(
+ exp.EQ(
+ this=exp.column(identifier, table=table),
+ expression=exp.column(identifier, table=join_table),
+ )
+ )
+ )
+
+ tables = column_tables.setdefault(identifier, [])
+ if table not in tables:
+ tables.append(table)
+ if join_table not in tables:
+ tables.append(join_table)
+
+ join.args.pop("using")
+ join.set("on", exp.and_(*conditions))
+
+ if column_tables:
+ for column in scope.columns:
+ if not column.table and column.name in column_tables:
+ tables = column_tables[column.name]
+ coalesce = [exp.column(column.name, table=table) for table in tables]
+ replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])
+
+ # Ensure selects keep their output name
+ if isinstance(column.parent, exp.Select):
+ replacement = exp.alias_(replacement, alias=column.name)
+
+ scope.replace(column, replacement)
+
+
+def _expand_group_by(scope, resolver):
+ group = scope.expression.args.get("group")
+ if not group:
+ return
+
+ # Replace references to select aliases
+ def transform(node, *_):
+ if isinstance(node, exp.Column) and not node.table:
+ table = resolver.get_table(node.name)
+
+ # Source columns get priority over select aliases
+ if table:
+ node.set("table", exp.to_identifier(table))
+ return node
+
+ selects = {s.alias_or_name: s for s in scope.selects}
+
+ select = selects.get(node.name)
+ if select:
+ scope.clear_cache()
+ if isinstance(select, exp.Alias):
+ select = select.this
+ return select.copy()
+
+ return node
+
+ group.transform(transform, copy=False)
+ group.set("expressions", _expand_positional_references(scope, group.expressions))
+ scope.expression.set("group", group)
+
+
+def _expand_order_by(scope):
+ order = scope.expression.args.get("order")
+ if not order:
+ return
+
+ ordereds = order.expressions
+ for ordered, new_expression in zip(
+ ordereds,
+ _expand_positional_references(scope, (o.this for o in ordereds)),
+ ):
+ ordered.set("this", new_expression)
+
+
+def _expand_positional_references(scope, expressions):
+ new_nodes = []
+ for node in expressions:
+ if node.is_int:
+ try:
+ select = scope.selects[int(node.name) - 1]
+ except IndexError:
+ raise OptimizeError(f"Unknown output column: {node.name}")
+ if isinstance(select, exp.Alias):
+ select = select.this
+ new_nodes.append(select.copy())
+ scope.clear_cache()
+ else:
+ new_nodes.append(node)
+
+ return new_nodes
+
+
+def _qualify_columns(scope, resolver):
+ """Disambiguate columns, ensuring each column specifies a source"""
+ for column in scope.columns:
+ column_table = column.table
+ column_name = column.name
+
+ if (
+ column_table
+ and column_table in scope.sources
+ and column_name not in resolver.get_source_columns(column_table)
+ ):
+ raise OptimizeError(f"Unknown column: {column_name}")
+
+ if not column_table:
+ column_table = resolver.get_table(column_name)
+
+ if not scope.is_subquery and not scope.is_unnest:
+ if column_name not in resolver.all_columns:
+ raise OptimizeError(f"Unknown column: {column_name}")
+
+ if column_table is None:
+ raise OptimizeError(f"Ambiguous column: {column_name}")
+
+ # column_table can be a '' because bigquery unnest has no table alias
+ if column_table:
+ column.set("table", exp.to_identifier(column_table))
+
+
+def _expand_stars(scope, resolver):
+ """Expand stars to lists of column selections"""
+
+ new_selections = []
+ except_columns = {}
+ replace_columns = {}
+
+ for expression in scope.selects:
+ if isinstance(expression, exp.Star):
+ tables = list(scope.selected_sources)
+ _add_except_columns(expression, tables, except_columns)
+ _add_replace_columns(expression, tables, replace_columns)
+ elif isinstance(expression, exp.Column) and isinstance(
+ expression.this, exp.Star
+ ):
+ tables = [expression.table]
+ _add_except_columns(expression.this, tables, except_columns)
+ _add_replace_columns(expression.this, tables, replace_columns)
+ else:
+ new_selections.append(expression)
+ continue
+
+ for table in tables:
+ if table not in scope.sources:
+ raise OptimizeError(f"Unknown table: {table}")
+ columns = resolver.get_source_columns(table)
+ table_id = id(table)
+ for name in columns:
+ if name not in except_columns.get(table_id, set()):
+ alias_ = replace_columns.get(table_id, {}).get(name, name)
+ column = exp.column(name, table)
+ new_selections.append(
+ alias(column, alias_) if alias_ != name else column
+ )
+
+ scope.expression.set("expressions", new_selections)
+
+
+def _add_except_columns(expression, tables, except_columns):
+ except_ = expression.args.get("except")
+
+ if not except_:
+ return
+
+ columns = {e.name for e in except_}
+
+ for table in tables:
+ except_columns[id(table)] = columns
+
+
+def _add_replace_columns(expression, tables, replace_columns):
+ replace = expression.args.get("replace")
+
+ if not replace:
+ return
+
+ columns = {e.this.name: e.alias for e in replace}
+
+ for table in tables:
+ replace_columns[id(table)] = columns
+
+
+def _qualify_outputs(scope):
+ """Ensure all output columns are aliased"""
+ new_selections = []
+
+ for i, (selection, aliased_column) in enumerate(
+ itertools.zip_longest(scope.selects, scope.outer_column_list)
+ ):
+ if isinstance(selection, exp.Column):
+ # convoluted setter because a simple selection.replace(alias) would require a copy
+ alias_ = alias(exp.column(""), alias=selection.name)
+ alias_.set("this", selection)
+ selection = alias_
+ elif not isinstance(selection, exp.Alias):
+ alias_ = alias(exp.column(""), f"_col_{i}")
+ alias_.set("this", selection)
+ selection = alias_
+
+ if aliased_column:
+ selection.set("alias", exp.to_identifier(aliased_column))
+
+ new_selections.append(selection)
+
+ scope.expression.set("expressions", new_selections)
+
+
+def _check_unknown_tables(scope):
+ if (
+ scope.external_columns
+ and not scope.is_unnest
+ and not scope.is_correlated_subquery
+ ):
+ raise OptimizeError(f"Unknown table: {scope.external_columns[0].text('table')}")
+
+
+class _Resolver:
+ """
+ Helper for resolving columns.
+
+ This is a class so we can lazily load some things and easily share them across functions.
+ """
+
+ def __init__(self, scope, schema):
+ self.scope = scope
+ self.schema = schema
+ self._source_columns = None
+ self._unambiguous_columns = None
+ self._all_columns = None
+
+ def get_table(self, column_name):
+ """
+ Get the table for a column name.
+
+ Args:
+ column_name (str)
+ Returns:
+ (str) table name
+ """
+ if self._unambiguous_columns is None:
+ self._unambiguous_columns = self._get_unambiguous_columns(
+ self._get_all_source_columns()
+ )
+ return self._unambiguous_columns.get(column_name)
+
+ @property
+ def all_columns(self):
+ """All available columns of all sources in this scope"""
+ if self._all_columns is None:
+ self._all_columns = set(
+ column
+ for columns in self._get_all_source_columns().values()
+ for column in columns
+ )
+ return self._all_columns
+
+ def get_source_columns(self, name):
+ """Resolve the source columns for a given source `name`"""
+ if name not in self.scope.sources:
+ raise OptimizeError(f"Unknown table: {name}")
+
+ source = self.scope.sources[name]
+
+ # If referencing a table, return the columns from the schema
+ if isinstance(source, exp.Table):
+ try:
+ return self.schema.column_names(source)
+ except Exception as e:
+ raise OptimizeError(str(e)) from e
+
+ # Otherwise, if referencing another scope, return that scope's named selects
+ return source.expression.named_selects
+
+ def _get_all_source_columns(self):
+ if self._source_columns is None:
+ self._source_columns = {
+ k: self.get_source_columns(k) for k in self.scope.selected_sources
+ }
+ return self._source_columns
+
+ def _get_unambiguous_columns(self, source_columns):
+ """
+ Find all the unambiguous columns in sources.
+
+ Args:
+ source_columns (dict): Mapping of names to source columns
+ Returns:
+ dict: Mapping of column name to source name
+ """
+ if not source_columns:
+ return {}
+
+ source_columns = list(source_columns.items())
+
+ first_table, first_columns = source_columns[0]
+ unambiguous_columns = {
+ col: first_table for col in self._find_unique_columns(first_columns)
+ }
+ all_columns = set(unambiguous_columns)
+
+ for table, columns in source_columns[1:]:
+ unique = self._find_unique_columns(columns)
+ ambiguous = set(all_columns).intersection(unique)
+ all_columns.update(columns)
+ for column in ambiguous:
+ unambiguous_columns.pop(column, None)
+ for column in unique.difference(ambiguous):
+ unambiguous_columns[column] = table
+
+ return unambiguous_columns
+
+ @staticmethod
+ def _find_unique_columns(columns):
+ """
+ Find the unique columns in a list of columns.
+
+ Example:
+ >>> sorted(_Resolver._find_unique_columns(["a", "b", "b", "c"]))
+ ['a', 'c']
+
+ This is necessary because duplicate column names are ambiguous.
+ """
+ counts = {}
+ for column in columns:
+ counts[column] = counts.get(column, 0) + 1
+ return {column for column, count in counts.items() if count == 1}
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py
new file mode 100644
index 0000000..9f8b9f5
--- /dev/null
+++ b/sqlglot/optimizer/qualify_tables.py
@@ -0,0 +1,54 @@
+import itertools
+
+from sqlglot import alias, exp
+from sqlglot.optimizer.scope import traverse_scope
+
+
+def qualify_tables(expression, db=None, catalog=None):
+ """
+ Rewrite sqlglot AST to have fully qualified tables.
+
+ Example:
+ >>> import sqlglot
+ >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
+ >>> qualify_tables(expression, db="db").sql()
+ 'SELECT 1 FROM db.tbl AS tbl'
+
+ Args:
+ expression (sqlglot.Expression): expression to qualify
+ db (str): Database name
+ catalog (str): Catalog name
+ Returns:
+ sqlglot.Expression: qualified expression
+ """
+ sequence = itertools.count()
+
+ for scope in traverse_scope(expression):
+ for derived_table in scope.ctes + scope.derived_tables:
+ if not derived_table.args.get("alias"):
+ alias_ = f"_q_{next(sequence)}"
+ derived_table.set(
+ "alias", exp.TableAlias(this=exp.to_identifier(alias_))
+ )
+ scope.rename_source(None, alias_)
+
+ for source in scope.sources.values():
+ if isinstance(source, exp.Table):
+ identifier = isinstance(source.this, exp.Identifier)
+
+ if identifier:
+ if not source.args.get("db"):
+ source.set("db", exp.to_identifier(db))
+ if not source.args.get("catalog"):
+ source.set("catalog", exp.to_identifier(catalog))
+
+ if not isinstance(source.parent, exp.Alias):
+ source.replace(
+ alias(
+ source.copy(),
+ source.this if identifier else f"_q_{next(sequence)}",
+ table=True,
+ )
+ )
+
+ return expression
diff --git a/sqlglot/optimizer/quote_identities.py b/sqlglot/optimizer/quote_identities.py
new file mode 100644
index 0000000..17623cc
--- /dev/null
+++ b/sqlglot/optimizer/quote_identities.py
@@ -0,0 +1,25 @@
+from sqlglot import exp
+
+
+def quote_identities(expression):
+ """
+ Rewrite sqlglot AST to ensure all identities are quoted.
+
+ Example:
+ >>> import sqlglot
+ >>> expression = sqlglot.parse_one("SELECT x.a AS a FROM db.x")
+ >>> quote_identities(expression).sql()
+ 'SELECT "x"."a" AS "a" FROM "db"."x"'
+
+ Args:
+ expression (sqlglot.Expression): expression to quote
+ Returns:
+ sqlglot.Expression: quoted expression
+ """
+
+ def qualify(node):
+ if isinstance(node, exp.Identifier):
+ node.set("quoted", True)
+ return node
+
+ return expression.transform(qualify, copy=False)
diff --git a/sqlglot/optimizer/schema.py b/sqlglot/optimizer/schema.py
new file mode 100644
index 0000000..9968108
--- /dev/null
+++ b/sqlglot/optimizer/schema.py
@@ -0,0 +1,129 @@
+import abc
+
+from sqlglot import exp
+from sqlglot.errors import OptimizeError
+from sqlglot.helper import csv_reader
+
+
+class Schema(abc.ABC):
+ """Abstract base class for database schemas"""
+
+ @abc.abstractmethod
+ def column_names(self, table):
+ """
+ Get the column names for a table.
+
+ Args:
+ table (sqlglot.expressions.Table): Table expression instance
+ Returns:
+ list[str]: list of column names
+ """
+
+
+class MappingSchema(Schema):
+ """
+ Schema based on a nested mapping.
+
+ Args:
+ schema (dict): Mapping in one of the following forms:
+ 1. {table: {col: type}}
+ 2. {db: {table: {col: type}}}
+ 3. {catalog: {db: {table: {col: type}}}}
+ """
+
+ def __init__(self, schema):
+ self.schema = schema
+
+ depth = _dict_depth(schema)
+
+ if not depth: # {}
+ self.supported_table_args = []
+ elif depth == 2: # {table: {col: type}}
+ self.supported_table_args = ("this",)
+ elif depth == 3: # {db: {table: {col: type}}}
+ self.supported_table_args = ("db", "this")
+ elif depth == 4: # {catalog: {db: {table: {col: type}}}}
+ self.supported_table_args = ("catalog", "db", "this")
+ else:
+ raise OptimizeError(f"Invalid schema shape. Depth: {depth}")
+
+ self.forbidden_args = {"catalog", "db", "this"} - set(self.supported_table_args)
+
+ def column_names(self, table):
+ if not isinstance(table.this, exp.Identifier):
+ return fs_get(table)
+
+ args = tuple(table.text(p) for p in self.supported_table_args)
+
+ for forbidden in self.forbidden_args:
+ if table.text(forbidden):
+ raise ValueError(
+ f"Schema doesn't support {forbidden}. Received: {table.sql()}"
+ )
+ return list(_nested_get(self.schema, *zip(self.supported_table_args, args)))
+
+
+def ensure_schema(schema):
+ if isinstance(schema, Schema):
+ return schema
+
+ return MappingSchema(schema)
+
+
+def fs_get(table):
+ name = table.this.name.upper()
+
+ if name.upper() == "READ_CSV":
+ with csv_reader(table) as reader:
+ return next(reader)
+
+ raise ValueError(f"Cannot read schema for {table}")
+
+
+def _nested_get(d, *path):
+ """
+ Get a value for a nested dictionary.
+
+ Args:
+ d (dict): dictionary
+ *path (tuple[str, str]): tuples of (name, key)
+ `key` is the key in the dictionary to get.
+ `name` is a string to use in the error if `key` isn't found.
+ """
+ for name, key in path:
+ d = d.get(key)
+ if d is None:
+ name = "table" if name == "this" else name
+ raise ValueError(f"Unknown {name}")
+ return d
+
+
+def _dict_depth(d):
+ """
+ Get the nesting depth of a dictionary.
+
+ For example:
+ >>> _dict_depth(None)
+ 0
+ >>> _dict_depth({})
+ 1
+ >>> _dict_depth({"a": "b"})
+ 1
+ >>> _dict_depth({"a": {}})
+ 2
+ >>> _dict_depth({"a": {"b": {}}})
+ 3
+
+ Args:
+ d (dict): dictionary
+ Returns:
+ int: depth
+ """
+ try:
+ return 1 + _dict_depth(next(iter(d.values())))
+ except AttributeError:
+ # d doesn't have attribute "values"
+ return 0
+ except StopIteration:
+ # d.values() returns an empty sequence
+ return 1
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
new file mode 100644
index 0000000..f6f59e8
--- /dev/null
+++ b/sqlglot/optimizer/scope.py
@@ -0,0 +1,438 @@
+from copy import copy
+from enum import Enum, auto
+
+from sqlglot import exp
+from sqlglot.errors import OptimizeError
+
+
+class ScopeType(Enum):
+ ROOT = auto()
+ SUBQUERY = auto()
+ DERIVED_TABLE = auto()
+ CTE = auto()
+ UNION = auto()
+ UNNEST = auto()
+
+
+class Scope:
+ """
+ Selection scope.
+
+ Attributes:
+ expression (exp.Select|exp.Union): Root expression of this scope
+ sources (dict[str, exp.Table|Scope]): Mapping of source name to either
+ a Table expression or another Scope instance. For example:
+ SELECT * FROM x {"x": Table(this="x")}
+ SELECT * FROM x AS y {"y": Table(this="x")}
+ SELECT * FROM (SELECT ...) AS y {"y": Scope(...)}
+ outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
+ defines a column list of it's alias of this scope, this is that list of columns.
+ For example:
+ SELECT * FROM (SELECT ...) AS y(col1, col2)
+ The inner query would have `["col1", "col2"]` for its `outer_column_list`
+ parent (Scope): Parent scope
+ scope_type (ScopeType): Type of this scope, relative to it's parent
+ subquery_scopes (list[Scope]): List of all child scopes for subqueries.
+ This does not include derived tables or CTEs.
+ union (tuple[Scope, Scope]): If this Scope is for a Union expression, this will be
+ a tuple of the left and right child scopes.
+ """
+
+ def __init__(
+ self,
+ expression,
+ sources=None,
+ outer_column_list=None,
+ parent=None,
+ scope_type=ScopeType.ROOT,
+ ):
+ self.expression = expression
+ self.sources = sources or {}
+ self.outer_column_list = outer_column_list or []
+ self.parent = parent
+ self.scope_type = scope_type
+ self.subquery_scopes = []
+ self.union = None
+ self.clear_cache()
+
+ def clear_cache(self):
+ self._collected = False
+ self._raw_columns = None
+ self._derived_tables = None
+ self._tables = None
+ self._ctes = None
+ self._subqueries = None
+ self._selected_sources = None
+ self._columns = None
+ self._external_columns = None
+
+ def branch(self, expression, scope_type, add_sources=None, **kwargs):
+ """Branch from the current scope to a new, inner scope"""
+ sources = copy(self.sources)
+ if add_sources:
+ sources.update(add_sources)
+ return Scope(
+ expression=expression.unnest(),
+ sources=sources,
+ parent=self,
+ scope_type=scope_type,
+ **kwargs,
+ )
+
+ def _collect(self):
+ self._tables = []
+ self._ctes = []
+ self._subqueries = []
+ self._derived_tables = []
+ self._raw_columns = []
+
+ # We'll use this variable to pass state into the dfs generator.
+ # Whenever we set it to True, we exclude a subtree from traversal.
+ prune = False
+
+ for node, parent, _ in self.expression.dfs(prune=lambda *_: prune):
+ prune = False
+
+ if node is self.expression:
+ continue
+ if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
+ self._raw_columns.append(node)
+ elif isinstance(node, exp.Table):
+ self._tables.append(node)
+ elif isinstance(node, (exp.Unnest, exp.Lateral)):
+ self._derived_tables.append(node)
+ elif isinstance(node, exp.CTE):
+ self._ctes.append(node)
+ prune = True
+ elif isinstance(node, exp.Subquery) and isinstance(
+ parent, (exp.From, exp.Join)
+ ):
+ self._derived_tables.append(node)
+ prune = True
+ elif isinstance(node, exp.Subqueryable):
+ self._subqueries.append(node)
+ prune = True
+
+ self._collected = True
+
+ def _ensure_collected(self):
+ if not self._collected:
+ self._collect()
+
+ def replace(self, old, new):
+ """
+ Replace `old` with `new`.
+
+ This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
+
+ Args:
+ old (exp.Expression): old node
+ new (exp.Expression): new node
+ """
+ old.replace(new)
+ self.clear_cache()
+
+ @property
+ def tables(self):
+ """
+ List of tables in this scope.
+
+ Returns:
+ list[exp.Table]: tables
+ """
+ self._ensure_collected()
+ return self._tables
+
+ @property
+ def ctes(self):
+ """
+ List of CTEs in this scope.
+
+ Returns:
+ list[exp.CTE]: ctes
+ """
+ self._ensure_collected()
+ return self._ctes
+
+ @property
+ def derived_tables(self):
+ """
+ List of derived tables in this scope.
+
+ For example:
+ SELECT * FROM (SELECT ...) <- that's a derived table
+
+ Returns:
+ list[exp.Subquery]: derived tables
+ """
+ self._ensure_collected()
+ return self._derived_tables
+
+ @property
+ def subqueries(self):
+ """
+ List of subqueries in this scope.
+
+ For example:
+ SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
+
+ Returns:
+ list[exp.Subqueryable]: subqueries
+ """
+ self._ensure_collected()
+ return self._subqueries
+
+ @property
+ def columns(self):
+ """
+ List of columns in this scope.
+
+ Returns:
+ list[exp.Column]: Column instances in this scope, plus any
+ Columns that reference this scope from correlated subqueries.
+ """
+ if self._columns is None:
+ self._ensure_collected()
+ columns = self._raw_columns
+
+ external_columns = [
+ column
+ for scope in self.subquery_scopes
+ for column in scope.external_columns
+ ]
+
+ named_outputs = {e.alias_or_name for e in self.expression.expressions}
+
+ self._columns = [
+ c
+ for c in columns + external_columns
+ if not (
+ c.find_ancestor(exp.Qualify, exp.Order) and c.name in named_outputs
+ )
+ ]
+ return self._columns
+
+ @property
+ def selected_sources(self):
+ """
+ Mapping of nodes and sources that are actually selected from in this scope.
+
+ That is, all tables in a schema are selectable at any point. But a
+ table only becomes a selected source if it's included in a FROM or JOIN clause.
+
+ Returns:
+ dict[str, (exp.Table|exp.Subquery, exp.Table|Scope)]: selected sources and nodes
+ """
+ if self._selected_sources is None:
+ referenced_names = []
+
+ for table in self.tables:
+ referenced_names.append(
+ (
+ table.parent.alias
+ if isinstance(table.parent, exp.Alias)
+ else table.name,
+ table,
+ )
+ )
+ for derived_table in self.derived_tables:
+ referenced_names.append((derived_table.alias, derived_table.unnest()))
+
+ result = {}
+
+ for name, node in referenced_names:
+ if name in self.sources:
+ result[name] = (node, self.sources[name])
+
+ self._selected_sources = result
+ return self._selected_sources
+
+ @property
+ def selects(self):
+ """
+ Select expressions of this scope.
+
+ For example, for the following expression:
+ SELECT 1 as a, 2 as b FROM x
+
+ The outputs are the "1 as a" and "2 as b" expressions.
+
+ Returns:
+ list[exp.Expression]: expressions
+ """
+ if isinstance(self.expression, exp.Union):
+ return []
+ return self.expression.selects
+
+ @property
+ def external_columns(self):
+ """
+ Columns that appear to reference sources in outer scopes.
+
+ Returns:
+ list[exp.Column]: Column instances that don't reference
+ sources in the current scope.
+ """
+ if self._external_columns is None:
+ self._external_columns = [
+ c for c in self.columns if c.table not in self.selected_sources
+ ]
+ return self._external_columns
+
+ def source_columns(self, source_name):
+ """
+ Get all columns in the current scope for a particular source.
+
+ Args:
+ source_name (str): Name of the source
+ Returns:
+ list[exp.Column]: Column instances that reference `source_name`
+ """
+ return [column for column in self.columns if column.table == source_name]
+
+ @property
+ def is_subquery(self):
+ """Determine if this scope is a subquery"""
+ return self.scope_type == ScopeType.SUBQUERY
+
+ @property
+ def is_unnest(self):
+ """Determine if this scope is an unnest"""
+ return self.scope_type == ScopeType.UNNEST
+
+ @property
+ def is_correlated_subquery(self):
+ """Determine if this scope is a correlated subquery"""
+ return bool(self.is_subquery and self.external_columns)
+
+ def rename_source(self, old_name, new_name):
+ """Rename a source in this scope"""
+ columns = self.sources.pop(old_name or "", [])
+ self.sources[new_name] = columns
+
+
+def traverse_scope(expression):
+ """
+ Traverse an expression by it's "scopes".
+
+ "Scope" represents the current context of a Select statement.
+
+ This is helpful for optimizing queries, where we need more information than
+ the expression tree itself. For example, we might care about the source
+ names within a subquery. Returns a list because a generator could result in
+ incomplete properties which is confusing.
+
+ Examples:
+ >>> import sqlglot
+ >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
+ >>> scopes = traverse_scope(expression)
+ >>> scopes[0].expression.sql(), list(scopes[0].sources)
+ ('SELECT a FROM x', ['x'])
+ >>> scopes[1].expression.sql(), list(scopes[1].sources)
+ ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
+
+ Args:
+ expression (exp.Expression): expression to traverse
+ Returns:
+ List[Scope]: scope instances
+ """
+ return list(_traverse_scope(Scope(expression)))
+
+
+def _traverse_scope(scope):
+ if isinstance(scope.expression, exp.Select):
+ yield from _traverse_select(scope)
+ elif isinstance(scope.expression, exp.Union):
+ yield from _traverse_union(scope)
+ elif isinstance(scope.expression, (exp.Lateral, exp.Unnest)):
+ pass
+ elif isinstance(scope.expression, exp.Subquery):
+ yield from _traverse_subqueries(scope)
+ else:
+ raise OptimizeError(f"Unexpected expression type: {type(scope.expression)}")
+ yield scope
+
+
+def _traverse_select(scope):
+ yield from _traverse_derived_tables(scope.ctes, scope, ScopeType.CTE)
+ yield from _traverse_subqueries(scope)
+ yield from _traverse_derived_tables(
+ scope.derived_tables, scope, ScopeType.DERIVED_TABLE
+ )
+ _add_table_sources(scope)
+
+
+def _traverse_union(scope):
+ yield from _traverse_derived_tables(scope.ctes, scope, scope_type=ScopeType.CTE)
+
+ # The last scope to be yield should be the top most scope
+ left = None
+ for left in _traverse_scope(
+ scope.branch(scope.expression.left, scope_type=ScopeType.UNION)
+ ):
+ yield left
+
+ right = None
+ for right in _traverse_scope(
+ scope.branch(scope.expression.right, scope_type=ScopeType.UNION)
+ ):
+ yield right
+
+ scope.union = (left, right)
+
+
+def _traverse_derived_tables(derived_tables, scope, scope_type):
+ sources = {}
+
+ for derived_table in derived_tables:
+ for child_scope in _traverse_scope(
+ scope.branch(
+ derived_table
+ if isinstance(derived_table, (exp.Unnest, exp.Lateral))
+ else derived_table.this,
+ add_sources=sources if scope_type == ScopeType.CTE else None,
+ outer_column_list=derived_table.alias_column_names,
+ scope_type=ScopeType.UNNEST
+ if isinstance(derived_table, exp.Unnest)
+ else scope_type,
+ )
+ ):
+ yield child_scope
+ # Tables without aliases will be set as ""
+ # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
+ # Until then, this means that only a single, unaliased derived table is allowed (rather,
+ # the latest one wins.
+ sources[derived_table.alias] = child_scope
+ scope.sources.update(sources)
+
+
+def _add_table_sources(scope):
+ sources = {}
+ for table in scope.tables:
+ table_name = table.name
+
+ if isinstance(table.parent, exp.Alias):
+ source_name = table.parent.alias
+ else:
+ source_name = table_name
+
+ if table_name in scope.sources:
+ # This is a reference to a parent source (e.g. a CTE), not an actual table.
+ scope.sources[source_name] = scope.sources[table_name]
+ elif source_name in scope.sources:
+ raise OptimizeError(f"Duplicate table name: {source_name}")
+ else:
+ sources[source_name] = table
+
+ scope.sources.update(sources)
+
+
+def _traverse_subqueries(scope):
+ for subquery in scope.subqueries:
+ top = None
+ for child_scope in _traverse_scope(
+ scope.branch(subquery, scope_type=ScopeType.SUBQUERY)
+ ):
+ yield child_scope
+ top = child_scope
+ scope.subquery_scopes.append(top)
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
new file mode 100644
index 0000000..6771153
--- /dev/null
+++ b/sqlglot/optimizer/simplify.py
@@ -0,0 +1,383 @@
+import datetime
+import functools
+import itertools
+from collections import deque
+from decimal import Decimal
+
+from sqlglot import exp
+from sqlglot.expressions import FALSE, NULL, TRUE
+from sqlglot.generator import Generator
+from sqlglot.helper import while_changing
+
+GENERATOR = Generator(normalize=True, identify=True)
+
+
+def simplify(expression):
+ """
+ Rewrite sqlglot AST to simplify expressions.
+
+ Example:
+ >>> import sqlglot
+ >>> expression = sqlglot.parse_one("TRUE AND TRUE")
+ >>> simplify(expression).sql()
+ 'TRUE'
+
+ Args:
+ expression (sqlglot.Expression): expression to simplify
+ Returns:
+ sqlglot.Expression: simplified expression
+ """
+
+ def _simplify(expression, root=True):
+ node = expression
+ node = uniq_sort(node)
+ node = absorb_and_eliminate(node)
+ exp.replace_children(node, lambda e: _simplify(e, False))
+ node = simplify_not(node)
+ node = flatten(node)
+ node = simplify_connectors(node)
+ node = remove_compliments(node)
+ node.parent = expression.parent
+ node = simplify_literals(node)
+ node = simplify_parens(node)
+ if root:
+ expression.replace(node)
+ return node
+
+ expression = while_changing(expression, _simplify)
+ remove_where_true(expression)
+ return expression
+
+
+def simplify_not(expression):
+ """
+ Demorgan's Law
+ NOT (x OR y) -> NOT x AND NOT y
+ NOT (x AND y) -> NOT x OR NOT y
+ """
+ if isinstance(expression, exp.Not):
+ if isinstance(expression.this, exp.Paren):
+ condition = expression.this.unnest()
+ if isinstance(condition, exp.And):
+ return exp.or_(exp.not_(condition.left), exp.not_(condition.right))
+ if isinstance(condition, exp.Or):
+ return exp.and_(exp.not_(condition.left), exp.not_(condition.right))
+ if always_true(expression.this):
+ return FALSE
+ if expression.this == FALSE:
+ return TRUE
+ if isinstance(expression.this, exp.Not):
+ # double negation
+ # NOT NOT x -> x
+ return expression.this.this
+ return expression
+
+
+def flatten(expression):
+ """
+ A AND (B AND C) -> A AND B AND C
+ A OR (B OR C) -> A OR B OR C
+ """
+ if isinstance(expression, exp.Connector):
+ for node in expression.args.values():
+ child = node.unnest()
+ if isinstance(child, expression.__class__):
+ node.replace(child)
+ return expression
+
+
+def simplify_connectors(expression):
+ if isinstance(expression, exp.Connector):
+ left = expression.left
+ right = expression.right
+
+ if left == right:
+ return left
+
+ if isinstance(expression, exp.And):
+ if NULL in (left, right):
+ return NULL
+ if FALSE in (left, right):
+ return FALSE
+ if always_true(left) and always_true(right):
+ return TRUE
+ if always_true(left):
+ return right
+ if always_true(right):
+ return left
+ elif isinstance(expression, exp.Or):
+ if always_true(left) or always_true(right):
+ return TRUE
+ if left == FALSE and right == FALSE:
+ return FALSE
+ if (
+ (left == NULL and right == NULL)
+ or (left == NULL and right == FALSE)
+ or (left == FALSE and right == NULL)
+ ):
+ return NULL
+ if left == FALSE:
+ return right
+ if right == FALSE:
+ return left
+ return expression
+
+
+def remove_compliments(expression):
+ """
+ Removing compliments.
+
+ A AND NOT A -> FALSE
+ A OR NOT A -> TRUE
+ """
+ if isinstance(expression, exp.Connector):
+ compliment = FALSE if isinstance(expression, exp.And) else TRUE
+
+ for a, b in itertools.permutations(expression.flatten(), 2):
+ if is_complement(a, b):
+ return compliment
+ return expression
+
+
+def uniq_sort(expression):
+ """
+ Uniq and sort a connector.
+
+ C AND A AND B AND B -> A AND B AND C
+ """
+ if isinstance(expression, exp.Connector):
+ result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
+ flattened = tuple(expression.flatten())
+ deduped = {GENERATOR.generate(e): e for e in flattened}
+ arr = tuple(deduped.items())
+
+ # check if the operands are already sorted, if not sort them
+ # A AND C AND B -> A AND B AND C
+ for i, (sql, e) in enumerate(arr[1:]):
+ if sql < arr[i][0]:
+ expression = result_func(*(deduped[sql] for sql in sorted(deduped)))
+ break
+ else:
+ # we didn't have to sort but maybe we need to dedup
+ if len(deduped) < len(flattened):
+ expression = result_func(*deduped.values())
+
+ return expression
+
+
+def absorb_and_eliminate(expression):
+ """
+ absorption:
+ A AND (A OR B) -> A
+ A OR (A AND B) -> A
+ A AND (NOT A OR B) -> A AND B
+ A OR (NOT A AND B) -> A OR B
+ elimination:
+ (A AND B) OR (A AND NOT B) -> A
+ (A OR B) AND (A OR NOT B) -> A
+ """
+ if isinstance(expression, exp.Connector):
+ kind = exp.Or if isinstance(expression, exp.And) else exp.And
+
+ for a, b in itertools.permutations(expression.flatten(), 2):
+ if isinstance(a, kind):
+ aa, ab = a.unnest_operands()
+
+ # absorb
+ if is_complement(b, aa):
+ aa.replace(exp.TRUE if kind == exp.And else exp.FALSE)
+ elif is_complement(b, ab):
+ ab.replace(exp.TRUE if kind == exp.And else exp.FALSE)
+ elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(
+ a.flatten()
+ ):
+ a.replace(exp.FALSE if kind == exp.And else exp.TRUE)
+ elif isinstance(b, kind):
+ # eliminate
+ rhs = b.unnest_operands()
+ ba, bb = rhs
+
+ if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
+ a.replace(aa)
+ b.replace(aa)
+ elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
+ a.replace(ab)
+ b.replace(ab)
+
+ return expression
+
+
+def simplify_literals(expression):
+ if isinstance(expression, exp.Binary):
+ operands = []
+ queue = deque(expression.flatten(unnest=False))
+ size = len(queue)
+
+ while queue:
+ a = queue.popleft()
+
+ for b in queue:
+ result = _simplify_binary(expression, a, b)
+
+ if result:
+ queue.remove(b)
+ queue.append(result)
+ break
+ else:
+ operands.append(a)
+
+ if len(operands) < size:
+ return functools.reduce(
+ lambda a, b: expression.__class__(this=a, expression=b), operands
+ )
+ elif isinstance(expression, exp.Neg):
+ this = expression.this
+ if this.is_number:
+ value = this.name
+ if value[0] == "-":
+ return exp.Literal.number(value[1:])
+ return exp.Literal.number(f"-{value}")
+
+ return expression
+
+
+def _simplify_binary(expression, a, b):
+ if isinstance(expression, exp.Is):
+ if isinstance(b, exp.Not):
+ c = b.this
+ not_ = True
+ else:
+ c = b
+ not_ = False
+
+ if c == NULL:
+ if isinstance(a, exp.Literal):
+ return TRUE if not_ else FALSE
+ if a == NULL:
+ return FALSE if not_ else TRUE
+ elif NULL in (a, b):
+ return NULL
+
+ if isinstance(expression, exp.EQ) and a == b:
+ return TRUE
+
+ if a.is_number and b.is_number:
+ a = int(a.name) if a.is_int else Decimal(a.name)
+ b = int(b.name) if b.is_int else Decimal(b.name)
+
+ if isinstance(expression, exp.Add):
+ return exp.Literal.number(a + b)
+ if isinstance(expression, exp.Sub):
+ return exp.Literal.number(a - b)
+ if isinstance(expression, exp.Mul):
+ return exp.Literal.number(a * b)
+ if isinstance(expression, exp.Div):
+ if isinstance(a, int) and isinstance(b, int):
+ return exp.Literal.number(a // b)
+ return exp.Literal.number(a / b)
+
+ boolean = eval_boolean(expression, a, b)
+
+ if boolean:
+ return boolean
+ elif a.is_string and b.is_string:
+ boolean = eval_boolean(expression, a, b)
+
+ if boolean:
+ return boolean
+ elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
+ a, b = extract_date(a), extract_interval(b)
+ if b:
+ if isinstance(expression, exp.Add):
+ return date_literal(a + b)
+ if isinstance(expression, exp.Sub):
+ return date_literal(a - b)
+ elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
+ a, b = extract_interval(a), extract_date(b)
+ # you cannot subtract a date from an interval
+ if a and isinstance(expression, exp.Add):
+ return date_literal(a + b)
+
+ return None
+
+
+def simplify_parens(expression):
+ if (
+ isinstance(expression, exp.Paren)
+ and not isinstance(expression.this, exp.Select)
+ and (
+ not isinstance(expression.parent, (exp.Condition, exp.Binary))
+ or isinstance(expression.this, (exp.Is, exp.Like))
+ or not isinstance(expression.this, exp.Binary)
+ )
+ ):
+ return expression.this
+ return expression
+
+
+def remove_where_true(expression):
+ for where in expression.find_all(exp.Where):
+ if always_true(where.this):
+ where.parent.set("where", None)
+ for join in expression.find_all(exp.Join):
+ if always_true(join.args.get("on")):
+ join.set("kind", "CROSS")
+ join.set("on", None)
+
+
+def always_true(expression):
+ return expression == TRUE or isinstance(expression, exp.Literal)
+
+
+def is_complement(a, b):
+ return isinstance(b, exp.Not) and b.this == a
+
+
+def eval_boolean(expression, a, b):
+ if isinstance(expression, (exp.EQ, exp.Is)):
+ return boolean_literal(a == b)
+ if isinstance(expression, exp.NEQ):
+ return boolean_literal(a != b)
+ if isinstance(expression, exp.GT):
+ return boolean_literal(a > b)
+ if isinstance(expression, exp.GTE):
+ return boolean_literal(a >= b)
+ if isinstance(expression, exp.LT):
+ return boolean_literal(a < b)
+ if isinstance(expression, exp.LTE):
+ return boolean_literal(a <= b)
+ return None
+
+
+def extract_date(cast):
+ if cast.args["to"].this == exp.DataType.Type.DATE:
+ return datetime.date.fromisoformat(cast.name)
+ return None
+
+
+def extract_interval(interval):
+ try:
+ from dateutil.relativedelta import relativedelta
+ except ModuleNotFoundError:
+ return None
+
+ n = int(interval.name)
+ unit = interval.text("unit").lower()
+
+ if unit == "year":
+ return relativedelta(years=n)
+ if unit == "month":
+ return relativedelta(months=n)
+ if unit == "week":
+ return relativedelta(weeks=n)
+ if unit == "day":
+ return relativedelta(days=n)
+ return None
+
+
+def date_literal(date):
+ return exp.Cast(this=exp.Literal.string(date), to=exp.DataType.build("DATE"))
+
+
+def boolean_literal(condition):
+ return TRUE if condition else FALSE
diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py
new file mode 100644
index 0000000..55c81c5
--- /dev/null
+++ b/sqlglot/optimizer/unnest_subqueries.py
@@ -0,0 +1,220 @@
+import itertools
+
+from sqlglot import exp
+from sqlglot.optimizer.scope import traverse_scope
+
+
+def unnest_subqueries(expression):
+ """
+ Rewrite sqlglot AST to convert some predicates with subqueries into joins.
+
+ Convert the subquery into a group by so it is not a many to many left join.
+ Unnesting can only occur if the subquery does not have LIMIT or OFFSET.
+ Unnesting non correlated subqueries only happens on IN statements or = ANY statements.
+
+ Example:
+ >>> import sqlglot
+ >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
+ >>> unnest_subqueries(expression).sql()
+ 'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a)\
+ AS "_u_0" ON x.a = "_u_0".a WHERE ("_u_0".a = 1 AND NOT "_u_0".a IS NULL)'
+
+ Args:
+ expression (sqlglot.Expression): expression to unnest
+ Returns:
+ sqlglot.Expression: unnested expression
+ """
+ sequence = itertools.count()
+
+ for scope in traverse_scope(expression):
+ select = scope.expression
+ parent = select.parent_select
+ if scope.external_columns:
+ decorrelate(select, parent, scope.external_columns, sequence)
+ else:
+ unnest(select, parent, sequence)
+
+ return expression
+
+
+def unnest(select, parent_select, sequence):
+ predicate = select.find_ancestor(exp.In, exp.Any)
+
+ if not predicate or parent_select is not predicate.parent_select:
+ return
+
+ if len(select.selects) > 1 or select.find(exp.Limit, exp.Offset):
+ return
+
+ if isinstance(predicate, exp.Any):
+ predicate = predicate.find_ancestor(exp.EQ)
+
+ if not predicate or parent_select is not predicate.parent_select:
+ return
+
+ column = _other_operand(predicate)
+ value = select.selects[0]
+ alias = _alias(sequence)
+
+ on = exp.condition(f'{column} = "{alias}"."{value.alias}"')
+ _replace(predicate, f"NOT {on.right} IS NULL")
+
+ parent_select.join(
+ select.group_by(value.this, copy=False),
+ on=on,
+ join_type="LEFT",
+ join_alias=alias,
+ copy=False,
+ )
+
+
+def decorrelate(select, parent_select, external_columns, sequence):
+ where = select.args.get("where")
+
+ if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
+ return
+
+ table_alias = _alias(sequence)
+ keys = []
+
+ # for all external columns in the where statement,
+ # split out the relevant data to convert it into a join
+ for column in external_columns:
+ if column.find_ancestor(exp.Where) is not where:
+ return
+
+ predicate = column.find_ancestor(exp.Predicate)
+
+ if not predicate or predicate.find_ancestor(exp.Where) is not where:
+ return
+
+ if isinstance(predicate, exp.Binary):
+ key = (
+ predicate.right
+ if any(node is column for node, *_ in predicate.left.walk())
+ else predicate.left
+ )
+ else:
+ return
+
+ keys.append((key, column, predicate))
+
+ if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
+ return
+
+ value = select.selects[0]
+ key_aliases = {}
+ group_by = []
+
+ for key, _, predicate in keys:
+ # if we filter on the value of the subquery, it needs to be unique
+ if key == value.this:
+ key_aliases[key] = value.alias
+ group_by.append(key)
+ else:
+ if key not in key_aliases:
+ key_aliases[key] = _alias(sequence)
+ # all predicates that are equalities must also be in the unique
+ # so that we don't do a many to many join
+ if isinstance(predicate, exp.EQ) and key not in group_by:
+ group_by.append(key)
+
+ parent_predicate = select.find_ancestor(exp.Predicate)
+
+ # if the value of the subquery is not an agg or a key, we need to collect it into an array
+ # so that it can be grouped
+ if not value.find(exp.AggFunc) and value.this not in group_by:
+ select.select(
+ f"ARRAY_AGG({value.this}) AS {value.alias}", append=False, copy=False
+ )
+
+ # exists queries should not have any selects as it only checks if there are any rows
+ # all selects will be added by the optimizer and only used for join keys
+ if isinstance(parent_predicate, exp.Exists):
+ select.args["expressions"] = []
+
+ for key, alias in key_aliases.items():
+ if key in group_by:
+ # add all keys to the projections of the subquery
+ # so that we can use it as a join key
+ if isinstance(parent_predicate, exp.Exists) or key != value.this:
+ select.select(f"{key} AS {alias}", copy=False)
+ else:
+ select.select(f"ARRAY_AGG({key}) AS {alias}", copy=False)
+
+ alias = exp.column(value.alias, table_alias)
+ other = _other_operand(parent_predicate)
+
+ if isinstance(parent_predicate, exp.Exists):
+ if value.this in group_by:
+ parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
+ else:
+ parent_predicate = _replace(parent_predicate, "TRUE")
+ elif isinstance(parent_predicate, exp.All):
+ parent_predicate = _replace(
+ parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
+ )
+ elif isinstance(parent_predicate, exp.Any):
+ if value.this in group_by:
+ parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
+ else:
+ parent_predicate = _replace(
+ parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})"
+ )
+ elif isinstance(parent_predicate, exp.In):
+ if value.this in group_by:
+ parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
+ else:
+ parent_predicate = _replace(
+ parent_predicate,
+ f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
+ )
+ else:
+ select.parent.replace(alias)
+
+ for key, column, predicate in keys:
+ predicate.replace(exp.TRUE)
+ nested = exp.column(key_aliases[key], table_alias)
+
+ if key in group_by:
+ key.replace(nested)
+ parent_predicate = _replace(
+ parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)"
+ )
+ elif isinstance(predicate, exp.EQ):
+ parent_predicate = _replace(
+ parent_predicate,
+ f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
+ )
+ else:
+ key.replace(exp.to_identifier("_x"))
+ parent_predicate = _replace(
+ parent_predicate,
+ f'({parent_predicate} AND ARRAY_ANY({nested}, "_x" -> {predicate}))',
+ )
+
+ parent_select.join(
+ select.group_by(*group_by, copy=False),
+ on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
+ join_type="LEFT",
+ join_alias=table_alias,
+ copy=False,
+ )
+
+
+def _alias(sequence):
+ return f"_u_{next(sequence)}"
+
+
+def _replace(expression, condition):
+ return expression.replace(exp.condition(condition))
+
+
+def _other_operand(expression):
+ if isinstance(expression, exp.In):
+ return expression.this
+
+ if isinstance(expression, exp.Binary):
+ return expression.right if expression.arg_key == "this" else expression.left
+
+ return None