diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/optimizer/normalize.py | 37 |
1 files changed, 29 insertions, 8 deletions
diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py index 1db094e..8d82b2d 100644 --- a/sqlglot/optimizer/normalize.py +++ b/sqlglot/optimizer/normalize.py @@ -6,6 +6,7 @@ from sqlglot import exp from sqlglot.errors import OptimizeError from sqlglot.generator import cached_generator from sqlglot.helper import while_changing +from sqlglot.optimizer.scope import find_all_in_scope from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort logger = logging.getLogger("sqlglot") @@ -63,15 +64,33 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = return expression -def normalized(expression, dnf=False): - ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And) +def normalized(expression: exp.Expression, dnf: bool = False) -> bool: + """ + Checks whether a given expression is in a normal form of interest. - return not any(connector.find_ancestor(ancestor) for connector in expression.find_all(root)) + Example: + >>> from sqlglot import parse_one + >>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True) + True + >>> normalized(parse_one("(a OR b) AND c")) # Checks CNF by default + True + >>> normalized(parse_one("a AND (b OR c)"), dnf=True) + False + Args: + expression: The expression to check if it's normalized. + dnf: Whether or not to check if the expression is in Disjunctive Normal Form (DNF). + Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF). + """ + ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And) + return not any( + connector.find_ancestor(ancestor) for connector in find_all_in_scope(expression, root) + ) -def normalization_distance(expression, dnf=False): + +def normalization_distance(expression: exp.Expression, dnf: bool = False) -> int: """ - The difference in the number of predicates between the current expression and the normalized form. + The difference in the number of predicates between a given expression and its normalized form. This is used as an estimate of the cost of the conversion which is exponential in complexity. @@ -82,10 +101,12 @@ def normalization_distance(expression, dnf=False): 4 Args: - expression (sqlglot.Expression): expression to compute distance - dnf (bool): compute to dnf distance instead + expression: The expression to compute the normalization distance for. + dnf: Whether or not to check if the expression is in Disjunctive Normal Form (DNF). + Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF). + Returns: - int: difference + The normalization distance. """ return sum(_predicate_lengths(expression, dnf)) - ( sum(1 for _ in expression.find_all(exp.Connector)) + 1 |