sqlglot.optimizer.simplify
import datetime
import functools
import itertools
from collections import deque
from decimal import Decimal

from sqlglot import exp
from sqlglot.generator import cached_generator
from sqlglot.helper import first, while_changing

# Final means that an expression should not be simplified
FINAL = "final"


def simplify(expression):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
    Returns:
        sqlglot.Expression: simplified expression
    """

    generate = cached_generator()

    # group by expressions cannot be simplified, for example
    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
    # the projection must exactly match the group by key
    for group in expression.find_all(exp.Group):
        select = group.parent
        groups = set(group.expressions)
        group.meta[FINAL] = True

        for e in select.selects:
            for node, *_ in e.walk():
                if node in groups:
                    e.meta[FINAL] = True
                    break

        having = select.args.get("having")
        if having:
            for node, *_ in having.walk():
                if node in groups:
                    having.meta[FINAL] = True
                    break

    def _simplify(expression, root=True):
        if expression.meta.get(FINAL):
            return expression

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, generate, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)

        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_compliments(node, root)
        node = simplify_coalesce(node)
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_parens(node)

        if root:
            expression.replace(node)

        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression


def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite x between y and z to x >= y AND x <= z.

    This is done because comparison simplification is only done on lt/lte/gt/gte.
    """
    if isinstance(expression, exp.Between):
        return exp.and_(
            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
            copy=False,
        )
    return expression


def simplify_not(expression):
    """
    De Morgan's laws:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

    Also folds NOT NULL -> NULL, NOT TRUE -> FALSE, NOT FALSE -> TRUE and NOT NOT x -> x.
    """
    if isinstance(expression, exp.Not):
        if is_null(expression.this):
            return exp.null()
        if isinstance(expression.this, exp.Paren):
            condition = expression.this.unnest()
            if isinstance(condition, exp.And):
                return exp.or_(
                    exp.not_(condition.left, copy=False),
                    exp.not_(condition.right, copy=False),
                    copy=False,
                )
            if isinstance(condition, exp.Or):
                return exp.and_(
                    exp.not_(condition.left, copy=False),
                    exp.not_(condition.right, copy=False),
                    copy=False,
                )
            if is_null(condition):
                return exp.null()
        if always_true(expression.this):
            return exp.false()
        if is_false(expression.this):
            return exp.true()
        if isinstance(expression.this, exp.Not):
            # double negation
            # NOT NOT x -> x
            return expression.this.this
    return expression


def flatten(expression):
    """
    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if isinstance(expression, exp.Connector):
        for node in expression.args.values():
            child = node.unnest()
            if isinstance(child, expression.__class__):
                node.replace(child)
    return expression


def simplify_connectors(expression, root=True):
    """Fold constant operands of AND/OR connectors and merge comparable comparisons."""

    def _simplify_connectors(expression, left, right):
        if left == right:
            return left
        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression


LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
    exp.Is,
)

INVERSE_COMPARISONS = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}


def _simplify_comparison(expression, left, right, or_=False):
    """Merge two comparisons of the same column against constants.

    Returns the merged node, `expression` when no constant side can be isolated, or None
    when nothing could be simplified.
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        matching = largs & rargs
        columns = {m for m in matching if isinstance(m, exp.Column)}

        if matching and columns:
            try:
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                return expression

            # make sure the comparison is always of the form x > 1 instead of 1 < x
            if left.__class__ in INVERSE_COMPARISONS and l == ll:
                left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
            if right.__class__ in INVERSE_COMPARISONS and r == rl:
                right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)

            if l.is_number and r.is_number:
                l = float(l.name)
                r = float(r.name)
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                return None

            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
                    if av == bv:
                        # Equal bounds: the result must not depend on operand order.
                        # AND keeps the strict form, OR keeps the inclusive one, e.g.
                        # 5 >= x AND x < 5 -> x < 5, x <= 5 OR x < 5 -> x <= 5
                        strict, inclusive = (
                            (left, right) if isinstance(left, exp.LT) else (right, left)
                        )
                        return inclusive if or_ else strict
                    return left if (av > bv if or_ else av <= bv) else right
                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
                    if av == bv:
                        # Equal bounds: AND keeps x > c, OR keeps x >= c.
                        strict, inclusive = (
                            (left, right) if isinstance(left, exp.GT) else (right, left)
                        )
                        return inclusive if or_ else strict
                    return left if (av < bv if or_ else av >= bv) else right

                # we can't ever shortcut to true because the column could be null
                if not or_:
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    elif isinstance(a, exp.EQ):
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None


