sqlglot.optimizer.simplify
import datetime
import functools
import itertools
from collections import deque
from decimal import Decimal

from sqlglot import exp
from sqlglot.generator import cached_generator
from sqlglot.helper import first, while_changing


def simplify(expression):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
    Returns:
        sqlglot.Expression: simplified expression
    """

    generate = cached_generator()

    def _simplify(expression, root=True):
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, generate, root)
        node = absorb_and_eliminate(node, root)
        # simplify children bottom-up; nested calls are never the root
        exp.replace_children(node, lambda e: _simplify(e, False))
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_compliments(node, root)
        # keep the original parent so the parent-aware passes below see the right context
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_parens(node)
        if root:
            expression.replace(node)
        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression


def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite x between y and z to x >= y AND x <= z.

    This is done because comparison simplification is only done on lt/lte/gt/gte.
    """
    if isinstance(expression, exp.Between):
        return exp.and_(
            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
            copy=False,
        )
    return expression


def simplify_not(expression):
    """
    De Morgan's laws:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

    Also folds NOT applied to NULL / always-true / FALSE constants and removes
    double negation (NOT NOT x -> x).
    """
    if isinstance(expression, exp.Not):
        if is_null(expression.this):
            return exp.null()
        if isinstance(expression.this, exp.Paren):
            condition = expression.this.unnest()
            if isinstance(condition, exp.And):
                return exp.or_(
                    exp.not_(condition.left, copy=False),
                    exp.not_(condition.right, copy=False),
                    copy=False,
                )
            if isinstance(condition, exp.Or):
                return exp.and_(
                    exp.not_(condition.left, copy=False),
                    exp.not_(condition.right, copy=False),
                    copy=False,
                )
            if is_null(condition):
                return exp.null()
        if always_true(expression.this):
            return exp.false()
        if is_false(expression.this):
            return exp.true()
        if isinstance(expression.this, exp.Not):
            # double negation
            # NOT NOT x -> x
            return expression.this.this
    return expression


def flatten(expression):
    """
    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if isinstance(expression, exp.Connector):
        for node in expression.args.values():
            child = node.unnest()
            if isinstance(child, expression.__class__):
                node.replace(child)
    return expression


def simplify_connectors(expression, root=True):
    """Fold TRUE/FALSE/NULL operands of AND/OR and merge comparable comparisons."""

    def _simplify_connectors(expression, left, right):
        if left == right:
            return left
        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression


LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
)

INVERSE_COMPARISONS = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}


def _simplify_comparison(expression, left, right, or_=False):
    """
    Merge two comparisons of the same column against literals of the same kind.

    Example: x > 1 AND x > 2 -> x > 2. Returns the replacement expression, or
    None when the pair cannot be simplified.
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        matching = largs & rargs
        columns = {m for m in matching if isinstance(m, exp.Column)}

        if matching and columns:
            try:
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                # Both operands are columns (e.g. x = y AND y < x): nothing to fold.
                # Return None, not `expression`: _flat_simplify treats any truthy
                # result as a replacement operand. Re-inserting the whole connector
                # would duplicate the remaining operands.
                return None

            # make sure the comparison is always of the form x > 1 instead of 1 < x
            if left.__class__ in INVERSE_COMPARISONS and l == ll:
                left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
            if right.__class__ in INVERSE_COMPARISONS and r == rl:
                right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)

            if l.is_number and r.is_number:
                l = float(l.name)
                r = float(r.name)
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                return None

            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
                    return left if (av > bv if or_ else av <= bv) else right
                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
                    return left if (av < bv if or_ else av >= bv) else right

                # we can't ever shortcut to true because the column could be null
                if not or_:
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    elif isinstance(a, exp.EQ):
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None


