From 8f88a01462641cbf930b3c43b780565d0fb7d37e Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Thu, 22 Jun 2023 20:53:34 +0200 Subject: Merging upstream version 16.4.0. Signed-off-by: Daniel Baumann --- docs/sqlglot/optimizer/optimize_joins.html | 298 ++++++++++++++--------------- 1 file changed, 146 insertions(+), 152 deletions(-) (limited to 'docs/sqlglot/optimizer/optimize_joins.html') diff --git a/docs/sqlglot/optimizer/optimize_joins.html b/docs/sqlglot/optimizer/optimize_joins.html index b914197..4ffa623 100644 --- a/docs/sqlglot/optimizer/optimize_joins.html +++ b/docs/sqlglot/optimizer/optimize_joins.html @@ -65,89 +65,88 @@ -
 1from sqlglot import exp
- 2from sqlglot.helper import tsort
- 3
- 4JOIN_ATTRS = ("on", "side", "kind", "using", "method")
- 5
- 6
- 7def optimize_joins(expression):
- 8    """
- 9    Removes cross joins if possible and reorder joins based on predicate dependencies.
+                        
 1from __future__ import annotations
+ 2
+ 3import typing as t
+ 4
+ 5from sqlglot import exp
+ 6from sqlglot.helper import tsort
+ 7
+ 8JOIN_ATTRS = ("on", "side", "kind", "using", "method")
+ 9
 10
-11    Example:
-12        >>> from sqlglot import parse_one
-13        >>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
-14        'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'
-15    """
-16    for select in expression.find_all(exp.Select):
-17        references = {}
-18        cross_joins = []
-19
-20        for join in select.args.get("joins", []):
-21            name = join.this.alias_or_name
-22            tables = other_table_names(join, name)
-23
-24            if tables:
-25                for table in tables:
-26                    references[table] = references.get(table, []) + [join]
-27            else:
-28                cross_joins.append((name, join))
-29
-30        for name, join in cross_joins:
-31            for dep in references.get(name, []):
-32                on = dep.args["on"]
+11def optimize_joins(expression):
+12    """
+13    Removes cross joins if possible and reorder joins based on predicate dependencies.
+14
+15    Example:
+16        >>> from sqlglot import parse_one
+17        >>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
+18        'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'
+19    """
+20
+21    for select in expression.find_all(exp.Select):
+22        references = {}
+23        cross_joins = []
+24
+25        for join in select.args.get("joins", []):
+26            tables = other_table_names(join)
+27
+28            if tables:
+29                for table in tables:
+30                    references[table] = references.get(table, []) + [join]
+31            else:
+32                cross_joins.append((join.alias_or_name, join))
 33
-34                if isinstance(on, exp.Connector):
-35                    for predicate in on.flatten():
-36                        if name in exp.column_table_names(predicate):
-37                            predicate.replace(exp.true())
-38                            join.on(predicate, copy=False)
-39
-40    expression = reorder_joins(expression)
-41    expression = normalize(expression)
-42    return expression
-43
-44
-45def reorder_joins(expression):
-46    """
-47    Reorder joins by topological sort order based on predicate references.
-48    """
-49    for from_ in expression.find_all(exp.From):
-50        head = from_.this
-51        parent = from_.parent
-52        joins = {join.this.alias_or_name: join for join in parent.args.get("joins", [])}
-53        dag = {head.alias_or_name: []}
-54
-55        for name, join in joins.items():
-56            dag[name] = other_table_names(join, name)
-57
-58        parent.set(
-59            "joins",
-60            [joins[name] for name in tsort(dag) if name != head.alias_or_name],
-61        )
-62    return expression
-63
-64
-65def normalize(expression):
-66    """
-67    Remove INNER and OUTER from joins as they are optional.
-68    """
-69    for join in expression.find_all(exp.Join):
-70        if not any(join.args.get(k) for k in JOIN_ATTRS):
-71            join.set("kind", "CROSS")
-72
-73        if join.kind != "CROSS":
-74            join.set("kind", None)
-75    return expression
-76
-77
-78def other_table_names(join, exclude):
-79    return [
-80        name
-81        for name in (exp.column_table_names(join.args.get("on") or exp.true()))
-82        if name != exclude
-83    ]
+34        for name, join in cross_joins:
+35            for dep in references.get(name, []):
+36                on = dep.args["on"]
+37
+38                if isinstance(on, exp.Connector):
+39                    if len(other_table_names(dep)) < 2:
+40                        continue
+41
+42                    for predicate in on.flatten():
+43                        if name in exp.column_table_names(predicate):
+44                            predicate.replace(exp.true())
+45                            join.on(predicate, copy=False)
+46
+47    expression = reorder_joins(expression)
+48    expression = normalize(expression)
+49    return expression
+50
+51
+52def reorder_joins(expression):
+53    """
+54    Reorder joins by topological sort order based on predicate references.
+55    """
+56    for from_ in expression.find_all(exp.From):
+57        parent = from_.parent
+58        joins = {join.alias_or_name: join for join in parent.args.get("joins", [])}
+59        dag = {name: other_table_names(join) for name, join in joins.items()}
+60        parent.set(
+61            "joins",
+62            [joins[name] for name in tsort(dag) if name != from_.alias_or_name],
+63        )
+64    return expression
+65
+66
+67def normalize(expression):
+68    """
+69    Remove INNER and OUTER from joins as they are optional.
+70    """
+71    for join in expression.find_all(exp.Join):
+72        if not any(join.args.get(k) for k in JOIN_ATTRS):
+73            join.set("kind", "CROSS")
+74
+75        if join.kind != "CROSS":
+76            join.set("kind", None)
+77    return expression
+78
+79
+80def other_table_names(join: exp.Join) -> t.Set[str]:
+81    on = join.args.get("on")
+82    return exp.column_table_names(on, join.alias_or_name) if on else set()
 
