sqlglot.optimizer.qualify_columns
from __future__ import annotations

import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot._typing import E
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get
from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope
from sqlglot.schema import Schema, ensure_schema


def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify. It is modified in place.
        schema: Database schema (a plain dict is normalized via `ensure_schema`).
        expand_alias_refs: Whether or not to expand references to aliases.
        infer_schema: Whether or not to infer the schema if missing. When None,
            inference is enabled exactly when `schema` is empty.

    Returns:
        The qualified expression.
    """
    schema = ensure_schema(schema)
    infer_schema = schema.empty if infer_schema is None else infer_schema

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Without a schema, alias references are expanded before qualification;
        # with a schema, after it.
        if schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        _qualify_columns(scope, resolver)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        if not isinstance(scope.expression, exp.UDTF):
            _expand_stars(scope, resolver, using_column_tables)
            _qualify_outputs(scope)
        _expand_group_by(scope)
        _expand_order_by(scope, resolver)

    return expression


def validate_qualify_columns(expression: E) -> E:
    """
    Raise an `OptimizeError` if any columns aren't qualified.

    Only SELECT scopes are inspected. A scope fails when it has unqualified
    (ambiguous) columns, or external columns while being neither a correlated
    subquery nor a pivoted scope.

    Args:
        expression: The expression to validate.

    Returns:
        The same expression, unchanged.

    Raises:
        OptimizeError: Immediately for the first unresolvable external column, or
            after all scopes are visited if any ambiguous columns were collected.
    """
    unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns.extend(scope.unqualified_columns)
            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
                column = scope.external_columns[0]
                raise OptimizeError(
                    f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}"""
                )

    if unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
    return expression


def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived_table in derived_tables:
        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.args.pop("columns", None)


def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """
    Rewrite `JOIN ... USING (...)` clauses into equivalent `ON` equality conditions.

    Unqualified references to a USING column elsewhere in the scope are replaced by
    a COALESCE over every source joined on that column. References whose parent is
    the SELECT itself are aliased so they keep their output name.

    Returns:
        Mapping of each USING column name to an insertion-ordered dict whose keys
        are the source names joined on it (the values are always None).

    Raises:
        OptimizeError: If a USING column is missing from either side while both
            sides have known columns.
    """
    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for join in joins:
        using = join.args.get("using")

        if not using:
            continue

        join_table = join.alias_or_name

        # First source (in selected_sources order) that provides each column among
        # the sources already joined so far.
        columns = {}

        for k in scope.selected_sources:
            if k in ordered:
                for column in resolver.get_source_columns(k):
                    if column not in columns:
                        columns[column] = k

        source_table = ordered[-1]
        ordered.append(join_table)
        join_columns = resolver.get_source_columns(join_table)
        conditions = []

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                if columns and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table
            conditions.append(
                exp.condition(
                    exp.EQ(
                        this=exp.column(identifier, table=table),
                        expression=exp.column(identifier, table=join_table),
                    )
                )
            )

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})
            if table not in tables:
                tables[table] = None
            if join_table not in tables:
                tables[join_table] = None

        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce = [exp.column(column.name, table=table) for table in tables]
                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])

                # Ensure selects keep their output name
                if isinstance(column.parent, exp.Select):
                    replacement = alias(replacement, alias=column.name, copy=False)

                scope.replace(column, replacement)

    return column_tables


