sqlglot.optimizer.qualify_columns
from __future__ import annotations

import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot._typing import E
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get
from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope
from sqlglot.schema import Schema, ensure_schema


def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: expression to qualify
        schema: Database schema
        expand_alias_refs: whether or not to expand references to aliases
        infer_schema: whether or not to infer the schema if missing
    Returns:
        sqlglot.Expression: qualified expression
    """
    schema = ensure_schema(schema)
    infer_schema = schema.empty if infer_schema is None else infer_schema

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Without a schema, alias references are expanded before qualification;
        # with a schema, they are expanded after it.
        if schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        _qualify_columns(scope, resolver)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        # Star expansion and output aliasing are skipped for UDTF scopes.
        if not isinstance(scope.expression, exp.UDTF):
            _expand_stars(scope, resolver, using_column_tables)
            _qualify_outputs(scope)
        _expand_group_by(scope, resolver)
        _expand_order_by(scope)

    return expression


def validate_qualify_columns(expression):
    """Raise an `OptimizeError` if any columns aren't qualified"""
    unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns.extend(scope.unqualified_columns)
            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
                column = scope.external_columns[0]
                raise OptimizeError(
                    f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}"""
                )

    if unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
    return expression


def _pop_table_column_aliases(derived_tables):
    """
    Remove table column aliases.

    (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived_table in derived_tables:
        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.args.pop("columns", None)


def _expand_using(scope, resolver):
    """
    Rewrite every `JOIN ... USING (...)` in the scope into an explicit `ON` condition.

    Unqualified references to a USING column are replaced by a COALESCE over the
    tables that take part in the join on that column.

    Returns:
        dict: Mapping of USING column name to an insertion-ordered dict whose keys are
        the names of the sources joined on that column (the values are always None).
    Raises:
        OptimizeError: if a USING column can't be matched on both sides of the join
            although columns are known for both sides.
    """
    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    # Sources that are not the target of any join; joined sources are appended in join order.
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables = {}

    for join in joins:
        using = join.args.get("using")

        if not using:
            continue

        join_table = join.alias_or_name

        # First source (in selection order, among those seen so far) providing each column.
        columns = {}

        for k in scope.selected_sources:
            if k in ordered:
                for column in resolver.get_source_columns(k):
                    if column not in columns:
                        columns[column] = k

        source_table = ordered[-1]
        ordered.append(join_table)
        join_columns = resolver.get_source_columns(join_table)
        conditions = []

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                # Only an error when columns are known on both sides; otherwise fall through
                # and join against the most recently seen source.
                if columns and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table
            conditions.append(
                exp.condition(
                    exp.EQ(
                        this=exp.column(identifier, table=table),
                        expression=exp.column(identifier, table=join_table),
                    )
                )
            )

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})
            if table not in tables:
                tables[table] = None
            if join_table not in tables:
                tables[join_table] = None

        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce = [exp.column(column.name, table=table) for table in tables]
                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])

                # Ensure selects keep their output name
                if isinstance(column.parent, exp.Select):
                    replacement = alias(replacement, alias=column.name, copy=False)

                scope.replace(column, replacement)

    return column_tables


def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
    """
    Replace references to projection aliases with the aliased expression.

    Only applies to SELECT scopes. Projections are processed in order, so an alias can
    only be referenced by later projections and by the WHERE, GROUP BY, HAVING and
    QUALIFY clauses. ORDER BY is visited with `expand=False`, so aliases are kept there.
    For HAVING, QUALIFY and ORDER BY, unqualified columns under an aggregate function
    are qualified with their resolved table instead of being expanded.
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    alias_to_expression: t.Dict[str, exp.Expression] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], expand: bool = True, resolve_agg: bool = False
    ):
        if not node:
            return

        for column, *_ in walk_in_scope(node):
            if not isinstance(column, exp.Column):
                continue
            table = resolver.get_table(column.name) if resolve_agg and not column.table else None
            if table and column.find_ancestor(exp.AggFunc):
                column.set("table", table)
            elif expand and not column.table and column.name in alias_to_expression:
                column.replace(alias_to_expression[column.name].copy())

    for projection in scope.selects:
        replace_columns(projection)

        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = projection.this

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"))
    replace_columns(expression.args.get("having"), resolve_agg=True)
    replace_columns(expression.args.get("qualify"), resolve_agg=True)
    replace_columns(expression.args.get("order"), expand=False, resolve_agg=True)
    scope.clear_cache()


def _expand_group_by(scope, resolver):
    """Expand positional references (e.g. `GROUP BY 1`) in the scope's GROUP BY clause."""
    group = scope.expression.args.get("group")
    if not group:
        return

    group.set("expressions", _expand_positional_references(scope, group.expressions))
    scope.expression.set("group", group)


