sqlglot.optimizer.optimize_joins
from sqlglot import exp
from sqlglot.helper import tsort


def optimize_joins(expression):
    """
    Removes cross joins where possible and reorders joins based on predicate dependencies.

    Within each SELECT, a join whose ON clause references no other table (such as a
    CROSS JOIN) takes over the conjuncts of other inner joins' ON clauses that
    reference it; each moved conjunct is replaced with TRUE where it used to be.
    The joins are then reordered with `reorder_joins` and stripped of optional
    keywords with `normalize`.

    Example:
        >>> from sqlglot import parse_one
        >>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
        'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'

    Args:
        expression: the expression tree to optimize; its nodes are modified in place.

    Returns:
        The optimized expression.
    """
    for select in expression.find_all(exp.Select):
        # table name -> joins whose ON clause references that table
        references = {}
        # (name, join) pairs for joins whose ON clause references no other table
        cross_joins = []

        # `or []` also covers a "joins" arg that is present but set to None.
        for join in select.args.get("joins") or []:
            name = join.this.alias_or_name
            tables = other_table_names(join, name)

            if tables:
                for table in tables:
                    references.setdefault(table, []).append(join)
            # Attaching an ON clause to an outer join, or to a join that already
            # has USING, would change its result or produce invalid SQL.
            elif not join.args.get("side") and not join.args.get("using"):
                cross_joins.append((name, join))

        for name, join in cross_joins:
            for dep in references.get(name, []):
                on = dep.args["on"]

                # Only the conjuncts (AND terms) of an inner join's ON clause can be
                # moved safely: lifting a term out of an OR, or out of an outer
                # join's ON clause, changes which rows the query returns.
                if not isinstance(on, exp.And) or dep.args.get("side"):
                    continue

                for predicate in on.flatten():
                    if name in exp.column_table_names(predicate):
                        predicate.replace(exp.true())
                        join.on(predicate, copy=False)

    expression = reorder_joins(expression)
    expression = normalize(expression)
    return expression


def reorder_joins(expression):
    """
    Reorders joins in topological order based on predicate references.

    For every FROM clause, each join of the enclosing query depends on the tables
    referenced by its ON clause; the joins are put back on the parent in the order
    returned by `tsort`, with the first FROM table excluded.

    Args:
        expression: the expression tree to reorder; its nodes are modified in place.

    Returns:
        The same expression.
    """
    for from_ in expression.find_all(exp.From):
        head = from_.expressions[0]
        parent = from_.parent
        # `or []` also covers a "joins" arg that is present but set to None.
        joins = {join.this.alias_or_name: join for join in parent.args.get("joins") or []}
        dag = {head.alias_or_name: []}

        for name, join in joins.items():
            dag[name] = other_table_names(join, name)

        # `tsort` presumably also returns names that only occur as dependencies,
        # e.g. further FROM tables (`FROM x, y`) or outer-scope tables referenced by a
        # correlated predicate. They are not joins here, so indexing `joins` with them
        # would raise KeyError; keep only actual joins.
        parent.set(
            "joins",
            [
                joins[name]
                for name in tsort(dag)
                if name != head.alias_or_name and name in joins
            ],
        )
    return expression


def normalize(expression):
    """
    Remove INNER and OUTER from joins as they are optional.

    Only those two kinds are cleared: CROSS and any other join kind are kept,
    because dropping them would change the join's meaning. The join side
    (e.g. LEFT, FULL) is not touched.

    Args:
        expression: the expression tree to normalize; its nodes are modified in place.

    Returns:
        The same expression.
    """
    for join in expression.find_all(exp.Join):
        if (join.kind or "").upper() in ("INNER", "OUTER"):
            join.set("kind", None)
    return expression


def other_table_names(join, exclude):
    """
    Returns the table names referenced by columns in `join`'s ON clause, except `exclude`.

    A join without an ON clause is treated as `ON TRUE`.
    """
    return [
        name
        for name in exp.column_table_names(join.args.get("on") or exp.true())
        if name != exclude
    ]