def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
    """
    Replace references to SELECT aliases with the aliased expressions.

    Projections are visited in order, so a projection may reference aliases
    defined before it. WHERE, GROUP BY, HAVING and QUALIFY are rewritten as well.
    In GROUP BY, a reference to an alias of a literal becomes the 1-based position
    of that projection. In HAVING and QUALIFY, unqualified columns that do not name
    an alias are qualified through `resolver`. A reference is left alone when both
    the aliased expression and the reference's context contain an aggregate, since
    expanding it would nest aggregates.
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    # alias name -> (aliased expression, 1-based projection position)
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        if not node:
            return

        for column, *_ in walk_in_scope(node):
            if not isinstance(column, exp.Column):
                continue
            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
            double_agg = (
                (alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc))
                if alias_expr
                else False
            )

            if table and (not alias_expr or double_agg):
                column.set("table", table)
            elif not column.table and alias_expr and not double_agg:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    if literal_index:
                        column.replace(exp.Literal.number(i))
                else:
                    column.replace(alias_expr.copy())

    for i, projection in enumerate(scope.expression.selects):
        replace_columns(projection)

        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    scope.clear_cache()


def _expand_group_by(scope: Scope):
    """Replace positional GROUP BY references (e.g. `GROUP BY 1`) with the projected expressions."""
    expression = scope.expression
    group = expression.args.get("group")
    if not group:
        return

    group.set("expressions", _expand_positional_references(scope, group.expressions))
    expression.set("group", group)


def _expand_order_by(scope: Scope, resolver: Resolver):
    """
    Expand positional ORDER BY references and qualify columns inside ORDER BY aggregates.

    When the query also has a GROUP BY, each ORDER BY term that matches a projection
    is replaced with a reference to that projection's output name.
    """
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(scope, (o.this for o in ordereds)),
    ):
        for agg in ordered.find_all(exp.AggFunc):
            for col in agg.find_all(exp.Column):
                if not col.table:
                    col.set("table", resolver.get_table(col.name))

        ordered.set("this", new_expression)

    if scope.expression.args.get("group"):
        selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

        for ordered in ordereds:
            ordered = ordered.this

            ordered.replace(
                exp.to_identifier(_select_by_pos(scope, ordered).alias)
                if ordered.is_int
                else selects.get(ordered, ordered)
            )


def _expand_positional_references(scope: Scope, expressions: t.Iterable[E]) -> t.List[E]:
    """
    Return `expressions` with integer positions replaced by copies of the projections.

    A position whose projection is itself a literal is kept as-is.
    """
    new_nodes = []
    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node)).this

            if isinstance(select, exp.Literal):
                new_nodes.append(node)
            else:
                new_nodes.append(select.copy())
                scope.clear_cache()
        else:
            new_nodes.append(node)

    return new_nodes


def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """
    Return the projection at the 1-based position given by the integer literal `node`.

    Raises:
        OptimizeError: If the position does not name an existing projection.
    """
    position = int(node.this)
    # Positions are 1-based. Without this guard, 0 would silently wrap around via
    # Python's negative indexing and select the last projection.
    if position < 1:
        raise OptimizeError(f"Unknown output column: {node.name}")
    try:
        return scope.expression.selects[position - 1].assert_is(exp.Alias)
    except IndexError:
        raise OptimizeError(f"Unknown output column: {node.name}")


def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
    """
    Disambiguate columns, ensuring each column specifies a source.

    Raises:
        OptimizeError: If a column qualified with a known source isn't one of that
            source's (known, non-star) columns.
    """
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if source_columns and column_name not in source_columns and "*" not in source_columns:
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            column_table = resolver.get_table(column_name)

            # column_table can be a '' because bigquery unnest has no table alias
            if column_table:
                column.set("table", column_table)
        elif column_table not in scope.sources and (
            not scope.parent or column_table not in scope.parent.sources
        ):
            # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
            # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))

            root, *parts = column.parts

            if root.name in scope.sources:
                # struct is already qualified, but we still need to change the AST representation
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)


def _expand_stars(
    scope: Scope, resolver: Resolver, using_column_tables: t.Dict[str, t.Any]
) -> None:
    """
    Expand stars to lists of column selections.

    Honors `* EXCEPT (...)` and `* REPLACE (... AS ...)` modifiers. USING join
    columns are emitted once, as a COALESCE over the joined sources. If any
    star's source has unknown columns (none, or "*"), the function returns
    without changing the projections.

    Raises:
        OptimizeError: If a star references an unknown table.
    """

    new_selections = []
    # Keyed by id() of the table-name object; see _add_except_columns/_add_replace_columns.
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, str]] = {}
    coalesced_columns = set()

    # TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future
    pivot_columns = None
    pivot_output_columns = None
    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))

    has_pivoted_source = pivot and not pivot.args.get("unpivot")
    if pivot and has_pivoted_source:
        pivot_columns = set(col.output_name for col in pivot.find_all(exp.Column))

        pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])]
        if not pivot_output_columns:
            pivot_output_columns = [col.alias_or_name for col in pivot.expressions]

    for expression in scope.expression.selects:
        if isinstance(expression, exp.Star):
            tables = list(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
        elif expression.is_star:
            tables = [expression.table]
            _add_except_columns(expression.this, tables, except_columns)
            _add_replace_columns(expression.this, tables, replace_columns)
        else:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)

            # The _PARTITIONTIME and _PARTITIONDATE pseudo-columns are not returned by a SELECT * statement
            # https://cloud.google.com/bigquery/docs/querying-partitioned-tables#query_an_ingestion-time_partitioned_table
            if resolver.schema.dialect == "bigquery":
                columns = [
                    name
                    for name in columns
                    if name.upper() not in ("_PARTITIONTIME", "_PARTITIONDATE")
                ]

            if columns and "*" not in columns:
                if pivot and has_pivoted_source and pivot_columns and pivot_output_columns:
                    implicit_columns = [col for col in columns if col not in pivot_columns]
                    new_selections.extend(
                        exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in implicit_columns + pivot_output_columns
                    )
                    continue

                table_id = id(table)
                for name in columns:
                    if name in using_column_tables and table in using_column_tables[name]:
                        if name in coalesced_columns:
                            continue

                        coalesced_columns.add(name)
                        # Distinct name so the outer `tables` being iterated isn't shadowed.
                        using_tables = using_column_tables[name]
                        coalesce = [
                            exp.column(name, table=using_table) for using_table in using_tables
                        ]

                        new_selections.append(
                            alias(
                                exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
                                alias=name,
                                copy=False,
                            )
                        )
                    elif name not in except_columns.get(table_id, set()):
                        alias_ = replace_columns.get(table_id, {}).get(name, name)
                        column = exp.column(name, table=table)
                        new_selections.append(
                            alias(column, alias_, copy=False) if alias_ != name else column
                        )
            else:
                return

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections:
        scope.expression.set("expressions", new_selections)


def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    """Record the names in a star's EXCEPT clause for each of `tables`, keyed by id()."""
    # NOTE(review): keying by id() of the table-name string relies on the same string
    # objects being reused by the caller's loop; a later star without EXCEPT does not
    # clear an earlier entry.
    except_ = expression.args.get("except")

    if not except_:
        return

    columns = {e.name for e in except_}

    for table in tables:
        except_columns[id(table)] = columns


