From beba715b97dd2349e01dde9b077d2535680ebdca Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Wed, 10 May 2023 08:44:58 +0200 Subject: Merging upstream version 12.2.0. Signed-off-by: Daniel Baumann --- docs/sqlglot/optimizer/simplify.html | 1366 +++++++++++++++++----------------- 1 file changed, 692 insertions(+), 674 deletions(-) (limited to 'docs/sqlglot/optimizer/simplify.html') diff --git a/docs/sqlglot/optimizer/simplify.html b/docs/sqlglot/optimizer/simplify.html index 5ba4a5e..598977c 100644 --- a/docs/sqlglot/optimizer/simplify.html +++ b/docs/sqlglot/optimizer/simplify.html @@ -175,449 +175,458 @@ 60 return exp.and_( 61 exp.GTE(this=expression.this.copy(), expression=expression.args["low"]), 62 exp.LTE(this=expression.this.copy(), expression=expression.args["high"]), - 63 ) - 64 return expression - 65 + 63 copy=False, + 64 ) + 65 return expression 66 - 67def simplify_not(expression): - 68 """ - 69 Demorgan's Law - 70 NOT (x OR y) -> NOT x AND NOT y - 71 NOT (x AND y) -> NOT x OR NOT y - 72 """ - 73 if isinstance(expression, exp.Not): - 74 if is_null(expression.this): - 75 return exp.null() - 76 if isinstance(expression.this, exp.Paren): - 77 condition = expression.this.unnest() - 78 if isinstance(condition, exp.And): - 79 return exp.or_(exp.not_(condition.left), exp.not_(condition.right)) - 80 if isinstance(condition, exp.Or): - 81 return exp.and_(exp.not_(condition.left), exp.not_(condition.right)) - 82 if is_null(condition): - 83 return exp.null() - 84 if always_true(expression.this): - 85 return exp.false() - 86 if is_false(expression.this): - 87 return exp.true() - 88 if isinstance(expression.this, exp.Not): - 89 # double negation - 90 # NOT NOT x -> x - 91 return expression.this.this - 92 return expression - 93 - 94 - 95def flatten(expression): - 96 """ - 97 A AND (B AND C) -> A AND B AND C - 98 A OR (B OR C) -> A OR B OR C - 99 """ -100 if isinstance(expression, exp.Connector): -101 for node in expression.args.values(): -102 child = node.unnest() -103 if isinstance(child, expression.__class__): -104 node.replace(child) -105 return expression -106 -107 -108def simplify_connectors(expression, root=True): -109 def _simplify_connectors(expression, left, right): -110 if left == right: -111 return left -112 if isinstance(expression, exp.And): -113 if is_false(left) or is_false(right): -114 return exp.false() -115 if is_null(left) or is_null(right): -116 return exp.null() -117 if always_true(left) and always_true(right): -118 return exp.true() -119 if always_true(left): -120 return right -121 if always_true(right): -122 return left -123 return _simplify_comparison(expression, left, right) -124 elif isinstance(expression, exp.Or): -125 if always_true(left) or always_true(right): -126 return exp.true() -127 if is_false(left) and is_false(right): -128 return exp.false() -129 if ( -130 (is_null(left) and is_null(right)) -131 or (is_null(left) and is_false(right)) -132 or (is_false(left) and is_null(right)) -133 ): -134 return exp.null() -135 if is_false(left): -136 return right -137 if is_false(right): -138 return left -139 return _simplify_comparison(expression, left, right, or_=True) -140 -141 if isinstance(expression, exp.Connector): -142 return _flat_simplify(expression, _simplify_connectors, root) -143 return expression -144 -145 -146LT_LTE = (exp.LT, exp.LTE) -147GT_GTE = (exp.GT, exp.GTE) -148 -149COMPARISONS = ( -150 *LT_LTE, -151 *GT_GTE, -152 exp.EQ, -153 exp.NEQ, -154) -155 -156INVERSE_COMPARISONS = { -157 exp.LT: exp.GT, -158 exp.GT: exp.LT, -159 exp.LTE: exp.GTE, -160 exp.GTE: exp.LTE, -161} -162 -163 -164def _simplify_comparison(expression, left, right, or_=False): -165 if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS): -166 ll, lr = left.args.values() -167 rl, rr = right.args.values() -168 -169 largs = {ll, lr} -170 rargs = {rl, rr} + 67 + 68def simplify_not(expression): + 69 """ + 70 Demorgan's Law + 71 NOT (x OR y) -> NOT x AND NOT y + 72 NOT (x AND y) -> NOT x OR NOT y + 73 """ + 74 if isinstance(expression, exp.Not): + 75 if is_null(expression.this): + 76 return exp.null() + 77 if isinstance(expression.this, exp.Paren): + 78 condition = expression.this.unnest() + 79 if isinstance(condition, exp.And): + 80 return exp.or_( + 81 exp.not_(condition.left, copy=False), + 82 exp.not_(condition.right, copy=False), + 83 copy=False, + 84 ) + 85 if isinstance(condition, exp.Or): + 86 return exp.and_( + 87 exp.not_(condition.left, copy=False), + 88 exp.not_(condition.right, copy=False), + 89 copy=False, + 90 ) + 91 if is_null(condition): + 92 return exp.null() + 93 if always_true(expression.this): + 94 return exp.false() + 95 if is_false(expression.this): + 96 return exp.true() + 97 if isinstance(expression.this, exp.Not): + 98 # double negation + 99 # NOT NOT x -> x +100 return expression.this.this +101 return expression +102 +103 +104def flatten(expression): +105 """ +106 A AND (B AND C) -> A AND B AND C +107 A OR (B OR C) -> A OR B OR C +108 """ +109 if isinstance(expression, exp.Connector): +110 for node in expression.args.values(): +111 child = node.unnest() +112 if isinstance(child, expression.__class__): +113 node.replace(child) +114 return expression +115 +116 +117def simplify_connectors(expression, root=True): +118 def _simplify_connectors(expression, left, right): +119 if left == right: +120 return left +121 if isinstance(expression, exp.And): +122 if is_false(left) or is_false(right): +123 return exp.false() +124 if is_null(left) or is_null(right): +125 return exp.null() +126 if always_true(left) and always_true(right): +127 return exp.true() +128 if always_true(left): +129 return right +130 if always_true(right): +131 return left +132 return _simplify_comparison(expression, left, right) +133 elif isinstance(expression, exp.Or): +134 if always_true(left) or always_true(right): +135 return exp.true() +136 if is_false(left) and is_false(right): +137 return exp.false() +138 if ( +139 (is_null(left) and is_null(right)) +140 or (is_null(left) and is_false(right)) +141 or (is_false(left) and is_null(right)) +142 ): +143 return exp.null() +144 if is_false(left): +145 return right +146 if is_false(right): +147 return left +148 return _simplify_comparison(expression, left, right, or_=True) +149 +150 if isinstance(expression, exp.Connector): +151 return _flat_simplify(expression, _simplify_connectors, root) +152 return expression +153 +154 +155LT_LTE = (exp.LT, exp.LTE) +156GT_GTE = (exp.GT, exp.GTE) +157 +158COMPARISONS = ( +159 *LT_LTE, +160 *GT_GTE, +161 exp.EQ, +162 exp.NEQ, +163) +164 +165INVERSE_COMPARISONS = { +166 exp.LT: exp.GT, +167 exp.GT: exp.LT, +168 exp.LTE: exp.GTE, +169 exp.GTE: exp.LTE, +170} 171 -172 matching = largs & rargs -173 columns = {m for m in matching if isinstance(m, exp.Column)} -174 -175 if matching and columns: -176 try: -177 l = first(largs - columns) -178 r = first(rargs - columns) -179 except StopIteration: -180 return expression -181 -182 # make sure the comparison is always of the form x > 1 instead of 1 < x -183 if left.__class__ in INVERSE_COMPARISONS and l == ll: -184 left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll) -185 if right.__class__ in INVERSE_COMPARISONS and r == rl: -186 right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl) -187 -188 if l.is_number and r.is_number: -189 l = float(l.name) -190 r = float(r.name) -191 elif l.is_string and r.is_string: -192 l = l.name -193 r = r.name -194 else: -195 return None +172 +173def _simplify_comparison(expression, left, right, or_=False): +174 if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS): +175 ll, lr = left.args.values() +176 rl, rr = right.args.values() +177 +178 largs = {ll, lr} +179 rargs = {rl, rr} +180 +181 matching = largs & rargs +182 columns = {m for m in matching if isinstance(m, exp.Column)} +183 +184 if matching and columns: +185 try: +186 l = first(largs - columns) +187 r = first(rargs - columns) +188 except StopIteration: +189 return expression +190 +191 # make sure the comparison is always of the form x > 1 instead of 1 < x +192 if left.__class__ in INVERSE_COMPARISONS and l == ll: +193 left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll) +194 if right.__class__ in INVERSE_COMPARISONS and r == rl: +195 right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl) 196 -197 for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))): -198 if isinstance(a, LT_LTE) and isinstance(b, LT_LTE): -199 return left if (av > bv if or_ else av <= bv) else right -200 if isinstance(a, GT_GTE) and isinstance(b, GT_GTE): -201 return left if (av < bv if or_ else av >= bv) else right -202 -203 # we can't ever shortcut to true because the column could be null -204 if not or_: -205 if isinstance(a, exp.LT) and isinstance(b, GT_GTE): -206 if av <= bv: -207 return exp.false() -208 elif isinstance(a, exp.GT) and isinstance(b, LT_LTE): -209 if av >= bv: -210 return exp.false() -211 elif isinstance(a, exp.EQ): -212 if isinstance(b, exp.LT): -213 return exp.false() if av >= bv else a -214 if isinstance(b, exp.LTE): -215 return exp.false() if av > bv else a -216 if isinstance(b, exp.GT): -217 return exp.false() if av <= bv else a -218 if isinstance(b, exp.GTE): -219 return exp.false() if av < bv else a -220 if isinstance(b, exp.NEQ): -221 return exp.false() if av == bv else a -222 return None -223 -224 -225def remove_compliments(expression, root=True): -226 """ -227 Removing compliments. -228 -229 A AND NOT A -> FALSE -230 A OR NOT A -> TRUE -231 """ -232 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): -233 compliment = exp.false() if isinstance(expression, exp.And) else exp.true() -234 -235 for a, b in itertools.permutations(expression.flatten(), 2): -236 if is_complement(a, b): -237 return compliment -238 return expression -239 -240 -241def uniq_sort(expression, cache=None, root=True): -242 """ -243 Uniq and sort a connector. -244 -245 C AND A AND B AND B -> A AND B AND C -246 """ -247 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): -248 result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_ -249 flattened = tuple(expression.flatten()) -250 deduped = {GENERATOR.generate(e, cache): e for e in flattened} -251 arr = tuple(deduped.items()) -252 -253 # check if the operands are already sorted, if not sort them -254 # A AND C AND B -> A AND B AND C -255 for i, (sql, e) in enumerate(arr[1:]): -256 if sql < arr[i][0]: -257 expression = result_func(*(e for _, e in sorted(arr))) -258 break -259 else: -260 # we didn't have to sort but maybe we need to dedup -261 if len(deduped) < len(flattened): -262 expression = result_func(*deduped.values()) -263 -264 return expression -265 -266 -267def absorb_and_eliminate(expression, root=True): -268 """ -269 absorption: -270 A AND (A OR B) -> A -271 A OR (A AND B) -> A -272 A AND (NOT A OR B) -> A AND B -273 A OR (NOT A AND B) -> A OR B -274 elimination: -275 (A AND B) OR (A AND NOT B) -> A -276 (A OR B) AND (A OR NOT B) -> A -277 """ -278 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): -279 kind = exp.Or if isinstance(expression, exp.And) else exp.And -280 -281 for a, b in itertools.permutations(expression.flatten(), 2): -282 if isinstance(a, kind): -283 aa, ab = a.unnest_operands() -284 -285 # absorb -286 if is_complement(b, aa): -287 aa.replace(exp.true() if kind == exp.And else exp.false()) -288 elif is_complement(b, ab): -289 ab.replace(exp.true() if kind == exp.And else exp.false()) -290 elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()): -291 a.replace(exp.false() if kind == exp.And else exp.true()) -292 elif isinstance(b, kind): -293 # eliminate -294 rhs = b.unnest_operands() -295 ba, bb = rhs -296 -297 if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)): -298 a.replace(aa) -299 b.replace(aa) -300 elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)): -301 a.replace(ab) -302 b.replace(ab) -303 -304 return expression +197 if l.is_number and r.is_number: +198 l = float(l.name) +199 r = float(r.name) +200 elif l.is_string and r.is_string: +201 l = l.name +202 r = r.name +203 else: +204 return None +205 +206 for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))): +207 if isinstance(a, LT_LTE) and isinstance(b, LT_LTE): +208 return left if (av > bv if or_ else av <= bv) else right +209 if isinstance(a, GT_GTE) and isinstance(b, GT_GTE): +210 return left if (av < bv if or_ else av >= bv) else right +211 +212 # we can't ever shortcut to true because the column could be null +213 if not or_: +214 if isinstance(a, exp.LT) and isinstance(b, GT_GTE): +215 if av <= bv: +216 return exp.false() +217 elif isinstance(a, exp.GT) and isinstance(b, LT_LTE): +218 if av >= bv: +219 return exp.false() +220 elif isinstance(a, exp.EQ): +221 if isinstance(b, exp.LT): +222 return exp.false() if av >= bv else a +223 if isinstance(b, exp.LTE): +224 return exp.false() if av > bv else a +225 if isinstance(b, exp.GT): +226 return exp.false() if av <= bv else a +227 if isinstance(b, exp.GTE): +228 return exp.false() if av < bv else a +229 if isinstance(b, exp.NEQ): +230 return exp.false() if av == bv else a +231 return None +232 +233 +234def remove_compliments(expression, root=True): +235 """ +236 Removing compliments. +237 +238 A AND NOT A -> FALSE +239 A OR NOT A -> TRUE +240 """ +241 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): +242 compliment = exp.false() if isinstance(expression, exp.And) else exp.true() +243 +244 for a, b in itertools.permutations(expression.flatten(), 2): +245 if is_complement(a, b): +246 return compliment +247 return expression +248 +249 +250def uniq_sort(expression, cache=None, root=True): +251 """ +252 Uniq and sort a connector. +253 +254 C AND A AND B AND B -> A AND B AND C +255 """ +256 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): +257 result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_ +258 flattened = tuple(expression.flatten()) +259 deduped = {GENERATOR.generate(e, cache): e for e in flattened} +260 arr = tuple(deduped.items()) +261 +262 # check if the operands are already sorted, if not sort them +263 # A AND C AND B -> A AND B AND C +264 for i, (sql, e) in enumerate(arr[1:]): +265 if sql < arr[i][0]: +266 expression = result_func(*(e for _, e in sorted(arr)), copy=False) +267 break +268 else: +269 # we didn't have to sort but maybe we need to dedup +270 if len(deduped) < len(flattened): +271 expression = result_func(*deduped.values(), copy=False) +272 +273 return expression +274 +275 +276def absorb_and_eliminate(expression, root=True): +277 """ +278 absorption: +279 A AND (A OR B) -> A +280 A OR (A AND B) -> A +281 A AND (NOT A OR B) -> A AND B +282 A OR (NOT A AND B) -> A OR B +283 elimination: +284 (A AND B) OR (A AND NOT B) -> A +285 (A OR B) AND (A OR NOT B) -> A +286 """ +287 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): +288 kind = exp.Or if isinstance(expression, exp.And) else exp.And +289 +290 for a, b in itertools.permutations(expression.flatten(), 2): +291 if isinstance(a, kind): +292 aa, ab = a.unnest_operands() +293 +294 # absorb +295 if is_complement(b, aa): +296 aa.replace(exp.true() if kind == exp.And else exp.false()) +297 elif is_complement(b, ab): +298 ab.replace(exp.true() if kind == exp.And else exp.false()) +299 elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()): +300 a.replace(exp.false() if kind == exp.And else exp.true()) +301 elif isinstance(b, kind): +302 # eliminate +303 rhs = b.unnest_operands() +304 ba, bb = rhs 305 -306 -307def simplify_literals(expression, root=True): -308 if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector): -309 return _flat_simplify(expression, _simplify_binary, root) -310 elif isinstance(expression, exp.Neg): -311 this = expression.this -312 if this.is_number: -313 value = this.name -314 if value[0] == "-": -315 return exp.Literal.number(value[1:]) -316 return exp.Literal.number(f"-{value}") -317 -318 return expression -319 -320 -321def _simplify_binary(expression, a, b): -322 if isinstance(expression, exp.Is): -323 if isinstance(b, exp.Not): -324 c = b.this -325 not_ = True -326 else: -327 c = b -328 not_ = False +306 if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)): +307 a.replace(aa) +308 b.replace(aa) +309 elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)): +310 a.replace(ab) +311 b.replace(ab) +312 +313 return expression +314 +315 +316def simplify_literals(expression, root=True): +317 if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector): +318 return _flat_simplify(expression, _simplify_binary, root) +319 elif isinstance(expression, exp.Neg): +320 this = expression.this +321 if this.is_number: +322 value = this.name +323 if value[0] == "-": +324 return exp.Literal.number(value[1:]) +325 return exp.Literal.number(f"-{value}") +326 +327 return expression +328 329 -330 if is_null(c): -331 if isinstance(a, exp.Literal): -332 return exp.true() if not_ else exp.false() -333 if is_null(a): -334 return exp.false() if not_ else exp.true() -335 elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)): -336 return None -337 elif is_null(a) or is_null(b): -338 return exp.null() -339 -340 if a.is_number and b.is_number: -341 a = int(a.name) if a.is_int else Decimal(a.name) -342 b = int(b.name) if b.is_int else Decimal(b.name) -343 -344 if isinstance(expression, exp.Add): -345 return exp.Literal.number(a + b) -346 if isinstance(expression, exp.Sub): -347 return exp.Literal.number(a - b) -348 if isinstance(expression, exp.Mul): -349 return exp.Literal.number(a * b) -350 if isinstance(expression, exp.Div): -351 # engines have differing int div behavior so intdiv is not safe -352 if isinstance(a, int) and isinstance(b, int): -353 return None -354 return exp.Literal.number(a / b) -355 -356 boolean = eval_boolean(expression, a, b) -357 -358 if boolean: -359 return boolean -360 elif a.is_string and b.is_string: -361 boolean = eval_boolean(expression, a.this, b.this) -362 -363 if boolean: -364 return boolean -365 elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval): -366 a, b = extract_date(a), extract_interval(b) -367 if a and b: -368 if isinstance(expression, exp.Add): -369 return date_literal(a + b) -370 if isinstance(expression, exp.Sub): -371 return date_literal(a - b) -372 elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast): -373 a, b = extract_interval(a), extract_date(b) -374 # you cannot subtract a date from an interval -375 if a and b and isinstance(expression, exp.Add): -376 return date_literal(a + b) -377 -378 return None -379 -380 -381def simplify_parens(expression): -382 if ( -383 isinstance(expression, exp.Paren) -384 and not isinstance(expression.this, exp.Select) -385 and ( -386 not isinstance(expression.parent, (exp.Condition, exp.Binary)) -387 or isinstance(expression.this, exp.Predicate) -388 or not isinstance(expression.this, exp.Binary) -389 ) -390 ): -391 return expression.this -392 return expression -393 -394 -395def remove_where_true(expression): -396 for where in expression.find_all(exp.Where): -397 if always_true(where.this): -398 where.parent.set("where", None) -399 for join in expression.find_all(exp.Join): -400 if always_true(join.args.get("on")): -401 join.set("kind", "CROSS") -402 join.set("on", None) +330def _simplify_binary(expression, a, b): +331 if isinstance(expression, exp.Is): +332 if isinstance(b, exp.Not): +333 c = b.this +334 not_ = True +335 else: +336 c = b +337 not_ = False +338 +339 if is_null(c): +340 if isinstance(a, exp.Literal): +341 return exp.true() if not_ else exp.false() +342 if is_null(a): +343 return exp.false() if not_ else exp.true() +344 elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)): +345 return None +346 elif is_null(a) or is_null(b): +347 return exp.null() +348 +349 if a.is_number and b.is_number: +350 a = int(a.name) if a.is_int else Decimal(a.name) +351 b = int(b.name) if b.is_int else Decimal(b.name) +352 +353 if isinstance(expression, exp.Add): +354 return exp.Literal.number(a + b) +355 if isinstance(expression, exp.Sub): +356 return exp.Literal.number(a - b) +357 if isinstance(expression, exp.Mul): +358 return exp.Literal.number(a * b) +359 if isinstance(expression, exp.Div): +360 # engines have differing int div behavior so intdiv is not safe +361 if isinstance(a, int) and isinstance(b, int): +362 return None +363 return exp.Literal.number(a / b) +364 +365 boolean = eval_boolean(expression, a, b) +366 +367 if boolean: +368 return boolean +369 elif a.is_string and b.is_string: +370 boolean = eval_boolean(expression, a.this, b.this) +371 +372 if boolean: +373 return boolean +374 elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval): +375 a, b = extract_date(a), extract_interval(b) +376 if a and b: +377 if isinstance(expression, exp.Add): +378 return date_literal(a + b) +379 if isinstance(expression, exp.Sub): +380 return date_literal(a - b) +381 elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast): +382 a, b = extract_interval(a), extract_date(b) +383 # you cannot subtract a date from an interval +384 if a and b and isinstance(expression, exp.Add): +385 return date_literal(a + b) +386 +387 return None +388 +389 +390def simplify_parens(expression): +391 if ( +392 isinstance(expression, exp.Paren) +393 and not isinstance(expression.this, exp.Select) +394 and ( +395 not isinstance(expression.parent, (exp.Condition, exp.Binary)) +396 or isinstance(expression.this, exp.Predicate) +397 or not isinstance(expression.this, exp.Binary) +398 ) +399 ): +400 return expression.this +401 return expression +402 403 -404 -405def always_true(expression): -406 return (isinstance(expression, exp.Boolean) and expression.this) or isinstance( -407 expression, exp.Literal -408 ) -409 -410 -411def is_complement(a, b): -412 return isinstance(b, exp.Not) and b.this == a +404def remove_where_true(expression): +405 for where in expression.find_all(exp.Where): +406 if always_true(where.this): +407 where.parent.set("where", None) +408 for join in expression.find_all(exp.Join): +409 if always_true(join.args.get("on")): +410 join.set("kind", "CROSS") +411 join.set("on", None) +412 413 -414 -415def is_false(a: exp.Expression) -> bool: -416 return type(a) is exp.Boolean and not a.this -417 +414def always_true(expression): +415 return (isinstance(expression, exp.Boolean) and expression.this) or isinstance( +416 expression, exp.Literal +417 ) 418 -419def is_null(a: exp.Expression) -> bool: -420 return type(a) is exp.Null -421 +419 +420def is_complement(a, b): +421 return isinstance(b, exp.Not) and b.this == a 422 -423def eval_boolean(expression, a, b): -424 if isinstance(expression, (exp.EQ, exp.Is)): -425 return boolean_literal(a == b) -426 if isinstance(expression, exp.NEQ): -427 return boolean_literal(a != b) -428 if isinstance(expression, exp.GT): -429 return boolean_literal(a > b) -430 if isinstance(expression, exp.GTE): -431 return boolean_literal(a >= b) -432 if isinstance(expression, exp.LT): -433 return boolean_literal(a < b) -434 if isinstance(expression, exp.LTE): -435 return boolean_literal(a <= b) -436 return None -437 -438 -439def extract_date(cast): -440 # The "fromisoformat" conversion could fail if the cast is used on an identifier, -441 # so in that case we can't extract the date. -442 try: -443 if cast.args["to"].this == exp.DataType.Type.DATE: -444 return datetime.date.fromisoformat(cast.name) -445 if cast.args["to"].this == exp.DataType.Type.DATETIME: -446 return datetime.datetime.fromisoformat(cast.name) -447 except ValueError: -448 return None -449 -450 -451def extract_interval(interval): -452 try: -453 from dateutil.relativedelta import relativedelta # type: ignore -454 except ModuleNotFoundError: -455 return None -456 -457 n = int(interval.name) -458 unit = interval.text("unit").lower() +423 +424def is_false(a: exp.Expression) -> bool: +425 return type(a) is exp.Boolean and not a.this +426 +427 +428def is_null(a: exp.Expression) -> bool: +429 return type(a) is exp.Null +430 +431 +432def eval_boolean(expression, a, b): +433 if isinstance(expression, (exp.EQ, exp.Is)): +434 return boolean_literal(a == b) +435 if isinstance(expression, exp.NEQ): +436 return boolean_literal(a != b) +437 if isinstance(expression, exp.GT): +438 return boolean_literal(a > b) +439 if isinstance(expression, exp.GTE): +440 return boolean_literal(a >= b) +441 if isinstance(expression, exp.LT): +442 return boolean_literal(a < b) +443 if isinstance(expression, exp.LTE): +444 return boolean_literal(a <= b) +445 return None +446 +447 +448def extract_date(cast): +449 # The "fromisoformat" conversion could fail if the cast is used on an identifier, +450 # so in that case we can't extract the date. +451 try: +452 if cast.args["to"].this == exp.DataType.Type.DATE: +453 return datetime.date.fromisoformat(cast.name) +454 if cast.args["to"].this == exp.DataType.Type.DATETIME: +455 return datetime.datetime.fromisoformat(cast.name) +456 except ValueError: +457 return None +458 459 -460 if unit == "year": -461 return relativedelta(years=n) -462 if unit == "month": -463 return relativedelta(months=n) -464 if unit == "week": -465 return relativedelta(weeks=n) -466 if unit == "day": -467 return relativedelta(days=n) -468 return None -469 -470 -471def date_literal(date): -472 return exp.cast( -473 exp.Literal.string(date), -474 "DATETIME" if isinstance(date, datetime.datetime) else "DATE", -475 ) -476 -477 -478def boolean_literal(condition): -479 return exp.true() if condition else exp.false() -480 -481 -482def _flat_simplify(expression, simplifier, root=True): -483 if root or not expression.same_parent: -484 operands = [] -485 queue = deque(expression.flatten(unnest=False)) -486 size = len(queue) -487 -488 while queue: -489 a = queue.popleft() +460def extract_interval(interval): +461 try: +462 from dateutil.relativedelta import relativedelta # type: ignore +463 except ModuleNotFoundError: +464 return None +465 +466 n = int(interval.name) +467 unit = interval.text("unit").lower() +468 +469 if unit == "year": +470 return relativedelta(years=n) +471 if unit == "month": +472 return relativedelta(months=n) +473 if unit == "week": +474 return relativedelta(weeks=n) +475 if unit == "day": +476 return relativedelta(days=n) +477 return None +478 +479 +480def date_literal(date): +481 return exp.cast( +482 exp.Literal.string(date), +483 "DATETIME" if isinstance(date, datetime.datetime) else "DATE", +484 ) +485 +486 +487def boolean_literal(condition): +488 return exp.true() if condition else exp.false() +489 490 -491 for b in queue: -492 result = simplifier(expression, a, b) -493 -494 if result: -495 queue.remove(b) -496 queue.appendleft(result) -497 break -498 else: -499 operands.append(a) -500 -501 if len(operands) < size: -502 return functools.reduce( -503 lambda a, b: expression.__class__(this=a, expression=b), operands -504 ) -505 return expression +491def _flat_simplify(expression, simplifier, root=True): +492 if root or not expression.same_parent: +493 operands = [] +494 queue = deque(expression.flatten(unnest=False)) +495 size = len(queue) +496 +497 while queue: +498 a = queue.popleft() +499 +500 for b in queue: +501 result = simplifier(expression, a, b) +502 +503 if result: +504 queue.remove(b) +505 queue.appendleft(result) +506 break +507 else: +508 operands.append(a) +509 +510 if len(operands) < size: +511 return functools.reduce( +512 lambda a, b: expression.__class__(this=a, expression=b), operands +513 ) +514 return expression @@ -723,8 +732,9 @@ 61 return exp.and_( 62 exp.GTE(this=expression.this.copy(), expression=expression.args["low"]), 63 exp.LTE(this=expression.this.copy(), expression=expression.args["high"]), -64 ) -65 return expression +64 copy=False, +65 ) +66 return expression @@ -746,32 +756,40 @@ -
68def simplify_not(expression):
-69    """
-70    Demorgan's Law
-71    NOT (x OR y) -> NOT x AND NOT y
-72    NOT (x AND y) -> NOT x OR NOT y
-73    """
-74    if isinstance(expression, exp.Not):
-75        if is_null(expression.this):
-76            return exp.null()
-77        if isinstance(expression.this, exp.Paren):
-78            condition = expression.this.unnest()
-79            if isinstance(condition, exp.And):
-80                return exp.or_(exp.not_(condition.left), exp.not_(condition.right))
-81            if isinstance(condition, exp.Or):
-82                return exp.and_(exp.not_(condition.left), exp.not_(condition.right))
-83            if is_null(condition):
-84                return exp.null()
-85        if always_true(expression.this):
-86            return exp.false()
-87        if is_false(expression.this):
-88            return exp.true()
-89        if isinstance(expression.this, exp.Not):
-90            # double negation
-91            # NOT NOT x -> x
-92            return expression.this.this
-93    return expression
+            
 69def simplify_not(expression):
