From 20739a12c39121a9e7ad3c9a2469ec5a6876199d Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sat, 3 Jun 2023 01:59:40 +0200 Subject: Merging upstream version 15.0.0. Signed-off-by: Daniel Baumann --- docs/sqlglot/optimizer/qualify_columns.html | 1515 ++++++++++++++------------- 1 file changed, 775 insertions(+), 740 deletions(-) (limited to 'docs/sqlglot/optimizer/qualify_columns.html') diff --git a/docs/sqlglot/optimizer/qualify_columns.html b/docs/sqlglot/optimizer/qualify_columns.html index 2d257f2..950bc67 100644 --- a/docs/sqlglot/optimizer/qualify_columns.html +++ b/docs/sqlglot/optimizer/qualify_columns.html @@ -77,539 +77,559 @@ -
  1import itertools
-  2import typing as t
-  3
-  4from sqlglot import alias, exp
-  5from sqlglot.errors import OptimizeError
-  6from sqlglot.optimizer.expand_laterals import expand_laterals as _expand_laterals
-  7from sqlglot.optimizer.scope import Scope, traverse_scope
-  8from sqlglot.schema import ensure_schema
-  9
- 10
- 11def qualify_columns(expression, schema, expand_laterals=True):
- 12    """
- 13    Rewrite sqlglot AST to have fully qualified columns.
- 14
- 15    Example:
- 16        >>> import sqlglot
- 17        >>> schema = {"tbl": {"col": "INT"}}
- 18        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
- 19        >>> qualify_columns(expression, schema).sql()
- 20        'SELECT tbl.col AS col FROM tbl'
+                        
  1from __future__ import annotations
+  2
+  3import itertools
+  4import typing as t
+  5
+  6from sqlglot import alias, exp
+  7from sqlglot.errors import OptimizeError
+  8from sqlglot.helper import seq_get
+  9from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope
+ 10from sqlglot.schema import Schema, ensure_schema
+ 11
+ 12
+ 13def qualify_columns(
+ 14    expression: exp.Expression,
+ 15    schema: dict | Schema,
+ 16    expand_alias_refs: bool = True,
+ 17    infer_schema: t.Optional[bool] = None,
+ 18) -> exp.Expression:
+ 19    """
+ 20    Rewrite sqlglot AST to have fully qualified columns.
  21
- 22    Args:
- 23        expression (sqlglot.Expression): expression to qualify
- 24        schema (dict|sqlglot.optimizer.Schema): Database schema
- 25    Returns:
- 26        sqlglot.Expression: qualified expression
- 27    """
- 28    schema = ensure_schema(schema)
- 29
- 30    if not schema.mapping and expand_laterals:
- 31        expression = _expand_laterals(expression)
- 32
- 33    for scope in traverse_scope(expression):
- 34        resolver = Resolver(scope, schema)
- 35        _pop_table_column_aliases(scope.ctes)
- 36        _pop_table_column_aliases(scope.derived_tables)
- 37        using_column_tables = _expand_using(scope, resolver)
- 38        _qualify_columns(scope, resolver)
- 39        if not isinstance(scope.expression, exp.UDTF):
- 40            _expand_stars(scope, resolver, using_column_tables)
- 41            _qualify_outputs(scope)
- 42        _expand_alias_refs(scope, resolver)
- 43        _expand_group_by(scope, resolver)
- 44        _expand_order_by(scope)
+ 22    Example:
+ 23        >>> import sqlglot
+ 24        >>> schema = {"tbl": {"col": "INT"}}
+ 25        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
+ 26        >>> qualify_columns(expression, schema).sql()
+ 27        'SELECT tbl.col AS col FROM tbl'
+ 28
+ 29    Args:
+ 30        expression: expression to qualify
+ 31        schema: Database schema
+ 32        expand_alias_refs: whether or not to expand references to aliases
+ 33        infer_schema: whether or not to infer the schema if missing
+ 34    Returns:
+ 35        sqlglot.Expression: qualified expression
+ 36    """
+ 37    schema = ensure_schema(schema)
+ 38    infer_schema = schema.empty if infer_schema is None else infer_schema
+ 39
+ 40    for scope in traverse_scope(expression):
+ 41        resolver = Resolver(scope, schema, infer_schema=infer_schema)
+ 42        _pop_table_column_aliases(scope.ctes)
+ 43        _pop_table_column_aliases(scope.derived_tables)
+ 44        using_column_tables = _expand_using(scope, resolver)
  45
- 46    if schema.mapping and expand_laterals:
- 47        expression = _expand_laterals(expression)
+ 46        if schema.empty and expand_alias_refs:
+ 47            _expand_alias_refs(scope, resolver)
  48
- 49    return expression
+ 49        _qualify_columns(scope, resolver)
  50
- 51
- 52def validate_qualify_columns(expression):
- 53    """Raise an `OptimizeError` if any columns aren't qualified"""
- 54    unqualified_columns = []
- 55    for scope in traverse_scope(expression):
- 56        if isinstance(scope.expression, exp.Select):
- 57            unqualified_columns.extend(scope.unqualified_columns)
- 58            if scope.external_columns and not scope.is_correlated_subquery:
- 59                column = scope.external_columns[0]
- 60                raise OptimizeError(f"Unknown table: '{column.table}' for column '{column}'")
+ 51        if not schema.empty and expand_alias_refs:
+ 52            _expand_alias_refs(scope, resolver)
+ 53
+ 54        if not isinstance(scope.expression, exp.UDTF):
+ 55            _expand_stars(scope, resolver, using_column_tables)
+ 56            _qualify_outputs(scope)
+ 57        _expand_group_by(scope, resolver)
+ 58        _expand_order_by(scope)
+ 59
+ 60    return expression
  61
- 62    if unqualified_columns:
- 63        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
- 64    return expression
- 65
- 66
- 67def _pop_table_column_aliases(derived_tables):
- 68    """
- 69    Remove table column aliases.
- 70
- 71    (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
- 72    """
- 73    for derived_table in derived_tables:
- 74        table_alias = derived_table.args.get("alias")
- 75        if table_alias:
- 76            table_alias.args.pop("columns", None)
- 77
+ 62
+ 63def validate_qualify_columns(expression):
+ 64    """Raise an `OptimizeError` if any columns aren't qualified"""
+ 65    unqualified_columns = []
+ 66    for scope in traverse_scope(expression):
+ 67        if isinstance(scope.expression, exp.Select):
+ 68            unqualified_columns.extend(scope.unqualified_columns)
+ 69            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
+ 70                column = scope.external_columns[0]
+ 71                raise OptimizeError(
+ 72                    f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}"""
+ 73                )
+ 74
+ 75    if unqualified_columns:
+ 76        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
+ 77    return expression
  78
- 79def _expand_using(scope, resolver):
- 80    joins = list(scope.find_all(exp.Join))
- 81    names = {join.this.alias for join in joins}
- 82    ordered = [key for key in scope.selected_sources if key not in names]
+ 79
+ 80def _pop_table_column_aliases(derived_tables):
+ 81    """
+ 82    Remove table column aliases.
  83
