sqlglot.optimizer.qualify_columns
1from __future__ import annotations 2 3import itertools 4import typing as t 5 6from sqlglot import alias, exp 7from sqlglot.errors import OptimizeError 8from sqlglot.helper import seq_get 9from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope 10from sqlglot.schema import Schema, ensure_schema 11 12 13def qualify_columns( 14 expression: exp.Expression, 15 schema: dict | Schema, 16 expand_alias_refs: bool = True, 17 infer_schema: t.Optional[bool] = None, 18) -> exp.Expression: 19 """ 20 Rewrite sqlglot AST to have fully qualified columns. 21 22 Example: 23 >>> import sqlglot 24 >>> schema = {"tbl": {"col": "INT"}} 25 >>> expression = sqlglot.parse_one("SELECT col FROM tbl") 26 >>> qualify_columns(expression, schema).sql() 27 'SELECT tbl.col AS col FROM tbl' 28 29 Args: 30 expression: expression to qualify 31 schema: Database schema 32 expand_alias_refs: whether or not to expand references to aliases 33 infer_schema: whether or not to infer the schema if missing 34 Returns: 35 sqlglot.Expression: qualified expression 36 """ 37 schema = ensure_schema(schema) 38 infer_schema = schema.empty if infer_schema is None else infer_schema 39 40 for scope in traverse_scope(expression): 41 resolver = Resolver(scope, schema, infer_schema=infer_schema) 42 _pop_table_column_aliases(scope.ctes) 43 _pop_table_column_aliases(scope.derived_tables) 44 using_column_tables = _expand_using(scope, resolver) 45 46 if schema.empty and expand_alias_refs: 47 _expand_alias_refs(scope, resolver) 48 49 _qualify_columns(scope, resolver) 50 51 if not schema.empty and expand_alias_refs: 52 _expand_alias_refs(scope, resolver) 53 54 if not isinstance(scope.expression, exp.UDTF): 55 _expand_stars(scope, resolver, using_column_tables) 56 _qualify_outputs(scope) 57 _expand_group_by(scope, resolver) 58 _expand_order_by(scope) 59 60 return expression 61 62 63def validate_qualify_columns(expression): 64 """Raise an `OptimizeError` if any columns aren't qualified""" 65 unqualified_columns = 
[] 66 for scope in traverse_scope(expression): 67 if isinstance(scope.expression, exp.Select): 68 unqualified_columns.extend(scope.unqualified_columns) 69 if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots: 70 column = scope.external_columns[0] 71 raise OptimizeError( 72 f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}""" 73 ) 74 75 if unqualified_columns: 76 raise OptimizeError(f"Ambiguous columns: {unqualified_columns}") 77 return expression 78 79 80def _pop_table_column_aliases(derived_tables): 81 """ 82 Remove table column aliases. 83 84 (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2) 85 """ 86 for derived_table in derived_tables: 87 table_alias = derived_table.args.get("alias") 88 if table_alias: 89 table_alias.args.pop("columns", None) 90 91 92def _expand_using(scope, resolver): 93 joins = list(scope.find_all(exp.Join)) 94 names = {join.this.alias for join in joins} 95 ordered = [key for key in scope.selected_sources if key not in names] 96 97 # Mapping of automatically joined column names to an ordered set of source names (dict). 
98 column_tables = {} 99 100 for join in joins: 101 using = join.args.get("using") 102 103 if not using: 104 continue 105 106 join_table = join.this.alias_or_name 107 108 columns = {} 109 110 for k in scope.selected_sources: 111 if k in ordered: 112 for column in resolver.get_source_columns(k): 113 if column not in columns: 114 columns[column] = k 115 116 source_table = ordered[-1] 117 ordered.append(join_table) 118 join_columns = resolver.get_source_columns(join_table) 119 conditions = [] 120 121 for identifier in using: 122 identifier = identifier.name 123 table = columns.get(identifier) 124 125 if not table or identifier not in join_columns: 126 if columns and join_columns: 127 raise OptimizeError(f"Cannot automatically join: {identifier}") 128 129 table = table or source_table 130 conditions.append( 131 exp.condition( 132 exp.EQ( 133 this=exp.column(identifier, table=table), 134 expression=exp.column(identifier, table=join_table), 135 ) 136 ) 137 ) 138 139 # Set all values in the dict to None, because we only care about the key ordering 140 tables = column_tables.setdefault(identifier, {}) 141 if table not in tables: 142 tables[table] = None 143 if join_table not in tables: 144 tables[join_table] = None 145 146 join.args.pop("using") 147 join.set("on", exp.and_(*conditions, copy=False)) 148 149 if column_tables: 150 for column in scope.columns: 151 if not column.table and column.name in column_tables: 152 tables = column_tables[column.name] 153 coalesce = [exp.column(column.name, table=table) for table in tables] 154 replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]) 155 156 # Ensure selects keep their output name 157 if isinstance(column.parent, exp.Select): 158 replacement = alias(replacement, alias=column.name, copy=False) 159 160 scope.replace(column, replacement) 161 162 return column_tables 163 164 165def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None: 166 expression = scope.expression 167 168 if not 
isinstance(expression, exp.Select): 169 return 170 171 alias_to_expression: t.Dict[str, exp.Expression] = {} 172 173 def replace_columns( 174 node: t.Optional[exp.Expression], expand: bool = True, resolve_agg: bool = False 175 ): 176 if not node: 177 return 178 179 for column, *_ in walk_in_scope(node): 180 if not isinstance(column, exp.Column): 181 continue 182 table = resolver.get_table(column.name) if resolve_agg and not column.table else None 183 if table and column.find_ancestor(exp.AggFunc): 184 column.set("table", table) 185 elif expand and not column.table and column.name in alias_to_expression: 186 column.replace(alias_to_expression[column.name].copy()) 187 188 for projection in scope.selects: 189 replace_columns(projection) 190 191 if isinstance(projection, exp.Alias): 192 alias_to_expression[projection.alias] = projection.this 193 194 replace_columns(expression.args.get("where")) 195 replace_columns(expression.args.get("group")) 196 replace_columns(expression.args.get("having"), resolve_agg=True) 197 replace_columns(expression.args.get("qualify"), resolve_agg=True) 198 replace_columns(expression.args.get("order"), expand=False, resolve_agg=True) 199 scope.clear_cache() 200 201 202def _expand_group_by(scope, resolver): 203 group = scope.expression.args.get("group") 204 if not group: 205 return 206 207 group.set("expressions", _expand_positional_references(scope, group.expressions)) 208 scope.expression.set("group", group) 209 210 211def _expand_order_by(scope): 212 order = scope.expression.args.get("order") 213 if not order: 214 return 215 216 ordereds = order.expressions 217 for ordered, new_expression in zip( 218 ordereds, 219 _expand_positional_references(scope, (o.this for o in ordereds)), 220 ): 221 ordered.set("this", new_expression) 222 223 224def _expand_positional_references(scope, expressions): 225 new_nodes = [] 226 for node in expressions: 227 if node.is_int: 228 try: 229 select = scope.selects[int(node.name) - 1] 230 except IndexError: 231 
raise OptimizeError(f"Unknown output column: {node.name}") 232 if isinstance(select, exp.Alias): 233 select = select.this 234 new_nodes.append(select.copy()) 235 scope.clear_cache() 236 else: 237 new_nodes.append(node) 238 239 return new_nodes 240 241 242def _qualify_columns(scope, resolver): 243 """Disambiguate columns, ensuring each column specifies a source""" 244 for column in scope.columns: 245 column_table = column.table 246 column_name = column.name 247 248 if column_table and column_table in scope.sources: 249 source_columns = resolver.get_source_columns(column_table) 250 if source_columns and column_name not in source_columns and "*" not in source_columns: 251 raise OptimizeError(f"Unknown column: {column_name}") 252 253 if not column_table: 254 if scope.pivots and not column.find_ancestor(exp.Pivot): 255 # If the column is under the Pivot expression, we need to qualify it 256 # using the name of the pivoted source instead of the pivot's alias 257 column.set("table", exp.to_identifier(scope.pivots[0].alias)) 258 continue 259 260 column_table = resolver.get_table(column_name) 261 262 # column_table can be a '' because bigquery unnest has no table alias 263 if column_table: 264 column.set("table", column_table) 265 elif column_table not in scope.sources and ( 266 not scope.parent or column_table not in scope.parent.sources 267 ): 268 # structs are used like tables (e.g. 
"struct"."field"), so they need to be qualified 269 # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...)) 270 271 root, *parts = column.parts 272 273 if root.name in scope.sources: 274 # struct is already qualified, but we still need to change the AST representation 275 column_table = root 276 root, *parts = parts 277 else: 278 column_table = resolver.get_table(root.name) 279 280 if column_table: 281 column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts])) 282 283 for pivot in scope.pivots: 284 for column in pivot.find_all(exp.Column): 285 if not column.table and column.name in resolver.all_columns: 286 column_table = resolver.get_table(column.name) 287 if column_table: 288 column.set("table", column_table) 289 290 291def _expand_stars(scope, resolver, using_column_tables): 292 """Expand stars to lists of column selections""" 293 294 new_selections = [] 295 except_columns = {} 296 replace_columns = {} 297 coalesced_columns = set() 298 299 # TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future 300 pivot_columns = None 301 pivot_output_columns = None 302 pivot = seq_get(scope.pivots, 0) 303 304 has_pivoted_source = pivot and not pivot.args.get("unpivot") 305 if has_pivoted_source: 306 pivot_columns = set(col.output_name for col in pivot.find_all(exp.Column)) 307 308 pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])] 309 if not pivot_output_columns: 310 pivot_output_columns = [col.alias_or_name for col in pivot.expressions] 311 312 for expression in scope.selects: 313 if isinstance(expression, exp.Star): 314 tables = list(scope.selected_sources) 315 _add_except_columns(expression, tables, except_columns) 316 _add_replace_columns(expression, tables, replace_columns) 317 elif expression.is_star: 318 tables = [expression.table] 319 _add_except_columns(expression.this, tables, except_columns) 320 _add_replace_columns(expression.this, tables, replace_columns) 
321 else: 322 new_selections.append(expression) 323 continue 324 325 for table in tables: 326 if table not in scope.sources: 327 raise OptimizeError(f"Unknown table: {table}") 328 329 columns = resolver.get_source_columns(table, only_visible=True) 330 331 if columns and "*" not in columns: 332 if has_pivoted_source: 333 implicit_columns = [col for col in columns if col not in pivot_columns] 334 new_selections.extend( 335 exp.alias_(exp.column(name, table=pivot.alias), name, copy=False) 336 for name in implicit_columns + pivot_output_columns 337 ) 338 continue 339 340 table_id = id(table) 341 for name in columns: 342 if name in using_column_tables and table in using_column_tables[name]: 343 if name in coalesced_columns: 344 continue 345 346 coalesced_columns.add(name) 347 tables = using_column_tables[name] 348 coalesce = [exp.column(name, table=table) for table in tables] 349 350 new_selections.append( 351 alias( 352 exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]), 353 alias=name, 354 copy=False, 355 ) 356 ) 357 elif name not in except_columns.get(table_id, set()): 358 alias_ = replace_columns.get(table_id, {}).get(name, name) 359 column = exp.column(name, table=table) 360 new_selections.append( 361 alias(column, alias_, copy=False) if alias_ != name else column 362 ) 363 else: 364 return 365 366 scope.expression.set("expressions", new_selections) 367 368 369def _add_except_columns(expression, tables, except_columns): 370 except_ = expression.args.get("except") 371 372 if not except_: 373 return 374 375 columns = {e.name for e in except_} 376 377 for table in tables: 378 except_columns[id(table)] = columns 379 380 381def _add_replace_columns(expression, tables, replace_columns): 382 replace = expression.args.get("replace") 383 384 if not replace: 385 return 386 387 columns = {e.this.name: e.alias for e in replace} 388 389 for table in tables: 390 replace_columns[id(table)] = columns 391 392 393def _qualify_outputs(scope): 394 """Ensure all output columns 
are aliased""" 395 new_selections = [] 396 397 for i, (selection, aliased_column) in enumerate( 398 itertools.zip_longest(scope.selects, scope.outer_column_list) 399 ): 400 if isinstance(selection, exp.Subquery): 401 if not selection.output_name: 402 selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}"))) 403 elif not isinstance(selection, exp.Alias) and not selection.is_star: 404 selection = alias( 405 selection, 406 alias=selection.output_name or f"_col_{i}", 407 quoted=True 408 if isinstance(selection, exp.Column) and selection.this.quoted 409 else None, 410 ) 411 if aliased_column: 412 selection.set("alias", exp.to_identifier(aliased_column)) 413 414 new_selections.append(selection) 415 416 scope.expression.set("expressions", new_selections) 417 418 419class Resolver: 420 """ 421 Helper for resolving columns. 422 423 This is a class so we can lazily load some things and easily share them across functions. 424 """ 425 426 def __init__(self, scope, schema, infer_schema: bool = True): 427 self.scope = scope 428 self.schema = schema 429 self._source_columns = None 430 self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None 431 self._all_columns = None 432 self._infer_schema = infer_schema 433 434 def get_table(self, column_name: str) -> t.Optional[exp.Identifier]: 435 """ 436 Get the table for a column name. 437 438 Args: 439 column_name: The column name to find the table for. 440 Returns: 441 The table name if it can be found/inferred. 
442 """ 443 if self._unambiguous_columns is None: 444 self._unambiguous_columns = self._get_unambiguous_columns( 445 self._get_all_source_columns() 446 ) 447 448 table_name = self._unambiguous_columns.get(column_name) 449 450 if not table_name and self._infer_schema: 451 sources_without_schema = tuple( 452 source 453 for source, columns in self._get_all_source_columns().items() 454 if not columns or "*" in columns 455 ) 456 if len(sources_without_schema) == 1: 457 table_name = sources_without_schema[0] 458 459 if table_name not in self.scope.selected_sources: 460 return exp.to_identifier(table_name) 461 462 node, _ = self.scope.selected_sources.get(table_name) 463 464 if isinstance(node, exp.Subqueryable): 465 while node and node.alias != table_name: 466 node = node.parent 467 468 node_alias = node.args.get("alias") 469 if node_alias: 470 return exp.to_identifier(node_alias.this) 471 472 return exp.to_identifier( 473 table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None 474 ) 475 476 @property 477 def all_columns(self): 478 """All available columns of all sources in this scope""" 479 if self._all_columns is None: 480 self._all_columns = { 481 column for columns in self._get_all_source_columns().values() for column in columns 482 } 483 return self._all_columns 484 485 def get_source_columns(self, name, only_visible=False): 486 """Resolve the source columns for a given source `name`""" 487 if name not in self.scope.sources: 488 raise OptimizeError(f"Unknown table: {name}") 489 490 source = self.scope.sources[name] 491 492 # If referencing a table, return the columns from the schema 493 if isinstance(source, exp.Table): 494 return self.schema.column_names(source, only_visible) 495 496 if isinstance(source, Scope) and isinstance(source.expression, exp.Values): 497 return source.expression.alias_column_names 498 499 # Otherwise, if referencing another scope, return that scope's named selects 500 return source.expression.named_selects 501 502 def 
_get_all_source_columns(self): 503 if self._source_columns is None: 504 self._source_columns = { 505 k: self.get_source_columns(k) 506 for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources) 507 } 508 return self._source_columns 509 510 def _get_unambiguous_columns(self, source_columns): 511 """ 512 Find all the unambiguous columns in sources. 513 514 Args: 515 source_columns (dict): Mapping of names to source columns 516 Returns: 517 dict: Mapping of column name to source name 518 """ 519 if not source_columns: 520 return {} 521 522 source_columns = list(source_columns.items()) 523 524 first_table, first_columns = source_columns[0] 525 unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)} 526 all_columns = set(unambiguous_columns) 527 528 for table, columns in source_columns[1:]: 529 unique = self._find_unique_columns(columns) 530 ambiguous = set(all_columns).intersection(unique) 531 all_columns.update(columns) 532 for column in ambiguous: 533 unambiguous_columns.pop(column, None) 534 for column in unique.difference(ambiguous): 535 unambiguous_columns[column] = table 536 537 return unambiguous_columns 538 539 @staticmethod 540 def _find_unique_columns(columns): 541 """ 542 Find the unique columns in a list of columns. 543 544 Example: 545 >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"])) 546 ['a', 'c'] 547 548 This is necessary because duplicate column names are ambiguous. 549 """ 550 counts = {} 551 for column in columns: 552 counts[column] = counts.get(column, 0) + 1 553 return {column for column, count in counts.items() if count == 1}
def
qualify_columns( expression: sqlglot.expressions.Expression, schema: dict | sqlglot.schema.Schema, expand_alias_refs: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: dict | Schema,
    expand_alias_refs: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Qualify every column of a sqlglot AST with the source it comes from.

    The expression is modified in place, one scope at a time, and returned.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: expression to qualify
        schema: Database schema
        expand_alias_refs: whether or not to expand references to aliases
        infer_schema: whether or not to infer the schema if missing; when None,
            this defaults to whether the schema is empty
    Returns:
        sqlglot.Expression: qualified expression
    """
    schema = ensure_schema(schema)
    if infer_schema is None:
        infer_schema = schema.empty

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)

        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Alias references are expanded before qualification for an empty schema,
        # and after qualification otherwise.
        if expand_alias_refs and schema.empty:
            _expand_alias_refs(scope, resolver)

        _qualify_columns(scope, resolver)

        if expand_alias_refs and not schema.empty:
            _expand_alias_refs(scope, resolver)

        is_udtf = isinstance(scope.expression, exp.UDTF)
        if not is_udtf:
            _expand_stars(scope, resolver, using_column_tables)
            _qualify_outputs(scope)

        _expand_group_by(scope, resolver)
        _expand_order_by(scope)

    return expression
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot >>> schema = {"tbl": {"col": "INT"}} >>> expression = sqlglot.parse_one("SELECT col FROM tbl") >>> qualify_columns(expression, schema).sql() 'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression: expression to qualify
- schema: Database schema
- expand_alias_refs: whether or not to expand references to aliases
- infer_schema: whether or not to infer the schema if missing
Returns:
sqlglot.Expression: qualified expression
def
validate_qualify_columns(expression):
def validate_qualify_columns(expression):
    """
    Check that every column in `expression` has been qualified.

    Returns:
        The unchanged `expression` when validation passes.
    Raises:
        OptimizeError: as soon as a SELECT scope has an external column while it is
        neither a correlated subquery nor has pivots, or, after all scopes were
        visited, if any SELECT scope reported unqualified columns.
    """
    ambiguous = []

    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        ambiguous.extend(scope.unqualified_columns)

        if not scope.external_columns or scope.is_correlated_subquery or scope.pivots:
            continue

        unresolved = scope.external_columns[0]
        table_hint = f" for table: '{unresolved.table}'" if unresolved.table else ""
        raise OptimizeError(f"Column '{unresolved}' could not be resolved{table_hint}")

    if ambiguous:
        raise OptimizeError(f"Ambiguous columns: {ambiguous}")

    return expression
Raise an `OptimizeError` if any columns aren't qualified.
class
Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope, schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # The three caches below are filled lazily on first use
        self._source_columns = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # If exactly one source has unknown columns, assume the column belongs to it
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Walk up to the ancestor that actually carries the alias `table_name`
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(
            table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None
        )

    @property
    def all_columns(self):
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name, only_visible=False):
        """Resolve the source columns for a given source `name`"""
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        # If referencing a table, return the columns from the schema
        if isinstance(source, exp.Table):
            return self.schema.column_names(source, only_visible)

        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            return source.expression.alias_column_names

        # Otherwise, if referencing another scope, return that scope's named selects
        return source.expression.named_selects

    def _get_all_source_columns(self):
        """Return (and cache) a mapping of every selected/lateral source name to its columns."""
        if self._source_columns is None:
            self._source_columns = {
                k: self.get_source_columns(k)
                for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
            }
        return self._source_columns

    def _get_unambiguous_columns(self, source_columns):
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns (dict): Mapping of names to source columns
        Returns:
            dict: Mapping of column name to source name
        """
        if not source_columns:
            return {}

        source_columns = list(source_columns.items())

        first_table, first_columns = source_columns[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns[1:]:
            unique = self._find_unique_columns(columns)
            ambiguous = all_columns & unique
            all_columns.update(columns)
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns):
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        # Imported locally so this class does not depend on the module's import block
        from collections import Counter

        return {column for column, count in Counter(columns).items() if count == 1}
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
    """
    Find the source that an unqualified column belongs to.

    Args:
        column_name: The column name to find the table for.
    Returns:
        An identifier for the owning table if it can be found/inferred.
    """
    if self._unambiguous_columns is None:
        # Computed once per resolver and then reused for every lookup
        self._unambiguous_columns = self._get_unambiguous_columns(
            self._get_all_source_columns()
        )

    table_name = self._unambiguous_columns.get(column_name)

    if not table_name and self._infer_schema:
        # When exactly one source has unknown columns, attribute the column to it
        schemaless_sources = [
            source
            for source, columns in self._get_all_source_columns().items()
            if not columns or "*" in columns
        ]
        if len(schemaless_sources) == 1:
            (table_name,) = schemaless_sources

    selected_sources = self.scope.selected_sources
    if table_name not in selected_sources:
        return exp.to_identifier(table_name)

    node, _ = selected_sources.get(table_name)

    if isinstance(node, exp.Subqueryable):
        # Climb to the ancestor that actually carries the alias `table_name`
        while node and node.alias != table_name:
            node = node.parent

    node_alias = node.args.get("alias")
    if node_alias:
        return exp.to_identifier(node_alias.this)

    quoted = node.this.quoted if isinstance(node, exp.Table) else None
    return exp.to_identifier(table_name, quoted=quoted)
Get the table for a column name.
Arguments:
- column_name: The column name to find the table for.
Returns:
The table name if it can be found/inferred.
def
get_source_columns(self, name, only_visible=False):
def get_source_columns(self, name, only_visible=False):
    """
    Return the column names exposed by the source called `name` in this scope.

    Args:
        name: Name of a source of the resolver's scope.
        only_visible: Passed through to the schema lookup when the source is a table.
    Raises:
        OptimizeError: if `name` is not a source of the scope.
    """
    sources = self.scope.sources
    if name not in sources:
        raise OptimizeError(f"Unknown table: {name}")

    source = sources[name]

    # Tables get their columns from the schema
    if isinstance(source, exp.Table):
        return self.schema.column_names(source, only_visible)

    source_expression = source.expression

    # A VALUES scope exposes the column names declared on its alias
    if isinstance(source, Scope) and isinstance(source_expression, exp.Values):
        return source_expression.alias_column_names

    # Any other source exposes its named selects
    return source_expression.named_selects
Resolve the source columns for a given source name