diff options
Diffstat (limited to 'sqlglot/optimizer/optimize_joins.py')
-rw-r--r-- | sqlglot/optimizer/optimize_joins.py | 75 |
1 files changed, 75 insertions, 0 deletions
diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py new file mode 100644 index 0000000..40e4ab1 --- /dev/null +++ b/sqlglot/optimizer/optimize_joins.py @@ -0,0 +1,75 @@ +from sqlglot import exp +from sqlglot.helper import tsort +from sqlglot.optimizer.simplify import simplify + + +def optimize_joins(expression): + """ + Removes cross joins if possible and reorder joins based on predicate dependencies. + """ + for select in expression.find_all(exp.Select): + references = {} + cross_joins = [] + + for join in select.args.get("joins", []): + name = join.this.alias_or_name + tables = other_table_names(join, name) + + if tables: + for table in tables: + references[table] = references.get(table, []) + [join] + else: + cross_joins.append((name, join)) + + for name, join in cross_joins: + for dep in references.get(name, []): + on = dep.args["on"] + on = on.replace(simplify(on)) + + if isinstance(on, exp.Connector): + for predicate in on.flatten(): + if name in exp.column_table_names(predicate): + predicate.replace(exp.TRUE) + join.on(predicate, copy=False) + + expression = reorder_joins(expression) + expression = normalize(expression) + return expression + + +def reorder_joins(expression): + """ + Reorder joins by topological sort order based on predicate references. + """ + for from_ in expression.find_all(exp.From): + head = from_.expressions[0] + parent = from_.parent + joins = {join.this.alias_or_name: join for join in parent.args.get("joins", [])} + dag = {head.alias_or_name: []} + + for name, join in joins.items(): + dag[name] = other_table_names(join, name) + + parent.set( + "joins", + [joins[name] for name in tsort(dag) if name != head.alias_or_name], + ) + return expression + + +def normalize(expression): + """ + Remove INNER and OUTER from joins as they are optional. + """ + for join in expression.find_all(exp.Join): + if join.kind != "CROSS": + join.set("kind", None) + return expression + + +def other_table_names(join, exclude): + return [ + name + for name in (exp.column_table_names(join.args.get("on") or exp.TRUE)) + if name != exclude + ] |