+ 70    """
+ 71    Demorgan's Law
+ 72    NOT (x OR y) -> NOT x AND NOT y
+ 73    NOT (x AND y) -> NOT x OR NOT y
+ 74    """
+ 75    if isinstance(expression, exp.Not):
+ 76        if is_null(expression.this):
+ 77            return exp.null()
+ 78        if isinstance(expression.this, exp.Paren):
+ 79            condition = expression.this.unnest()
+ 80            if isinstance(condition, exp.And):
+ 81                return exp.or_(
+ 82                    exp.not_(condition.left, copy=False),
+ 83                    exp.not_(condition.right, copy=False),
+ 84                    copy=False,
+ 85                )
+ 86            if isinstance(condition, exp.Or):
+ 87                return exp.and_(
+ 88                    exp.not_(condition.left, copy=False),
+ 89                    exp.not_(condition.right, copy=False),
+ 90                    copy=False,
+ 91                )
+ 92            if is_null(condition):
+ 93                return exp.null()
+ 94        if always_true(expression.this):
+ 95            return exp.false()
+ 96        if is_false(expression.this):
+ 97            return exp.true()
+ 98        if isinstance(expression.this, exp.Not):
+ 99            # double negation
+100            # NOT NOT x -> x
+101            return expression.this.this
+102    return expression
 
