From bea2635be022e272ddac349f5e396ec901fc37e5 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Mon, 12 Dec 2022 16:42:38 +0100 Subject: Merging upstream version 10.2.6. Signed-off-by: Daniel Baumann --- sqlglot/optimizer/annotate_types.py | 37 ++++-- sqlglot/optimizer/canonicalize.py | 25 +++- sqlglot/optimizer/simplify.py | 235 +++++++++++++++++++++++++----------- 3 files changed, 212 insertions(+), 85 deletions(-) (limited to 'sqlglot/optimizer') diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index 191ea52..be17f15 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -14,7 +14,7 @@ def annotate_types(expression, schema=None, annotators=None, coerces_to=None): >>> schema = {"y": {"cola": "SMALLINT"}} >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x" >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema) - >>> annotated_expr.expressions[0].type # Get the type of "x.cola + 2.5 AS cola" + >>> annotated_expr.expressions[0].type.this # Get the type of "x.cola + 2.5 AS cola" Args: @@ -41,9 +41,12 @@ class TypeAnnotator: expr_type: lambda self, expr: self._annotate_binary(expr) for expr_type in subclasses(exp.__name__, exp.Binary) }, - exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"].this), - exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.this), + exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]), + exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]), + exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr), exp.Alias: lambda self, expr: self._annotate_unary(expr), + exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), + exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), exp.Literal: lambda self, expr: self._annotate_literal(expr), exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL), @@ -52,6 +55,9 @@ class TypeAnnotator: expr, exp.DataType.Type.BIGINT ), exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Min: lambda self, expr: self._annotate_by_args(expr, "this"), + exp.Max: lambda self, expr: self._annotate_by_args(expr, "this"), + exp.Sum: lambda self, expr: self._annotate_by_args(expr, "this", promote=True), exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), @@ -263,10 +269,10 @@ class TypeAnnotator: } # First annotate the current scope's column references for col in scope.columns: - source = scope.sources[col.table] + source = scope.sources.get(col.table) if isinstance(source, exp.Table): col.type = self.schema.get_column_type(source, col) - else: + elif source: col.type = selects[col.table][col.name].type # Then (possibly) annotate the remaining expressions in the scope self._maybe_annotate(scope.expression) @@ -280,6 +286,7 @@ class TypeAnnotator: return expression # We've already inferred the expression's type annotator = self.annotators.get(expression.__class__) + return ( annotator(self, expression) if annotator @@ -295,18 +302,23 @@ class TypeAnnotator: def _maybe_coerce(self, type1, type2): # We propagate the NULL / UNKNOWN types upwards if found + if isinstance(type1, exp.DataType): + type1 = type1.this + if isinstance(type2, exp.DataType): + type2 = type2.this + if exp.DataType.Type.NULL in (type1, type2): return exp.DataType.Type.NULL if exp.DataType.Type.UNKNOWN in (type1, type2): return exp.DataType.Type.UNKNOWN - return type2 if type2 in self.coerces_to[type1] else type1 + return type2 if type2 in self.coerces_to.get(type1, {}) else type1 def _annotate_binary(self, expression): self._annotate_args(expression) - left_type = expression.left.type - right_type = expression.right.type + left_type = expression.left.type.this + right_type = expression.right.type.this if isinstance(expression, (exp.And, exp.Or)): if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL: @@ -348,7 +360,7 @@ class TypeAnnotator: expression.type = target_type return self._annotate_args(expression) - def _annotate_by_args(self, expression, *args): + def _annotate_by_args(self, expression, *args, promote=False): self._annotate_args(expression) expressions = [] for arg in args: @@ -360,4 +372,11 @@ class TypeAnnotator: last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type) expression.type = last_datatype or exp.DataType.Type.UNKNOWN + + if promote: + if expression.type.this in exp.DataType.INTEGER_TYPES: + expression.type = exp.DataType.Type.BIGINT + elif expression.type.this in exp.DataType.FLOAT_TYPES: + expression.type = exp.DataType.Type.DOUBLE + return expression diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py index 9b3d98a..33529a5 100644 --- a/sqlglot/optimizer/canonicalize.py +++ b/sqlglot/optimizer/canonicalize.py @@ -13,13 +13,16 @@ def canonicalize(expression: exp.Expression) -> exp.Expression: expression: The expression to canonicalize. """ exp.replace_children(expression, canonicalize) + expression = add_text_to_concat(expression) expression = coerce_type(expression) + expression = remove_redundant_casts(expression) + return expression def add_text_to_concat(node: exp.Expression) -> exp.Expression: - if isinstance(node, exp.Add) and node.type in exp.DataType.TEXT_TYPES: + if isinstance(node, exp.Add) and node.type and node.type.this in exp.DataType.TEXT_TYPES: node = exp.Concat(this=node.this, expression=node.expression) return node @@ -30,14 +33,30 @@ def coerce_type(node: exp.Expression) -> exp.Expression: elif isinstance(node, exp.Between): _coerce_date(node.this, node.args["low"]) elif isinstance(node, exp.Extract): - if node.expression.type not in exp.DataType.TEMPORAL_TYPES: + if node.expression.type.this not in exp.DataType.TEMPORAL_TYPES: _replace_cast(node.expression, "datetime") return node +def remove_redundant_casts(expression: exp.Expression) -> exp.Expression: + if ( + isinstance(expression, exp.Cast) + and expression.to.type + and expression.this.type + and expression.to.type.this == expression.this.type.this + ): + return expression.this + return expression + + def _coerce_date(a: exp.Expression, b: exp.Expression) -> None: for a, b in itertools.permutations([a, b]): - if a.type == exp.DataType.Type.DATE and b.type != exp.DataType.Type.DATE: + if ( + a.type + and a.type.this == exp.DataType.Type.DATE + and b.type + and b.type.this != exp.DataType.Type.DATE + ): _replace_cast(b, "date") diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index c432c59..c0719f2 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -7,7 +7,7 @@ from decimal import Decimal from sqlglot import exp from sqlglot.expressions import FALSE, NULL, TRUE from sqlglot.generator import Generator -from sqlglot.helper import while_changing +from sqlglot.helper import first, while_changing GENERATOR = Generator(normalize=True, identify=True) @@ -30,6 +30,7 @@ def simplify(expression): def _simplify(expression, root=True): node = expression + node = rewrite_between(node) node = uniq_sort(node) node = absorb_and_eliminate(node) exp.replace_children(node, lambda e: _simplify(e, False)) @@ -49,6 +50,19 @@ def simplify(expression): return expression +def rewrite_between(expression: exp.Expression) -> exp.Expression: + """Rewrite x between y and z to x >= y AND x <= z. + + This is done because comparison simplification is only done on lt/lte/gt/gte. + """ + if isinstance(expression, exp.Between): + return exp.and_( + exp.GTE(this=expression.this.copy(), expression=expression.args["low"]), + exp.LTE(this=expression.this.copy(), expression=expression.args["high"]), + ) + return expression + + def simplify_not(expression): """ Demorgan's Law @@ -57,7 +71,7 @@ def simplify_not(expression): """ if isinstance(expression, exp.Not): if isinstance(expression.this, exp.Null): - return NULL + return exp.null() if isinstance(expression.this, exp.Paren): condition = expression.this.unnest() if isinstance(condition, exp.And): @@ -65,11 +79,11 @@ def simplify_not(expression): if isinstance(condition, exp.Or): return exp.and_(exp.not_(condition.left), exp.not_(condition.right)) if isinstance(condition, exp.Null): - return NULL + return exp.null() if always_true(expression.this): - return FALSE + return exp.false() if expression.this == FALSE: - return TRUE + return exp.true() if isinstance(expression.this, exp.Not): # double negation # NOT NOT x -> x @@ -91,40 +105,119 @@ def flatten(expression): def simplify_connectors(expression): - if isinstance(expression, exp.Connector): - left = expression.left - right = expression.right - - if left == right: - return left - - if isinstance(expression, exp.And): - if FALSE in (left, right): - return FALSE - if NULL in (left, right): - return NULL - if always_true(left) and always_true(right): - return TRUE - if always_true(left): - return right - if always_true(right): - return left - elif isinstance(expression, exp.Or): - if always_true(left) or always_true(right): - return TRUE - if left == FALSE and right == FALSE: - return FALSE - if ( - (left == NULL and right == NULL) - or (left == NULL and right == FALSE) - or (left == FALSE and right == NULL) - ): - return NULL - if left == FALSE: - return right - if right == FALSE: + def _simplify_connectors(expression, left, right): + if isinstance(expression, exp.Connector): + if left == right: return left - return expression + if isinstance(expression, exp.And): + if FALSE in (left, right): + return exp.false() + if NULL in (left, right): + return exp.null() + if always_true(left) and always_true(right): + return exp.true() + if always_true(left): + return right + if always_true(right): + return left + return _simplify_comparison(expression, left, right) + elif isinstance(expression, exp.Or): + if always_true(left) or always_true(right): + return exp.true() + if left == FALSE and right == FALSE: + return exp.false() + if ( + (left == NULL and right == NULL) + or (left == NULL and right == FALSE) + or (left == FALSE and right == NULL) + ): + return exp.null() + if left == FALSE: + return right + if right == FALSE: + return left + return _simplify_comparison(expression, left, right, or_=True) + return None + + return _flat_simplify(expression, _simplify_connectors) + + +LT_LTE = (exp.LT, exp.LTE) +GT_GTE = (exp.GT, exp.GTE) + +COMPARISONS = ( + *LT_LTE, + *GT_GTE, + exp.EQ, + exp.NEQ, +) + +INVERSE_COMPARISONS = { + exp.LT: exp.GT, + exp.GT: exp.LT, + exp.LTE: exp.GTE, + exp.GTE: exp.LTE, +} + + +def _simplify_comparison(expression, left, right, or_=False): + if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS): + ll, lr = left.args.values() + rl, rr = right.args.values() + + largs = {ll, lr} + rargs = {rl, rr} + + matching = largs & rargs + columns = {m for m in matching if isinstance(m, exp.Column)} + + if matching and columns: + try: + l = first(largs - columns) + r = first(rargs - columns) + except StopIteration: + return expression + + # make sure the comparison is always of the form x > 1 instead of 1 < x + if left.__class__ in INVERSE_COMPARISONS and l == ll: + left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll) + if right.__class__ in INVERSE_COMPARISONS and r == rl: + right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl) + + if l.is_number and r.is_number: + l = float(l.name) + r = float(r.name) + elif l.is_string and r.is_string: + l = l.name + r = r.name + else: + return None + + for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))): + if isinstance(a, LT_LTE) and isinstance(b, LT_LTE): + return left if (av > bv if or_ else av <= bv) else right + if isinstance(a, GT_GTE) and isinstance(b, GT_GTE): + return left if (av < bv if or_ else av >= bv) else right + + # we can't ever shortcut to true because the column could be null + if isinstance(a, exp.LT) and isinstance(b, GT_GTE): + if not or_ and av <= bv: + return exp.false() + elif isinstance(a, exp.GT) and isinstance(b, LT_LTE): + if not or_ and av >= bv: + return exp.false() + elif isinstance(a, exp.EQ): + if isinstance(b, exp.LT): + return exp.false() if av >= bv else a + if isinstance(b, exp.LTE): + return exp.false() if av > bv else a + if isinstance(b, exp.GT): + return exp.false() if av <= bv else a + if isinstance(b, exp.GTE): + return exp.false() if av < bv else a + if isinstance(b, exp.NEQ): + return exp.false() if av == bv else a + return None def remove_compliments(expression): @@ -135,7 +228,7 @@ def remove_compliments(expression): A OR NOT A -> TRUE """ if isinstance(expression, exp.Connector): - compliment = FALSE if isinstance(expression, exp.And) else TRUE + compliment = exp.false() if isinstance(expression, exp.And) else exp.true() for a, b in itertools.permutations(expression.flatten(), 2): if is_complement(a, b): @@ -211,27 +304,7 @@ def absorb_and_eliminate(expression): def simplify_literals(expression): if isinstance(expression, exp.Binary): - operands = [] - queue = deque(expression.flatten(unnest=False)) - size = len(queue) - - while queue: - a = queue.popleft() - - for b in queue: - result = _simplify_binary(expression, a, b) - - if result: - queue.remove(b) - queue.append(result) - break - else: - operands.append(a) - - if len(operands) < size: - return functools.reduce( - lambda a, b: expression.__class__(this=a, expression=b), operands - ) + return _flat_simplify(expression, _simplify_binary) elif isinstance(expression, exp.Neg): this = expression.this if this.is_number: @@ -254,20 +327,13 @@ def _simplify_binary(expression, a, b): if c == NULL: if isinstance(a, exp.Literal): - return TRUE if not_ else FALSE + return exp.true() if not_ else exp.false() if a == NULL: - return FALSE if not_ else TRUE - elif isinstance(expression, exp.NullSafeEQ): - if a == b: - return TRUE - elif isinstance(expression, exp.NullSafeNEQ): - if a == b: - return FALSE + return exp.false() if not_ else exp.true() + elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)): + return None elif NULL in (a, b): - return NULL - - if isinstance(expression, exp.EQ) and a == b: - return TRUE + return exp.null() if a.is_number and b.is_number: a = int(a.name) if a.is_int else Decimal(a.name) @@ -388,4 +454,27 @@ def date_literal(date): def boolean_literal(condition): - return TRUE if condition else FALSE + return exp.true() if condition else exp.false() + + +def _flat_simplify(expression, simplifier): + operands = [] + queue = deque(expression.flatten(unnest=False)) + size = len(queue) + + while queue: + a = queue.popleft() + + for b in queue: + result = simplifier(expression, a, b) + + if result: + queue.remove(b) + queue.append(result) + break + else: + operands.append(a) + + if len(operands) < size: + return functools.reduce(lambda a, b: expression.__class__(this=a, expression=b), operands) + return expression -- cgit v1.2.3