def remove_compliments(expression, root=True):
    """
    Remove complements.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        compliment = exp.false() if isinstance(expression, exp.And) else exp.true()

        for a, b in itertools.permutations(expression.flatten(), 2):
            if is_complement(a, b):
                return compliment
    return expression


def uniq_sort(expression, generate, root=True):
    """
    Uniq and sort a connector.

    C AND A AND B AND B -> A AND B AND C
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
        flattened = tuple(expression.flatten())
        deduped = {generate(e): e for e in flattened}
        arr = tuple(deduped.items())

        # check if the operands are already sorted, if not sort them
        # A AND C AND B -> A AND B AND C
        for i, (sql, e) in enumerate(arr[1:]):
            if sql < arr[i][0]:
                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
                break
        else:
            # we didn't have to sort but maybe we need to dedup
            if len(deduped) < len(flattened):
                expression = result_func(*deduped.values(), copy=False)

    return expression


def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression


def simplify_literals(expression, root=True):
    """Fold constant binary operations and negate numeric literals under a unary minus."""
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        this = expression.this
        if this.is_number:
            value = this.name
            if value[0] == "-":
                return exp.Literal.number(value[1:])
            return exp.Literal.number(f"-{value}")

    return expression


def _simplify_binary(expression, a, b):
    """Evaluate `a <op> b` for constant operands; return None when it cannot be folded."""
    if isinstance(expression, exp.Is):
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if is_null(c):
            if isinstance(a, exp.Literal):
                return exp.true() if not_ else exp.false()
            if is_null(a):
                return exp.false() if not_ else exp.true()
    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
        return None
    elif is_null(a) or is_null(b):
        return exp.null()

    if a.is_number and b.is_number:
        a = int(a.name) if a.is_int else Decimal(a.name)
        b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(a + b)
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(a - b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(a * b)
        if isinstance(expression, exp.Div):
            # engines have differing int div behavior so intdiv is not safe
            if isinstance(a, int) and isinstance(b, int):
                return None
            # Decimal division by zero raises; leave the expression for the engine to handle
            if b == 0:
                return None
            return exp.Literal.number(a / b)

        boolean = eval_boolean(expression, a, b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        boolean = eval_boolean(expression, a.this, b.this)

        if boolean:
            return boolean
    elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
        a, b = extract_date(a), extract_interval(b)
        if a and b:
            if isinstance(expression, exp.Add):
                return date_literal(a + b)
            if isinstance(expression, exp.Sub):
                return date_literal(a - b)
    elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
        a, b = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if a and b and isinstance(expression, exp.Add):
            return date_literal(a + b)

    return None


def simplify_parens(expression):
    """Remove parentheses that do not affect evaluation order."""
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent

    if not isinstance(this, exp.Select) and (
        not isinstance(parent, (exp.Condition, exp.Binary))
        or isinstance(parent, exp.Paren)
        or not isinstance(this, exp.Binary)
        or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    ):
        return this
    return expression


CONSTANTS = (
    exp.Literal,
    exp.Boolean,
    exp.Null,
)


def simplify_coalesce(expression):
    """Simplify COALESCE(x) and comparisons of a COALESCE against a constant."""
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and not expression.expressions
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not isinstance(other, CONSTANTS):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if isinstance(arg, CONSTANTS):
            break
    else:
        return expression

    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )


CONCATS = (exp.Concat, exp.DPipe)
SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe)