def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    """Record a star's REPLACE mapping (column name -> alias) for each of `tables`, keyed by id()."""
    replace = expression.args.get("replace")

    if not replace:
        return

    columns = {e.this.name: e.alias for e in replace}

    for table in tables:
        replace_columns[id(table)] = columns


def _qualify_outputs(scope: Scope):
    """
    Ensure all output columns are aliased.

    Unaliased projections get their output name, or `_col_<i>` if they have none.
    Names from the scope's outer column list override the projection aliases.
    """
    new_selections = []

    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
    ):
        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
            )
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    scope.expression.set("expressions", new_selections)


def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""
    return expression.transform(
        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
    )


class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        """
        Args:
            scope: The scope whose columns are being resolved.
            schema: Schema used to look up the columns of physical tables.
            infer_schema: When True, a column that can't be attributed to exactly one
                source falls back to the single source with unknown columns, if any.
        """
        self.scope = scope
        self.schema = schema
        # The caches below are populated on first use.
        self._source_columns: t.Optional[t.Dict[str, t.Any]] = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        The column is first looked up among columns that belong to exactly one source.
        Failing that, and if schema inference is enabled, it is attributed to the only
        source whose columns are unknown (empty or containing "*"), if exactly one exists.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred. For a source backed by a
            subquery, the alias found on the subquery's ancestors is used when present.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Climb to the ancestor that carries this source's alias.
            while node and node.alias != table_name:
                node = node.parent

        # The climb above may run past the root without finding the alias; fall back
        # to the plain source name instead of dereferencing None.
        node_alias = node.args.get("alias") if node else None
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self):
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name, only_visible=False):
        """
        Resolve the source columns for a given source `name`.

        Raises:
            OptimizeError: If `name` is not a source of this scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        # If referencing a table, return the columns from the schema
        if isinstance(source, exp.Table):
            return self.schema.column_names(source, only_visible)

        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            return source.expression.alias_column_names

        # Otherwise, if referencing another scope, return that scope's named selects
        return source.expression.named_selects

    def _get_all_source_columns(self):
        """Cached mapping of every selected and lateral source name to its columns."""
        if self._source_columns is None:
            self._source_columns = {
                k: self.get_source_columns(k)
                for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
            }
        return self._source_columns

    def _get_unambiguous_columns(self, source_columns):
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when it appears exactly once across all sources.
        A column listed twice within a single source, or present in more than one
        source, is ambiguous.

        Args:
            source_columns (dict): Mapping of names to source columns
        Returns:
            dict: Mapping of column name to source name
        """
        if not source_columns:
            return {}

        source_columns = list(source_columns.items())

        first_table, first_columns = source_columns[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
        # Track every column seen so far (including duplicated ones), not just the
        # unique ones, so a later occurrence is recognized as a clash.
        all_columns = set(first_columns)

        for table, columns in source_columns[1:]:
            unique = self._find_unique_columns(columns)
            # Intersect with all of this source's columns: a column duplicated here
            # but seen earlier must also be dropped from the earlier source's mapping.
            ambiguous = all_columns.intersection(columns)
            all_columns.update(columns)
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns):
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
def
qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Qualify every column in a sqlglot AST with the source it belongs to.

    Every scope yielded by `traverse_scope` is processed in turn:
    - USING joins are rewritten.
    - Columns are qualified and alias references are optionally expanded.
    - Stars are expanded and outputs aliased (except for UDTF scopes).
    - Positional GROUP BY / ORDER BY references are expanded.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema; a plain dict is normalized with `ensure_schema`.
        expand_alias_refs: Whether to expand references to projection aliases.
        infer_schema: Whether to infer the schema if missing. When None, this is
            True exactly when the schema is empty.

    Returns:
        The qualified expression.
    """
    schema = ensure_schema(schema)
    if infer_schema is None:
        infer_schema = schema.empty

    for current_scope in traverse_scope(expression):
        column_resolver = Resolver(current_scope, schema, infer_schema=infer_schema)

        # Column-list aliases like `AS foo(col1, col2)` are dropped from CTEs first,
        # then from derived tables.
        _pop_table_column_aliases(current_scope.ctes)
        _pop_table_column_aliases(current_scope.derived_tables)

        joined_using = _expand_using(current_scope, column_resolver)

        # Alias references are expanded before qualification when the schema is
        # empty, and after it otherwise.
        if expand_alias_refs and schema.empty:
            _expand_alias_refs(current_scope, column_resolver)

        _qualify_columns(current_scope, column_resolver)

        if expand_alias_refs and not schema.empty:
            _expand_alias_refs(current_scope, column_resolver)

        # Star expansion and output aliasing are skipped for UDTF scopes.
        is_udtf = isinstance(current_scope.expression, exp.UDTF)
        if not is_udtf:
            _expand_stars(current_scope, column_resolver, joined_using)
            _qualify_outputs(current_scope)

        _expand_group_by(current_scope)
        _expand_order_by(current_scope, column_resolver)

    return expression
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression: Expression to qualify.
- schema: Database schema.
- expand_alias_refs: Whether or not to expand references to aliases.
- infer_schema: Whether or not to infer the schema if missing.
Returns:
The qualified expression.
def
validate_qualify_columns(expression: ~E) -> ~E:
def validate_qualify_columns(expression: E) -> E:
    """
    Check that every column in `expression` was qualified.

    Only SELECT scopes are inspected. A scope is rejected outright if it has
    external columns and is neither a correlated subquery nor pivoted. Any
    unqualified columns are gathered across all scopes and reported together.

    Args:
        expression: The expression to validate.

    Returns:
        `expression`, unchanged.

    Raises:
        OptimizeError: For the first unresolvable column, or for ambiguous columns.
    """
    ambiguous = []
    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        ambiguous.extend(scope.unqualified_columns)

        unresolvable = (
            scope.external_columns and not scope.is_correlated_subquery and not scope.pivots
        )
        if unresolvable:
            first = scope.external_columns[0]
            suffix = f" for table: '{first.table}'" if first.table else ''
            raise OptimizeError(f"Column '{first}' could not be resolved{suffix}")

    if ambiguous:
        raise OptimizeError(f"Ambiguous columns: {ambiguous}")
    return expression
Raise an `OptimizeError` if any columns aren't qualified (ambiguous) or can't be resolved.
def
quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Quote the identifiers of `expression` in place, according to `dialect`.

    Every node is passed to the dialect's `quote_identifier` via `transform`
    (with `copy=False`), forwarding `identify`.
    """
    quoter = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quoter, identify=identify, copy=False)