def _expand_order_by(scope):
    """Expand positional references (e.g. `ORDER BY 1`) in the scope's ORDER BY clause."""
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(scope, (o.this for o in ordereds)),
    ):
        ordered.set("this", new_expression)


def _expand_positional_references(scope, expressions):
    """
    Replace integer literals with a copy of the projection they point to (1-based).

    Raises:
        OptimizeError: if the position is outside `1..len(scope.selects)`.
    """
    new_nodes = []
    for node in expressions:
        if node.is_int:
            position = int(node.name)
            selects = scope.selects

            # Positions are 1-based. Validate explicitly: relying on IndexError would let
            # position 0 wrap around to index -1 and silently pick the last projection.
            if not 1 <= position <= len(selects):
                raise OptimizeError(f"Unknown output column: {node.name}")

            select = selects[position - 1]
            if isinstance(select, exp.Alias):
                select = select.this
            new_nodes.append(select.copy())
            scope.clear_cache()
        else:
            new_nodes.append(node)

    return new_nodes


def _qualify_columns(scope, resolver):
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if source_columns and column_name not in source_columns and "*" not in source_columns:
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # Columns outside of the Pivot expression are qualified with the pivot's alias.
                # Columns under the Pivot expression are qualified further below, using the
                # name of the pivoted source instead.
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            column_table = resolver.get_table(column_name)

            # column_table can be a '' because bigquery unnest has no table alias
            if column_table:
                column.set("table", column_table)
        elif column_table not in scope.sources and (
            not scope.parent or column_table not in scope.parent.sources
        ):
            # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
            # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))

            root, *parts = column.parts

            if root.name in scope.sources:
                # struct is already qualified, but we still need to change the AST representation
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)


def _expand_stars(scope, resolver, using_column_tables):
    """
    Expand stars to lists of column selections.

    If the columns of any starred source are unknown (empty, or containing "*"),
    the projections are left untouched.

    Raises:
        OptimizeError: if a star refers to a table that is not a source of the scope.
    """

    new_selections = []
    except_columns = {}
    replace_columns = {}
    coalesced_columns = set()

    # TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future
    pivot_columns = None
    pivot_output_columns = None
    pivot = seq_get(scope.pivots, 0)

    has_pivoted_source = pivot and not pivot.args.get("unpivot")
    if has_pivoted_source:
        pivot_columns = set(col.output_name for col in pivot.find_all(exp.Column))

        pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])]
        if not pivot_output_columns:
            pivot_output_columns = [col.alias_or_name for col in pivot.expressions]

    for expression in scope.selects:
        if isinstance(expression, exp.Star):
            tables = list(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
        elif expression.is_star:
            tables = [expression.table]
            _add_except_columns(expression.this, tables, except_columns)
            _add_replace_columns(expression.this, tables, replace_columns)
        else:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)

            if columns and "*" not in columns:
                if has_pivoted_source:
                    # Columns consumed by the pivot are dropped; the pivot's outputs are added.
                    implicit_columns = [col for col in columns if col not in pivot_columns]
                    new_selections.extend(
                        exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in implicit_columns + pivot_output_columns
                    )
                    continue

                # EXCEPT / REPLACE bookkeeping is keyed by the identity of the table name object.
                table_id = id(table)
                for name in columns:
                    if name in using_column_tables and table in using_column_tables[name]:
                        # A USING column is emitted once, as a COALESCE over the joined tables.
                        if name in coalesced_columns:
                            continue

                        coalesced_columns.add(name)
                        using_tables = using_column_tables[name]
                        coalesce = [
                            exp.column(name, table=using_table) for using_table in using_tables
                        ]

                        new_selections.append(
                            alias(
                                exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
                                alias=name,
                                copy=False,
                            )
                        )
                    elif name not in except_columns.get(table_id, set()):
                        alias_ = replace_columns.get(table_id, {}).get(name, name)
                        column = exp.column(name, table=table)
                        new_selections.append(
                            alias(column, alias_, copy=False) if alias_ != name else column
                        )
            else:
                return

    scope.expression.set("expressions", new_selections)