def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them."""
    if not isinstance(expression, CONCATS) or isinstance(expression, exp.ConcatWs):
        return expression

    new_args = []
    for is_string_group, group in itertools.groupby(
        expression.expressions or expression.flatten(), lambda e: e.is_string
    ):
        if is_string_group:
            new_args.append(exp.Literal.string("".join(string.name for string in group)))
        else:
            new_args.extend(group)

    # Ensures we preserve the right concat type, i.e. whether it's "safe" or not
    concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
    return new_args[0] if len(new_args) == 1 else concat_type(expressions=new_args)


# CROSS joins result in an empty table if the right table is empty.
# So we can only simplify certain types of joins to CROSS.
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
JOINS = {
    ("", ""),
    ("", "INNER"),
    ("RIGHT", ""),
    ("RIGHT", "OUTER"),
}


def remove_where_true(expression):
    """Drop always-true WHERE clauses and turn eligible always-true joins into CROSS joins."""
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)
    for join in expression.find_all(exp.Join):
        if (
            always_true(join.args.get("on"))
            and not join.args.get("using")
            and not join.args.get("method")
            and (join.side, join.kind) in JOINS
        ):
            join.set("on", None)
            join.set("side", None)
            join.set("kind", "CROSS")


def always_true(expression):
    """Return a truthy value if `expression` is a constant that is always true.

    TRUE and non-zero literals qualify. Numeric zero is excluded: treating `0` as TRUE let
    `WHERE 0` be dropped by `remove_where_true` and folded `x OR 0` to TRUE. String
    literals are still treated as truthy here.
    """
    return (isinstance(expression, exp.Boolean) and expression.this) or (
        isinstance(expression, exp.Literal) and not _is_zero(expression)
    )


def _is_zero(expression):
    """Return True if the literal `expression` is a numeric literal equal to zero."""
    if not expression.is_number:
        return False
    try:
        return Decimal(expression.name) == 0
    except ArithmeticError:
        # decimal.InvalidOperation for numeric spellings Decimal cannot parse (e.g. hex)
        return False


def is_complement(a, b):
    """Return True if `b` is exactly NOT `a`."""
    return isinstance(b, exp.Not) and b.this == a


def is_false(a: exp.Expression) -> bool:
    return type(a) is exp.Boolean and not a.this


def is_null(a: exp.Expression) -> bool:
    return type(a) is exp.Null


def eval_boolean(expression, a, b):
    """Evaluate a comparison node on Python values; None for non-comparison nodes."""
    if isinstance(expression, (exp.EQ, exp.Is)):
        return boolean_literal(a == b)
    if isinstance(expression, exp.NEQ):
        return boolean_literal(a != b)
    if isinstance(expression, exp.GT):
        return boolean_literal(a > b)
    if isinstance(expression, exp.GTE):
        return boolean_literal(a >= b)
    if isinstance(expression, exp.LT):
        return boolean_literal(a < b)
    if isinstance(expression, exp.LTE):
        return boolean_literal(a <= b)
    return None


def extract_date(cast):
    # The "fromisoformat" conversion could fail if the cast is used on an identifier,
    # so in that case we can't extract the date.
    try:
        if cast.args["to"].this == exp.DataType.Type.DATE:
            return datetime.date.fromisoformat(cast.name)
        if cast.args["to"].this == exp.DataType.Type.DATETIME:
            return datetime.datetime.fromisoformat(cast.name)
    except ValueError:
        return None


def extract_interval(interval):
    """Convert a constant INTERVAL into a relativedelta, or None if that is not possible."""
    try:
        from dateutil.relativedelta import relativedelta  # type: ignore
    except ModuleNotFoundError:
        return None

    try:
        n = int(interval.name)
    except ValueError:
        # Non-constant or fractional amounts (e.g. INTERVAL x DAY) cannot be folded;
        # previously this raised and aborted the whole simplification.
        return None

    unit = interval.text("unit").lower()

    if unit == "year":
        return relativedelta(years=n)
    if unit == "month":
        return relativedelta(months=n)
    if unit == "week":
        return relativedelta(weeks=n)
    if unit == "day":
        return relativedelta(days=n)
    return None


def date_literal(date):
    return exp.cast(
        exp.Literal.string(date),
        "DATETIME" if isinstance(date, datetime.datetime) else "DATE",
    )


def boolean_literal(condition):
    return exp.true() if condition else exp.false()


def _flat_simplify(expression, simplifier, root=True):
    """Pairwise-apply `simplifier` across the flattened operands of `expression`.

    Whenever `simplifier(expression, a, b)` returns a new node, `b` is dropped and the
    result is re-queued in front so it can combine further. The operands are rebuilt into a
    left-deep chain of `expression`'s class only if at least one pair was merged.
    """
    if root or not expression.same_parent:
        operands = []
        queue = deque(expression.flatten(unnest=False))
        size = len(queue)

        while queue:
            a = queue.popleft()

            for b in queue:
                result = simplifier(expression, a, b)

                if result and result is not expression:
                    queue.remove(b)
                    queue.appendleft(result)
                    break
            else:
                operands.append(a)

        if len(operands) < size:
            return functools.reduce(
                lambda a, b: expression.__class__(this=a, expression=b), operands
            )
    return expression
def simplify(expression):
    """
    Rewrite a sqlglot AST so that its expressions are simplified.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): the expression tree to simplify
    Returns:
        sqlglot.Expression: the simplified expression tree
    """
    generate = cached_generator()

    # A projection (or HAVING clause) that uses a GROUP BY key must keep matching that key
    # verbatim, e.g. SELECT x + 1 + 1 FROM y GROUP BY x + 1 + 1, so such nodes are frozen.
    for group in expression.find_all(exp.Group):
        select = group.parent
        keys = set(group.expressions)
        group.meta[FINAL] = True

        for projection in select.selects:
            if any(node in keys for node, *_ in projection.walk()):
                projection.meta[FINAL] = True

        having = select.args.get("having")
        if having and any(node in keys for node, *_ in having.walk()):
            having.meta[FINAL] = True

    def _simplify(current, root=True):
        if current.meta.get(FINAL):
            return current

        # Rewrites applied before the children are visited.
        node = rewrite_between(current)
        node = uniq_sort(node, generate, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)

        exp.replace_children(node, lambda child: _simplify(child, False))

        # Rewrites applied once the children have been simplified.
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_compliments(node, root)
        node = simplify_coalesce(node)
        node.parent = current.parent
        node = simplify_literals(node, root)
        node = simplify_parens(node)

        if root:
            current.replace(node)
        return node

    simplified = while_changing(expression, _simplify)
    remove_where_true(simplified)
    return simplified
Rewrite sqlglot AST to simplify expressions.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
- expression (sqlglot.Expression): expression to simplify
Returns:
sqlglot.Expression: simplified expression
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Expand `x BETWEEN y AND z` into `x >= y AND x <= z`.

    Comparison simplification only understands lt/lte/gt/gte, so BETWEEN is expanded
    before it runs. Any other node is returned untouched.
    """
    if not isinstance(expression, exp.Between):
        return expression

    lower_bound = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    return exp.and_(lower_bound, upper_bound, copy=False)