def remove_compliments(expression, root=True):
    """
    Remove complements.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        compliment = exp.false() if isinstance(expression, exp.And) else exp.true()

        for a, b in itertools.permutations(expression.flatten(), 2):
            if is_complement(a, b):
                return compliment
    return expression


def uniq_sort(expression, generate, root=True):
    """
    Deduplicate and sort the operands of a connector by their generated SQL.

    C AND A AND B AND B -> A AND B AND C
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
        flattened = tuple(expression.flatten())
        deduped = {generate(e): e for e in flattened}
        arr = tuple(deduped.items())

        # check if the operands are already sorted, if not sort them
        # A AND C AND B -> A AND B AND C
        for i, (sql, e) in enumerate(arr[1:]):
            if sql < arr[i][0]:
                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
                break
        else:
            # we didn't have to sort but maybe we need to dedup
            if len(deduped) < len(flattened):
                expression = result_func(*deduped.values(), copy=False)

    return expression


def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression


def simplify_literals(expression, root=True):
    """Fold literal operands of non-connector binaries and of numeric negation."""
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)
    elif isinstance(expression, exp.Neg):
        this = expression.this
        if this.is_number:
            value = this.name
            if value[0] == "-":
                return exp.Literal.number(value[1:])
            return exp.Literal.number(f"-{value}")

    return expression