@@ -793,17 +811,17 @@ NOT (x AND y) -> NOT x OR NOT y

-
 96def flatten(expression):
- 97    """
- 98    A AND (B AND C) -> A AND B AND C
- 99    A OR (B OR C) -> A OR B OR C
-100    """
-101    if isinstance(expression, exp.Connector):
-102        for node in expression.args.values():
-103            child = node.unnest()
-104            if isinstance(child, expression.__class__):
-105                node.replace(child)
-106    return expression
+            
105def flatten(expression):
+106    """
+107    A AND (B AND C) -> A AND B AND C
+108    A OR (B OR C) -> A OR B OR C
+109    """
+110    if isinstance(expression, exp.Connector):
+111        for node in expression.args.values():
+112            child = node.unnest()
+113            if isinstance(child, expression.__class__):
+114                node.replace(child)
+115    return expression
 
@@ -824,42 +842,42 @@ A OR (B OR C) -> A OR B OR C

-
109def simplify_connectors(expression, root=True):
-110    def _simplify_connectors(expression, left, right):
-111        if left == right:
-112            return left
-113        if isinstance(expression, exp.And):
-114            if is_false(left) or is_false(right):
-115                return exp.false()
-116            if is_null(left) or is_null(right):
-117                return exp.null()
-118            if always_true(left) and always_true(right):
-119                return exp.true()
-120            if always_true(left):
-121                return right
-122            if always_true(right):
-123                return left
-124            return _simplify_comparison(expression, left, right)
-125        elif isinstance(expression, exp.Or):
-126            if always_true(left) or always_true(right):
-127                return exp.true()
-128            if is_false(left) and is_false(right):
-129                return exp.false()
-130            if (
-131                (is_null(left) and is_null(right))
-132                or (is_null(left) and is_false(right))
-133                or (is_false(left) and is_null(right))
-134            ):
-135                return exp.null()
-136            if is_false(left):
-137                return right
-138            if is_false(right):
-139                return left
-140            return _simplify_comparison(expression, left, right, or_=True)
-141
-142    if isinstance(expression, exp.Connector):
-143        return _flat_simplify(expression, _simplify_connectors, root)
-144    return expression
+            
118def simplify_connectors(expression, root=True):
+119    def _simplify_connectors(expression, left, right):
+120        if left == right:
+121            return left
+122        if isinstance(expression, exp.And):
+123            if is_false(left) or is_false(right):
+124                return exp.false()
+125            if is_null(left) or is_null(right):
+126                return exp.null()
+127            if always_true(left) and always_true(right):
+128                return exp.true()
+129            if always_true(left):
+130                return right
+131            if always_true(right):
+132                return left
+133            return _simplify_comparison(expression, left, right)
+134        elif isinstance(expression, exp.Or):
+135            if always_true(left) or always_true(right):
+136                return exp.true()
+137            if is_false(left) and is_false(right):
+138                return exp.false()
+139            if (
+140                (is_null(left) and is_null(right))
+141                or (is_null(left) and is_false(right))
+142                or (is_false(left) and is_null(right))
+143            ):
+144                return exp.null()
+145            if is_false(left):
+146                return right
+147            if is_false(right):
+148                return left
+149            return _simplify_comparison(expression, left, right, or_=True)
+150
+151    if isinstance(expression, exp.Connector):
+152        return _flat_simplify(expression, _simplify_connectors, root)
+153    return expression
 