Rewrite x between y and z to x >= y AND x <= z.
This is done because comparison simplification is only done on lt/lte/gt/gte.
def simplify_not(expression):
    """
    Push NOT inward using De Morgan's laws and fold constant negations.

    NOT (x OR y) -> NOT x AND NOT y
    NOT (x AND y) -> NOT x OR NOT y
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()
        if isinstance(inner, (exp.And, exp.Or)):
            # the connector flips: AND becomes OR and vice versa
            combine = exp.or_ if isinstance(inner, exp.And) else exp.and_
            return combine(
                exp.not_(inner.left, copy=False),
                exp.not_(inner.right, copy=False),
                copy=False,
            )
        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # NOT NOT x -> x
        return operand.this
    return expression
De Morgan's laws:
NOT (x OR y) -> NOT x AND NOT y
NOT (x AND y) -> NOT x OR NOT y
def flatten(expression):
    """
    Inline directly nested connectors of the same type.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = type(expression)
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)
    return expression
A AND (B AND C) -> A AND B AND C
A OR (B OR C) -> A OR B OR C
def simplify_connectors(expression, root=True):
    """Fold constant operands of AND/OR connectors and merge comparable comparisons."""

    def _fold_and(connector, left, right):
        if is_false(left) or is_false(right):
            return exp.false()
        if is_null(left) or is_null(right):
            return exp.null()
        if always_true(left) and always_true(right):
            return exp.true()
        if always_true(left):
            return right
        if always_true(right):
            return left
        return _simplify_comparison(connector, left, right)

    def _fold_or(connector, left, right):
        if always_true(left) or always_true(right):
            return exp.true()
        if is_false(left) and is_false(right):
            return exp.false()
        # FALSE/FALSE was handled above, so this covers NULL/NULL, NULL/FALSE, FALSE/NULL
        if all(is_null(side) or is_false(side) for side in (left, right)):
            return exp.null()
        if is_false(left):
            return right
        if is_false(right):
            return left
        return _simplify_comparison(connector, left, right, or_=True)

    def _simplify_connectors(connector, left, right):
        if left == right:
            return left
        if isinstance(connector, exp.And):
            return _fold_and(connector, left, right)
        if isinstance(connector, exp.Or):
            return _fold_or(connector, left, right)
        return None

    if not isinstance(expression, exp.Connector):
        return expression
    return _flat_simplify(expression, _simplify_connectors, root)