def _add_except_columns(expression, tables, except_columns):
    """Record the names listed in the star's `except` arg for each table, keyed by `id(table)`."""
    except_ = expression.args.get("except")

    if not except_:
        return

    columns = {e.name for e in except_}

    for table in tables:
        except_columns[id(table)] = columns


def _add_replace_columns(expression, tables, replace_columns):
    """Record the star's `replace` arg as a name -> alias mapping per table, keyed by `id(table)`."""
    replace = expression.args.get("replace")

    if not replace:
        return

    columns = {e.this.name: e.alias for e in replace}

    for table in tables:
        replace_columns[id(table)] = columns


def _qualify_outputs(scope):
    """Ensure all output columns are aliased"""
    new_selections = []

    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.selects, scope.outer_column_list)
    ):
        if selection is None:
            # The outer column list is longer than the projection list; there is
            # nothing left to alias, and dereferencing None below would crash.
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
            )
        # An alias from the outer column list takes precedence over the projection's own.
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    scope.expression.set("expressions", new_selections)


def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""
    return expression.transform(
        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
    )


class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope, schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily computed caches; None means "not computed yet".
        self._source_columns = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # If exactly one source has unknown columns, assume the column belongs to it.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Walk up to the ancestor that carries the alias matching the source name.
            while node and node.alias != table_name:
                node = node.parent

        # The walk above can run past the root; fall back to the plain name in that case.
        node_alias = node.args.get("alias") if node is not None else None
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self):
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name, only_visible=False):
        """Resolve the source columns for a given source `name`"""
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        # If referencing a table, return the columns from the schema
        if isinstance(source, exp.Table):
            return self.schema.column_names(source, only_visible)

        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            return source.expression.alias_column_names

        # Otherwise, if referencing another scope, return that scope's named selects
        return source.expression.named_selects

    def _get_all_source_columns(self):
        """Return (and cache) a mapping of source name to columns for selected and lateral sources."""
        if self._source_columns is None:
            self._source_columns = {
                k: self.get_source_columns(k)
                for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
            }
        return self._source_columns

    def _get_unambiguous_columns(self, source_columns):
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns (dict): Mapping of names to source columns
        Returns:
            dict: Mapping of column name to source name
        """
        if not source_columns:
            return {}

        source_columns = list(source_columns.items())

        first_table, first_columns = source_columns[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
        # Track every column name seen so far, including names duplicated within the first
        # source, so that a later source cannot claim a name the first source also exposes.
        all_columns = set(first_columns)

        for table, columns in source_columns[1:]:
            unique = self._find_unique_columns(columns)
            ambiguous = set(all_columns).intersection(unique)
            all_columns.update(columns)
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns):
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
def
qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: expression to qualify
        schema: Database schema
        expand_alias_refs: whether or not to expand references to aliases
        infer_schema: whether or not to infer the schema if missing
    Returns:
        sqlglot.Expression: qualified expression
    """
    schema = ensure_schema(schema)
    if infer_schema is None:
        # Default: only infer when no schema was provided.
        infer_schema = schema.empty

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)

        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # With an empty schema alias references are expanded before qualification ...
        if expand_alias_refs and schema.empty:
            _expand_alias_refs(scope, resolver)

        _qualify_columns(scope, resolver)

        # ... otherwise they are expanded after it.
        if expand_alias_refs and not schema.empty:
            _expand_alias_refs(scope, resolver)

        is_udtf = isinstance(scope.expression, exp.UDTF)
        if not is_udtf:
            _expand_stars(scope, resolver, using_column_tables)
            _qualify_outputs(scope)

        _expand_group_by(scope, resolver)
        _expand_order_by(scope)

    return expression
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression: expression to qualify
- schema: Database schema
- expand_alias_refs: whether or not to expand references to aliases
- infer_schema: whether or not to infer the schema if missing
Returns:
sqlglot.Expression: qualified expression
def
validate_qualify_columns(expression):
def validate_qualify_columns(expression):
    """
    Check that every column in `expression` has been qualified.

    Only SELECT scopes are inspected.

    Returns:
        The given expression, unchanged.
    Raises:
        OptimizeError: if a scope has external columns while being neither a correlated
            subquery nor pivoted, or if any unqualified columns remain.
    """
    ambiguous = []

    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        ambiguous += scope.unqualified_columns

        if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
            column = scope.external_columns[0]
            # Mention the table only when the column carries one.
            for_table = f" for table: '{column.table}'" if column.table else ""
            raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")

    if ambiguous:
        raise OptimizeError(f"Ambiguous columns: {ambiguous}")

    return expression