@@ -877,20 +895,20 @@ A OR (B OR C) -> A OR B OR C

-
226def remove_compliments(expression, root=True):
-227    """
-228    Removing compliments.
-229
-230    A AND NOT A -> FALSE
-231    A OR NOT A -> TRUE
-232    """
-233    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
-234        compliment = exp.false() if isinstance(expression, exp.And) else exp.true()
-235
-236        for a, b in itertools.permutations(expression.flatten(), 2):
-237            if is_complement(a, b):
-238                return compliment
-239    return expression
+            
235def remove_compliments(expression, root=True):
+236    """
+237    Removing compliments.
+238
+239    A AND NOT A -> FALSE
+240    A OR NOT A -> TRUE
+241    """
+242    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
+243        compliment = exp.false() if isinstance(expression, exp.And) else exp.true()
+244
+245        for a, b in itertools.permutations(expression.flatten(), 2):
+246            if is_complement(a, b):
+247                return compliment
+248    return expression
 
@@ -913,30 +931,30 @@ A OR NOT A -> TRUE

-
242def uniq_sort(expression, cache=None, root=True):
-243    """
-244    Uniq and sort a connector.
-245
-246    C AND A AND B AND B -> A AND B AND C
-247    """
-248    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
-249        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
-250        flattened = tuple(expression.flatten())
-251        deduped = {GENERATOR.generate(e, cache): e for e in flattened}
-252        arr = tuple(deduped.items())
-253
-254        # check if the operands are already sorted, if not sort them
-255        # A AND C AND B -> A AND B AND C
-256        for i, (sql, e) in enumerate(arr[1:]):
-257            if sql < arr[i][0]:
-258                expression = result_func(*(e for _, e in sorted(arr)))
-259                break
-260        else:
-261            # we didn't have to sort but maybe we need to dedup
-262            if len(deduped) < len(flattened):
-263                expression = result_func(*deduped.values())
-264
-265    return expression
+            
251def uniq_sort(expression, cache=None, root=True):
+252    """
+253    Uniq and sort a connector.
+254
+255    C AND A AND B AND B -> A AND B AND C
+256    """
+257    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
+258        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
+259        flattened = tuple(expression.flatten())
+260        deduped = {GENERATOR.generate(e, cache): e for e in flattened}
+261        arr = tuple(deduped.items())
+262
+263        # check if the operands are already sorted, if not sort them
+264        # A AND C AND B -> A AND B AND C
+265        for i, (sql, e) in enumerate(arr[1:]):
+266            if sql < arr[i][0]:
+267                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
+268                break
+269        else:
+270            # we didn't have to sort but maybe we need to dedup
+271            if len(deduped) < len(flattened):
+272                expression = result_func(*deduped.values(), copy=False)
+273
+274    return expression
 
