Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r--  sqlglot/optimizer/annotate_types.py         12
-rw-r--r--  sqlglot/optimizer/canonicalize.py            2
-rw-r--r--  sqlglot/optimizer/eliminate_joins.py         5
-rw-r--r--  sqlglot/optimizer/eliminate_subqueries.py    2
-rw-r--r--  sqlglot/optimizer/lower_identities.py        8
-rw-r--r--  sqlglot/optimizer/merge_subqueries.py        5
-rw-r--r--  sqlglot/optimizer/normalize.py             104
-rw-r--r--  sqlglot/optimizer/optimize_joins.py          2
-rw-r--r--  sqlglot/optimizer/optimizer.py               4
-rw-r--r--  sqlglot/optimizer/qualify_columns.py        88
-rw-r--r--  sqlglot/optimizer/scope.py                   2
-rw-r--r--  sqlglot/optimizer/simplify.py              166
12 files changed, 237 insertions, 163 deletions
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index c2d6655..99888c6 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -1,5 +1,5 @@
from sqlglot import exp
-from sqlglot.helper import ensure_collection, ensure_list, subclasses
+from sqlglot.helper import ensure_list, subclasses
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema
@@ -108,6 +108,7 @@ class TypeAnnotator:
exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
+ exp.Concat: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.GroupConcat: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.VARCHAR
@@ -116,6 +117,7 @@ class TypeAnnotator:
expr, exp.DataType.Type.VARCHAR
),
exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+ exp.Interval: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INTERVAL),
exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"),
exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
@@ -296,9 +298,6 @@ class TypeAnnotator:
return self._maybe_annotate(expression) # This takes care of non-traversable expressions
def _maybe_annotate(self, expression):
- if not isinstance(expression, exp.Expression):
- return None
-
if expression.type:
return expression # We've already inferred the expression's type
@@ -311,9 +310,8 @@ class TypeAnnotator:
)
def _annotate_args(self, expression):
- for value in expression.args.values():
- for v in ensure_collection(value):
- self._maybe_annotate(v)
+ for _, value in expression.iter_expressions():
+ self._maybe_annotate(value)
return expression
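As an illustration (not part of this diff), a minimal sketch of the effect of the new Concat annotator, using the public annotate_types entry point; the exact printed form may vary:

    import sqlglot
    from sqlglot.optimizer.annotate_types import annotate_types

    # With the new ANNOTATORS entry, CONCAT is annotated as VARCHAR.
    annotated = annotate_types(sqlglot.parse_one("SELECT CONCAT('a', 'b') AS c"))
    print(annotated.expressions[0].type)  # expected: VARCHAR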
diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py
index c5c780d..ef929ac 100644
--- a/sqlglot/optimizer/canonicalize.py
+++ b/sqlglot/optimizer/canonicalize.py
@@ -75,7 +75,7 @@ def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
a.type
and a.type.this == exp.DataType.Type.DATE
and b.type
- and b.type.this != exp.DataType.Type.DATE
+ and b.type.this not in (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL)
):
_replace_cast(b, "date")
diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py
index 8e6a520..e0ddfa2 100644
--- a/sqlglot/optimizer/eliminate_joins.py
+++ b/sqlglot/optimizer/eliminate_joins.py
@@ -1,7 +1,6 @@
from sqlglot import expressions as exp
from sqlglot.optimizer.normalize import normalized
from sqlglot.optimizer.scope import Scope, traverse_scope
-from sqlglot.optimizer.simplify import simplify
def eliminate_joins(expression):
@@ -179,6 +178,4 @@ def join_condition(join):
for condition in conditions:
extract_condition(condition)
- on = simplify(on)
- remaining_condition = None if on == exp.true() else on
- return source_key, join_key, remaining_condition
+ return source_key, join_key, on
diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py
index 6f9db82..a39fe96 100644
--- a/sqlglot/optimizer/eliminate_subqueries.py
+++ b/sqlglot/optimizer/eliminate_subqueries.py
@@ -3,7 +3,6 @@ import itertools
from sqlglot import expressions as exp
from sqlglot.helper import find_new_name
from sqlglot.optimizer.scope import build_scope
-from sqlglot.optimizer.simplify import simplify
def eliminate_subqueries(expression):
@@ -31,7 +30,6 @@ def eliminate_subqueries(expression):
eliminate_subqueries(expression.this)
return expression
- expression = simplify(expression)
root = build_scope(expression)
# Map of alias->Scope|Table
diff --git a/sqlglot/optimizer/lower_identities.py b/sqlglot/optimizer/lower_identities.py
index 1cc76cf..fae1726 100644
--- a/sqlglot/optimizer/lower_identities.py
+++ b/sqlglot/optimizer/lower_identities.py
@@ -1,5 +1,4 @@
from sqlglot import exp
-from sqlglot.helper import ensure_collection
def lower_identities(expression):
@@ -40,13 +39,10 @@ def lower_identities(expression):
lower_identities(expression.right)
traversed |= {"this", "expression"}
- for k, v in expression.args.items():
+ for k, v in expression.iter_expressions():
if k in traversed:
continue
-
- for child in ensure_collection(v):
- if isinstance(child, exp.Expression):
- child.transform(_lower, copy=False)
+ v.transform(_lower, copy=False)
return expression
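The iter_expressions() pattern above replaces the ensure_collection loops; per this diff it yields (arg key, child expression) pairs, flattening list-valued args. A small sketch, for illustration only:

    import sqlglot

    # Prints each arg key together with the type of its child node.
    for key, child in sqlglot.parse_one("SELECT a, b FROM t").iter_expressions():
        print(key, type(child).__name__)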
diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py
index 70172f4..c3467b2 100644
--- a/sqlglot/optimizer/merge_subqueries.py
+++ b/sqlglot/optimizer/merge_subqueries.py
@@ -3,7 +3,6 @@ from collections import defaultdict
from sqlglot import expressions as exp
from sqlglot.helper import find_new_name
from sqlglot.optimizer.scope import Scope, traverse_scope
-from sqlglot.optimizer.simplify import simplify
def merge_subqueries(expression, leave_tables_isolated=False):
@@ -330,11 +329,11 @@ def _merge_where(outer_scope, inner_scope, from_or_join):
if set(exp.column_table_names(where.this)) <= sources:
from_or_join.on(where.this, copy=False)
- from_or_join.set("on", simplify(from_or_join.args.get("on")))
+ from_or_join.set("on", from_or_join.args.get("on"))
return
expression.where(where.this, copy=False)
- expression.set("where", simplify(expression.args.get("where")))
+ expression.set("where", expression.args.get("where"))
def _merge_order(outer_scope, inner_scope):
diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py
index f16f519..f2df230 100644
--- a/sqlglot/optimizer/normalize.py
+++ b/sqlglot/optimizer/normalize.py
@@ -1,29 +1,63 @@
+from __future__ import annotations
+
+import logging
+import typing as t
+
from sqlglot import exp
+from sqlglot.errors import OptimizeError
from sqlglot.helper import while_changing
-from sqlglot.optimizer.simplify import flatten, simplify, uniq_sort
+from sqlglot.optimizer.simplify import flatten, uniq_sort
+
+logger = logging.getLogger("sqlglot")
-def normalize(expression, dnf=False, max_distance=128):
+def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128):
"""
- Rewrite sqlglot AST into conjunctive normal form.
+ Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("(x AND y) OR z")
- >>> normalize(expression).sql()
+ >>> normalize(expression, dnf=False).sql()
'(x OR z) AND (y OR z)'
Args:
- expression (sqlglot.Expression): expression to normalize
- dnf (bool): rewrite in disjunctive normal form instead
- max_distance (int): the maximal estimated distance from cnf to attempt conversion
+ expression: expression to normalize
+ dnf: rewrite in disjunctive normal form instead.
+ max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion
Returns:
sqlglot.Expression: normalized expression
"""
- expression = simplify(expression)
+ cache: t.Dict[int, str] = {}
+
+ for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))):
+ if isinstance(node, exp.Connector):
+ if normalized(node, dnf=dnf):
+ continue
+
+ distance = normalization_distance(node, dnf=dnf)
+
+ if distance > max_distance:
+ logger.info(
+ f"Skipping normalization because distance {distance} exceeds max {max_distance}"
+ )
+ return expression
+
+ root = node is expression
+ original = node.copy()
+ try:
+ node = while_changing(node, lambda e: distributive_law(e, dnf, max_distance, cache))
+ except OptimizeError as e:
+ logger.info(e)
+ node.replace(original)
+ if root:
+ return original
+ return expression
+
+ if root:
+ expression = node
- expression = while_changing(expression, lambda e: distributive_law(e, dnf, max_distance))
- return simplify(expression)
+ return expression
def normalized(expression, dnf=False):
@@ -51,7 +85,7 @@ def normalization_distance(expression, dnf=False):
int: difference
"""
return sum(_predicate_lengths(expression, dnf)) - (
- len(list(expression.find_all(exp.Connector))) + 1
+ sum(1 for _ in expression.find_all(exp.Connector)) + 1
)
@@ -64,29 +98,32 @@ def _predicate_lengths(expression, dnf):
expression = expression.unnest()
if not isinstance(expression, exp.Connector):
- return [1]
+ return (1,)
left, right = expression.args.values()
if isinstance(expression, exp.And if dnf else exp.Or):
- return [
+ return tuple(
a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)
- ]
+ )
return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
-def distributive_law(expression, dnf, max_distance):
+def distributive_law(expression, dnf, max_distance, cache=None):
"""
x OR (y AND z) -> (x OR y) AND (x OR z)
(x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)
"""
- if isinstance(expression.unnest(), exp.Connector):
- if normalization_distance(expression, dnf) > max_distance:
- return expression
+ if normalized(expression, dnf=dnf):
+ return expression
- to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)
+ distance = normalization_distance(expression, dnf=dnf)
- exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance))
+ if distance > max_distance:
+ raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")
+
+ exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, cache))
+ to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)
if isinstance(expression, from_exp):
a, b = expression.unnest_operands()
@@ -96,32 +133,29 @@ def distributive_law(expression, dnf, max_distance):
if isinstance(a, to_exp) and isinstance(b, to_exp):
if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
- return _distribute(a, b, from_func, to_func)
- return _distribute(b, a, from_func, to_func)
+ return _distribute(a, b, from_func, to_func, cache)
+ return _distribute(b, a, from_func, to_func, cache)
if isinstance(a, to_exp):
- return _distribute(b, a, from_func, to_func)
+ return _distribute(b, a, from_func, to_func, cache)
if isinstance(b, to_exp):
- return _distribute(a, b, from_func, to_func)
+ return _distribute(a, b, from_func, to_func, cache)
return expression
-def _distribute(a, b, from_func, to_func):
+def _distribute(a, b, from_func, to_func, cache):
if isinstance(a, exp.Connector):
exp.replace_children(
a,
lambda c: to_func(
- exp.paren(from_func(c, b.left)),
- exp.paren(from_func(c, b.right)),
+ uniq_sort(flatten(from_func(c, b.left)), cache),
+ uniq_sort(flatten(from_func(c, b.right)), cache),
),
)
else:
- a = to_func(from_func(a, b.left), from_func(a, b.right))
-
- return _simplify(a)
-
+ a = to_func(
+ uniq_sort(flatten(from_func(a, b.left)), cache),
+ uniq_sort(flatten(from_func(a, b.right)), cache),
+ )
-def _simplify(node):
- node = uniq_sort(flatten(node))
- exp.replace_children(node, _simplify)
- return node
+ return a
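An illustrative sketch (not part of this diff) of the rewritten normalize(): each connector is checked with normalized() and normalization_distance(), and conversion is skipped, with a log message, when the estimated distance exceeds max_distance:

    import sqlglot
    from sqlglot.optimizer.normalize import normalize, normalization_distance

    expr = sqlglot.parse_one("(x AND y) OR z")
    print(normalization_distance(expr))            # 1, well under the default max of 128
    print(normalize(expr.copy()).sql())            # (x OR z) AND (y OR z)
    print(normalize(expr.copy(), dnf=True).sql())  # already in DNF, returned unchanged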
diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py
index dc5ce44..8589657 100644
--- a/sqlglot/optimizer/optimize_joins.py
+++ b/sqlglot/optimizer/optimize_joins.py
@@ -1,6 +1,5 @@
from sqlglot import exp
from sqlglot.helper import tsort
-from sqlglot.optimizer.simplify import simplify
def optimize_joins(expression):
@@ -29,7 +28,6 @@ def optimize_joins(expression):
for name, join in cross_joins:
for dep in references.get(name, []):
on = dep.args["on"]
- on = on.replace(simplify(on))
if isinstance(on, exp.Connector):
for predicate in on.flatten():
diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py
index d9d04be..62eb11e 100644
--- a/sqlglot/optimizer/optimizer.py
+++ b/sqlglot/optimizer/optimizer.py
@@ -21,6 +21,7 @@ from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
from sqlglot.optimizer.pushdown_projections import pushdown_projections
from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns
from sqlglot.optimizer.qualify_tables import qualify_tables
+from sqlglot.optimizer.simplify import simplify
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
from sqlglot.schema import ensure_schema
@@ -43,6 +44,7 @@ RULES = (
eliminate_ctes,
annotate_types,
canonicalize,
+ simplify,
)
@@ -78,7 +80,7 @@ def optimize(
Returns:
sqlglot.Expression: optimized expression
"""
- schema = ensure_schema(schema or sqlglot.schema)
+ schema = ensure_schema(schema or sqlglot.schema, dialect=dialect)
possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs}
expression = exp.maybe_parse(expression, dialect=dialect, copy=True)
for rule in rules:
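With simplify appended to RULES and ensure_schema now receiving the dialect, a hedged end-to-end sketch (schema and query invented for illustration):

    import sqlglot
    from sqlglot.optimizer import optimize

    optimized = optimize(
        "SELECT x FROM t WHERE 1 = 1 AND x > 0",
        schema={"t": {"x": "INT"}},
        dialect="duckdb",
    )
    print(optimized.sql())  # the 1 = 1 conjunct is expected to be simplified away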
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 66b3170..5e40cf3 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -30,11 +30,12 @@ def qualify_columns(expression, schema):
resolver = Resolver(scope, schema)
_pop_table_column_aliases(scope.ctes)
_pop_table_column_aliases(scope.derived_tables)
- _expand_using(scope, resolver)
+ using_column_tables = _expand_using(scope, resolver)
_qualify_columns(scope, resolver)
if not isinstance(scope.expression, exp.UDTF):
- _expand_stars(scope, resolver)
+ _expand_stars(scope, resolver, using_column_tables)
_qualify_outputs(scope)
+ _expand_alias_refs(scope, resolver)
_expand_group_by(scope, resolver)
_expand_order_by(scope)
@@ -69,11 +70,11 @@ def _pop_table_column_aliases(derived_tables):
def _expand_using(scope, resolver):
- joins = list(scope.expression.find_all(exp.Join))
+ joins = list(scope.find_all(exp.Join))
names = {join.this.alias for join in joins}
ordered = [key for key in scope.selected_sources if key not in names]
- # Mapping of automatically joined column names to source names
+ # Mapping of automatically joined column names to an ordered set of source names (dict).
column_tables = {}
for join in joins:
@@ -112,11 +113,12 @@ def _expand_using(scope, resolver):
)
)
- tables = column_tables.setdefault(identifier, [])
+ # Set all values in the dict to None, because we only care about the key ordering
+ tables = column_tables.setdefault(identifier, {})
if table not in tables:
- tables.append(table)
+ tables[table] = None
if join_table not in tables:
- tables.append(join_table)
+ tables[join_table] = None
join.args.pop("using")
join.set("on", exp.and_(*conditions))
@@ -134,11 +136,11 @@ def _expand_using(scope, resolver):
scope.replace(column, replacement)
+ return column_tables
-def _expand_group_by(scope, resolver):
- group = scope.expression.args.get("group")
- if not group:
- return
+
+def _expand_alias_refs(scope, resolver):
+ selects = {}
# Replace references to select aliases
def transform(node, *_):
@@ -150,9 +152,11 @@ def _expand_group_by(scope, resolver):
node.set("table", table)
return node
- selects = {s.alias_or_name: s for s in scope.selects}
-
+ if not selects:
+ for s in scope.selects:
+ selects[s.alias_or_name] = s
select = selects.get(node.name)
+
if select:
scope.clear_cache()
if isinstance(select, exp.Alias):
@@ -161,7 +165,21 @@ def _expand_group_by(scope, resolver):
return node
- group.transform(transform, copy=False)
+ for select in scope.expression.selects:
+ select.transform(transform, copy=False)
+
+ for modifier in ("where", "group"):
+ part = scope.expression.args.get(modifier)
+
+ if part:
+ part.transform(transform, copy=False)
+
+
+def _expand_group_by(scope, resolver):
+ group = scope.expression.args.get("group")
+ if not group:
+ return
+
group.set("expressions", _expand_positional_references(scope, group.expressions))
scope.expression.set("group", group)
@@ -231,18 +249,24 @@ def _qualify_columns(scope, resolver):
column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))
columns_missing_from_scope = []
+
# Determine whether each reference in the order by clause is to a column or an alias.
- for ordered in scope.find_all(exp.Ordered):
- for column in ordered.find_all(exp.Column):
- if (
- not column.table
- and column.parent is not ordered
- and column.name in resolver.all_columns
- ):
- columns_missing_from_scope.append(column)
+ order = scope.expression.args.get("order")
+
+ if order:
+ for ordered in order.expressions:
+ for column in ordered.find_all(exp.Column):
+ if (
+ not column.table
+ and column.parent is not ordered
+ and column.name in resolver.all_columns
+ ):
+ columns_missing_from_scope.append(column)
# Determine whether each reference in the having clause is to a column or an alias.
- for having in scope.find_all(exp.Having):
+ having = scope.expression.args.get("having")
+
+ if having:
for column in having.find_all(exp.Column):
if (
not column.table
@@ -258,12 +282,13 @@ def _qualify_columns(scope, resolver):
column.set("table", column_table)
-def _expand_stars(scope, resolver):
+def _expand_stars(scope, resolver, using_column_tables):
"""Expand stars to lists of column selections"""
new_selections = []
except_columns = {}
replace_columns = {}
+ coalesced_columns = set()
for expression in scope.selects:
if isinstance(expression, exp.Star):
@@ -286,7 +311,20 @@ def _expand_stars(scope, resolver):
if columns and "*" not in columns:
table_id = id(table)
for name in columns:
- if name not in except_columns.get(table_id, set()):
+ if name in using_column_tables and table in using_column_tables[name]:
+ if name in coalesced_columns:
+ continue
+
+ coalesced_columns.add(name)
+ tables = using_column_tables[name]
+ coalesce = [exp.column(name, table=table) for table in tables]
+
+ new_selections.append(
+ exp.alias_(
+ exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]), alias=name
+ )
+ )
+ elif name not in except_columns.get(table_id, set()):
alias_ = replace_columns.get(table_id, {}).get(name, name)
column = exp.column(name, table)
new_selections.append(alias(column, alias_) if alias_ != name else column)
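An illustrative sketch (not part of this diff) of the new USING handling: _expand_using now returns the mapping of USING columns to their source tables, and _expand_stars turns each shared column into a single COALESCE projection. The schema and tables below are invented, and the exact output shape may differ:

    import sqlglot
    from sqlglot.optimizer.qualify_columns import qualify_columns
    from sqlglot.schema import ensure_schema

    schema = ensure_schema({"a": {"id": "INT", "x": "INT"}, "b": {"id": "INT", "y": "INT"}})
    sql = "SELECT * FROM a JOIN b USING (id)"
    print(qualify_columns(sqlglot.parse_one(sql), schema).sql())
    # expected shape:
    # SELECT COALESCE(a.id, b.id) AS id, a.x AS x, b.y AS y FROM a JOIN b ON a.id = b.id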
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index 9c0768c..b582eb0 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -160,7 +160,7 @@ class Scope:
Yields:
exp.Expression: nodes
"""
- for expression, _, _ in self.walk(bfs=bfs):
+ for expression, *_ in self.walk(bfs=bfs):
if isinstance(expression, expression_types):
yield expression
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index f80484d..1ed3ca2 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -5,11 +5,10 @@ from collections import deque
from decimal import Decimal
from sqlglot import exp
-from sqlglot.expressions import FALSE, NULL, TRUE
from sqlglot.generator import Generator
from sqlglot.helper import first, while_changing
-GENERATOR = Generator(normalize=True, identify=True)
+GENERATOR = Generator(normalize=True, identify="safe")
def simplify(expression):
@@ -28,18 +27,20 @@ def simplify(expression):
sqlglot.Expression: simplified expression
"""
+ cache = {}
+
def _simplify(expression, root=True):
node = expression
node = rewrite_between(node)
- node = uniq_sort(node)
- node = absorb_and_eliminate(node)
+ node = uniq_sort(node, cache, root)
+ node = absorb_and_eliminate(node, root)
exp.replace_children(node, lambda e: _simplify(e, False))
node = simplify_not(node)
node = flatten(node)
- node = simplify_connectors(node)
- node = remove_compliments(node)
+ node = simplify_connectors(node, root)
+ node = remove_compliments(node, root)
node.parent = expression.parent
- node = simplify_literals(node)
+ node = simplify_literals(node, root)
node = simplify_parens(node)
if root:
expression.replace(node)
@@ -70,7 +71,7 @@ def simplify_not(expression):
NOT (x AND y) -> NOT x OR NOT y
"""
if isinstance(expression, exp.Not):
- if isinstance(expression.this, exp.Null):
+ if is_null(expression.this):
return exp.null()
if isinstance(expression.this, exp.Paren):
condition = expression.this.unnest()
@@ -78,11 +79,11 @@ def simplify_not(expression):
return exp.or_(exp.not_(condition.left), exp.not_(condition.right))
if isinstance(condition, exp.Or):
return exp.and_(exp.not_(condition.left), exp.not_(condition.right))
- if isinstance(condition, exp.Null):
+ if is_null(condition):
return exp.null()
if always_true(expression.this):
return exp.false()
- if expression.this == FALSE:
+ if is_false(expression.this):
return exp.true()
if isinstance(expression.this, exp.Not):
# double negation
@@ -104,42 +105,42 @@ def flatten(expression):
return expression
-def simplify_connectors(expression):
+def simplify_connectors(expression, root=True):
def _simplify_connectors(expression, left, right):
- if isinstance(expression, exp.Connector):
- if left == right:
+ if left == right:
+ return left
+ if isinstance(expression, exp.And):
+ if is_false(left) or is_false(right):
+ return exp.false()
+ if is_null(left) or is_null(right):
+ return exp.null()
+ if always_true(left) and always_true(right):
+ return exp.true()
+ if always_true(left):
+ return right
+ if always_true(right):
return left
- if isinstance(expression, exp.And):
- if FALSE in (left, right):
- return exp.false()
- if NULL in (left, right):
- return exp.null()
- if always_true(left) and always_true(right):
- return exp.true()
- if always_true(left):
- return right
- if always_true(right):
- return left
- return _simplify_comparison(expression, left, right)
- elif isinstance(expression, exp.Or):
- if always_true(left) or always_true(right):
- return exp.true()
- if left == FALSE and right == FALSE:
- return exp.false()
- if (
- (left == NULL and right == NULL)
- or (left == NULL and right == FALSE)
- or (left == FALSE and right == NULL)
- ):
- return exp.null()
- if left == FALSE:
- return right
- if right == FALSE:
- return left
- return _simplify_comparison(expression, left, right, or_=True)
- return None
+ return _simplify_comparison(expression, left, right)
+ elif isinstance(expression, exp.Or):
+ if always_true(left) or always_true(right):
+ return exp.true()
+ if is_false(left) and is_false(right):
+ return exp.false()
+ if (
+ (is_null(left) and is_null(right))
+ or (is_null(left) and is_false(right))
+ or (is_false(left) and is_null(right))
+ ):
+ return exp.null()
+ if is_false(left):
+ return right
+ if is_false(right):
+ return left
+ return _simplify_comparison(expression, left, right, or_=True)
- return _flat_simplify(expression, _simplify_connectors)
+ if isinstance(expression, exp.Connector):
+ return _flat_simplify(expression, _simplify_connectors, root)
+ return expression
LT_LTE = (exp.LT, exp.LTE)
@@ -220,14 +221,14 @@ def _simplify_comparison(expression, left, right, or_=False):
return None
-def remove_compliments(expression):
+def remove_compliments(expression, root=True):
"""
Removing compliments.
A AND NOT A -> FALSE
A OR NOT A -> TRUE
"""
- if isinstance(expression, exp.Connector):
+ if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
compliment = exp.false() if isinstance(expression, exp.And) else exp.true()
for a, b in itertools.permutations(expression.flatten(), 2):
@@ -236,23 +237,23 @@ def remove_compliments(expression):
return expression
-def uniq_sort(expression):
+def uniq_sort(expression, cache=None, root=True):
"""
Uniq and sort a connector.
C AND A AND B AND B -> A AND B AND C
"""
- if isinstance(expression, exp.Connector):
+ if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
flattened = tuple(expression.flatten())
- deduped = {GENERATOR.generate(e): e for e in flattened}
+ deduped = {GENERATOR.generate(e, cache): e for e in flattened}
arr = tuple(deduped.items())
# check if the operands are already sorted, if not sort them
# A AND C AND B -> A AND B AND C
for i, (sql, e) in enumerate(arr[1:]):
if sql < arr[i][0]:
- expression = result_func(*(deduped[sql] for sql in sorted(deduped)))
+ expression = result_func(*(e for _, e in sorted(arr)))
break
else:
# we didn't have to sort but maybe we need to dedup
@@ -262,7 +263,7 @@ def uniq_sort(expression):
return expression
-def absorb_and_eliminate(expression):
+def absorb_and_eliminate(expression, root=True):
"""
absorption:
A AND (A OR B) -> A
@@ -273,7 +274,7 @@ def absorb_and_eliminate(expression):
(A AND B) OR (A AND NOT B) -> A
(A OR B) AND (A OR NOT B) -> A
"""
- if isinstance(expression, exp.Connector):
+ if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
kind = exp.Or if isinstance(expression, exp.And) else exp.And
for a, b in itertools.permutations(expression.flatten(), 2):
@@ -302,9 +303,9 @@ def absorb_and_eliminate(expression):
return expression
-def simplify_literals(expression):
- if isinstance(expression, exp.Binary):
- return _flat_simplify(expression, _simplify_binary)
+def simplify_literals(expression, root=True):
+ if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
+ return _flat_simplify(expression, _simplify_binary, root)
elif isinstance(expression, exp.Neg):
this = expression.this
if this.is_number:
@@ -325,14 +326,14 @@ def _simplify_binary(expression, a, b):
c = b
not_ = False
- if c == NULL:
+ if is_null(c):
if isinstance(a, exp.Literal):
return exp.true() if not_ else exp.false()
- if a == NULL:
+ if is_null(a):
return exp.false() if not_ else exp.true()
elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
return None
- elif NULL in (a, b):
+ elif is_null(a) or is_null(b):
return exp.null()
if a.is_number and b.is_number:
@@ -355,7 +356,7 @@ def _simplify_binary(expression, a, b):
if boolean:
return boolean
elif a.is_string and b.is_string:
- boolean = eval_boolean(expression, a, b)
+ boolean = eval_boolean(expression, a.this, b.this)
if boolean:
return boolean
@@ -381,7 +382,7 @@ def simplify_parens(expression):
and not isinstance(expression.this, exp.Select)
and (
not isinstance(expression.parent, (exp.Condition, exp.Binary))
- or isinstance(expression.this, (exp.Is, exp.Like))
+ or isinstance(expression.this, exp.Predicate)
or not isinstance(expression.this, exp.Binary)
)
):
@@ -400,13 +401,23 @@ def remove_where_true(expression):
def always_true(expression):
- return expression == TRUE or isinstance(expression, exp.Literal)
+ return (isinstance(expression, exp.Boolean) and expression.this) or isinstance(
+ expression, exp.Literal
+ )
def is_complement(a, b):
return isinstance(b, exp.Not) and b.this == a
+def is_false(a: exp.Expression) -> bool:
+ return type(a) is exp.Boolean and not a.this
+
+
+def is_null(a: exp.Expression) -> bool:
+ return type(a) is exp.Null
+
+
def eval_boolean(expression, a, b):
if isinstance(expression, (exp.EQ, exp.Is)):
return boolean_literal(a == b)
@@ -466,24 +477,27 @@ def boolean_literal(condition):
return exp.true() if condition else exp.false()
-def _flat_simplify(expression, simplifier):
- operands = []
- queue = deque(expression.flatten(unnest=False))
- size = len(queue)
+def _flat_simplify(expression, simplifier, root=True):
+ if root or not expression.same_parent:
+ operands = []
+ queue = deque(expression.flatten(unnest=False))
+ size = len(queue)
- while queue:
- a = queue.popleft()
+ while queue:
+ a = queue.popleft()
- for b in queue:
- result = simplifier(expression, a, b)
+ for b in queue:
+ result = simplifier(expression, a, b)
- if result:
- queue.remove(b)
- queue.append(result)
- break
- else:
- operands.append(a)
+ if result:
+ queue.remove(b)
+ queue.append(result)
+ break
+ else:
+ operands.append(a)
- if len(operands) < size:
- return functools.reduce(lambda a, b: expression.__class__(this=a, expression=b), operands)
+ if len(operands) < size:
+ return functools.reduce(
+ lambda a, b: expression.__class__(this=a, expression=b), operands
+ )
return expression
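Finally, an illustrative sketch (not part of this diff) of simplify(), which now runs once at the end of the optimizer pipeline instead of being invoked from within the other rules:

    import sqlglot
    from sqlglot.optimizer.simplify import simplify

    print(simplify(sqlglot.parse_one("(a AND NOT a) OR (1 = 1)")).sql())  # expected: TRUE
    print(simplify(sqlglot.parse_one("x = 1 AND x = 1")).sql())           # expected: x = 1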