sqlglot.optimizer.qualify_columns
1from __future__ import annotations 2 3import itertools 4import typing as t 5 6from sqlglot import alias, exp 7from sqlglot._typing import E 8from sqlglot.dialects.dialect import Dialect, DialectType 9from sqlglot.errors import OptimizeError 10from sqlglot.helper import seq_get 11from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope 12from sqlglot.schema import Schema, ensure_schema 13 14 15def qualify_columns( 16 expression: exp.Expression, 17 schema: t.Dict | Schema, 18 expand_alias_refs: bool = True, 19 infer_schema: t.Optional[bool] = None, 20) -> exp.Expression: 21 """ 22 Rewrite sqlglot AST to have fully qualified columns. 23 24 Example: 25 >>> import sqlglot 26 >>> schema = {"tbl": {"col": "INT"}} 27 >>> expression = sqlglot.parse_one("SELECT col FROM tbl") 28 >>> qualify_columns(expression, schema).sql() 29 'SELECT tbl.col AS col FROM tbl' 30 31 Args: 32 expression: Expression to qualify. 33 schema: Database schema. 34 expand_alias_refs: Whether or not to expand references to aliases. 35 infer_schema: Whether or not to infer the schema if missing. 36 37 Returns: 38 The qualified expression. 
39 """ 40 schema = ensure_schema(schema) 41 infer_schema = schema.empty if infer_schema is None else infer_schema 42 pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS 43 44 for scope in traverse_scope(expression): 45 resolver = Resolver(scope, schema, infer_schema=infer_schema) 46 _pop_table_column_aliases(scope.ctes) 47 _pop_table_column_aliases(scope.derived_tables) 48 using_column_tables = _expand_using(scope, resolver) 49 50 if schema.empty and expand_alias_refs: 51 _expand_alias_refs(scope, resolver) 52 53 _qualify_columns(scope, resolver) 54 55 if not schema.empty and expand_alias_refs: 56 _expand_alias_refs(scope, resolver) 57 58 if not isinstance(scope.expression, exp.UDTF): 59 _expand_stars(scope, resolver, using_column_tables, pseudocolumns) 60 _qualify_outputs(scope) 61 _expand_group_by(scope) 62 _expand_order_by(scope, resolver) 63 64 return expression 65 66 67def validate_qualify_columns(expression: E) -> E: 68 """Raise an `OptimizeError` if any columns aren't qualified""" 69 unqualified_columns = [] 70 for scope in traverse_scope(expression): 71 if isinstance(scope.expression, exp.Select): 72 unqualified_columns.extend(scope.unqualified_columns) 73 if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots: 74 column = scope.external_columns[0] 75 raise OptimizeError( 76 f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}""" 77 ) 78 79 if unqualified_columns: 80 raise OptimizeError(f"Ambiguous columns: {unqualified_columns}") 81 return expression 82 83 84def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None: 85 """ 86 Remove table column aliases. 87 88 (e.g. SELECT ... FROM (SELECT ...) 
AS foo(col1, col2) 89 """ 90 for derived_table in derived_tables: 91 table_alias = derived_table.args.get("alias") 92 if table_alias: 93 table_alias.args.pop("columns", None) 94 95 96def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]: 97 joins = list(scope.find_all(exp.Join)) 98 names = {join.alias_or_name for join in joins} 99 ordered = [key for key in scope.selected_sources if key not in names] 100 101 # Mapping of automatically joined column names to an ordered set of source names (dict). 102 column_tables: t.Dict[str, t.Dict[str, t.Any]] = {} 103 104 for join in joins: 105 using = join.args.get("using") 106 107 if not using: 108 continue 109 110 join_table = join.alias_or_name 111 112 columns = {} 113 114 for k in scope.selected_sources: 115 if k in ordered: 116 for column in resolver.get_source_columns(k): 117 if column not in columns: 118 columns[column] = k 119 120 source_table = ordered[-1] 121 ordered.append(join_table) 122 join_columns = resolver.get_source_columns(join_table) 123 conditions = [] 124 125 for identifier in using: 126 identifier = identifier.name 127 table = columns.get(identifier) 128 129 if not table or identifier not in join_columns: 130 if columns and join_columns: 131 raise OptimizeError(f"Cannot automatically join: {identifier}") 132 133 table = table or source_table 134 conditions.append( 135 exp.condition( 136 exp.EQ( 137 this=exp.column(identifier, table=table), 138 expression=exp.column(identifier, table=join_table), 139 ) 140 ) 141 ) 142 143 # Set all values in the dict to None, because we only care about the key ordering 144 tables = column_tables.setdefault(identifier, {}) 145 if table not in tables: 146 tables[table] = None 147 if join_table not in tables: 148 tables[join_table] = None 149 150 join.args.pop("using") 151 join.set("on", exp.and_(*conditions, copy=False)) 152 153 if column_tables: 154 for column in scope.columns: 155 if not column.table and column.name in column_tables: 156 tables = 
column_tables[column.name] 157 coalesce = [exp.column(column.name, table=table) for table in tables] 158 replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]) 159 160 # Ensure selects keep their output name 161 if isinstance(column.parent, exp.Select): 162 replacement = alias(replacement, alias=column.name, copy=False) 163 164 scope.replace(column, replacement) 165 166 return column_tables 167 168 169def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None: 170 expression = scope.expression 171 172 if not isinstance(expression, exp.Select): 173 return 174 175 alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {} 176 177 def replace_columns( 178 node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False 179 ) -> None: 180 if not node: 181 return 182 183 for column, *_ in walk_in_scope(node): 184 if not isinstance(column, exp.Column): 185 continue 186 table = resolver.get_table(column.name) if resolve_table and not column.table else None 187 alias_expr, i = alias_to_expression.get(column.name, (None, 1)) 188 double_agg = ( 189 (alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc)) 190 if alias_expr 191 else False 192 ) 193 194 if table and (not alias_expr or double_agg): 195 column.set("table", table) 196 elif not column.table and alias_expr and not double_agg: 197 if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table): 198 if literal_index: 199 column.replace(exp.Literal.number(i)) 200 else: 201 column.replace(alias_expr.copy()) 202 203 for i, projection in enumerate(scope.expression.selects): 204 replace_columns(projection) 205 206 if isinstance(projection, exp.Alias): 207 alias_to_expression[projection.alias] = (projection.this, i + 1) 208 209 replace_columns(expression.args.get("where")) 210 replace_columns(expression.args.get("group"), literal_index=True) 211 replace_columns(expression.args.get("having"), resolve_table=True) 212 
replace_columns(expression.args.get("qualify"), resolve_table=True) 213 scope.clear_cache() 214 215 216def _expand_group_by(scope: Scope): 217 expression = scope.expression 218 group = expression.args.get("group") 219 if not group: 220 return 221 222 group.set("expressions", _expand_positional_references(scope, group.expressions)) 223 expression.set("group", group) 224 225 226def _expand_order_by(scope: Scope, resolver: Resolver): 227 order = scope.expression.args.get("order") 228 if not order: 229 return 230 231 ordereds = order.expressions 232 for ordered, new_expression in zip( 233 ordereds, 234 _expand_positional_references(scope, (o.this for o in ordereds)), 235 ): 236 for agg in ordered.find_all(exp.AggFunc): 237 for col in agg.find_all(exp.Column): 238 if not col.table: 239 col.set("table", resolver.get_table(col.name)) 240 241 ordered.set("this", new_expression) 242 243 if scope.expression.args.get("group"): 244 selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects} 245 246 for ordered in ordereds: 247 ordered = ordered.this 248 249 ordered.replace( 250 exp.to_identifier(_select_by_pos(scope, ordered).alias) 251 if ordered.is_int 252 else selects.get(ordered, ordered) 253 ) 254 255 256def _expand_positional_references(scope: Scope, expressions: t.Iterable[E]) -> t.List[E]: 257 new_nodes = [] 258 for node in expressions: 259 if node.is_int: 260 select = _select_by_pos(scope, t.cast(exp.Literal, node)).this 261 262 if isinstance(select, exp.Literal): 263 new_nodes.append(node) 264 else: 265 new_nodes.append(select.copy()) 266 scope.clear_cache() 267 else: 268 new_nodes.append(node) 269 270 return new_nodes 271 272 273def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias: 274 try: 275 return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias) 276 except IndexError: 277 raise OptimizeError(f"Unknown output column: {node.name}") 278 279 280def _qualify_columns(scope: Scope, resolver: Resolver) -> None: 281 
"""Disambiguate columns, ensuring each column specifies a source""" 282 for column in scope.columns: 283 column_table = column.table 284 column_name = column.name 285 286 if column_table and column_table in scope.sources: 287 source_columns = resolver.get_source_columns(column_table) 288 if source_columns and column_name not in source_columns and "*" not in source_columns: 289 raise OptimizeError(f"Unknown column: {column_name}") 290 291 if not column_table: 292 if scope.pivots and not column.find_ancestor(exp.Pivot): 293 # If the column is under the Pivot expression, we need to qualify it 294 # using the name of the pivoted source instead of the pivot's alias 295 column.set("table", exp.to_identifier(scope.pivots[0].alias)) 296 continue 297 298 column_table = resolver.get_table(column_name) 299 300 # column_table can be a '' because bigquery unnest has no table alias 301 if column_table: 302 column.set("table", column_table) 303 elif column_table not in scope.sources and ( 304 not scope.parent or column_table not in scope.parent.sources 305 ): 306 # structs are used like tables (e.g. 
"struct"."field"), so they need to be qualified 307 # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...)) 308 309 root, *parts = column.parts 310 311 if root.name in scope.sources: 312 # struct is already qualified, but we still need to change the AST representation 313 column_table = root 314 root, *parts = parts 315 else: 316 column_table = resolver.get_table(root.name) 317 318 if column_table: 319 column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts])) 320 321 for pivot in scope.pivots: 322 for column in pivot.find_all(exp.Column): 323 if not column.table and column.name in resolver.all_columns: 324 column_table = resolver.get_table(column.name) 325 if column_table: 326 column.set("table", column_table) 327 328 329def _expand_stars( 330 scope: Scope, 331 resolver: Resolver, 332 using_column_tables: t.Dict[str, t.Any], 333 pseudocolumns: t.Set[str], 334) -> None: 335 """Expand stars to lists of column selections""" 336 337 new_selections = [] 338 except_columns: t.Dict[int, t.Set[str]] = {} 339 replace_columns: t.Dict[int, t.Dict[str, str]] = {} 340 coalesced_columns = set() 341 342 # TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future 343 pivot_columns = None 344 pivot_output_columns = None 345 pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0)) 346 347 has_pivoted_source = pivot and not pivot.args.get("unpivot") 348 if pivot and has_pivoted_source: 349 pivot_columns = set(col.output_name for col in pivot.find_all(exp.Column)) 350 351 pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])] 352 if not pivot_output_columns: 353 pivot_output_columns = [col.alias_or_name for col in pivot.expressions] 354 355 for expression in scope.expression.selects: 356 if isinstance(expression, exp.Star): 357 tables = list(scope.selected_sources) 358 _add_except_columns(expression, tables, except_columns) 359 _add_replace_columns(expression, tables, 
replace_columns) 360 elif expression.is_star: 361 tables = [expression.table] 362 _add_except_columns(expression.this, tables, except_columns) 363 _add_replace_columns(expression.this, tables, replace_columns) 364 else: 365 new_selections.append(expression) 366 continue 367 368 for table in tables: 369 if table not in scope.sources: 370 raise OptimizeError(f"Unknown table: {table}") 371 372 columns = resolver.get_source_columns(table, only_visible=True) 373 374 if pseudocolumns: 375 columns = [name for name in columns if name.upper() not in pseudocolumns] 376 377 if columns and "*" not in columns: 378 if pivot and has_pivoted_source and pivot_columns and pivot_output_columns: 379 implicit_columns = [col for col in columns if col not in pivot_columns] 380 new_selections.extend( 381 exp.alias_(exp.column(name, table=pivot.alias), name, copy=False) 382 for name in implicit_columns + pivot_output_columns 383 ) 384 continue 385 386 table_id = id(table) 387 for name in columns: 388 if name in using_column_tables and table in using_column_tables[name]: 389 if name in coalesced_columns: 390 continue 391 392 coalesced_columns.add(name) 393 tables = using_column_tables[name] 394 coalesce = [exp.column(name, table=table) for table in tables] 395 396 new_selections.append( 397 alias( 398 exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]), 399 alias=name, 400 copy=False, 401 ) 402 ) 403 elif name not in except_columns.get(table_id, set()): 404 alias_ = replace_columns.get(table_id, {}).get(name, name) 405 column = exp.column(name, table=table) 406 new_selections.append( 407 alias(column, alias_, copy=False) if alias_ != name else column 408 ) 409 else: 410 return 411 412 # Ensures we don't overwrite the initial selections with an empty list 413 if new_selections: 414 scope.expression.set("expressions", new_selections) 415 416 417def _add_except_columns( 418 expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]] 419) -> None: 420 except_ = 
expression.args.get("except") 421 422 if not except_: 423 return 424 425 columns = {e.name for e in except_} 426 427 for table in tables: 428 except_columns[id(table)] = columns 429 430 431def _add_replace_columns( 432 expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]] 433) -> None: 434 replace = expression.args.get("replace") 435 436 if not replace: 437 return 438 439 columns = {e.this.name: e.alias for e in replace} 440 441 for table in tables: 442 replace_columns[id(table)] = columns 443 444 445def _qualify_outputs(scope: Scope): 446 """Ensure all output columns are aliased""" 447 new_selections = [] 448 449 for i, (selection, aliased_column) in enumerate( 450 itertools.zip_longest(scope.expression.selects, scope.outer_column_list) 451 ): 452 if isinstance(selection, exp.Subquery): 453 if not selection.output_name: 454 selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}"))) 455 elif not isinstance(selection, exp.Alias) and not selection.is_star: 456 selection = alias( 457 selection, 458 alias=selection.output_name or f"_col_{i}", 459 ) 460 if aliased_column: 461 selection.set("alias", exp.to_identifier(aliased_column)) 462 463 new_selections.append(selection) 464 465 scope.expression.set("expressions", new_selections) 466 467 468def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E: 469 """Makes sure all identifiers that need to be quoted are quoted.""" 470 return expression.transform( 471 Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False 472 ) 473 474 475class Resolver: 476 """ 477 Helper for resolving columns. 478 479 This is a class so we can lazily load some things and easily share them across functions. 
480 """ 481 482 def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True): 483 self.scope = scope 484 self.schema = schema 485 self._source_columns = None 486 self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None 487 self._all_columns = None 488 self._infer_schema = infer_schema 489 490 def get_table(self, column_name: str) -> t.Optional[exp.Identifier]: 491 """ 492 Get the table for a column name. 493 494 Args: 495 column_name: The column name to find the table for. 496 Returns: 497 The table name if it can be found/inferred. 498 """ 499 if self._unambiguous_columns is None: 500 self._unambiguous_columns = self._get_unambiguous_columns( 501 self._get_all_source_columns() 502 ) 503 504 table_name = self._unambiguous_columns.get(column_name) 505 506 if not table_name and self._infer_schema: 507 sources_without_schema = tuple( 508 source 509 for source, columns in self._get_all_source_columns().items() 510 if not columns or "*" in columns 511 ) 512 if len(sources_without_schema) == 1: 513 table_name = sources_without_schema[0] 514 515 if table_name not in self.scope.selected_sources: 516 return exp.to_identifier(table_name) 517 518 node, _ = self.scope.selected_sources.get(table_name) 519 520 if isinstance(node, exp.Subqueryable): 521 while node and node.alias != table_name: 522 node = node.parent 523 524 node_alias = node.args.get("alias") 525 if node_alias: 526 return exp.to_identifier(node_alias.this) 527 528 return exp.to_identifier(table_name) 529 530 @property 531 def all_columns(self): 532 """All available columns of all sources in this scope""" 533 if self._all_columns is None: 534 self._all_columns = { 535 column for columns in self._get_all_source_columns().values() for column in columns 536 } 537 return self._all_columns 538 539 def get_source_columns(self, name, only_visible=False): 540 """Resolve the source columns for a given source `name`""" 541 if name not in self.scope.sources: 542 raise OptimizeError(f"Unknown table: 
{name}") 543 544 source = self.scope.sources[name] 545 546 # If referencing a table, return the columns from the schema 547 if isinstance(source, exp.Table): 548 return self.schema.column_names(source, only_visible) 549 550 if isinstance(source, Scope) and isinstance(source.expression, exp.Values): 551 return source.expression.alias_column_names 552 553 # Otherwise, if referencing another scope, return that scope's named selects 554 return source.expression.named_selects 555 556 def _get_all_source_columns(self): 557 if self._source_columns is None: 558 self._source_columns = { 559 k: self.get_source_columns(k) 560 for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources) 561 } 562 return self._source_columns 563 564 def _get_unambiguous_columns(self, source_columns): 565 """ 566 Find all the unambiguous columns in sources. 567 568 Args: 569 source_columns (dict): Mapping of names to source columns 570 Returns: 571 dict: Mapping of column name to source name 572 """ 573 if not source_columns: 574 return {} 575 576 source_columns = list(source_columns.items()) 577 578 first_table, first_columns = source_columns[0] 579 unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)} 580 all_columns = set(unambiguous_columns) 581 582 for table, columns in source_columns[1:]: 583 unique = self._find_unique_columns(columns) 584 ambiguous = set(all_columns).intersection(unique) 585 all_columns.update(columns) 586 for column in ambiguous: 587 unambiguous_columns.pop(column, None) 588 for column in unique.difference(ambiguous): 589 unambiguous_columns[column] = table 590 591 return unambiguous_columns 592 593 @staticmethod 594 def _find_unique_columns(columns): 595 """ 596 Find the unique columns in a list of columns. 597 598 Example: 599 >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"])) 600 ['a', 'c'] 601 602 This is necessary because duplicate column names are ambiguous. 
603 """ 604 counts = {} 605 for column in columns: 606 counts[column] = counts.get(column, 0) + 1 607 return {column for column, count in counts.items() if count == 1}
def
qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether or not to expand references to aliases.
        infer_schema: Whether or not to infer the schema if missing.
            Defaults to whether the schema is empty.

    Returns:
        The qualified expression.
    """
    schema = ensure_schema(schema)
    if infer_schema is None:
        infer_schema = schema.empty

    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS

    for current_scope in traverse_scope(expression):
        column_resolver = Resolver(current_scope, schema, infer_schema=infer_schema)

        _pop_table_column_aliases(current_scope.ctes)
        _pop_table_column_aliases(current_scope.derived_tables)
        joined_columns = _expand_using(current_scope, column_resolver)

        # Alias references are expanded before qualification when the schema is empty ...
        if schema.empty and expand_alias_refs:
            _expand_alias_refs(current_scope, column_resolver)

        _qualify_columns(current_scope, column_resolver)

        # ... and after it otherwise
        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(current_scope, column_resolver)

        is_udtf = isinstance(current_scope.expression, exp.UDTF)
        if not is_udtf:
            _expand_stars(current_scope, column_resolver, joined_columns, pseudocolumns)
            _qualify_outputs(current_scope)

        _expand_group_by(current_scope)
        _expand_order_by(current_scope, column_resolver)

    return expression
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression: Expression to qualify.
- schema: Database schema.
- expand_alias_refs: Whether or not to expand references to aliases.
- infer_schema: Whether or not to infer the schema if missing.
Returns:
The qualified expression.
def
validate_qualify_columns(expression: ~E) -> ~E:
def validate_qualify_columns(expression: E) -> E:
    """
    Raise an `OptimizeError` if any columns aren't qualified.

    Only scopes whose expression is a `Select` are checked. An external column in a scope
    that is neither a correlated subquery nor pivoted is reported immediately; unqualified
    columns are collected over all scopes and reported together at the end.

    Returns:
        The expression that was passed in.
    """
    ambiguous = []

    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        ambiguous.extend(scope.unqualified_columns)

        if not scope.external_columns:
            continue
        if scope.is_correlated_subquery or scope.pivots:
            continue

        column = scope.external_columns[0]
        raise OptimizeError(
            f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}"""
        )

    if ambiguous:
        raise OptimizeError(f"Ambiguous columns: {ambiguous}")

    return expression
Raise an `OptimizeError` if any columns aren't qualified.
def
quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Makes sure all identifiers that need to be quoted are quoted.

    The expression is transformed in place (`copy=False`) with the `quote_identifier`
    method of the given dialect, and `identify` is forwarded to it.
    """
    quote_identifier = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote_identifier, identify=identify, copy=False)