def _simplify_binary(expression, a, b):
    """Evaluate a binary operation on two literal operands; None if not foldable."""
    if isinstance(expression, exp.Is):
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if is_null(c):
            if isinstance(a, exp.Literal):
                return exp.true() if not_ else exp.false()
            if is_null(a):
                return exp.false() if not_ else exp.true()
    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
        return None
    elif is_null(a) or is_null(b):
        return exp.null()

    if a.is_number and b.is_number:
        a = int(a.name) if a.is_int else Decimal(a.name)
        b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(a + b)
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(a - b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(a * b)
        if isinstance(expression, exp.Div):
            # engines have differing int div behavior so intdiv is not safe
            if isinstance(a, int) and isinstance(b, int):
                return None
            # Decimal division by zero raises; leave it for the engine to handle.
            if b == 0:
                return None
            return exp.Literal.number(a / b)

        boolean = eval_boolean(expression, a, b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        boolean = eval_boolean(expression, a.this, b.this)

        if boolean:
            return boolean
    elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
        a, b = extract_date(a), extract_interval(b)
        if a and b:
            if isinstance(expression, exp.Add):
                return date_literal(a + b)
            if isinstance(expression, exp.Sub):
                return date_literal(a - b)
    elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
        a, b = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if a and b and isinstance(expression, exp.Add):
            return date_literal(a + b)

    return None


def simplify_parens(expression):
    """Drop parentheses that do not affect evaluation (subqueries keep theirs)."""
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent

    if not isinstance(this, exp.Select) and (
        not isinstance(parent, (exp.Condition, exp.Binary))
        or isinstance(this, exp.Predicate)
        or not isinstance(this, exp.Binary)
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
    ):
        return expression.this
    return expression


def remove_where_true(expression):
    """Drop always-true WHERE clauses and turn always-true ON joins into CROSS joins."""
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)
    for join in expression.find_all(exp.Join):
        if (
            always_true(join.args.get("on"))
            and not join.args.get("using")
            and not join.args.get("method")
        ):
            join.set("on", None)
            join.set("side", None)
            join.set("kind", "CROSS")


def always_true(expression):
    """
    Return a truthy value if `expression` is a constant that is always true.

    TRUE qualifies, as does any literal except a numeric zero literal. In SQL,
    `0` is falsy, so it must not be treated as always true (for example,
    `WHERE 0` must not be removed).
    """
    return (isinstance(expression, exp.Boolean) and expression.this) or (
        isinstance(expression, exp.Literal) and not is_zero(expression)
    )


def is_zero(expression):
    """Return True if `expression` is a numeric literal whose value equals zero."""
    if not (isinstance(expression, exp.Literal) and expression.is_number):
        return False
    try:
        return Decimal(expression.name) == 0
    except (ArithmeticError, ValueError):
        # a number format Decimal cannot parse: don't claim it is zero
        return False


def is_complement(a, b):
    """Return True if `b` is `NOT a`."""
    return isinstance(b, exp.Not) and b.this == a


def is_false(a: exp.Expression) -> bool:
    return type(a) is exp.Boolean and not a.this


def is_null(a: exp.Expression) -> bool:
    return type(a) is exp.Null


def eval_boolean(expression, a, b):
    """Evaluate a comparison node over Python values; None for other node types."""
    if isinstance(expression, (exp.EQ, exp.Is)):
        return boolean_literal(a == b)
    if isinstance(expression, exp.NEQ):
        return boolean_literal(a != b)
    if isinstance(expression, exp.GT):
        return boolean_literal(a > b)
    if isinstance(expression, exp.GTE):
        return boolean_literal(a >= b)
    if isinstance(expression, exp.LT):
        return boolean_literal(a < b)
    if isinstance(expression, exp.LTE):
        return boolean_literal(a <= b)
    return None


def extract_date(cast):
    """Return the date/datetime a DATE/DATETIME cast of an ISO string denotes, else None."""
    # The "fromisoformat" conversion could fail if the cast is used on an identifier,
    # so in that case we can't extract the date.
    try:
        if cast.args["to"].this == exp.DataType.Type.DATE:
            return datetime.date.fromisoformat(cast.name)
        if cast.args["to"].this == exp.DataType.Type.DATETIME:
            return datetime.datetime.fromisoformat(cast.name)
    except ValueError:
        return None
    return None


def extract_interval(interval):
    """
    Convert an INTERVAL node into a `dateutil.relativedelta`.

    Returns None if dateutil is unavailable, if the amount is not an integer
    (e.g. INTERVAL '1.5' DAY), or if the unit is not YEAR/MONTH/WEEK/DAY.
    """
    try:
        from dateutil.relativedelta import relativedelta  # type: ignore
    except ModuleNotFoundError:
        return None

    try:
        n = int(interval.name)
    except ValueError:
        # non-integer or non-literal amount: can't fold, don't crash the optimizer
        return None

    unit = interval.text("unit").lower()

    if unit == "year":
        return relativedelta(years=n)
    if unit == "month":
        return relativedelta(months=n)
    if unit == "week":
        return relativedelta(weeks=n)
    if unit == "day":
        return relativedelta(days=n)
    return None


def date_literal(date):
    """Build a CAST('<iso>' AS DATE|DATETIME) expression for a Python date/datetime."""
    return exp.cast(
        exp.Literal.string(date),
        "DATETIME" if isinstance(date, datetime.datetime) else "DATE",
    )


def boolean_literal(condition):
    return exp.true() if condition else exp.false()


def _flat_simplify(expression, simplifier, root=True):
    """
    Pairwise-simplify the flattened operands of a same-class binary chain.

    `simplifier(expression, a, b)` returns a replacement for the pair (a, b),
    or a falsy value to leave the pair alone. A replacement is pushed back to
    the front of the queue so it can combine further. The chain is rebuilt
    only if at least one pair was merged.
    """
    if root or not expression.same_parent:
        operands = []
        queue = deque(expression.flatten(unnest=False))
        size = len(queue)

        while queue:
            a = queue.popleft()

            for b in queue:
                result = simplifier(expression, a, b)

                if result:
                    # safe: we break out of the iteration right after mutating
                    queue.remove(b)
                    queue.appendleft(result)
                    break
            else:
                operands.append(a)

        if len(operands) < size:
            return functools.reduce(
                lambda a, b: expression.__class__(this=a, expression=b), operands
            )
    return expression
def
simplify(expression):
def simplify(expression):
    """
    Simplify a sqlglot AST by applying rewrite passes until a fixed point is reached.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
    Returns:
        sqlglot.Expression: simplified expression
    """
    to_sql = cached_generator()

    def _simplify(expression, root=True):
        # Connector-level rewrites that run before the children are visited.
        current = rewrite_between(expression)
        current = uniq_sort(current, to_sql, root)
        current = absorb_and_eliminate(current, root)

        # Recurse into children; nested calls are never the root.
        exp.replace_children(current, lambda child: _simplify(child, False))

        current = simplify_not(current)
        current = flatten(current)
        current = simplify_connectors(current, root)
        current = remove_compliments(current, root)

        # Re-attach to the original parent so parent-aware passes see the right context.
        current.parent = expression.parent
        current = simplify_literals(current, root)
        current = simplify_parens(current)

        if root:
            expression.replace(current)
        return current

    simplified = while_changing(expression, _simplify)
    remove_where_true(simplified)
    return simplified
Rewrite sqlglot AST to simplify expressions.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
- expression (sqlglot.Expression): expression to simplify
Returns:
sqlglot.Expression: simplified expression
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Expand `x BETWEEN low AND high` into `x >= low AND x <= high`.

    Comparison simplification only understands LT/LTE/GT/GTE, so BETWEEN is
    rewritten into that form first. Any other expression is returned unchanged.
    """
    if not isinstance(expression, exp.Between):
        return expression

    lower_check = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_check = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    return exp.and_(lower_check, upper_check, copy=False)
Rewrite x between y and z to x >= y AND x <= z.
This is done because comparison simplification only handles LT/LTE/GT/GTE comparisons.
def
simplify_not(expression):
def simplify_not(expression):
    """
    Simplify a NOT expression.

    De Morgan's laws for a parenthesized connector operand:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y
    Also folds NOT applied to NULL into NULL, NOT <always-true> into FALSE,
    NOT FALSE into TRUE, and NOT NOT x into x. Other input is returned as is.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()
        if isinstance(inner, (exp.And, exp.Or)):
            # swap the connector and negate each side
            combine = exp.or_ if isinstance(inner, exp.And) else exp.and_
            return combine(
                exp.not_(inner.left, copy=False),
                exp.not_(inner.right, copy=False),
                copy=False,
            )
        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # double negation: NOT NOT x -> x
        return operand.this
    return expression
De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y
def
flatten(expression):
def flatten(expression):
    """
    Inline nested connectors of the same kind:
        A AND (B AND C) -> A AND B AND C
        A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)
    return expression
A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C
def
simplify_connectors(expression, root=True):
def simplify_connectors(expression, root=True):
    """
    Fold constant operands of AND / OR connectors.

    Operand pairs are combined through `_flat_simplify`. Identical operands
    collapse to one, and TRUE/FALSE/NULL constants are folded. Otherwise
    comparison simplification is attempted. Non-connectors are returned as is.
    """

    def _fold_and(node, left, right):
        if is_false(left) or is_false(right):
            return exp.false()
        if is_null(left) or is_null(right):
            return exp.null()
        if always_true(left):
            return exp.true() if always_true(right) else right
        if always_true(right):
            return left
        return _simplify_comparison(node, left, right)

    def _fold_or(node, left, right):
        if always_true(left) or always_true(right):
            return exp.true()
        if is_false(left):
            if is_false(right):
                return exp.false()
            if is_null(right):
                return exp.null()
            return right
        if is_null(left) and (is_null(right) or is_false(right)):
            return exp.null()
        if is_false(right):
            return left
        return _simplify_comparison(node, left, right, or_=True)

    def _simplify_connectors(node, left, right):
        if left == right:
            return left
        if isinstance(node, exp.And):
            return _fold_and(node, left, right)
        if isinstance(node, exp.Or):
            return _fold_or(node, left, right)
        return None

    if not isinstance(expression, exp.Connector):
        return expression
    return _flat_simplify(expression, _simplify_connectors, root)
LT_LTE =
(<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>)
GT_GTE =
(<class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>)
COMPARISONS =
(<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>, <class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>, <class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.NEQ'>)
INVERSE_COMPARISONS =
{<class 'sqlglot.expressions.LT'>: <class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GT'>: <class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>: <class 'sqlglot.expressions.GTE'>, <class 'sqlglot.expressions.GTE'>: <class 'sqlglot.expressions.LTE'>}
def
remove_compliments(expression, root=True):
def remove_compliments(expression, root=True):
    """
    Collapse a connector that contains an operand together with its negation:
        A AND NOT A -> FALSE
        A OR NOT A -> TRUE
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        pairs = itertools.permutations(expression.flatten(), 2)
        if any(is_complement(first_op, second_op) for first_op, second_op in pairs):
            return exp.false() if isinstance(expression, exp.And) else exp.true()
    return expression
Remove complements.
A AND NOT A -> FALSE; A OR NOT A -> TRUE
def
uniq_sort(expression, generate, root=True):
def uniq_sort(expression, generate, root=True):
    """
    Deduplicate and sort a connector's operands by their generated SQL text.

        C AND A AND B AND B -> A AND B AND C

    The connector is rebuilt only when reordering or deduplication is needed.
    """
    if not (isinstance(expression, exp.Connector) and (root or not expression.same_parent)):
        return expression

    build = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())
    # Keyed by SQL text. A duplicate overwrites the stored operand but keeps the
    # key's first-seen position.
    by_sql = {generate(operand): operand for operand in operands}
    keys = list(by_sql)

    if any(later < earlier for earlier, later in zip(keys, keys[1:])):
        return build(*(by_sql[key] for key in sorted(keys)), copy=False)
    if len(by_sql) < len(operands):
        return build(*by_sql.values(), copy=False)
    return expression
