summary refs log tree commit diff stats
path: root/sqlglot/optimizer/simplify.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/simplify.py')
-rw-r--r-- sqlglot/optimizer/simplify.py 108
1 files changed, 108 insertions, 0 deletions
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index e247f58..e550603 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -54,11 +54,17 @@ def simplify(expression):
def _simplify(expression, root=True):
if expression.meta.get(FINAL):
return expression
+
+ # Pre-order transformations
node = expression
node = rewrite_between(node)
node = uniq_sort(node, generate, root)
node = absorb_and_eliminate(node, root)
+ node = simplify_concat(node)
+
exp.replace_children(node, lambda e: _simplify(e, False))
+
+ # Post-order transformations
node = simplify_not(node)
node = flatten(node)
node = simplify_connectors(node, root)
@@ -66,8 +72,11 @@ def simplify(expression):
node.parent = expression.parent
node = simplify_literals(node, root)
node = simplify_parens(node)
+ node = simplify_coalesce(node)
+
if root:
expression.replace(node)
+
return node
expression = while_changing(expression, _simplify)
@@ -184,6 +193,7 @@ COMPARISONS = (
*GT_GTE,
exp.EQ,
exp.NEQ,
+ exp.Is,
)
INVERSE_COMPARISONS = {
@@ -430,6 +440,103 @@ def simplify_parens(expression):
return expression
# Expression types that simplify_coalesce treats as constants.
CONSTANTS = (
    exp.Literal,
    exp.Boolean,
    exp.Null,
)


def simplify_coalesce(expression):
    """Simplify a COALESCE node, or a comparison that has a COALESCE operand.

    Two rewrites are performed:

    * ``COALESCE(x)`` with no further arguments is replaced by ``x``, unless
      the node's parent is an ``exp.Hint``.
    * A comparison (any type in ``COMPARISONS``) between a COALESCE and a
      constant, where the COALESCE has a constant among its ``expressions``,
      is rewritten as::

          (<prefix> IS NOT NULL AND <comparison using prefix>)
          OR (<prefix> IS NULL AND <comparison using the constant>)

      where ``<prefix>`` is the COALESCE truncated to the arguments that
      precede its first constant argument.

    Args:
        expression: The expression node to simplify. A matching COALESCE
            operand is truncated in place.

    Returns:
        The rewritten expression, or ``expression`` itself when neither
        rewrite applies.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and not expression.expressions
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
        coalesce_on_left = True
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
        coalesce_on_left = False
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not isinstance(other, CONSTANTS):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if isinstance(arg, CONSTANTS):
            break
    else:
        return expression

    # Keep only the arguments that precede the first constant one.
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # Comparison that applies when the remaining COALESCE arguments are NULL.
    # The constant must stay on the side the COALESCE occupied, otherwise
    # asymmetric comparisons are flipped: `1 < COALESCE(x, 2)` has to yield
    # `1 < 2`, not `2 < 1`.
    if coalesce_on_left:
        null_case = type(expression)(this=arg.copy(), expression=other.copy())
    else:
        null_case = type(expression)(this=other.copy(), expression=arg.copy())

    # This expression is more complex than when we started, but it will get simplified further
    return exp.or_(
        exp.and_(
            coalesce.is_(exp.null()).not_(copy=False),
            expression.copy(),
            copy=False,
        ),
        exp.and_(
            coalesce.is_(exp.null()),
            null_case,
            copy=False,
        ),
        copy=False,
    )
+
+
# Concatenation node types handled by simplify_concat, and the subset that is "safe".
CONCATS = (exp.Concat, exp.DPipe)
SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe)


def simplify_concat(expression):
    """Merge each run of adjacent string operands of a concatenation into one literal.

    Nodes that are not one of ``CONCATS``, and ``exp.ConcatWs`` nodes, are
    returned untouched. Otherwise the operands are scanned in order and every
    consecutive run whose ``is_string`` is truthy is replaced by a single
    string literal holding the joined text.

    Returns:
        The only remaining operand when exactly one is left, otherwise a new
        ``exp.SafeConcat`` (for ``SAFE_CONCATS`` inputs) or ``exp.Concat``
        built from the merged operands.
    """
    if isinstance(expression, exp.ConcatWs) or not isinstance(expression, CONCATS):
        return expression

    operands = expression.expressions or expression.flatten()

    merged = []
    for strings_only, run in itertools.groupby(operands, key=lambda node: node.is_string):
        if not strings_only:
            merged.extend(run)
            continue
        joined_text = "".join(piece.name for piece in run)
        merged.append(exp.Literal.string(joined_text))

    if len(merged) == 1:
        return merged[0]

    # Ensures we preserve the right concat type, i.e. whether it's "safe" or not
    if isinstance(expression, SAFE_CONCATS):
        result_type = exp.SafeConcat
    else:
        result_type = exp.Concat
    return result_type(expressions=merged)
+
+
+# A CROSS join produces an empty result if the right table is empty, so only the
+# (side, kind) pairs listed here may be simplified to CROSS when ON is always true.
+# In other words, LEFT JOIN x ON TRUE != CROSS JOIN x
+JOINS = {
+ ("", ""),
+ ("", "INNER"),
+ ("RIGHT", ""),
+ ("RIGHT", "OUTER"),
+}
+
+
def remove_where_true(expression):
for where in expression.find_all(exp.Where):
if always_true(where.this):
@@ -439,6 +546,7 @@ def remove_where_true(expression):
always_true(join.args.get("on"))
and not join.args.get("using")
and not join.args.get("method")
+ and (join.side, join.kind) in JOINS
):
join.set("on", None)
join.set("side", None)