Makes sure all identifiers that need to be quoted are quoted.
class
Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # The three caches below are filled lazily, on first use
        self._source_columns = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # If exactly one source has unknown columns, assume the column belongs to it
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Walk up to the enclosing node that is aliased with the source's name
            while node and node.alias != table_name:
                node = node.parent

        # The walk above can run past the root without finding the alias, leaving no node
        # to read from; fall back to the plain source name in that case.
        node_alias = node.args.get("alias") if node is not None else None
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self):
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name, only_visible=False):
        """Resolve the source columns for a given source `name`"""
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        # If referencing a table, return the columns from the schema
        if isinstance(source, exp.Table):
            return self.schema.column_names(source, only_visible)

        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            return source.expression.alias_column_names

        # Otherwise, if referencing another scope, return that scope's named selects
        return source.expression.named_selects

    def _get_all_source_columns(self):
        """Return (and cache) the columns of every selected and lateral source, keyed by source name"""
        if self._source_columns is None:
            self._source_columns = {
                k: self.get_source_columns(k)
                for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
            }
        return self._source_columns

    def _get_unambiguous_columns(self, source_columns):
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous if it appears exactly once across all of the sources.

        Args:
            source_columns (dict): Mapping of names to source columns
        Returns:
            dict: Mapping of column name to source name
        """
        if not source_columns:
            return {}

        unambiguous_columns = {}
        seen_columns = set()

        for table, columns in source_columns.items():
            table_columns = set(columns)

            # Any name that an earlier source already provides is ambiguous, no matter
            # how many times it appears in this source
            for column in table_columns & seen_columns:
                unambiguous_columns.pop(column, None)

            # Names duplicated within this source never become unambiguous
            for column in self._find_unique_columns(columns) - seen_columns:
                unambiguous_columns[column] = table

            seen_columns |= table_columns

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns):
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
    """
    Get the table for a column name.

    Args:
        column_name: The column name to find the table for.
    Returns:
        The table name if it can be found/inferred.
    """
    if self._unambiguous_columns is None:
        self._unambiguous_columns = self._get_unambiguous_columns(
            self._get_all_source_columns()
        )

    table_name = self._unambiguous_columns.get(column_name)

    if not table_name and self._infer_schema:
        # If exactly one source has unknown columns, assume the column belongs to it
        sources_without_schema = tuple(
            source
            for source, columns in self._get_all_source_columns().items()
            if not columns or "*" in columns
        )
        if len(sources_without_schema) == 1:
            table_name = sources_without_schema[0]

    if table_name not in self.scope.selected_sources:
        return exp.to_identifier(table_name)

    node, _ = self.scope.selected_sources.get(table_name)

    if isinstance(node, exp.Subqueryable):
        # Walk up to the enclosing node that is aliased with the source's name
        while node and node.alias != table_name:
            node = node.parent

    # The walk above can run past the root without finding the alias, leaving no node
    # to read from; fall back to the plain source name in that case.
    node_alias = node.args.get("alias") if node is not None else None
    if node_alias:
        return exp.to_identifier(node_alias.this)

    return exp.to_identifier(table_name)
Get the table for a column name.
Arguments:
- column_name: The column name to find the table for.
Returns:
The table name if it can be found/inferred.
def
get_source_columns(self, name, only_visible=False):
def get_source_columns(self, name, only_visible=False):
    """
    Resolve the source columns for a given source `name`.

    Tables are looked up in the schema (`only_visible` is forwarded to it), VALUES scopes
    yield their alias column names, and any other source yields its named selects.

    Raises:
        OptimizeError: If `name` is not a source of the scope.
    """
    if name not in self.scope.sources:
        raise OptimizeError(f"Unknown table: {name}")

    source = self.scope.sources[name]
    is_table = isinstance(source, exp.Table)

    if is_table:
        # The columns of a table come from the schema
        columns = self.schema.column_names(source, only_visible)
    elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
        columns = source.expression.alias_column_names
    else:
        # Any other source exposes the named selects of its expression
        columns = source.expression.named_selects

    return columns
Resolve the source columns for a given source name