From 3742f86d166160ca3843872ebecb6f30c51f6085 Mon Sep 17 00:00:00 2001
From: Daniel Baumann
Date: Mon, 14 Aug 2023 12:12:19 +0200
Subject: Merging upstream version 17.12.0.

Signed-off-by: Daniel Baumann
---
 sqlglot/optimizer/simplify.py | 108 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 108 insertions(+)

(limited to 'sqlglot/optimizer')

Notes (free text; this region precedes the first "diff --git" line):
  * _simplify() now runs simplify_concat() before it recurses into child
    nodes via exp.replace_children(), and simplify_coalesce() after
    simplify_parens().
  * exp.Is is appended to the COMPARISONS tuple.
  * New module-level names: CONSTANTS, simplify_coalesce(), CONCATS,
    SAFE_CONCATS, simplify_concat() and JOINS.
  * remove_where_true() additionally requires (join.side, join.kind) to be
    a member of JOINS before it clears the join's "on" and "side" args.
  * NOTE(review): line breaks and indentation in this copy were restored
    from a whitespace-collapsed rendering of the patch; the hunk line
    counts match the headers, but verify the indentation against commit
    3742f86d166160ca3843872ebecb6f30c51f6085 before applying.

diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index e247f58..e550603 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -54,11 +54,17 @@ def simplify(expression):
     def _simplify(expression, root=True):
         if expression.meta.get(FINAL):
             return expression
+
+        # Pre-order transformations
         node = expression
         node = rewrite_between(node)
         node = uniq_sort(node, generate, root)
         node = absorb_and_eliminate(node, root)
+        node = simplify_concat(node)
+
         exp.replace_children(node, lambda e: _simplify(e, False))
+
+        # Post-order transformations
         node = simplify_not(node)
         node = flatten(node)
         node = simplify_connectors(node, root)
@@ -66,8 +72,11 @@ def simplify(expression):
         node.parent = expression.parent
         node = simplify_literals(node, root)
         node = simplify_parens(node)
+        node = simplify_coalesce(node)
+
         if root:
             expression.replace(node)
+
         return node
 
     expression = while_changing(expression, _simplify)
@@ -184,6 +193,7 @@ COMPARISONS = (
     *GT_GTE,
     exp.EQ,
     exp.NEQ,
+    exp.Is,
 )
 
 INVERSE_COMPARISONS = {
@@ -430,6 +440,103 @@ def simplify_parens(expression):
     return expression
 
 
+CONSTANTS = (
+    exp.Literal,
+    exp.Boolean,
+    exp.Null,
+)
+
+
+def simplify_coalesce(expression):
+    # COALESCE(x) -> x
+    if (
+        isinstance(expression, exp.Coalesce)
+        and not expression.expressions
+        # COALESCE is also used as a Spark partitioning hint
+        and not isinstance(expression.parent, exp.Hint)
+    ):
+        return expression.this
+
+    if not isinstance(expression, COMPARISONS):
+        return expression
+
+    if isinstance(expression.left, exp.Coalesce):
+        coalesce = expression.left
+        other = expression.right
+    elif isinstance(expression.right, exp.Coalesce):
+        coalesce = expression.right
+        other = expression.left
+    else:
+        return expression
+
+    # This transformation is valid for non-constants,
+    # but it really only does anything if they are both constants.
+    if not isinstance(other, CONSTANTS):
+        return expression
+
+    # Find the first constant arg
+    for arg_index, arg in enumerate(coalesce.expressions):
+        if isinstance(arg, CONSTANTS):
+            break
+    else:
+        return expression
+
+    coalesce.set("expressions", coalesce.expressions[:arg_index])
+
+    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
+    # since we already remove COALESCE at the top of this function.
+    coalesce = coalesce if coalesce.expressions else coalesce.this
+
+    # This expression is more complex than when we started, but it will get simplified further
+    return exp.or_(
+        exp.and_(
+            coalesce.is_(exp.null()).not_(copy=False),
+            expression.copy(),
+            copy=False,
+        ),
+        exp.and_(
+            coalesce.is_(exp.null()),
+            type(expression)(this=arg.copy(), expression=other.copy()),
+            copy=False,
+        ),
+        copy=False,
+    )
+
+
+CONCATS = (exp.Concat, exp.DPipe)
+SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe)
+
+
+def simplify_concat(expression):
+    """Reduces all groups that contain string literals by concatenating them."""
+    if not isinstance(expression, CONCATS) or isinstance(expression, exp.ConcatWs):
+        return expression
+
+    new_args = []
+    for is_string_group, group in itertools.groupby(
+        expression.expressions or expression.flatten(), lambda e: e.is_string
+    ):
+        if is_string_group:
+            new_args.append(exp.Literal.string("".join(string.name for string in group)))
+        else:
+            new_args.extend(group)
+
+    # Ensures we preserve the right concat type, i.e. whether it's "safe" or not
+    concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
+    return new_args[0] if len(new_args) == 1 else concat_type(expressions=new_args)
+
+
+# CROSS joins result in an empty table if the right table is empty.
+# So we can only simplify certain types of joins to CROSS.
+# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
+JOINS = {
+    ("", ""),
+    ("", "INNER"),
+    ("RIGHT", ""),
+    ("RIGHT", "OUTER"),
+}
+
+
 def remove_where_true(expression):
     for where in expression.find_all(exp.Where):
         if always_true(where.this):
@@ -439,6 +546,7 @@ def remove_where_true(expression):
             always_true(join.args.get("on"))
             and not join.args.get("using")
             and not join.args.get("method")
+            and (join.side, join.kind) in JOINS
         ):
             join.set("on", None)
             join.set("side", None)
-- 
cgit v1.2.3