sqlglot.optimizer.optimize_joins
from sqlglot import exp
from sqlglot.helper import tsort

# Join arguments that, when all absent, mean the join carries no condition or modifier.
JOIN_ATTRS = ("on", "side", "kind", "using", "natural")


def optimize_joins(expression):
    """
    Removes cross joins if possible and reorders joins based on predicate dependencies.

    Example:
        >>> from sqlglot import parse_one
        >>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
        'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'

    Args:
        expression: the expression tree to optimize; it is modified in place.

    Returns:
        The expression after reorder_joins and normalize have been applied.
    """
    for select in expression.find_all(exp.Select):
        references = {}  # table name -> joins whose ON clause mentions that table
        cross_joins = []  # (name, join) pairs whose ON clause mentions no other table

        for join in select.args.get("joins", []):
            name = join.this.alias_or_name
            tables = other_table_names(join, name)

            if not tables:
                cross_joins.append((name, join))
                continue

            for table in tables:
                references.setdefault(table, []).append(join)

        for name, join in cross_joins:
            for dep in references.get(name, []):
                on = dep.args["on"]

                # Only a conjunction may be split. The moved predicate becomes a
                # standalone condition on the cross join, which is equivalent only
                # when it was AND-ed with the rest; pulling an operand out of an OR
                # (and leaving TRUE behind) would change which rows match.
                if not isinstance(on, exp.And):
                    continue

                for predicate in on.flatten():
                    if name in exp.column_table_names(predicate):
                        # Leave TRUE in the original slot and hand the predicate
                        # to the join that had no condition of its own.
                        predicate.replace(exp.true())
                        join.on(predicate, copy=False)

    expression = reorder_joins(expression)
    expression = normalize(expression)
    return expression


def reorder_joins(expression):
    """
    Reorder joins by topological sort order based on predicate references.

    Args:
        expression: the expression tree whose joins are reordered in place.

    Returns:
        The same expression.
    """
    for from_ in expression.find_all(exp.From):
        head = from_.this
        parent = from_.parent
        head_name = head.alias_or_name
        joins = {join.this.alias_or_name: join for join in parent.args.get("joins", [])}

        # Each join depends on the other tables its ON clause mentions; the FROM
        # table is seeded with no dependencies.
        dag = {head_name: []}
        for name, join in joins.items():
            dag[name] = other_table_names(join, name)

        parent.set(
            "joins",
            [joins[name] for name in tsort(dag) if name != head_name],
        )
    return expression


def normalize(expression):
    """
    Remove INNER and OUTER from joins as they are optional.

    A join with none of the JOIN_ATTRS set is marked as CROSS; every join whose
    kind is not CROSS has its kind cleared.

    Args:
        expression: the expression tree whose joins are normalized in place.

    Returns:
        The same expression.
    """
    for join in expression.find_all(exp.Join):
        if not any(join.args.get(k) for k in JOIN_ATTRS):
            join.set("kind", "CROSS")

        if join.kind != "CROSS":
            join.set("kind", None)
    return expression


def other_table_names(join, exclude):
    """
    Return the table names referenced by the join's ON clause, leaving out `exclude`.

    A join without an ON clause is treated as if its condition were TRUE.
    """
    condition = join.args.get("on") or exp.true()
    return [name for name in exp.column_table_names(condition) if name != exclude]
def
optimize_joins(expression):
def optimize_joins(expression):
    """
    Removes cross joins if possible and reorders joins based on predicate dependencies.

    Example:
        >>> from sqlglot import parse_one
        >>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
        'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'

    Args:
        expression: the expression tree to optimize; it is modified in place.

    Returns:
        The expression after reorder_joins and normalize have been applied.
    """
    for select in expression.find_all(exp.Select):
        references = {}  # table name -> joins whose ON clause mentions that table
        cross_joins = []  # (name, join) pairs whose ON clause mentions no other table

        for join in select.args.get("joins", []):
            name = join.this.alias_or_name
            tables = other_table_names(join, name)

            if not tables:
                cross_joins.append((name, join))
                continue

            for table in tables:
                references.setdefault(table, []).append(join)

        for name, join in cross_joins:
            for dep in references.get(name, []):
                on = dep.args["on"]

                # Only a conjunction may be split. The moved predicate becomes a
                # standalone condition on the cross join, which is equivalent only
                # when it was AND-ed with the rest; pulling an operand out of an OR
                # (and leaving TRUE behind) would change which rows match.
                if not isinstance(on, exp.And):
                    continue

                for predicate in on.flatten():
                    if name in exp.column_table_names(predicate):
                        # Leave TRUE in the original slot and hand the predicate
                        # to the join that had no condition of its own.
                        predicate.replace(exp.true())
                        join.on(predicate, copy=False)

    expression = reorder_joins(expression)
    expression = normalize(expression)
    return expression
Removes cross joins if possible and reorders joins based on predicate dependencies.
Example:
>>> from sqlglot import parse_one
>>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'
def
reorder_joins(expression):
def reorder_joins(expression):
    """
    Sort each FROM clause's joins topologically by the tables their predicates reference.

    The joins are rewritten in place on the parent of every FROM node, and the
    same expression is returned.
    """
    for from_clause in expression.find_all(exp.From):
        root_name = from_clause.this.alias_or_name
        owner = from_clause.parent

        by_name = {}
        for candidate in owner.args.get("joins", []):
            by_name[candidate.this.alias_or_name] = candidate

        # The FROM table has no dependencies; each join depends on the other
        # tables mentioned in its ON clause.
        graph = {root_name: []}
        for join_name, candidate in by_name.items():
            graph[join_name] = other_table_names(candidate, join_name)

        ordered = []
        for table_name in tsort(graph):
            if table_name == root_name:
                continue
            ordered.append(by_name[table_name])

        owner.set("joins", ordered)
    return expression
Reorder joins by topological sort order based on predicate references.
def
normalize(expression):
def normalize(expression):
    """
    Strip the optional INNER / OUTER join kinds.

    A join with none of the JOIN_ATTRS set is first marked as CROSS; any join
    whose kind is not CROSS then has its kind cleared. The expression is
    modified in place and returned.
    """
    for node in expression.find_all(exp.Join):
        if all(not node.args.get(attr) for attr in JOIN_ATTRS):
            node.set("kind", "CROSS")

        if node.kind == "CROSS":
            continue
        node.set("kind", None)
    return expression
Remove INNER and OUTER from joins as they are optional.
def
other_table_names(join, exclude):