@@ -958,44 +976,44 @@ A OR NOT A -> TRUE

-
268def absorb_and_eliminate(expression, root=True):
-269    """
-270    absorption:
-271        A AND (A OR B) -> A
-272        A OR (A AND B) -> A
-273        A AND (NOT A OR B) -> A AND B
-274        A OR (NOT A AND B) -> A OR B
-275    elimination:
-276        (A AND B) OR (A AND NOT B) -> A
-277        (A OR B) AND (A OR NOT B) -> A
-278    """
-279    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
-280        kind = exp.Or if isinstance(expression, exp.And) else exp.And
-281
-282        for a, b in itertools.permutations(expression.flatten(), 2):
-283            if isinstance(a, kind):
-284                aa, ab = a.unnest_operands()
-285
-286                # absorb
-287                if is_complement(b, aa):
-288                    aa.replace(exp.true() if kind == exp.And else exp.false())
-289                elif is_complement(b, ab):
-290                    ab.replace(exp.true() if kind == exp.And else exp.false())
-291                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
-292                    a.replace(exp.false() if kind == exp.And else exp.true())
-293                elif isinstance(b, kind):
-294                    # eliminate
-295                    rhs = b.unnest_operands()
-296                    ba, bb = rhs
-297
-298                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
-299                        a.replace(aa)
-300                        b.replace(aa)
-301                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
-302                        a.replace(ab)
-303                        b.replace(ab)
-304
-305    return expression
+            
277def absorb_and_eliminate(expression, root=True):
+278    """
+279    absorption:
+280        A AND (A OR B) -> A
+281        A OR (A AND B) -> A
+282        A AND (NOT A OR B) -> A AND B
+283        A OR (NOT A AND B) -> A OR B
+284    elimination:
+285        (A AND B) OR (A AND NOT B) -> A
+286        (A OR B) AND (A OR NOT B) -> A
+287    """
+288    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
+289        kind = exp.Or if isinstance(expression, exp.And) else exp.And
+290
+291        for a, b in itertools.permutations(expression.flatten(), 2):
+292            if isinstance(a, kind):
+293                aa, ab = a.unnest_operands()
+294
+295                # absorb
+296                if is_complement(b, aa):
+297                    aa.replace(exp.true() if kind == exp.And else exp.false())
+298                elif is_complement(b, ab):
+299                    ab.replace(exp.true() if kind == exp.And else exp.false())
+300                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
+301                    a.replace(exp.false() if kind == exp.And else exp.true())
+302                elif isinstance(b, kind):
+303                    # eliminate
+304                    rhs = b.unnest_operands()
+305                    ba, bb = rhs
+306
+307                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
+308                        a.replace(aa)
+309                        b.replace(aa)
+310                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
+311                        a.replace(ab)
+312                        b.replace(ab)
+313
+314    return expression
 