Raise an `OptimizeError` if any columns aren't qualified.
def
quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Makes sure all identifiers that need to be quoted are quoted.

    Args:
        expression: expression whose nodes are passed through the dialect's `quote_identifier`
        dialect: dialect (or dialect name) resolved via `Dialect.get_or_raise`
        identify: forwarded as-is to `quote_identifier`
    Returns:
        The result of `expression.transform(...)`, which is invoked with `copy=False`.
    """
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote, identify=identify, copy=False)
Makes sure all identifiers that need to be quoted are quoted.
class
Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope, schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily computed caches; None means "not computed yet".
        self._source_columns = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # If exactly one source has unknown columns, assume the column belongs to it.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Walk up to the ancestor that carries the alias matching the source name.
            while node and node.alias != table_name:
                node = node.parent

        # The walk above can run past the root; fall back to the plain name in that case.
        node_alias = node.args.get("alias") if node is not None else None
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self):
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name, only_visible=False):
        """Resolve the source columns for a given source `name`"""
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        # If referencing a table, return the columns from the schema
        if isinstance(source, exp.Table):
            return self.schema.column_names(source, only_visible)

        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            return source.expression.alias_column_names

        # Otherwise, if referencing another scope, return that scope's named selects
        return source.expression.named_selects

    def _get_all_source_columns(self):
        """Return (and cache) a mapping of source name to columns for selected and lateral sources."""
        if self._source_columns is None:
            self._source_columns = {
                k: self.get_source_columns(k)
                for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
            }
        return self._source_columns

    def _get_unambiguous_columns(self, source_columns):
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns (dict): Mapping of names to source columns
        Returns:
            dict: Mapping of column name to source name
        """
        if not source_columns:
            return {}

        source_columns = list(source_columns.items())

        first_table, first_columns = source_columns[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
        # Track every column name seen so far, including names duplicated within the first
        # source, so that a later source cannot claim a name the first source also exposes.
        all_columns = set(first_columns)

        for table, columns in source_columns[1:]:
            unique = self._find_unique_columns(columns)
            ambiguous = set(all_columns).intersection(unique)
            all_columns.update(columns)
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns):
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
    """
    Get the table for a column name.

    Args:
        column_name: The column name to find the table for.
    Returns:
        The table name if it can be found/inferred.
    """
    if self._unambiguous_columns is None:
        self._unambiguous_columns = self._get_unambiguous_columns(
            self._get_all_source_columns()
        )

    table_name = self._unambiguous_columns.get(column_name)

    if not table_name and self._infer_schema:
        # If exactly one source has unknown columns, assume the column belongs to it.
        sources_without_schema = tuple(
            source
            for source, columns in self._get_all_source_columns().items()
            if not columns or "*" in columns
        )
        if len(sources_without_schema) == 1:
            table_name = sources_without_schema[0]

    if table_name not in self.scope.selected_sources:
        return exp.to_identifier(table_name)

    node, _ = self.scope.selected_sources.get(table_name)

    if isinstance(node, exp.Subqueryable):
        # Walk up to the ancestor that carries the alias matching the source name.
        while node and node.alias != table_name:
            node = node.parent

    # The walk above can run past the root and leave `node` as None; fall back to
    # the plain source name instead of failing on `None.args`.
    node_alias = node.args.get("alias") if node is not None else None
    if node_alias:
        return exp.to_identifier(node_alias.this)

    return exp.to_identifier(table_name)
Get the table for a column name.
Arguments:
- column_name: The column name to find the table for.
Returns:
The table name if it can be found/inferred.
def
get_source_columns(self, name, only_visible=False):
def get_source_columns(self, name, only_visible=False):
    """
    Resolve the source columns for a given source `name`.

    Args:
        name: key of the source in `self.scope.sources`
        only_visible: forwarded to `schema.column_names` when the source is a table
    Returns:
        The schema's column names for a table source, the alias column names for a
        VALUES scope, and the named selects of the source's expression otherwise.
    Raises:
        OptimizeError: if `name` is not a source of the scope.
    """
    sources = self.scope.sources
    if name not in sources:
        raise OptimizeError(f"Unknown table: {name}")

    source = sources[name]

    if isinstance(source, exp.Table):
        # Plain tables are looked up in the schema.
        columns = self.schema.column_names(source, only_visible)
    elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
        columns = source.expression.alias_column_names
    else:
        # Any other source exposes the named selects of its expression.
        columns = source.expression.named_selects

    return columns
Resolve the source columns for a given source name