summary | refs | log | tree | commit | diff | stats
path: root/sqlglot/optimizer/normalize.py
diff options
context:
space:
mode:
author: Daniel Baumann <daniel.baumann@progress-linux.org> 2023-10-16 11:37:39 +0000
committer: Daniel Baumann <daniel.baumann@progress-linux.org> 2023-10-16 11:37:39 +0000
commit: f10d022e11dcd1015db1a74ce9f4198ebdcb7f40 (patch)
tree: ac7bdc1d214a0f97f991cff14e933f4895ee68e1 /sqlglot/optimizer/normalize.py
parent: Releasing progress-linux version 18.11.6-1. (diff)
download: sqlglot-f10d022e11dcd1015db1a74ce9f4198ebdcb7f40.tar.xz
download: sqlglot-f10d022e11dcd1015db1a74ce9f4198ebdcb7f40.zip
Merging upstream version 18.13.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/optimizer/normalize.py')
-rw-r--r-- sqlglot/optimizer/normalize.py | 37
1 files changed, 29 insertions, 8 deletions
diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py
index 1db094e..8d82b2d 100644
--- a/sqlglot/optimizer/normalize.py
+++ b/sqlglot/optimizer/normalize.py
@@ -6,6 +6,7 @@ from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.generator import cached_generator
from sqlglot.helper import while_changing
+from sqlglot.optimizer.scope import find_all_in_scope
from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort
logger = logging.getLogger("sqlglot")
@@ -63,15 +64,33 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
return expression
-def normalized(expression, dnf=False):
- ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And)
+def normalized(expression: exp.Expression, dnf: bool = False) -> bool:
+ """
+ Checks whether a given expression is in a normal form of interest.
- return not any(connector.find_ancestor(ancestor) for connector in expression.find_all(root))
+ Example:
+ >>> from sqlglot import parse_one
+ >>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True)
+ True
+ >>> normalized(parse_one("(a OR b) AND c")) # Checks CNF by default
+ True
+ >>> normalized(parse_one("a AND (b OR c)"), dnf=True)
+ False
+ Args:
+ expression: The expression to check if it's normalized.
+ dnf: Whether or not to check if the expression is in Disjunctive Normal Form (DNF).
+ Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
+ """
+ ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And)
+ return not any(
+ connector.find_ancestor(ancestor) for connector in find_all_in_scope(expression, root)
+ )
-def normalization_distance(expression, dnf=False):
+
+def normalization_distance(expression: exp.Expression, dnf: bool = False) -> int:
"""
- The difference in the number of predicates between the current expression and the normalized form.
+ The difference in the number of predicates between a given expression and its normalized form.
This is used as an estimate of the cost of the conversion which is exponential in complexity.
@@ -82,10 +101,12 @@ def normalization_distance(expression, dnf=False):
4
Args:
- expression (sqlglot.Expression): expression to compute distance
- dnf (bool): compute to dnf distance instead
+ expression: The expression to compute the normalization distance for.
+ dnf: Whether or not to check if the expression is in Disjunctive Normal Form (DNF).
+ Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
+
Returns:
- int: difference
+ The normalization distance.
"""
return sum(_predicate_lengths(expression, dnf)) - (
sum(1 for _ in expression.find_all(exp.Connector)) + 1