Deduplicate and sort the operands of a connector.
C AND A AND B AND B -> A AND B AND C
def
absorb_and_eliminate(expression, root=True):
def absorb_and_eliminate(expression, root=True):
    """
    Apply absorption and elimination laws to a connector's operands.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are rewritten in place (replaced with TRUE/FALSE or with a shared
    sub-operand); later connector simplification folds the constants away.
    """
    if not (isinstance(expression, exp.Connector) and (root or not expression.same_parent)):
        return expression

    # Only operands of the opposite connector type can be absorbed/eliminated.
    kind = exp.Or if isinstance(expression, exp.And) else exp.And
    # Identity element of `kind`: replaces a cancelled sub-operand.
    inner_identity = exp.true if kind == exp.And else exp.false
    # Identity element of the outer connector: replaces an absorbed operand.
    outer_identity = exp.false if kind == exp.And else exp.true

    for term, other in itertools.permutations(expression.flatten(), 2):
        if not isinstance(term, kind):
            continue

        term_left, term_right = term.unnest_operands()

        if is_complement(other, term_left):
            term_left.replace(inner_identity())
        elif is_complement(other, term_right):
            term_right.replace(inner_identity())
        elif (set(other.flatten()) if isinstance(other, kind) else {other}) < set(term.flatten()):
            term.replace(outer_identity())
        elif isinstance(other, kind):
            other_operands = other.unnest_operands()
            other_left, other_right = other_operands

            if term_left in other_operands and (
                is_complement(term_right, other_left) or is_complement(term_right, other_right)
            ):
                term.replace(term_left)
                other.replace(term_left)
            elif term_right in other_operands and (
                is_complement(term_left, other_left) or is_complement(term_left, other_right)
            ):
                term.replace(term_right)
                other.replace(term_right)

    return expression