@@ -1022,18 +1040,18 @@ elimination:
-
308def simplify_literals(expression, root=True):
-309    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
-310        return _flat_simplify(expression, _simplify_binary, root)
-311    elif isinstance(expression, exp.Neg):
-312        this = expression.this
-313        if this.is_number:
-314            value = this.name
-315            if value[0] == "-":
-316                return exp.Literal.number(value[1:])
-317            return exp.Literal.number(f"-{value}")
-318
-319    return expression
+            
317def simplify_literals(expression, root=True):
+318    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
+319        return _flat_simplify(expression, _simplify_binary, root)
+320    elif isinstance(expression, exp.Neg):
+321        this = expression.this
+322        if this.is_number:
+323            value = this.name
+324            if value[0] == "-":
+325                return exp.Literal.number(value[1:])
+326            return exp.Literal.number(f"-{value}")
+327
+328    return expression
 
@@ -1051,18 +1069,18 @@ elimination:
-
382def simplify_parens(expression):
-383    if (
-384        isinstance(expression, exp.Paren)
-385        and not isinstance(expression.this, exp.Select)
-386        and (
-387            not isinstance(expression.parent, (exp.Condition, exp.Binary))
-388            or isinstance(expression.this, exp.Predicate)
-389            or not isinstance(expression.this, exp.Binary)
-390        )
-391    ):
-392        return expression.this
-393    return expression
+            
391def simplify_parens(expression):
+392    if (
+393        isinstance(expression, exp.Paren)
+394        and not isinstance(expression.this, exp.Select)
+395        and (
+396            not isinstance(expression.parent, (exp.Condition, exp.Binary))
+397            or isinstance(expression.this, exp.Predicate)
+398            or not isinstance(expression.this, exp.Binary)
+399        )
+400    ):
+401        return expression.this
+402    return expression
 
