sqlglot.optimizer.optimize_joins
from sqlglot import exp
from sqlglot.helper import tsort
from sqlglot.optimizer.simplify import simplify


def optimize_joins(expression):
    """
    Removes cross joins if possible and reorders joins based on predicate dependencies.

    A join whose ON condition references no other table is treated as a cross join.
    Any conjunct of another (inner) join's ON condition that references the cross-joined
    table is moved into that cross join's ON clause, leaving TRUE in its place. The joins
    are then topologically reordered and INNER/OUTER kinds are dropped.

    Outer (LEFT/RIGHT/FULL) joins never give up or receive predicates, USING joins never
    receive them, and only AND-connected predicates are split. Moving a predicate out of
    an outer join or out of an OR would change the query's result.

    Example:
        >>> from sqlglot import parse_one
        >>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
        'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'

    Args:
        expression: the sqlglot expression to optimize; its nodes are modified in place.

    Returns:
        The optimized expression.
    """
    for select in expression.find_all(exp.Select):
        # table name -> joins whose ON condition references that table
        references = {}
        # (table name, join) pairs for joins whose condition references no other table
        cross_joins = []

        # `or []` also covers an explicitly stored `joins=None`.
        for join in select.args.get("joins") or []:
            name = join.this.alias_or_name
            tables = other_table_names(join, name)

            if tables:
                for table in tables:
                    references.setdefault(table, []).append(join)
            elif not join.args.get("side") and not join.args.get("using"):
                # Only plain inner/cross joins may receive predicates. Adding an ON
                # to a USING join, or a condition to an outer join, is not a no-op.
                cross_joins.append((name, join))

        for name, join in cross_joins:
            for dep in references.get(name, []):
                on = dep.args["on"]
                on = on.replace(simplify(on))

                # Splitting is only sound for the conjuncts of an inner join. A
                # disjunct of an OR, or any part of an outer join's condition,
                # cannot be moved elsewhere without changing the result.
                if dep.args.get("side") or not isinstance(on, exp.And):
                    continue

                for predicate in on.flatten():
                    if name in exp.column_table_names(predicate):
                        # Leave TRUE behind so the remaining conjunction stays intact.
                        predicate.replace(exp.true())
                        join.on(predicate, copy=False)

    expression = reorder_joins(expression)
    expression = normalize(expression)
    return expression


def reorder_joins(expression):
    """
    Reorder joins in topological order based on predicate references.

    For every FROM clause, the joins of the owning query are sorted with ``tsort``.
    Each join is placed after the tables its ON condition references, and the FROM
    table is the root of the ordering.

    Args:
        expression: the sqlglot expression whose joins are reordered in place.

    Returns:
        The same expression.
    """
    for from_ in expression.find_all(exp.From):
        head = from_.expressions[0]
        parent = from_.parent
        joins = {join.this.alias_or_name: join for join in parent.args.get("joins") or []}
        dag = {head.alias_or_name: []}

        for name, join in joins.items():
            dag[name] = other_table_names(join, name)

        # A referenced table need not be one of this query's joins (e.g. an
        # outer-scope table in a correlated subquery). If tsort returns such a
        # name, skip it instead of raising KeyError on the lookup.
        parent.set(
            "joins",
            [
                joins[name]
                for name in tsort(dag)
                if name != head.alias_or_name and name in joins
            ],
        )
    return expression


def normalize(expression):
    """
    Remove INNER and OUTER from joins as they are optional.

    Only those two kinds are dropped. Any other kind (CROSS, or SEMI/ANTI where
    the parser produces them) carries meaning and is kept.

    Args:
        expression: the sqlglot expression whose joins are normalized in place.

    Returns:
        The same expression.
    """
    for join in expression.find_all(exp.Join):
        if join.kind in ("INNER", "OUTER"):
            join.set("kind", None)
    return expression


def other_table_names(join, exclude):
    """
    Return the table names referenced by columns in ``join``'s ON condition.

    Args:
        join: the join whose ON condition is inspected (a missing ON counts as TRUE).
        exclude: a name to leave out, normally the joined table's own name.

    Returns:
        A list of the referenced table names, minus ``exclude`` and minus any empty
        name (which unqualified columns may produce).
    """
    return [
        name
        for name in exp.column_table_names(join.args.get("on") or exp.true())
        if name and name != exclude
    ]
