sqlglot.optimizer.optimize_joins
from __future__ import annotations

import typing as t

from sqlglot import exp
from sqlglot.helper import tsort

# Join arguments that, when all unset, mark a join as an implicit CROSS JOIN (see `normalize`).
JOIN_ATTRS = ("on", "side", "kind", "using", "method")


def optimize_joins(expression):
    """
    Removes cross joins where possible and reorders joins based on predicate dependencies.

    A join whose ON clause references no other table (e.g. a CROSS JOIN) takes over the
    conjuncts of another join's AND-ed ON condition that mention it, as in the example
    below. Queries whose joins cannot be rearranged without changing their meaning
    (see `_is_reorderable`) are left untouched.

    Example:
        >>> from sqlglot import parse_one
        >>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
        'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'
    """
    for select in expression.find_all(exp.Select):
        joins = select.args.get("joins") or []
        from_ = select.args.get("from")

        # A moved predicate can reference a table joined later, so moving is only valid
        # when reorder_joins will also be allowed to reorder this query's joins.
        if not _is_reorderable(joins, from_.alias_or_name if from_ else None):
            continue

        references = {}
        cross_joins = []

        for join in joins:
            tables = other_table_names(join)

            if tables:
                for table in tables:
                    references.setdefault(table, []).append(join)
            else:
                cross_joins.append((join.alias_or_name, join))

        for name, join in cross_joins:
            for dep in references.get(name, []):
                on = dep.args["on"]

                # Only conjuncts of an AND can be split off; pulling one operand out of an
                # OR (or any other connector) would change the condition's meaning.
                if not isinstance(on, exp.And) or len(other_table_names(dep)) < 2:
                    continue

                # Snapshot the conjuncts before mutating the tree.
                for predicate in list(on.flatten()):
                    if name in exp.column_table_names(predicate):
                        predicate.replace(exp.true())
                        join.on(predicate, copy=False)

    expression = reorder_joins(expression)
    expression = normalize(expression)
    return expression


def reorder_joins(expression):
    """
    Reorder joins in topological order of the tables their ON predicates reference.

    Queries whose joins are not safely reorderable (see `_is_reorderable`) are skipped.
    """
    for from_ in expression.find_all(exp.From):
        parent = from_.parent
        joins = parent.args.get("joins") or []

        if not _is_reorderable(joins, from_.alias_or_name):
            continue

        joins_by_name = {join.alias_or_name: join for join in joins}
        # Each join depends on the other tables referenced by its ON condition.
        dag = {name: other_table_names(join) for name, join in joins_by_name.items()}
        parent.set(
            "joins",
            # tsort may also yield referenced names that are not joins (e.g. the FROM source).
            [joins_by_name[name] for name in tsort(dag) if name in joins_by_name],
        )
    return expression


def normalize(expression):
    """
    Remove INNER and OUTER from joins as they are optional.

    Joins with none of the JOIN_ATTRS set become CROSS JOINs, and CROSS JOINs lose any ON
    clause. Other joins lacking both ON and USING get an explicit `ON TRUE`, unless they
    carry a join `method` (e.g. NATURAL). Join kinds other than INNER/OUTER (such as SEMI
    or ANTI) are kept because they change the join's meaning.
    """
    for join in expression.find_all(exp.Join):
        if not any(join.args.get(k) for k in JOIN_ATTRS):
            join.set("kind", "CROSS")

        if join.kind == "CROSS":
            join.set("on", None)
        else:
            if join.kind in ("INNER", "OUTER"):
                join.set("kind", None)

            if not any(join.args.get(k) for k in ("on", "using", "method")):
                join.set("on", exp.true())
    return expression


def other_table_names(join: exp.Join) -> t.Set[str]:
    """Return the names of tables referenced by `join`'s ON condition, excluding the join itself."""
    on = join.args.get("on")
    return exp.column_table_names(on, join.alias_or_name) if on else set()


def _is_reorderable(joins, from_name=None) -> bool:
    """
    Check whether predicates can be moved between `joins` and the joins reordered.

    Only plain inner/cross joins over tables or subqueries qualify. The following joins
    depend on their position, so rearranging them can change the query's result:
    outer joins (a `side`), kinds other than INNER/CROSS (e.g. SEMI, ANTI), NATURAL-style
    joins (a `method`), USING joins, and joined expressions such as UNNEST whose table
    dependencies are not visible in the ON clause.

    Join names must be unique and differ from the FROM source's name, because joins are
    indexed by `alias_or_name`; a collision would silently drop a join.

    Args:
        joins: the join nodes attached to a single query.
        from_name: `alias_or_name` of the query's FROM source, or None if unknown.

    Returns:
        True if the joins can safely be rearranged.
    """
    seen = set() if from_name is None else {from_name}
    for join in joins:
        if (
            join.args.get("side")
            or join.args.get("method")
            or join.args.get("using")
            or join.kind not in ("", "INNER", "CROSS")
            or not isinstance(join.this, (exp.Table, exp.Subquery))
        ):
            return False

        name = join.alias_or_name
        if name in seen:
            return False
        seen.add(name)
    return True