@@ -1080,14 +1098,14 @@ elimination:
-
396def remove_where_true(expression):
-397    for where in expression.find_all(exp.Where):
-398        if always_true(where.this):
-399            where.parent.set("where", None)
-400    for join in expression.find_all(exp.Join):
-401        if always_true(join.args.get("on")):
-402            join.set("kind", "CROSS")
-403            join.set("on", None)
+            
405def remove_where_true(expression):
+406    for where in expression.find_all(exp.Where):
+407        if always_true(where.this):
+408            where.parent.set("where", None)
+409    for join in expression.find_all(exp.Join):
+410        if always_true(join.args.get("on")):
+411            join.set("kind", "CROSS")
+412            join.set("on", None)
 
@@ -1105,10 +1123,10 @@ elimination:
-
406def always_true(expression):
-407    return (isinstance(expression, exp.Boolean) and expression.this) or isinstance(
-408        expression, exp.Literal
-409    )
+            
415def always_true(expression):
+416    return (isinstance(expression, exp.Boolean) and expression.this) or isinstance(
+417        expression, exp.Literal
+418    )
 
@@ -1126,8 +1144,8 @@ elimination:
-
412def is_complement(a, b):
-413    return isinstance(b, exp.Not) and b.this == a
+            
421def is_complement(a, b):
+422    return isinstance(b, exp.Not) and b.this == a
 