def optimize_joins(expression):
def optimize_joins(expression):
    """
    Removes cross joins where possible and reorders joins based on predicate dependencies.

    Within each SELECT, a join whose ON clause references no other table (such as a
    CROSS JOIN) takes over the conjuncts of other inner joins' ON clauses that
    reference it; each moved conjunct is replaced with TRUE where it used to be.
    The joins are then reordered with `reorder_joins` and stripped of optional
    keywords with `normalize`.

    Example:
        >>> from sqlglot import parse_one
        >>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
        'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'

    Args:
        expression: the expression tree to optimize; its nodes are modified in place.

    Returns:
        The optimized expression.
    """
    for select in expression.find_all(exp.Select):
        # table name -> joins whose ON clause references that table
        references = {}
        # (name, join) pairs for joins whose ON clause references no other table
        cross_joins = []

        # `or []` also covers a "joins" arg that is present but set to None.
        for join in select.args.get("joins") or []:
            name = join.this.alias_or_name
            tables = other_table_names(join, name)

            if tables:
                for table in tables:
                    references.setdefault(table, []).append(join)
            # Attaching an ON clause to an outer join, or to a join that already
            # has USING, would change its result or produce invalid SQL.
            elif not join.args.get("side") and not join.args.get("using"):
                cross_joins.append((name, join))

        for name, join in cross_joins:
            for dep in references.get(name, []):
                on = dep.args["on"]

                # Only the conjuncts (AND terms) of an inner join's ON clause can be
                # moved safely: lifting a term out of an OR, or out of an outer
                # join's ON clause, changes which rows the query returns.
                if not isinstance(on, exp.And) or dep.args.get("side"):
                    continue

                for predicate in on.flatten():
                    if name in exp.column_table_names(predicate):
                        predicate.replace(exp.true())
                        join.on(predicate, copy=False)

    expression = reorder_joins(expression)
    expression = normalize(expression)
    return expression
Removes cross joins where possible and reorders joins based on predicate dependencies.
Example:
>>> from sqlglot import parse_one
>>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'
def reorder_joins(expression):
def reorder_joins(expression):
    """
    Reorders joins in topological order based on predicate references.

    For every FROM clause, each join of the enclosing query depends on the tables
    referenced by its ON clause; the joins are put back on the parent in the order
    returned by `tsort`, with the first FROM table excluded.

    Args:
        expression: the expression tree to reorder; its nodes are modified in place.

    Returns:
        The same expression.
    """
    for from_ in expression.find_all(exp.From):
        head = from_.expressions[0]
        parent = from_.parent
        # `or []` also covers a "joins" arg that is present but set to None.
        joins = {join.this.alias_or_name: join for join in parent.args.get("joins") or []}
        dag = {head.alias_or_name: []}

        for name, join in joins.items():
            dag[name] = other_table_names(join, name)

        # `tsort` presumably also returns names that only occur as dependencies,
        # e.g. further FROM tables (`FROM x, y`) or outer-scope tables referenced by a
        # correlated predicate. They are not joins here, so indexing `joins` with them
        # would raise KeyError; keep only actual joins.
        parent.set(
            "joins",
            [
                joins[name]
                for name in tsort(dag)
                if name != head.alias_or_name and name in joins
            ],
        )
    return expression
Reorder joins in topological order based on predicate references.
def normalize(expression):
def normalize(expression):
    """
    Remove INNER and OUTER from joins as they are optional.

    Only those two kinds are cleared: CROSS and any other join kind are kept,
    because dropping them would change the join's meaning. The join side
    (e.g. LEFT, FULL) is not touched.

    Args:
        expression: the expression tree to normalize; its nodes are modified in place.

    Returns:
        The same expression.
    """
    for join in expression.find_all(exp.Join):
        if (join.kind or "").upper() in ("INNER", "OUTER"):
            join.set("kind", None)
    return expression
Remove INNER and OUTER from joins as they are optional.
def other_table_names(join, exclude):