def remove_compliments(expression, root=True):
    """
    Collapse connectors that contain an operand together with its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(a, b) for a, b in pairs):
        return exp.false() if isinstance(expression, exp.And) else exp.true()
    return expression
Remove complements.
A AND NOT A -> FALSE
A OR NOT A -> TRUE
def uniq_sort(expression, generate, root=True):
    """
    Deduplicate a connector's operands and order them by their generated SQL.

    C AND A AND B AND B -> A AND B AND C
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    combine = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())
    # keyed by SQL text: the first occurrence fixes the position, the last one the node
    by_sql = {generate(operand): operand for operand in operands}
    keys = list(by_sql)

    if any(later < earlier for earlier, later in zip(keys, keys[1:])):
        # A AND C AND B -> A AND B AND C
        return combine(*(by_sql[key] for key in sorted(keys)), copy=False)
    if len(by_sql) < len(operands):
        # already in order, but duplicates were dropped
        return combine(*by_sql.values(), copy=False)
    return expression
Uniq and sort a connector.
C AND A AND B AND B -> A AND B AND C
def absorb_and_eliminate(expression, root=True):
    """
    Apply boolean absorption and elimination to the operands of a connector.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Rewrites are done in place through ``replace`` on the operands, and the input
    ``expression`` itself is returned. Absorbed operands are replaced by TRUE/FALSE
    rather than removed; simplify_connectors later folds those constants away.
    Only runs when ``root`` is true or ``expression.same_parent`` is false.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # Candidates for absorption are operands of the opposite connector type:
        # OR groups inside an AND, AND groups inside an OR.
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        # permutations() materializes its input first, so the in-place replacements
        # below do not change which pairs are visited.
        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                # NOT b inside `a` becomes the identity element of `kind`
                # (FALSE for OR, TRUE for AND): A AND (NOT A OR B) -> A AND (FALSE OR B)
                if is_complement(b, aa):
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                # b's operands are a strict subset of a's, so `a` is redundant and becomes
                # the identity element of the outer connector: A AND (A OR B) -> A AND TRUE
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    # `b` shares one operand with `a` and holds the negation of a's other
                    # operand, so both collapse to the shared operand.
                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression
absorption:
A AND (A OR B) -> A
A OR (A AND B) -> A
A AND (NOT A OR B) -> A AND B
A OR (NOT A AND B) -> A OR B
elimination:
(A AND B) OR (A AND NOT B) -> A
(A OR B) AND (A OR NOT B) -> A
def simplify_literals(expression, root=True):
    """Fold constant non-connector binary nodes and negate numeric literals under NEG."""
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg) and expression.this.is_number:
        digits = expression.this.name
        # -(-5) -> 5, -(5) -> -5
        negated = digits[1:] if digits[0] == "-" else f"-{digits}"
        return exp.Literal.number(negated)

    return expression
def simplify_parens(expression):
    """Drop parentheses that do not affect how the wrapped expression is evaluated."""
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    outer = expression.parent

    # parenthesized SELECTs (subqueries) are always kept
    if isinstance(inner, exp.Select):
        return expression

    redundant = (
        not isinstance(outer, (exp.Condition, exp.Binary))
        or isinstance(outer, exp.Paren)
        or not isinstance(inner, exp.Binary)
        or (isinstance(inner, exp.Predicate) and not isinstance(outer, exp.Predicate))
        or (isinstance(inner, exp.Add) and isinstance(outer, exp.Add))
        or (isinstance(inner, exp.Mul) and isinstance(outer, exp.Mul))
        or (isinstance(inner, exp.Mul) and isinstance(outer, (exp.Add, exp.Sub)))
    )
    return inner if redundant else expression
def simplify_coalesce(expression):
    """
    Simplify COALESCE calls and comparisons against them.

    * ``COALESCE(x)`` with no further arguments (and not inside a hint) becomes ``x``.
    * ``COALESCE(x, ..., c, ...) <cmp> k``, where ``c`` is the first constant argument and
      ``k`` is a constant, becomes
      ``(NOT C IS NULL AND COALESCE(x, ...) <cmp> k) OR (C IS NULL AND c <cmp> k)``.
      Here every argument from ``c`` onwards is dropped from the COALESCE. ``C`` is that
      truncated COALESCE, or just ``x`` when no extra arguments remain.

    Any other expression is returned unchanged.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and not expression.expressions
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not isinstance(other, CONSTANTS):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if isinstance(arg, CONSTANTS):
            break
    else:
        return expression

    # Truncates the COALESCE node in place. It is still a child of `expression`, so the
    # `expression.copy()` below compares the truncated COALESCE.
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )
def simplify_concat(expression):
    """Merge every run of adjacent string literals in a concatenation into one literal."""
    if not isinstance(expression, CONCATS) or isinstance(expression, exp.ConcatWs):
        return expression

    operands = expression.expressions or expression.flatten()
    merged = []
    for all_strings, run in itertools.groupby(operands, lambda e: e.is_string):
        if all_strings:
            merged.append(exp.Literal.string("".join(literal.name for literal in run)))
        else:
            merged.extend(run)

    if len(merged) == 1:
        return merged[0]
    # keep the "safe" flavour of the concatenation when the input had it
    concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
    return concat_type(expressions=merged)