@@ -163,42 +162,45 @@
-
 8def optimize_joins(expression):
- 9    """
-10    Removes cross joins if possible and reorder joins based on predicate dependencies.
-11
-12    Example:
-13        >>> from sqlglot import parse_one
-14        >>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
-15        'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'
-16    """
-17    for select in expression.find_all(exp.Select):
-18        references = {}
-19        cross_joins = []
-20
-21        for join in select.args.get("joins", []):
-22            name = join.this.alias_or_name
-23            tables = other_table_names(join, name)
-24
-25            if tables:
-26                for table in tables:
-27                    references[table] = references.get(table, []) + [join]
-28            else:
-29                cross_joins.append((name, join))
-30
-31        for name, join in cross_joins:
-32            for dep in references.get(name, []):
-33                on = dep.args["on"]
+            
12def optimize_joins(expression):
+13    """
+14    Removes cross joins if possible and reorder joins based on predicate dependencies.
+15
+16    Example:
+17        >>> from sqlglot import parse_one
+18        >>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
+19        'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'
+20    """
+21
+22    for select in expression.find_all(exp.Select):
+23        references = {}
+24        cross_joins = []
+25
+26        for join in select.args.get("joins", []):
+27            tables = other_table_names(join)
+28
+29            if tables:
+30                for table in tables:
+31                    references[table] = references.get(table, []) + [join]
+32            else:
+33                cross_joins.append((join.alias_or_name, join))
 34
-35                if isinstance(on, exp.Connector):
-36                    for predicate in on.flatten():
-37                        if name in exp.column_table_names(predicate):
-38                            predicate.replace(exp.true())
-39                            join.on(predicate, copy=False)
-40
-41    expression = reorder_joins(expression)
-42    expression = normalize(expression)
-43    return expression
+35        for name, join in cross_joins:
+36            for dep in references.get(name, []):
+37                on = dep.args["on"]
+38
+39                if isinstance(on, exp.Connector):
+40                    if len(other_table_names(dep)) < 2:
+41                        continue
+42
+43                    for predicate in on.flatten():
+44                        if name in exp.column_table_names(predicate):
+45                            predicate.replace(exp.true())
+46                            join.on(predicate, copy=False)
+47
+48    expression = reorder_joins(expression)
+49    expression = normalize(expression)
+50    return expression
 
@@ -229,24 +231,19 @@
-
46def reorder_joins(expression):
-47    """
-48    Reorder joins by topological sort order based on predicate references.
-49    """
-50    for from_ in expression.find_all(exp.From):
-51        head = from_.this
-52        parent = from_.parent
-53        joins = {join.this.alias_or_name: join for join in parent.args.get("joins", [])}
-54        dag = {head.alias_or_name: []}
-55
-56        for name, join in joins.items():
-57            dag[name] = other_table_names(join, name)
-58
-59        parent.set(
-60            "joins",
-61            [joins[name] for name in tsort(dag) if name != head.alias_or_name],
-62        )
-63    return expression
+            
53def reorder_joins(expression):
+54    """
+55    Reorder joins by topological sort order based on predicate references.
+56    """
+57    for from_ in expression.find_all(exp.From):
+58        parent = from_.parent
+59        joins = {join.alias_or_name: join for join in parent.args.get("joins", [])}
+60        dag = {name: other_table_names(join) for name, join in joins.items()}
+61        parent.set(
+62            "joins",
+63            [joins[name] for name in tsort(dag) if name != from_.alias_or_name],
+64        )
+65    return expression
 
@@ -266,17 +263,17 @@
-
66def normalize(expression):
-67    """
-68    Remove INNER and OUTER from joins as they are optional.
-69    """
-70    for join in expression.find_all(exp.Join):
-71        if not any(join.args.get(k) for k in JOIN_ATTRS):
-72            join.set("kind", "CROSS")
-73
-74        if join.kind != "CROSS":
-75            join.set("kind", None)
-76    return expression
+            
68def normalize(expression):
+69    """
+70    Remove INNER and OUTER from joins as they are optional.
+71    """
+72    for join in expression.find_all(exp.Join):
+73        if not any(join.args.get(k) for k in JOIN_ATTRS):
+74            join.set("kind", "CROSS")
+75
+76        if join.kind != "CROSS":
+77            join.set("kind", None)
+78    return expression
 
@@ -290,18 +287,15 @@
def - other_table_names(join, exclude): + other_table_names(join: sqlglot.expressions.Join) -> Set[str]:
-
79def other_table_names(join, exclude):
-80    return [
-81        name
-82        for name in (exp.column_table_names(join.args.get("on") or exp.true()))
-83        if name != exclude
-84    ]
+            
81def other_table_names(join: exp.Join) -> t.Set[str]:
+82    on = join.args.get("on")
+83    return exp.column_table_names(on, join.alias_or_name) if on else set()
 
-- cgit v1.2.3