@@ -1145,8 +1163,8 @@ elimination:
-
416def is_false(a: exp.Expression) -> bool:
-417    return type(a) is exp.Boolean and not a.this
+            
425def is_false(a: exp.Expression) -> bool:
+426    return type(a) is exp.Boolean and not a.this
 
@@ -1164,8 +1182,8 @@ elimination:
-
420def is_null(a: exp.Expression) -> bool:
-421    return type(a) is exp.Null
+            
429def is_null(a: exp.Expression) -> bool:
+430    return type(a) is exp.Null
 
@@ -1183,20 +1201,20 @@ elimination:
-
424def eval_boolean(expression, a, b):
-425    if isinstance(expression, (exp.EQ, exp.Is)):
-426        return boolean_literal(a == b)
-427    if isinstance(expression, exp.NEQ):
-428        return boolean_literal(a != b)
-429    if isinstance(expression, exp.GT):
-430        return boolean_literal(a > b)
-431    if isinstance(expression, exp.GTE):
-432        return boolean_literal(a >= b)
-433    if isinstance(expression, exp.LT):
-434        return boolean_literal(a < b)
-435    if isinstance(expression, exp.LTE):
-436        return boolean_literal(a <= b)
-437    return None
+            
433def eval_boolean(expression, a, b):
+434    if isinstance(expression, (exp.EQ, exp.Is)):
+435        return boolean_literal(a == b)
+436    if isinstance(expression, exp.NEQ):
+437        return boolean_literal(a != b)
+438    if isinstance(expression, exp.GT):
+439        return boolean_literal(a > b)
+440    if isinstance(expression, exp.GTE):
+441        return boolean_literal(a >= b)
+442    if isinstance(expression, exp.LT):
+443        return boolean_literal(a < b)
+444    if isinstance(expression, exp.LTE):
+445        return boolean_literal(a <= b)
+446    return None
 
@@ -1214,16 +1232,16 @@ elimination:
-
440def extract_date(cast):
-441    # The "fromisoformat" conversion could fail if the cast is used on an identifier,
-442    # so in that case we can't extract the date.
-443    try:
-444        if cast.args["to"].this == exp.DataType.Type.DATE:
-445            return datetime.date.fromisoformat(cast.name)
-446        if cast.args["to"].this == exp.DataType.Type.DATETIME:
-447            return datetime.datetime.fromisoformat(cast.name)
-448    except ValueError:
-449        return None
+            
449def extract_date(cast):
+450    # The "fromisoformat" conversion could fail if the cast is used on an identifier,
+451    # so in that case we can't extract the date.
+452    try:
+453        if cast.args["to"].this == exp.DataType.Type.DATE:
+454            return datetime.date.fromisoformat(cast.name)
+455        if cast.args["to"].this == exp.DataType.Type.DATETIME:
+456            return datetime.datetime.fromisoformat(cast.name)
+457    except ValueError:
+458        return None
 
@@ -1241,24 +1259,24 @@ elimination:
-
452def extract_interval(interval):
-453    try:
-454        from dateutil.relativedelta import relativedelta  # type: ignore
-455    except ModuleNotFoundError:
-456        return None
-457
-458    n = int(interval.name)
-459    unit = interval.text("unit").lower()
-460
-461    if unit == "year":
-462        return relativedelta(years=n)
-463    if unit == "month":
-464        return relativedelta(months=n)
-465    if unit == "week":
-466        return relativedelta(weeks=n)
-467    if unit == "day":
-468        return relativedelta(days=n)
-469    return None
+            
461def extract_interval(interval):
+462    try:
+463        from dateutil.relativedelta import relativedelta  # type: ignore
+464    except ModuleNotFoundError:
+465        return None
+466
+467    n = int(interval.name)
+468    unit = interval.text("unit").lower()
+469
+470    if unit == "year":
+471        return relativedelta(years=n)
+472    if unit == "month":
+473        return relativedelta(months=n)
+474    if unit == "week":
+475        return relativedelta(weeks=n)
+476    if unit == "day":
+477        return relativedelta(days=n)
+478    return None
 
@@ -1276,11 +1294,11 @@ elimination:
-
472def date_literal(date):
-473    return exp.cast(
-474        exp.Literal.string(date),
-475        "DATETIME" if isinstance(date, datetime.datetime) else "DATE",
-476    )
+            
481def date_literal(date):
+482    return exp.cast(
+483        exp.Literal.string(date),
+484        "DATETIME" if isinstance(date, datetime.datetime) else "DATE",
+485    )
 
@@ -1298,8 +1316,8 @@ elimination:
-
479def boolean_literal(condition):
-480    return exp.true() if condition else exp.false()
+            
488def boolean_literal(condition):
+489    return exp.true() if condition else exp.false()
 
-- cgit v1.2.3