# Join arguments that, when all unset, mark a join as an implicit CROSS JOIN (see `normalize`).
JOIN_ATTRS = ("on", "side", "kind", "using", "method")
def optimize_joins(expression):
    """
    Removes cross joins where possible and reorders joins based on predicate dependencies.

    A join whose ON clause references no other table (e.g. a CROSS JOIN) takes over the
    conjuncts of another join's AND-ed ON condition that mention it, as in the example
    below. Queries whose joins cannot be rearranged without changing their meaning
    (see `_is_reorderable`) are left untouched.

    Example:
        >>> from sqlglot import parse_one
        >>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
        'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'
    """
    for select in expression.find_all(exp.Select):
        joins = select.args.get("joins") or []
        from_ = select.args.get("from")

        # A moved predicate can reference a table joined later, so moving is only valid
        # when reorder_joins will also be allowed to reorder this query's joins.
        if not _is_reorderable(joins, from_.alias_or_name if from_ else None):
            continue

        references = {}
        cross_joins = []

        for join in joins:
            tables = other_table_names(join)

            if tables:
                for table in tables:
                    references.setdefault(table, []).append(join)
            else:
                cross_joins.append((join.alias_or_name, join))

        for name, join in cross_joins:
            for dep in references.get(name, []):
                on = dep.args["on"]

                # Only conjuncts of an AND can be split off; pulling one operand out of an
                # OR (or any other connector) would change the condition's meaning.
                if not isinstance(on, exp.And) or len(other_table_names(dep)) < 2:
                    continue

                # Snapshot the conjuncts before mutating the tree.
                for predicate in list(on.flatten()):
                    if name in exp.column_table_names(predicate):
                        predicate.replace(exp.true())
                        join.on(predicate, copy=False)

    expression = reorder_joins(expression)
    expression = normalize(expression)
    return expression


def _is_reorderable(joins, from_name=None) -> bool:
    """
    Check whether predicates can be moved between `joins` and the joins reordered.

    Only plain inner/cross joins over tables or subqueries qualify. The following joins
    depend on their position, so rearranging them can change the query's result:
    outer joins (a `side`), kinds other than INNER/CROSS (e.g. SEMI, ANTI), NATURAL-style
    joins (a `method`), USING joins, and joined expressions such as UNNEST whose table
    dependencies are not visible in the ON clause.

    Join names must be unique and differ from the FROM source's name, because joins are
    indexed by `alias_or_name`; a collision would silently drop a join.

    Args:
        joins: the join nodes attached to a single query.
        from_name: `alias_or_name` of the query's FROM source, or None if unknown.

    Returns:
        True if the joins can safely be rearranged.
    """
    seen = set() if from_name is None else {from_name}
    for join in joins:
        if (
            join.args.get("side")
            or join.args.get("method")
            or join.args.get("using")
            or join.kind not in ("", "INNER", "CROSS")
            or not isinstance(join.this, (exp.Table, exp.Subquery))
        ):
            return False

        name = join.alias_or_name
        if name in seen:
            return False
        seen.add(name)
    return True
Removes cross joins where possible and reorders joins based on predicate dependencies.
Example:
>>> from sqlglot import parse_one
>>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'
def reorder_joins(expression):
    """
    Reorder joins in topological order of the tables their ON predicates reference.

    Queries whose joins are not safely reorderable (see `_is_reorderable`) are skipped.
    """
    for from_ in expression.find_all(exp.From):
        parent = from_.parent
        joins = parent.args.get("joins") or []

        if not _is_reorderable(joins, from_.alias_or_name):
            continue

        joins_by_name = {join.alias_or_name: join for join in joins}
        # Each join depends on the other tables referenced by its ON condition.
        dag = {name: other_table_names(join) for name, join in joins_by_name.items()}
        parent.set(
            "joins",
            # tsort may also yield referenced names that are not joins (e.g. the FROM source).
            [joins_by_name[name] for name in tsort(dag) if name in joins_by_name],
        )
    return expression


def _is_reorderable(joins, from_name=None) -> bool:
    """
    Check whether predicates can be moved between `joins` and the joins reordered.

    Only plain inner/cross joins over tables or subqueries qualify. The following joins
    depend on their position, so rearranging them can change the query's result:
    outer joins (a `side`), kinds other than INNER/CROSS (e.g. SEMI, ANTI), NATURAL-style
    joins (a `method`), USING joins, and joined expressions such as UNNEST whose table
    dependencies are not visible in the ON clause.

    Join names must be unique and differ from the FROM source's name, because joins are
    indexed by `alias_or_name`; a collision would silently drop a join.

    Args:
        joins: the join nodes attached to a single query.
        from_name: `alias_or_name` of the query's FROM source, or None if unknown.

    Returns:
        True if the joins can safely be rearranged.
    """
    seen = set() if from_name is None else {from_name}
    for join in joins:
        if (
            join.args.get("side")
            or join.args.get("method")
            or join.args.get("using")
            or join.kind not in ("", "INNER", "CROSS")
            or not isinstance(join.this, (exp.Table, exp.Subquery))
        ):
            return False

        name = join.alias_or_name
        if name in seen:
            return False
        seen.add(name)
    return True
Reorder joins in topological order of the tables their ON predicates reference.
def normalize(expression):
    """
    Remove INNER and OUTER from joins as they are optional.

    Joins with none of the JOIN_ATTRS set become CROSS JOINs, and CROSS JOINs lose any ON
    clause. Other joins lacking both ON and USING get an explicit `ON TRUE`, unless they
    carry a join `method` (e.g. NATURAL). Join kinds other than INNER/OUTER (such as SEMI
    or ANTI) are kept because they change the join's meaning.
    """
    for join in expression.find_all(exp.Join):
        if not any(join.args.get(k) for k in JOIN_ATTRS):
            join.set("kind", "CROSS")

        if join.kind == "CROSS":
            join.set("on", None)
        else:
            if join.kind in ("INNER", "OUTER"):
                join.set("kind", None)

            if not any(join.args.get(k) for k in ("on", "using", "method")):
                join.set("on", exp.true())
    return expression
Remove the INNER and OUTER keywords from joins, since they are optional.