Makes sure all identifiers that need to be quoted are quoted.
class
Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        """
        Args:
            scope: The scope whose columns are being resolved.
            schema: Schema used to look up the columns of physical tables.
            infer_schema: When True, a column that can't be attributed to exactly one
                source falls back to the single source with unknown columns, if any.
        """
        self.scope = scope
        self.schema = schema
        # The caches below are populated on first use.
        self._source_columns: t.Optional[t.Dict[str, t.Any]] = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        The column is first looked up among columns that belong to exactly one source.
        Failing that, and if schema inference is enabled, it is attributed to the only
        source whose columns are unknown (empty or containing "*"), if exactly one exists.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred. For a source backed by a
            subquery, the alias found on the subquery's ancestors is used when present.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Climb to the ancestor that carries this source's alias.
            while node and node.alias != table_name:
                node = node.parent

        # The climb above may run past the root without finding the alias; fall back
        # to the plain source name instead of dereferencing None.
        node_alias = node.args.get("alias") if node else None
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self):
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name, only_visible=False):
        """
        Resolve the source columns for a given source `name`.

        Raises:
            OptimizeError: If `name` is not a source of this scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        # If referencing a table, return the columns from the schema
        if isinstance(source, exp.Table):
            return self.schema.column_names(source, only_visible)

        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            return source.expression.alias_column_names

        # Otherwise, if referencing another scope, return that scope's named selects
        return source.expression.named_selects

    def _get_all_source_columns(self):
        """Cached mapping of every selected and lateral source name to its columns."""
        if self._source_columns is None:
            self._source_columns = {
                k: self.get_source_columns(k)
                for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
            }
        return self._source_columns

    def _get_unambiguous_columns(self, source_columns):
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when it appears exactly once across all sources.
        A column listed twice within a single source, or present in more than one
        source, is ambiguous.

        Args:
            source_columns (dict): Mapping of names to source columns
        Returns:
            dict: Mapping of column name to source name
        """
        if not source_columns:
            return {}

        source_columns = list(source_columns.items())

        first_table, first_columns = source_columns[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
        # Track every column seen so far (including duplicated ones), not just the
        # unique ones, so a later occurrence is recognized as a clash.
        all_columns = set(first_columns)

        for table, columns in source_columns[1:]:
            unique = self._find_unique_columns(columns)
            # Intersect with all of this source's columns: a column duplicated here
            # but seen earlier must also be dropped from the earlier source's mapping.
            ambiguous = all_columns.intersection(columns)
            all_columns.update(columns)
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns):
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
    """
    Get the table for a column name.

    The column is first looked up among columns that belong to exactly one source.
    Failing that, and if schema inference is enabled, it is attributed to the only
    source whose columns are unknown (empty or containing "*"), if exactly one exists.

    Args:
        column_name: The column name to find the table for.
    Returns:
        The table name if it can be found/inferred. For a source backed by a
        subquery, the alias found on the subquery's ancestors is used when present.
    """
    if self._unambiguous_columns is None:
        self._unambiguous_columns = self._get_unambiguous_columns(
            self._get_all_source_columns()
        )

    table_name = self._unambiguous_columns.get(column_name)

    if not table_name and self._infer_schema:
        sources_without_schema = tuple(
            source
            for source, columns in self._get_all_source_columns().items()
            if not columns or "*" in columns
        )
        if len(sources_without_schema) == 1:
            table_name = sources_without_schema[0]

    if table_name not in self.scope.selected_sources:
        return exp.to_identifier(table_name)

    node, _ = self.scope.selected_sources.get(table_name)

    if isinstance(node, exp.Subqueryable):
        # Climb to the ancestor that carries this source's alias.
        while node and node.alias != table_name:
            node = node.parent

    # The climb above may run past the root without finding the alias; fall back
    # to the plain source name instead of dereferencing None.
    node_alias = node.args.get("alias") if node else None
    if node_alias:
        return exp.to_identifier(node_alias.this)

    return exp.to_identifier(table_name)
Get the table for a column name.
Arguments:
- column_name: The column name to find the table for.
Returns:
The table name if it can be found/inferred.
def
get_source_columns(self, name, only_visible=False):
def get_source_columns(self, name, only_visible=False):
    """
    Return the column names exposed by the source called `name` in this scope.

    - Physical tables: the columns come from the schema, forwarding `only_visible`.
    - VALUES scopes: the alias column names are returned.
    - Any other scope: its named selects are returned.

    Raises:
        OptimizeError: If `name` is not a source of this scope.
    """
    if name not in self.scope.sources:
        raise OptimizeError(f"Unknown table: {name}")

    source = self.scope.sources[name]

    if isinstance(source, exp.Table):
        return self.schema.column_names(source, only_visible)

    is_values_scope = isinstance(source, Scope) and isinstance(source.expression, exp.Values)
    if is_values_scope:
        return source.expression.alias_column_names
    return source.expression.named_selects
Resolve the source columns for a given source `name`.