- 84    # Mapping of automatically joined column names to an ordered set of source names (dict).
- 85    column_tables = {}
- 86
- 87    for join in joins:
- 88        using = join.args.get("using")
- 89
- 90        if not using:
- 91            continue
- 92
- 93        join_table = join.this.alias_or_name
- 94
- 95        columns = {}
+ 84    (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
+ 85    """
+ 86    for derived_table in derived_tables:
+ 87        table_alias = derived_table.args.get("alias")
+ 88        if table_alias:
+ 89            table_alias.args.pop("columns", None)
+ 90
+ 91
+ 92def _expand_using(scope, resolver):
+ 93    joins = list(scope.find_all(exp.Join))
+ 94    names = {join.this.alias for join in joins}
+ 95    ordered = [key for key in scope.selected_sources if key not in names]
  96
- 97        for k in scope.selected_sources:
- 98            if k in ordered:
- 99                for column in resolver.get_source_columns(k):
-100                    if column not in columns:
-101                        columns[column] = k
+ 97    # Mapping of automatically joined column names to an ordered set of source names (dict).
+ 98    column_tables = {}
+ 99
+100    for join in joins:
+101        using = join.args.get("using")
 102
-103        source_table = ordered[-1]
-104        ordered.append(join_table)
-105        join_columns = resolver.get_source_columns(join_table)
-106        conditions = []
+103        if not using:
+104            continue
+105
+106        join_table = join.this.alias_or_name
 107
-108        for identifier in using:
-109            identifier = identifier.name
-110            table = columns.get(identifier)
-111
-112            if not table or identifier not in join_columns:
-113                if columns and join_columns:
-114                    raise OptimizeError(f"Cannot automatically join: {identifier}")
+108        columns = {}
+109
+110        for k in scope.selected_sources:
+111            if k in ordered:
+112                for column in resolver.get_source_columns(k):
+113                    if column not in columns:
+114                        columns[column] = k
 115
-116            table = table or source_table
-117            conditions.append(
-118                exp.condition(
-119                    exp.EQ(
-120                        this=exp.column(identifier, table=table),
-121                        expression=exp.column(identifier, table=join_table),
-122                    )
-123                )
-124            )
-125
-126            # Set all values in the dict to None, because we only care about the key ordering
-127            tables = column_tables.setdefault(identifier, {})
-128            if table not in tables:
-129                tables[table] = None
-130            if join_table not in tables:
-131                tables[join_table] = None
-132
-133        join.args.pop("using")
-134        join.set("on", exp.and_(*conditions, copy=False))
-135
-136    if column_tables:
-137        for column in scope.columns:
-138            if not column.table and column.name in column_tables:
-139                tables = column_tables[column.name]
-140                coalesce = [exp.column(column.name, table=table) for table in tables]
-141                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])
-142
-143                # Ensure selects keep their output name
-144                if isinstance(column.parent, exp.Select):
-145                    replacement = exp.alias_(replacement, alias=column.name)
-146
-147                scope.replace(column, replacement)
+116        source_table = ordered[-1]
+117        ordered.append(join_table)
+118        join_columns = resolver.get_source_columns(join_table)
+119        conditions = []
+120
+121        for identifier in using:
+122            identifier = identifier.name
+123            table = columns.get(identifier)
+124
+125            if not table or identifier not in join_columns:
+126                if columns and join_columns:
+127                    raise OptimizeError(f"Cannot automatically join: {identifier}")
+128
+129            table = table or source_table
+130            conditions.append(
+131                exp.condition(
+132                    exp.EQ(
+133                        this=exp.column(identifier, table=table),
+134                        expression=exp.column(identifier, table=join_table),
+135                    )
+136                )
+137            )
+138
+139            # Set all values in the dict to None, because we only care about the key ordering
+140            tables = column_tables.setdefault(identifier, {})
+141            if table not in tables:
+142                tables[table] = None
+143            if join_table not in tables:
+144                tables[join_table] = None
+145
+146        join.args.pop("using")
+147        join.set("on", exp.and_(*conditions, copy=False))
 148
-149    return column_tables
-150
-151
-152def _expand_alias_refs(scope, resolver):
-153    selects = {}
-154
-155    # Replace references to select aliases
-156    def transform(node, source_first=True):
-157        if isinstance(node, exp.Column) and not node.table:
-158            table = resolver.get_table(node.name)
+149    if column_tables:
+150        for column in scope.columns:
+151            if not column.table and column.name in column_tables:
+152                tables = column_tables[column.name]
+153                coalesce = [exp.column(column.name, table=table) for table in tables]
+154                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])
+155
+156                # Ensure selects keep their output name
+157                if isinstance(column.parent, exp.Select):
+158                    replacement = alias(replacement, alias=column.name, copy=False)
 159
-160            # Source columns get priority over select aliases
-161            if source_first and table:
-162                node.set("table", table)
-163                return node
+160                scope.replace(column, replacement)
+161
+162    return column_tables
+163
 164
-165            if not selects:
-166                for s in scope.selects:
-167                    selects[s.alias_or_name] = s
-168            select = selects.get(node.name)
-169
-170            if select:
-171                scope.clear_cache()
-172                if isinstance(select, exp.Alias):
-173                    select = select.this
-174                return select.copy()
-175
-176            node.set("table", table)
-177        elif isinstance(node, exp.Expression) and not isinstance(node, exp.Subqueryable):
-178            exp.replace_children(node, transform, source_first)
-179
-180        return node
-181
-182    for select in scope.expression.selects:
-183        transform(select)
-184
-185    for modifier, source_first in (
-186        ("where", True),
-187        ("group", True),
-188        ("having", False),
-189    ):
-190        transform(scope.expression.args.get(modifier), source_first=source_first)
-191
-192
-193def _expand_group_by(scope, resolver):
-194    group = scope.expression.args.get("group")
-195    if not group:
-196        return
-197
-198    group.set("expressions", _expand_positional_references(scope, group.expressions))
-199    scope.expression.set("group", group)
+165def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
+166    expression = scope.expression
+167
+168    if not isinstance(expression, exp.Select):
+169        return
+170
+171    alias_to_expression: t.Dict[str, exp.Expression] = {}
+172
+173    def replace_columns(
+174        node: t.Optional[exp.Expression], expand: bool = True, resolve_agg: bool = False
+175    ):
+176        if not node:
+177            return
+178
+179        for column, *_ in walk_in_scope(node):
+180            if not isinstance(column, exp.Column):
+181                continue
+182            table = resolver.get_table(column.name) if resolve_agg and not column.table else None
+183            if table and column.find_ancestor(exp.AggFunc):
+184                column.set("table", table)
+185            elif expand and not column.table and column.name in alias_to_expression:
+186                column.replace(alias_to_expression[column.name].copy())
+187
+188    for projection in scope.selects:
+189        replace_columns(projection)
+190
+191        if isinstance(projection, exp.Alias):
+192            alias_to_expression[projection.alias] = projection.this
+193
+194    replace_columns(expression.args.get("where"))
+195    replace_columns(expression.args.get("group"))
+196    replace_columns(expression.args.get("having"), resolve_agg=True)
+197    replace_columns(expression.args.get("qualify"), resolve_agg=True)
+198    replace_columns(expression.args.get("order"), expand=False, resolve_agg=True)
+199    scope.clear_cache()
 200
 201
-202def _expand_order_by(scope):
-203    order = scope.expression.args.get("order")
-204    if not order:
+202def _expand_group_by(scope, resolver):
+203    group = scope.expression.args.get("group")
+204    if not group:
 205        return
 206
-207    ordereds = order.expressions
-208    for ordered, new_expression in zip(
-209        ordereds,
-210        _expand_positional_references(scope, (o.this for o in ordereds)),
-211    ):
-212        ordered.set("this", new_expression)
-213
-214
-215def _expand_positional_references(scope, expressions):
-216    new_nodes = []
-217    for node in expressions:
-218        if node.is_int:
-219            try:
-220                select = scope.selects[int(node.name) - 1]
-221            except IndexError:
-222                raise OptimizeError(f"Unknown output column: {node.name}")
-223            if isinstance(select, exp.Alias):
-224                select = select.this
-225            new_nodes.append(select.copy())
-226            scope.clear_cache()
-227        else:
-228            new_nodes.append(node)
-229
-230    return new_nodes
-231
-232
-233def _qualify_columns(scope, resolver):
-234    """Disambiguate columns, ensuring each column specifies a source"""
-235    for column in scope.columns:
-236        column_table = column.table
-237        column_name = column.name
+207    group.set("expressions", _expand_positional_references(scope, group.expressions))
+208    scope.expression.set("group", group)
+209
+210
+211def _expand_order_by(scope):
+212    order = scope.expression.args.get("order")
+213    if not order:
+214        return
+215
+216    ordereds = order.expressions
+217    for ordered, new_expression in zip(
+218        ordereds,
+219        _expand_positional_references(scope, (o.this for o in ordereds)),
+220    ):
+221        ordered.set("this", new_expression)
+222
+223
+224def _expand_positional_references(scope, expressions):
+225    new_nodes = []
+226    for node in expressions:
+227        if node.is_int:
+228            try:
+229                select = scope.selects[int(node.name) - 1]
+230            except IndexError:
+231                raise OptimizeError(f"Unknown output column: {node.name}")
+232            if isinstance(select, exp.Alias):
+233                select = select.this
+234            new_nodes.append(select.copy())
+235            scope.clear_cache()
+236        else:
+237            new_nodes.append(node)
 238
-239        if column_table and column_table in scope.sources:
-240            source_columns = resolver.get_source_columns(column_table)
-241            if source_columns and column_name not in source_columns and "*" not in source_columns:
-242                raise OptimizeError(f"Unknown column: {column_name}")
-243
-244        if not column_table:
-245            column_table = resolver.get_table(column_name)
-246
-247            # column_table can be a '' because bigquery unnest has no table alias
-248            if column_table:
-249                column.set("table", column_table)
-250        elif column_table not in scope.sources and (
-251            not scope.parent or column_table not in scope.parent.sources
-252        ):
-253            # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
-254            # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))
-255
-256            root, *parts = column.parts
-257
-258            if root.name in scope.sources:
-259                # struct is already qualified, but we still need to change the AST representation
-260                column_table = root
-261                root, *parts = parts
-262            else:
-263                column_table = resolver.get_table(root.name)
-264
-265            if column_table:
-266                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))
-267
-268    columns_missing_from_scope = []
-269
-270    # Determine whether each reference in the order by clause is to a column or an alias.
-271    order = scope.expression.args.get("order")
+239    return new_nodes
+240
+241
+242def _qualify_columns(scope, resolver):
+243    """Disambiguate columns, ensuring each column specifies a source"""
+244    for column in scope.columns:
+245        column_table = column.table
+246        column_name = column.name
+247
+248        if column_table and column_table in scope.sources:
+249            source_columns = resolver.get_source_columns(column_table)
+250            if source_columns and column_name not in source_columns and "*" not in source_columns:
+251                raise OptimizeError(f"Unknown column: {column_name}")
+252
+253        if not column_table:
+254            if scope.pivots and not column.find_ancestor(exp.Pivot):
+255                # If the column is under the Pivot expression, we need to qualify it
+256                # using the name of the pivoted source instead of the pivot's alias
+257                column.set("table", exp.to_identifier(scope.pivots[0].alias))
+258                continue
+259
+260            column_table = resolver.get_table(column_name)
+261
+262            # column_table can be a '' because bigquery unnest has no table alias
+263            if column_table:
+264                column.set("table", column_table)
+265        elif column_table not in scope.sources and (
+266            not scope.parent or column_table not in scope.parent.sources
+267        ):
+268            # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
+269            # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))
+270
+271            root, *parts = column.parts
 272
-273    if order:
-274        for ordered in order.expressions:
-275            for column in ordered.find_all(exp.Column):
-276                if (
-277                    not column.table
-278                    and column.parent is not ordered
-279                    and column.name in resolver.all_columns
-280                ):
-281                    columns_missing_from_scope.append(column)
+273            if root.name in scope.sources:
+274                # struct is already qualified, but we still need to change the AST representation
+275                column_table = root
+276                root, *parts = parts
+277            else:
+278                column_table = resolver.get_table(root.name)
+279
+280            if column_table:
+281                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))
 282
-283    # Determine whether each reference in the having clause is to a column or an alias.
-284    having = scope.expression.args.get("having")
-285
-286    if having:
-287        for column in having.find_all(exp.Column):
-288            if (
-289                not column.table
-290                and column.find_ancestor(exp.AggFunc)
-291                and column.name in resolver.all_columns
-292            ):
-293                columns_missing_from_scope.append(column)
-294
-295    for column in columns_missing_from_scope:
-296        column_table = resolver.get_table(column.name)
-297
-298        if column_table:
-299            column.set("table", column_table)
-300
-301
-302def _expand_stars(scope, resolver, using_column_tables):
-303    """Expand stars to lists of column selections"""
-304
-305    new_selections = []
-306    except_columns = {}
-307    replace_columns = {}
-308    coalesced_columns = set()
-309
-310    for expression in scope.selects:
-311        if isinstance(expression, exp.Star):
-312            tables = list(scope.selected_sources)
-313            _add_except_columns(expression, tables, except_columns)
-314            _add_replace_columns(expression, tables, replace_columns)
-315        elif expression.is_star:
-316            tables = [expression.table]
-317            _add_except_columns(expression.this, tables, except_columns)
-318            _add_replace_columns(expression.this, tables, replace_columns)
-319        else:
-320            new_selections.append(expression)
-321            continue
-322
-323        for table in tables:
-324            if table not in scope.sources:
-325                raise OptimizeError(f"Unknown table: {table}")
-326            columns = resolver.get_source_columns(table, only_visible=True)
-327
-328            if columns and "*" not in columns:
-329                table_id = id(table)
-330                for name in columns:
-331                    if name in using_column_tables and table in using_column_tables[name]:
-332                        if name in coalesced_columns:
-333                            continue
-334
-335                        coalesced_columns.add(name)
-336                        tables = using_column_tables[name]
-337                        coalesce = [exp.column(name, table=table) for table in tables]
-338
-339                        new_selections.append(
-340                            exp.alias_(
-341                                exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]), alias=name
-342                            )
-343                        )
-344                    elif name not in except_columns.get(table_id, set()):
-345                        alias_ = replace_columns.get(table_id, {}).get(name, name)
-346                        column = exp.column(name, table)
-347                        new_selections.append(alias(column, alias_) if alias_ != name else column)
-348            else:
-349                return
-350    scope.expression.set("expressions", new_selections)
-351
-352
-353def _add_except_columns(expression, tables, except_columns):
-354    except_ = expression.args.get("except")
-355
-356    if not except_:
-357        return
-358
-359    columns = {e.name for e in except_}
-360
-361    for table in tables:
-362        except_columns[id(table)] = columns
-363
-364
-365def _add_replace_columns(expression, tables, replace_columns):
-366    replace = expression.args.get("replace")
+283    for pivot in scope.pivots:
+284        for column in pivot.find_all(exp.Column):
+285            if not column.table and column.name in resolver.all_columns:
+286                column_table = resolver.get_table(column.name)
+287                if column_table:
+288                    column.set("table", column_table)
+289
+290
+291def _expand_stars(scope, resolver, using_column_tables):
+292    """Expand stars to lists of column selections"""
+293
+294    new_selections = []
+295    except_columns = {}
+296    replace_columns = {}
+297    coalesced_columns = set()
+298
+299    # TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future
+300    pivot_columns = None
+301    pivot_output_columns = None
+302    pivot = seq_get(scope.pivots, 0)
+303
+304    has_pivoted_source = pivot and not pivot.args.get("unpivot")
+305    if has_pivoted_source:
+306        pivot_columns = set(col.output_name for col in pivot.find_all(exp.Column))
+307
+308        pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])]
+309        if not pivot_output_columns:
+310            pivot_output_columns = [col.alias_or_name for col in pivot.expressions]
+311
+312    for expression in scope.selects:
+313        if isinstance(expression, exp.Star):
+314            tables = list(scope.selected_sources)
+315            _add_except_columns(expression, tables, except_columns)
+316            _add_replace_columns(expression, tables, replace_columns)
+317        elif expression.is_star:
+318            tables = [expression.table]
+319            _add_except_columns(expression.this, tables, except_columns)
+320            _add_replace_columns(expression.this, tables, replace_columns)
+321        else:
+322            new_selections.append(expression)
+323            continue
+324
+325        for table in tables:
+326            if table not in scope.sources:
+327                raise OptimizeError(f"Unknown table: {table}")
+328
+329            columns = resolver.get_source_columns(table, only_visible=True)
+330
+331            if columns and "*" not in columns:
+332                if has_pivoted_source:
+333                    implicit_columns = [col for col in columns if col not in pivot_columns]
+334                    new_selections.extend(
+335                        exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
+336                        for name in implicit_columns + pivot_output_columns
+337                    )
+338                    continue
+339
+340                table_id = id(table)
+341                for name in columns:
+342                    if name in using_column_tables and table in using_column_tables[name]:
+343                        if name in coalesced_columns:
+344                            continue
+345
+346                        coalesced_columns.add(name)
+347                        tables = using_column_tables[name]
+348                        coalesce = [exp.column(name, table=table) for table in tables]
+349
+350                        new_selections.append(
+351                            alias(
+352                                exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
+353                                alias=name,
+354                                copy=False,
+355                            )
+356                        )
+357                    elif name not in except_columns.get(table_id, set()):
+358                        alias_ = replace_columns.get(table_id, {}).get(name, name)
+359                        column = exp.column(name, table=table)
+360                        new_selections.append(
+361                            alias(column, alias_, copy=False) if alias_ != name else column
+362                        )
+363            else:
+364                return
+365
+366    scope.expression.set("expressions", new_selections)
 367
-368    if not replace:
-369        return
-370
-371    columns = {e.this.name: e.alias for e in replace}
-372
-373    for table in tables:
-374        replace_columns[id(table)] = columns
-375
+368
+369def _add_except_columns(expression, tables, except_columns):
+370    except_ = expression.args.get("except")
+371
+372    if not except_:
+373        return
+374
+375    columns = {e.name for e in except_}
 376
-377def _qualify_outputs(scope):
-378    """Ensure all output columns are aliased"""
-379    new_selections = []
+377    for table in tables:
+378        except_columns[id(table)] = columns
+379
 380
-381    for i, (selection, aliased_column) in enumerate(
-382        itertools.zip_longest(scope.selects, scope.outer_column_list)
-383    ):
-384        if isinstance(selection, exp.Subquery):
-385            if not selection.output_name:
-386                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
-387        elif not isinstance(selection, exp.Alias) and not selection.is_star:
-388            alias_ = alias(exp.column(""), alias=selection.output_name or f"_col_{i}")
-389            alias_.set("this", selection)
-390            selection = alias_
+381def _add_replace_columns(expression, tables, replace_columns):
+382    replace = expression.args.get("replace")
+383
+384    if not replace:
+385        return
+386
+387    columns = {e.this.name: e.alias for e in replace}
+388
+389    for table in tables:
+390        replace_columns[id(table)] = columns
 391
-392        if aliased_column:
-393            selection.set("alias", exp.to_identifier(aliased_column))
-394
-395        new_selections.append(selection)
+392
+393def _qualify_outputs(scope):
+394    """Ensure all output columns are aliased"""
+395    new_selections = []
 396
-397    scope.expression.set("expressions", new_selections)
-398
-399
-400class Resolver:
-401    """
-402    Helper for resolving columns.
-403
-404    This is a class so we can lazily load some things and easily share them across functions.
-405    """
-406
-407    def __init__(self, scope, schema):
-408        self.scope = scope
-409        self.schema = schema
-410        self._source_columns = None
-411        self._unambiguous_columns = None
-412        self._all_columns = None
+397    for i, (selection, aliased_column) in enumerate(
+398        itertools.zip_longest(scope.selects, scope.outer_column_list)
+399    ):
+400        if isinstance(selection, exp.Subquery):
+401            if not selection.output_name:
+402                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
+403        elif not isinstance(selection, exp.Alias) and not selection.is_star:
+404            selection = alias(
+405                selection,
+406                alias=selection.output_name or f"_col_{i}",
+407                quoted=True
+408                if isinstance(selection, exp.Column) and selection.this.quoted
+409                else None,
+410            )
+411        if aliased_column:
+412            selection.set("alias", exp.to_identifier(aliased_column))
 413
-414    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
-415        """
-416        Get the table for a column name.
+414        new_selections.append(selection)
+415
+416    scope.expression.set("expressions", new_selections)
 417
-418        Args:
-419            column_name: The column name to find the table for.
-420        Returns:
-421            The table name if it can be found/inferred.
-422        """
-423        if self._unambiguous_columns is None:
-424            self._unambiguous_columns = self._get_unambiguous_columns(
-425                self._get_all_source_columns()
-426            )
-427
-428        table_name = self._unambiguous_columns.get(column_name)
-429
-430        if not table_name:
-431            sources_without_schema = tuple(
-432                source
-433                for source, columns in self._get_all_source_columns().items()
-434                if not columns or "*" in columns
-435            )
-436            if len(sources_without_schema) == 1:
-437                table_name = sources_without_schema[0]
-438
-439        if table_name not in self.scope.selected_sources:
-440            return exp.to_identifier(table_name)
-441
-442        node, _ = self.scope.selected_sources.get(table_name)
-443
-444        if isinstance(node, exp.Subqueryable):
-445            while node and node.alias != table_name:
-446                node = node.parent
+418
+419class Resolver:
+420    """
+421    Helper for resolving columns.
+422
+423    This is a class so we can lazily load some things and easily share them across functions.
+424    """
+425
+426    def __init__(self, scope, schema, infer_schema: bool = True):
+427        self.scope = scope
+428        self.schema = schema
+429        self._source_columns = None
+430        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
+431        self._all_columns = None
+432        self._infer_schema = infer_schema
+433
+434    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
+435        """
+436        Get the table for a column name.
+437
+438        Args:
+439            column_name: The column name to find the table for.
+440        Returns:
+441            The table name if it can be found/inferred.
+442        """
+443        if self._unambiguous_columns is None:
+444            self._unambiguous_columns = self._get_unambiguous_columns(
+445                self._get_all_source_columns()
+446            )
 447
-448        node_alias = node.args.get("alias")
-449        if node_alias:
-450            return node_alias.this
-451
-452        return exp.to_identifier(
-453            table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None
-454        )
-455
-456    @property
-457    def all_columns(self):
-458        """All available columns of all sources in this scope"""
-459        if self._all_columns is None:
-460            self._all_columns = {
-461                column for columns in self._get_all_source_columns().values() for column in columns
-462            }
-463        return self._all_columns
-464
-465    def get_source_columns(self, name, only_visible=False):
-466        """Resolve the source columns for a given source `name`"""
-467        if name not in self.scope.sources:
-468            raise OptimizeError(f"Unknown table: {name}")
-469
-470        source = self.scope.sources[name]
+448        table_name = self._unambiguous_columns.get(column_name)
+449
+450        if not table_name and self._infer_schema:
+451            sources_without_schema = tuple(
+452                source
+453                for source, columns in self._get_all_source_columns().items()
+454                if not columns or "*" in columns
+455            )
+456            if len(sources_without_schema) == 1:
+457                table_name = sources_without_schema[0]
+458
+459        if table_name not in self.scope.selected_sources:
+460            return exp.to_identifier(table_name)
+461
+462        node, _ = self.scope.selected_sources.get(table_name)
+463
+464        if isinstance(node, exp.Subqueryable):
+465            while node and node.alias != table_name:
+466                node = node.parent
+467
+468        node_alias = node.args.get("alias")
+469        if node_alias:
+470            return exp.to_identifier(node_alias.this)
 471
-472        # If referencing a table, return the columns from the schema
-473        if isinstance(source, exp.Table):
-474            return self.schema.column_names(source, only_visible)
+472        return exp.to_identifier(
+473            table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None
+474        )
 475
-476        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
-477            return source.expression.alias_column_names
-478
-479        # Otherwise, if referencing another scope, return that scope's named selects
-480        return source.expression.named_selects
-481
-482    def _get_all_source_columns(self):
-483        if self._source_columns is None:
-484            self._source_columns = {
-485                k: self.get_source_columns(k)
-486                for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
-487            }
-488        return self._source_columns
+476    @property
+477    def all_columns(self):
+478        """All available columns of all sources in this scope"""
+479        if self._all_columns is None:
+480            self._all_columns = {
+481                column for columns in self._get_all_source_columns().values() for column in columns
+482            }
+483        return self._all_columns
+484
+485    def get_source_columns(self, name, only_visible=False):
+486        """Resolve the source columns for a given source `name`"""
+487        if name not in self.scope.sources:
+488            raise OptimizeError(f"Unknown table: {name}")
 489
-490    def _get_unambiguous_columns(self, source_columns):
-491        """
-492        Find all the unambiguous columns in sources.
-493
-494        Args:
-495            source_columns (dict): Mapping of names to source columns
-496        Returns:
-497            dict: Mapping of column name to source name
-498        """
-499        if not source_columns:
-500            return {}
+490        source = self.scope.sources[name]
+491
+492        # If referencing a table, return the columns from the schema
+493        if isinstance(source, exp.Table):
+494            return self.schema.column_names(source, only_visible)
+495
+496        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
+497            return source.expression.alias_column_names
+498
+499        # Otherwise, if referencing another scope, return that scope's named selects
+500        return source.expression.named_selects
 501
-502        source_columns = list(source_columns.items())
-503
-504        first_table, first_columns = source_columns[0]
-505        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
-506        all_columns = set(unambiguous_columns)
-507
-508        for table, columns in source_columns[1:]:
-509            unique = self._find_unique_columns(columns)
-510            ambiguous = set(all_columns).intersection(unique)
-511            all_columns.update(columns)
-512            for column in ambiguous:
-513                unambiguous_columns.pop(column, None)
-514            for column in unique.difference(ambiguous):
-515                unambiguous_columns[column] = table
-516
-517        return unambiguous_columns
-518
-519    @staticmethod
-520    def _find_unique_columns(columns):
-521        """
-522        Find the unique columns in a list of columns.
+502    def _get_all_source_columns(self):
+503        if self._source_columns is None:
+504            self._source_columns = {
+505                k: self.get_source_columns(k)
+506                for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
+507            }
+508        return self._source_columns
+509
+510    def _get_unambiguous_columns(self, source_columns):
+511        """
+512        Find all the unambiguous columns in sources.
+513
+514        Args:
+515            source_columns (dict): Mapping of names to source columns
+516        Returns:
+517            dict: Mapping of column name to source name
+518        """
+519        if not source_columns:
+520            return {}
+521
+522        source_columns = list(source_columns.items())
 523
-524        Example:
-525            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
-526            ['a', 'c']
+524        first_table, first_columns = source_columns[0]
+525        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
+526        all_columns = set(unambiguous_columns)
 527
-528        This is necessary because duplicate column names are ambiguous.
-529        """
-530        counts = {}
-531        for column in columns:
-532            counts[column] = counts.get(column, 0) + 1
-533        return {column for column, count in counts.items() if count == 1}
+528        for table, columns in source_columns[1:]:
+529            unique = self._find_unique_columns(columns)
+530            ambiguous = set(all_columns).intersection(unique)
+531            all_columns.update(columns)
+532            for column in ambiguous:
+533                unambiguous_columns.pop(column, None)
+534            for column in unique.difference(ambiguous):
+535                unambiguous_columns[column] = table
+536
+537        return unambiguous_columns
+538
+539    @staticmethod
+540    def _find_unique_columns(columns):
+541        """
+542        Find the unique columns in a list of columns.
+543
+544        Example:
+545            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
+546            ['a', 'c']
+547
+548        This is necessary because duplicate column names are ambiguous.
+549        """
+550        counts = {}
+551        for column in columns:
+552            counts[column] = counts.get(column, 0) + 1
+553        return {column for column, count in counts.items() if count == 1}
 
@@ -619,51 +639,60 @@
def - qualify_columns(expression, schema, expand_laterals=True): + qualify_columns( expression: sqlglot.expressions.Expression, schema: dict | sqlglot.schema.Schema, expand_alias_refs: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
-
12def qualify_columns(expression, schema, expand_laterals=True):
-13    """
-14    Rewrite sqlglot AST to have fully qualified columns.
-15
-16    Example:
-17        >>> import sqlglot
-18        >>> schema = {"tbl": {"col": "INT"}}
-19        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
-20        >>> qualify_columns(expression, schema).sql()
-21        'SELECT tbl.col AS col FROM tbl'
+            
14def qualify_columns(
+15    expression: exp.Expression,
+16    schema: dict | Schema,
+17    expand_alias_refs: bool = True,
+18    infer_schema: t.Optional[bool] = None,
+19) -> exp.Expression:
+20    """
+21    Rewrite sqlglot AST to have fully qualified columns.
 22
-23    Args:
-24        expression (sqlglot.Expression): expression to qualify
-25        schema (dict|sqlglot.optimizer.Schema): Database schema
-26    Returns:
-27        sqlglot.Expression: qualified expression
-28    """
-29    schema = ensure_schema(schema)
-30
-31    if not schema.mapping and expand_laterals:
-32        expression = _expand_laterals(expression)
-33
-34    for scope in traverse_scope(expression):
-35        resolver = Resolver(scope, schema)
-36        _pop_table_column_aliases(scope.ctes)
-37        _pop_table_column_aliases(scope.derived_tables)
-38        using_column_tables = _expand_using(scope, resolver)
-39        _qualify_columns(scope, resolver)
-40        if not isinstance(scope.expression, exp.UDTF):
-41            _expand_stars(scope, resolver, using_column_tables)
-42            _qualify_outputs(scope)
-43        _expand_alias_refs(scope, resolver)
-44        _expand_group_by(scope, resolver)
-45        _expand_order_by(scope)
+23    Example:
+24        >>> import sqlglot
+25        >>> schema = {"tbl": {"col": "INT"}}
+26        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
+27        >>> qualify_columns(expression, schema).sql()
+28        'SELECT tbl.col AS col FROM tbl'
+29
+30    Args:
+31        expression: expression to qualify
+32        schema: Database schema
+33        expand_alias_refs: whether or not to expand references to aliases
+34        infer_schema: whether or not to infer the schema if missing
+35    Returns:
+36        sqlglot.Expression: qualified expression
+37    """
+38    schema = ensure_schema(schema)
+39    infer_schema = schema.empty if infer_schema is None else infer_schema
+40
+41    for scope in traverse_scope(expression):
+42        resolver = Resolver(scope, schema, infer_schema=infer_schema)
+43        _pop_table_column_aliases(scope.ctes)
+44        _pop_table_column_aliases(scope.derived_tables)
+45        using_column_tables = _expand_using(scope, resolver)
 46
-47    if schema.mapping and expand_laterals:
-48        expression = _expand_laterals(expression)
+47        if schema.empty and expand_alias_refs:
+48            _expand_alias_refs(scope, resolver)
 49
-50    return expression
+50        _qualify_columns(scope, resolver)
+51
+52        if not schema.empty and expand_alias_refs:
+53            _expand_alias_refs(scope, resolver)
+54
+55        if not isinstance(scope.expression, exp.UDTF):
+56            _expand_stars(scope, resolver, using_column_tables)
+57            _qualify_outputs(scope)
+58        _expand_group_by(scope, resolver)
+59        _expand_order_by(scope)
+60
+61    return expression
 
@@ -685,8 +714,10 @@
Arguments:
    -
  • expression (sqlglot.Expression): expression to qualify
  • -
  • schema (dict|sqlglot.optimizer.Schema): Database schema
  • +
  • expression: expression to qualify
  • +
  • schema: Database schema
  • +
  • expand_alias_refs: whether or not to expand references to aliases
  • +
  • infer_schema: whether or not to infer the schema if missing
Returns:
@@ -709,19 +740,21 @@
-
53def validate_qualify_columns(expression):
-54    """Raise an `OptimizeError` if any columns aren't qualified"""
-55    unqualified_columns = []
-56    for scope in traverse_scope(expression):
-57        if isinstance(scope.expression, exp.Select):
-58            unqualified_columns.extend(scope.unqualified_columns)
-59            if scope.external_columns and not scope.is_correlated_subquery:
-60                column = scope.external_columns[0]
-61                raise OptimizeError(f"Unknown table: '{column.table}' for column '{column}'")
-62
-63    if unqualified_columns:
-64        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
-65    return expression
+            
64def validate_qualify_columns(expression):
+65    """Raise an `OptimizeError` if any columns aren't qualified"""
+66    unqualified_columns = []
+67    for scope in traverse_scope(expression):
+68        if isinstance(scope.expression, exp.Select):
+69            unqualified_columns.extend(scope.unqualified_columns)
+70            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
+71                column = scope.external_columns[0]
+72                raise OptimizeError(
+73                    f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}"""
+74                )
+75
+76    if unqualified_columns:
+77        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
+78    return expression
 
@@ -741,140 +774,141 @@
-
401class Resolver:
-402    """
-403    Helper for resolving columns.
-404
-405    This is a class so we can lazily load some things and easily share them across functions.
-406    """
-407
-408    def __init__(self, scope, schema):
-409        self.scope = scope
-410        self.schema = schema
-411        self._source_columns = None
-412        self._unambiguous_columns = None
-413        self._all_columns = None
-414
-415    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
-416        """
-417        Get the table for a column name.
-418
-419        Args:
-420            column_name: The column name to find the table for.
-421        Returns:
-422            The table name if it can be found/inferred.
-423        """
-424        if self._unambiguous_columns is None:
-425            self._unambiguous_columns = self._get_unambiguous_columns(
-426                self._get_all_source_columns()
-427            )
-428
-429        table_name = self._unambiguous_columns.get(column_name)
-430
-431        if not table_name:
-432            sources_without_schema = tuple(
-433                source
-434                for source, columns in self._get_all_source_columns().items()
-435                if not columns or "*" in columns
-436            )
-437            if len(sources_without_schema) == 1:
-438                table_name = sources_without_schema[0]
-439
-440        if table_name not in self.scope.selected_sources:
-441            return exp.to_identifier(table_name)
-442
-443        node, _ = self.scope.selected_sources.get(table_name)
-444
-445        if isinstance(node, exp.Subqueryable):
-446            while node and node.alias != table_name:
-447                node = node.parent
+            
420class Resolver:
+421    """
+422    Helper for resolving columns.
+423
+424    This is a class so we can lazily load some things and easily share them across functions.
+425    """
+426
+427    def __init__(self, scope, schema, infer_schema: bool = True):
+428        self.scope = scope
+429        self.schema = schema
+430        self._source_columns = None
+431        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
+432        self._all_columns = None
+433        self._infer_schema = infer_schema
+434
+435    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
+436        """
+437        Get the table for a column name.
+438
+439        Args:
+440            column_name: The column name to find the table for.
+441        Returns:
+442            The table name if it can be found/inferred.
+443        """
+444        if self._unambiguous_columns is None:
+445            self._unambiguous_columns = self._get_unambiguous_columns(
+446                self._get_all_source_columns()
+447            )
 448
-449        node_alias = node.args.get("alias")
-450        if node_alias:
-451            return node_alias.this
-452
-453        return exp.to_identifier(
-454            table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None
-455        )
-456
-457    @property
-458    def all_columns(self):
-459        """All available columns of all sources in this scope"""
-460        if self._all_columns is None:
-461            self._all_columns = {
-462                column for columns in self._get_all_source_columns().values() for column in columns
-463            }
-464        return self._all_columns
-465
-466    def get_source_columns(self, name, only_visible=False):
-467        """Resolve the source columns for a given source `name`"""
-468        if name not in self.scope.sources:
-469            raise OptimizeError(f"Unknown table: {name}")
-470
-471        source = self.scope.sources[name]
+449        table_name = self._unambiguous_columns.get(column_name)
+450
+451        if not table_name and self._infer_schema:
+452            sources_without_schema = tuple(
+453                source
+454                for source, columns in self._get_all_source_columns().items()
+455                if not columns or "*" in columns
+456            )
+457            if len(sources_without_schema) == 1:
+458                table_name = sources_without_schema[0]
+459
+460        if table_name not in self.scope.selected_sources:
+461            return exp.to_identifier(table_name)
+462
+463        node, _ = self.scope.selected_sources.get(table_name)
+464
+465        if isinstance(node, exp.Subqueryable):
+466            while node and node.alias != table_name:
+467                node = node.parent
+468
+469        node_alias = node.args.get("alias")
+470        if node_alias:
+471            return exp.to_identifier(node_alias.this)
 472
-473        # If referencing a table, return the columns from the schema
-474        if isinstance(source, exp.Table):
-475            return self.schema.column_names(source, only_visible)
+473        return exp.to_identifier(
+474            table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None
+475        )
 476
-477        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
-478            return source.expression.alias_column_names
-479
-480        # Otherwise, if referencing another scope, return that scope's named selects
-481        return source.expression.named_selects
-482
-483    def _get_all_source_columns(self):
-484        if self._source_columns is None:
-485            self._source_columns = {
-486                k: self.get_source_columns(k)
-487                for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
-488            }
-489        return self._source_columns
+477    @property
+478    def all_columns(self):
+479        """All available columns of all sources in this scope"""
+480        if self._all_columns is None:
+481            self._all_columns = {
+482                column for columns in self._get_all_source_columns().values() for column in columns
+483            }
+484        return self._all_columns
+485
+486    def get_source_columns(self, name, only_visible=False):
+487        """Resolve the source columns for a given source `name`"""
+488        if name not in self.scope.sources:
+489            raise OptimizeError(f"Unknown table: {name}")
 490
-491    def _get_unambiguous_columns(self, source_columns):
-492        """
-493        Find all the unambiguous columns in sources.
-494
-495        Args:
-496            source_columns (dict): Mapping of names to source columns
-497        Returns:
-498            dict: Mapping of column name to source name
-499        """
-500        if not source_columns:
-501            return {}
+491        source = self.scope.sources[name]
+492
+493        # If referencing a table, return the columns from the schema
+494        if isinstance(source, exp.Table):
+495            return self.schema.column_names(source, only_visible)
+496
+497        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
+498            return source.expression.alias_column_names
+499
+500        # Otherwise, if referencing another scope, return that scope's named selects
+501        return source.expression.named_selects
 502
-503        source_columns = list(source_columns.items())
-504
-505        first_table, first_columns = source_columns[0]
-506        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
-507        all_columns = set(unambiguous_columns)
-508
-509        for table, columns in source_columns[1:]:
-510            unique = self._find_unique_columns(columns)
-511            ambiguous = set(all_columns).intersection(unique)
-512            all_columns.update(columns)
-513            for column in ambiguous:
-514                unambiguous_columns.pop(column, None)
-515            for column in unique.difference(ambiguous):
-516                unambiguous_columns[column] = table
-517
-518        return unambiguous_columns
-519
-520    @staticmethod
-521    def _find_unique_columns(columns):
-522        """
-523        Find the unique columns in a list of columns.
+503    def _get_all_source_columns(self):
+504        if self._source_columns is None:
+505            self._source_columns = {
+506                k: self.get_source_columns(k)
+507                for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
+508            }
+509        return self._source_columns
+510
+511    def _get_unambiguous_columns(self, source_columns):
+512        """
+513        Find all the unambiguous columns in sources.
+514
+515        Args:
+516            source_columns (dict): Mapping of names to source columns
+517        Returns:
+518            dict: Mapping of column name to source name
+519        """
+520        if not source_columns:
+521            return {}
+522
+523        source_columns = list(source_columns.items())
 524
-525        Example:
-526            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
-527            ['a', 'c']
+525        first_table, first_columns = source_columns[0]
+526        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
+527        all_columns = set(unambiguous_columns)
 528
-529        This is necessary because duplicate column names are ambiguous.
-530        """
-531        counts = {}
-532        for column in columns:
-533            counts[column] = counts.get(column, 0) + 1
-534        return {column for column, count in counts.items() if count == 1}
+529        for table, columns in source_columns[1:]:
+530            unique = self._find_unique_columns(columns)
+531            ambiguous = set(all_columns).intersection(unique)
+532            all_columns.update(columns)
+533            for column in ambiguous:
+534                unambiguous_columns.pop(column, None)
+535            for column in unique.difference(ambiguous):
+536                unambiguous_columns[column] = table
+537
+538        return unambiguous_columns
+539
+540    @staticmethod
+541    def _find_unique_columns(columns):
+542        """
+543        Find the unique columns in a list of columns.
+544
+545        Example:
+546            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
+547            ['a', 'c']
+548
+549        This is necessary because duplicate column names are ambiguous.
+550        """
+551        counts = {}
+552        for column in columns:
+553            counts[column] = counts.get(column, 0) + 1
+554        return {column for column, count in counts.items() if count == 1}
 
@@ -888,18 +922,19 @@
- Resolver(scope, schema) + Resolver(scope, schema, infer_schema: bool = True)
-
408    def __init__(self, scope, schema):
-409        self.scope = scope
-410        self.schema = schema
-411        self._source_columns = None
-412        self._unambiguous_columns = None
-413        self._all_columns = None
+            
427    def __init__(self, scope, schema, infer_schema: bool = True):
+428        self.scope = scope
+429        self.schema = schema
+430        self._source_columns = None
+431        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
+432        self._all_columns = None
+433        self._infer_schema = infer_schema
 
@@ -917,47 +952,47 @@
-
415    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
-416        """
-417        Get the table for a column name.
-418
-419        Args:
-420            column_name: The column name to find the table for.
-421        Returns:
-422            The table name if it can be found/inferred.
-423        """
-424        if self._unambiguous_columns is None:
-425            self._unambiguous_columns = self._get_unambiguous_columns(
-426                self._get_all_source_columns()
-427            )
-428
-429        table_name = self._unambiguous_columns.get(column_name)
-430
-431        if not table_name:
-432            sources_without_schema = tuple(
-433                source
-434                for source, columns in self._get_all_source_columns().items()
-435                if not columns or "*" in columns
-436            )
-437            if len(sources_without_schema) == 1:
-438                table_name = sources_without_schema[0]
-439
-440        if table_name not in self.scope.selected_sources:
-441            return exp.to_identifier(table_name)
-442
-443        node, _ = self.scope.selected_sources.get(table_name)
-444
-445        if isinstance(node, exp.Subqueryable):
-446            while node and node.alias != table_name:
-447                node = node.parent
+            
435    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
+436        """
+437        Get the table for a column name.
+438
+439        Args:
+440            column_name: The column name to find the table for.
+441        Returns:
+442            The table name if it can be found/inferred.
+443        """
+444        if self._unambiguous_columns is None:
+445            self._unambiguous_columns = self._get_unambiguous_columns(
+446                self._get_all_source_columns()
+447            )
 448
-449        node_alias = node.args.get("alias")
-450        if node_alias:
-451            return node_alias.this
-452
-453        return exp.to_identifier(
-454            table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None
-455        )
+449        table_name = self._unambiguous_columns.get(column_name)
+450
+451        if not table_name and self._infer_schema:
+452            sources_without_schema = tuple(
+453                source
+454                for source, columns in self._get_all_source_columns().items()
+455                if not columns or "*" in columns
+456            )
+457            if len(sources_without_schema) == 1:
+458                table_name = sources_without_schema[0]
+459
+460        if table_name not in self.scope.selected_sources:
+461            return exp.to_identifier(table_name)
+462
+463        node, _ = self.scope.selected_sources.get(table_name)
+464
+465        if isinstance(node, exp.Subqueryable):
+466            while node and node.alias != table_name:
+467                node = node.parent
+468
+469        node_alias = node.args.get("alias")
+470        if node_alias:
+471            return exp.to_identifier(node_alias.this)
+472
+473        return exp.to_identifier(
+474            table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None
+475        )
 
@@ -1002,22 +1037,22 @@
-
466    def get_source_columns(self, name, only_visible=False):
-467        """Resolve the source columns for a given source `name`"""
-468        if name not in self.scope.sources:
-469            raise OptimizeError(f"Unknown table: {name}")
-470
-471        source = self.scope.sources[name]
-472
-473        # If referencing a table, return the columns from the schema
-474        if isinstance(source, exp.Table):
-475            return self.schema.column_names(source, only_visible)
-476
-477        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
-478            return source.expression.alias_column_names
-479
-480        # Otherwise, if referencing another scope, return that scope's named selects
-481        return source.expression.named_selects
+            
486    def get_source_columns(self, name, only_visible=False):
+487        """Resolve the source columns for a given source `name`"""
+488        if name not in self.scope.sources:
+489            raise OptimizeError(f"Unknown table: {name}")
+490
+491        source = self.scope.sources[name]
+492
+493        # If referencing a table, return the columns from the schema
+494        if isinstance(source, exp.Table):
+495            return self.schema.column_names(source, only_visible)
+496
+497        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
+498            return source.expression.alias_column_names
+499
+500        # Otherwise, if referencing another scope, return that scope's named selects
+501        return source.expression.named_selects
 
-- cgit v1.2.3