Reduces all groups that contain string literals by concatenating them.
def remove_where_true(expression):
    """Drop always-true WHERE clauses and turn eligible always-true joins into CROSS joins."""
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)

    for join in expression.find_all(exp.Join):
        if not always_true(join.args.get("on")):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        # only join types listed in JOINS may become CROSS joins
        if (join.side, join.kind) not in JOINS:
            continue
        join.set("on", None)
        join.set("side", None)
        join.set("kind", "CROSS")
def eval_boolean(expression, a, b):
    """Evaluate a comparison node on plain Python values.

    Returns the TRUE/FALSE literal built by boolean_literal for EQ/IS, NEQ, GT, GTE, LT and
    LTE nodes, and None for any other node type.
    """
    comparisons = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )
    for node_types, compare in comparisons:
        if isinstance(expression, node_types):
            return boolean_literal(compare(a, b))
    return None
def extract_date(cast):
    """Parse the operand of a CAST to DATE/DATETIME into a Python date/datetime.

    Returns None for other target types, or when the cast operand is not an ISO-formatted
    string (e.g. the cast wraps an identifier, which makes "fromisoformat" fail).
    """
    target_type = cast.args["to"].this
    if target_type == exp.DataType.Type.DATE:
        parse = datetime.date.fromisoformat
    elif target_type == exp.DataType.Type.DATETIME:
        parse = datetime.datetime.fromisoformat
    else:
        return None

    try:
        return parse(cast.name)
    except ValueError:
        return None
def extract_interval(interval):
    """Convert a constant INTERVAL expression into a ``relativedelta``.

    Args:
        interval: the interval node; its ``name`` is the amount and its ``unit`` text is
            the unit (year, month, week or day, case-insensitive).

    Returns:
        A ``dateutil.relativedelta.relativedelta``. Returns None when python-dateutil is not
        installed, when the amount is not an integer (e.g. ``INTERVAL x DAY``), or when the
        unit is not supported.
    """
    try:
        from dateutil.relativedelta import relativedelta  # type: ignore
    except ModuleNotFoundError:
        return None

    try:
        n = int(interval.name)
    except ValueError:
        # Non-constant or fractional amounts cannot be folded. Raising here used to abort
        # the whole simplification instead of just skipping this interval.
        return None

    unit_to_kwarg = {
        "year": "years",
        "month": "months",
        "week": "weeks",
        "day": "days",
    }
    kwarg = unit_to_kwarg.get(interval.text("unit").lower())
    if kwarg is None:
        return None
    return relativedelta(**{kwarg: n})