def
optimize_joins(expression):
def optimize_joins(expression):
    """
    Removes cross joins if possible and reorders joins based on predicate dependencies.

    A join whose ON condition references no other table is treated as a cross join.
    Any conjunct of another (inner) join's ON condition that references the cross-joined
    table is moved into that cross join's ON clause, leaving TRUE in its place. The joins
    are then topologically reordered and INNER/OUTER kinds are dropped.

    Outer (LEFT/RIGHT/FULL) joins never give up or receive predicates, USING joins never
    receive them, and only AND-connected predicates are split. Moving a predicate out of
    an outer join or out of an OR would change the query's result.

    Example:
        >>> from sqlglot import parse_one
        >>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
        'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'

    Args:
        expression: the sqlglot expression to optimize; its nodes are modified in place.

    Returns:
        The optimized expression.
    """
    for select in expression.find_all(exp.Select):
        # table name -> joins whose ON condition references that table
        references = {}
        # (table name, join) pairs for joins whose condition references no other table
        cross_joins = []

        # `or []` also covers an explicitly stored `joins=None`.
        for join in select.args.get("joins") or []:
            name = join.this.alias_or_name
            tables = other_table_names(join, name)

            if tables:
                for table in tables:
                    references.setdefault(table, []).append(join)
            elif not join.args.get("side") and not join.args.get("using"):
                # Only plain inner/cross joins may receive predicates. Adding an ON
                # to a USING join, or a condition to an outer join, is not a no-op.
                cross_joins.append((name, join))

        for name, join in cross_joins:
            for dep in references.get(name, []):
                on = dep.args["on"]
                on = on.replace(simplify(on))

                # Splitting is only sound for the conjuncts of an inner join. A
                # disjunct of an OR, or any part of an outer join's condition,
                # cannot be moved elsewhere without changing the result.
                if dep.args.get("side") or not isinstance(on, exp.And):
                    continue

                for predicate in on.flatten():
                    if name in exp.column_table_names(predicate):
                        # Leave TRUE behind so the remaining conjunction stays intact.
                        predicate.replace(exp.true())
                        join.on(predicate, copy=False)

    expression = reorder_joins(expression)
    expression = normalize(expression)
    return expression
Removes cross joins if possible and reorders joins based on predicate dependencies.
Example:
>>> from sqlglot import parse_one
>>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'
def
reorder_joins(expression):
def reorder_joins(expression):
    """
    Reorder joins in topological order based on predicate references.

    For every FROM clause, the joins of the owning query are sorted with ``tsort``.
    Each join is placed after the tables its ON condition references, and the FROM
    table is the root of the ordering.

    Args:
        expression: the sqlglot expression whose joins are reordered in place.

    Returns:
        The same expression.
    """
    for from_ in expression.find_all(exp.From):
        head = from_.expressions[0]
        parent = from_.parent
        # `or []` also covers an explicitly stored `joins=None`.
        joins = {join.this.alias_or_name: join for join in parent.args.get("joins") or []}
        dag = {head.alias_or_name: []}

        for name, join in joins.items():
            dag[name] = other_table_names(join, name)

        # A referenced table need not be one of this query's joins (e.g. an
        # outer-scope table in a correlated subquery). If tsort returns such a
        # name, skip it instead of raising KeyError on the lookup.
        parent.set(
            "joins",
            [
                joins[name]
                for name in tsort(dag)
                if name != head.alias_or_name and name in joins
            ],
        )
    return expression
Reorder joins in topological order based on predicate references.
def
normalize(expression):
def normalize(expression):
    """
    Remove INNER and OUTER from joins as they are optional.

    Only those two kinds are dropped. Any other kind (CROSS, or SEMI/ANTI where
    the parser produces them) carries meaning and is kept.

    Args:
        expression: the sqlglot expression whose joins are normalized in place.

    Returns:
        The same expression.
    """
    for join in expression.find_all(exp.Join):
        if join.kind in ("INNER", "OUTER"):
            join.set("kind", None)
    return expression
Remove INNER and OUTER from joins as they are optional.
def
other_table_names(join, exclude):