Absorption: A AND (A OR B) -> A; A OR (A AND B) -> A; A AND (NOT A OR B) -> A AND B; A OR (NOT A AND B) -> A OR B. Elimination: (A AND B) OR (A AND NOT B) -> A; (A OR B) AND (A OR NOT B) -> A.
def
simplify_literals(expression, root=True):
def simplify_literals(expression, root=True):
    """
    Fold literal operands.

    Non-connector binary expressions are folded pairwise via `_simplify_binary`.
    Negation of a numeric literal becomes a single numeric literal with the
    sign flipped. Anything else is returned unchanged.
    """
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg) and expression.this.is_number:
        digits = expression.this.name
        flipped = digits[1:] if digits[0] == "-" else f"-{digits}"
        return exp.Literal.number(flipped)

    return expression
def
simplify_parens(expression):
def simplify_parens(expression):
    """
    Drop redundant parentheses.

    A parenthesized SELECT always keeps its parentheses. Otherwise they are
    removed when:
    - the parent is not a condition/binary node;
    - the wrapped node is a predicate or is not a binary operation; or
    - an addition/multiplication is nested directly inside the same operator.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    outer = expression.parent

    if isinstance(inner, exp.Select):
        return expression

    same_associative_op = any(
        isinstance(inner, op) and isinstance(outer, op) for op in (exp.Add, exp.Mul)
    )
    removable = (
        not isinstance(outer, (exp.Condition, exp.Binary))
        or isinstance(inner, exp.Predicate)
        or not isinstance(inner, exp.Binary)
        or same_associative_op
    )
    return inner if removable else expression
def
remove_where_true(expression):
def remove_where_true(expression):
    """
    Remove always-true filters in place.

    WHERE clauses whose condition is always true are dropped. Joins whose ON
    condition is always true, and that have no USING or method, become CROSS
    joins.
    """
    for where_clause in expression.find_all(exp.Where):
        if always_true(where_clause.this):
            where_clause.parent.set("where", None)

    for join in expression.find_all(exp.Join):
        if not always_true(join.args.get("on")):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        join.set("on", None)
        join.set("side", None)
        join.set("kind", "CROSS")
def
always_true(expression):
def
is_complement(a, b):
def
eval_boolean(expression, a, b):
def eval_boolean(expression, a, b):
    """
    Evaluate a comparison node over two Python values.

    Returns a TRUE/FALSE literal for EQ/IS, NEQ, GT, GTE, LT and LTE nodes, and
    None for any other node type.
    """
    comparators = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )
    for node_types, compare in comparators:
        if isinstance(expression, node_types):
            return boolean_literal(compare(a, b))
    return None
def
extract_date(cast):
def extract_date(cast):
    """
    Convert a CAST to DATE/DATETIME into a Python date/datetime.

    Returns None when the target type is neither DATE nor DATETIME. Also
    returns None when the cast's name is not ISO-formatted, e.g. when the cast
    wraps an identifier rather than a string literal.
    """
    target_type = cast.args["to"].this
    try:
        if target_type == exp.DataType.Type.DATE:
            return datetime.date.fromisoformat(cast.name)
        if target_type == exp.DataType.Type.DATETIME:
            return datetime.datetime.fromisoformat(cast.name)
    except ValueError:
        # not an ISO date/datetime string, so nothing to extract
        return None
    return None
def
extract_interval(interval):
def extract_interval(interval):
    """
    Convert an INTERVAL node into a `dateutil.relativedelta`.

    Returns None when:
    - dateutil is not installed;
    - the amount is not an integer (e.g. INTERVAL '1.5' DAY or INTERVAL '1 day'),
      which previously raised ValueError; or
    - the unit is not one of YEAR, MONTH, WEEK or DAY.
    """
    try:
        from dateutil.relativedelta import relativedelta  # type: ignore
    except ModuleNotFoundError:
        return None

    try:
        n = int(interval.name)
    except ValueError:
        # non-integer or non-literal amount: can't fold, don't crash the optimizer
        return None

    unit = interval.text("unit").lower()

    if unit == "year":
        return relativedelta(years=n)
    if unit == "month":
        return relativedelta(months=n)
    if unit == "week":
        return relativedelta(weeks=n)
    if unit == "day":
        return relativedelta(days=n)
    return None
def
date_literal(date):
def
boolean_literal(condition):