sqlglot.optimizer.qualify_columns
"""Qualify column references in a sqlglot AST.

The entry point is `qualify_columns`, which visits every scope of an expression and:
rewrites ``JOIN ... USING`` into ``JOIN ... ON``, qualifies each column with its source,
expands star projections, aliases every output column, and replaces alias references
(WHERE / GROUP BY) and positional references (GROUP BY / ORDER BY).
"""
import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema


def qualify_columns(expression, schema):
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression (sqlglot.Expression): expression to qualify
        schema (dict|sqlglot.optimizer.Schema): Database schema
    Returns:
        sqlglot.Expression: qualified expression (the same object that was passed in)
    """
    schema = ensure_schema(schema)

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)
        _qualify_columns(scope, resolver)
        # Star expansion is skipped for UDTF scopes.
        if not isinstance(scope.expression, exp.UDTF):
            _expand_stars(scope, resolver, using_column_tables)
        _qualify_outputs(scope)
        _expand_alias_refs(scope, resolver)
        _expand_group_by(scope, resolver)
        _expand_order_by(scope)

    return expression


def validate_qualify_columns(expression):
    """
    Raise an `OptimizeError` if any columns aren't qualified.

    Also raises if a SELECT scope that is not a correlated subquery references a column
    from an unknown (external) table.
    """
    unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns.extend(scope.unqualified_columns)
            if scope.external_columns and not scope.is_correlated_subquery:
                column = scope.external_columns[0]
                raise OptimizeError(f"Unknown table: '{column.table}' for column '{column}'")

    if unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
    return expression


def _pop_table_column_aliases(derived_tables):
    """
    Remove table column aliases.

    For example, the ``(col1, col2)`` part of
    ``SELECT ... FROM (SELECT ...) AS foo(col1, col2)`` is dropped from each alias.
    """
    for derived_table in derived_tables:
        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.args.pop("columns", None)


def _expand_using(scope, resolver):
    """
    Rewrite ``JOIN ... USING (cols)`` into ``JOIN ... ON`` equality conditions.

    Each USING column is matched against the first preceding source (in
    `scope.selected_sources` order) that provides it. Unqualified references to the
    joined columns are then replaced by ``COALESCE`` over every source involved.

    Returns:
        dict: column name -> dict whose keys are the source names, in join order.
    Raises:
        OptimizeError: if a USING column is missing from the joined table or from
            every preceding source.
    """
    joins = list(scope.find_all(exp.Join))
    # Use alias_or_name (like `join_table` below): unaliased joined tables have an
    # empty `.alias`. With that, they would wrongly stay in `ordered` and could be
    # matched against themselves.
    names = {join.this.alias_or_name for join in joins}
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables = {}

    for join in joins:
        using = join.args.get("using")

        if not using:
            continue

        join_table = join.this.alias_or_name

        # First source (in selection order) that provides each column, restricted to
        # sources that precede this join.
        columns = {}

        for k in scope.selected_sources:
            if k in ordered:
                for column in resolver.get_source_columns(k):
                    if column not in columns:
                        columns[column] = k

        ordered.append(join_table)
        join_columns = resolver.get_source_columns(join_table)
        conditions = []

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                raise OptimizeError(f"Cannot automatically join: {identifier}")

            conditions.append(
                exp.condition(
                    exp.EQ(
                        this=exp.column(identifier, table=table),
                        expression=exp.column(identifier, table=join_table),
                    )
                )
            )

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})
            if table not in tables:
                tables[table] = None
            if join_table not in tables:
                tables[join_table] = None

        join.args.pop("using")
        join.set("on", exp.and_(*conditions))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce = [exp.column(column.name, table=table) for table in tables]
                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])

                # Ensure selects keep their output name
                if isinstance(column.parent, exp.Select):
                    replacement = exp.alias_(replacement, alias=column.name)

                scope.replace(column, replacement)

    return column_tables


def _expand_alias_refs(scope, resolver):
    """
    Replace unqualified column references in WHERE and GROUP BY.

    A column that resolves to a source table is qualified with that table. Otherwise,
    if it names a projection of this scope, it is replaced by a copy of that
    projection's expression.
    """
    selects = {}

    # Replace references to select aliases
    def transform(node, *_):
        if isinstance(node, exp.Column) and not node.table:
            table = resolver.get_table(node.name)

            # Source columns get priority over select aliases
            if table:
                node.set("table", table)
                return node

            # Build the alias -> projection map lazily, on first need.
            if not selects:
                for s in scope.selects:
                    selects[s.alias_or_name] = s
            select = selects.get(node.name)

            if select:
                scope.clear_cache()
                if isinstance(select, exp.Alias):
                    select = select.this
                return select.copy()

        return node

    where = scope.expression.args.get("where")
    if where:
        where.transform(transform, copy=False)

    group = scope.expression.args.get("group")
    if group:
        group.transform(transform, copy=False)


def _expand_group_by(scope, resolver):
    """Replace positional references (e.g. ``GROUP BY 1``) with the projected expression."""
    group = scope.expression.args.get("group")
    if not group:
        return

    group.set("expressions", _expand_positional_references(scope, group.expressions))
    scope.expression.set("group", group)


def _expand_order_by(scope):
    """Replace positional references (e.g. ``ORDER BY 1``) with the projected expression."""
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(scope, (o.this for o in ordereds)),
    ):
        ordered.set("this", new_expression)


def _expand_positional_references(scope, expressions):
    """
    Map 1-based integer literals to copies of the corresponding projections.

    Non-integer expressions are returned unchanged.

    Raises:
        OptimizeError: if a position is outside ``1..len(scope.selects)``.
    """
    new_nodes = []
    for node in expressions:
        if node.is_int:
            position = int(node.name)
            selects = scope.selects
            # Positions are 1-based. Without this check, 0 would silently index the
            # *last* projection via selects[-1].
            if not 1 <= position <= len(selects):
                raise OptimizeError(f"Unknown output column: {node.name}")
            select = selects[position - 1]
            if isinstance(select, exp.Alias):
                select = select.this
            new_nodes.append(select.copy())
            scope.clear_cache()
        else:
            new_nodes.append(node)

    return new_nodes


def _qualify_columns(scope, resolver):
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if source_columns and column_name not in source_columns and "*" not in source_columns:
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            column_table = resolver.get_table(column_name)

            # column_table can be a '' because bigquery unnest has no table alias
            if column_table:
                column.set("table", column_table)
        elif column_table not in scope.sources:
            # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
            # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))

            root, *parts = column.parts

            if root.name in scope.sources:
                # struct is already qualified, but we still need to change the AST representation
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    columns_missing_from_scope = []

    # Determine whether each reference in the order by clause is to a column or an alias.
    order = scope.expression.args.get("order")

    if order:
        for ordered in order.expressions:
            for column in ordered.find_all(exp.Column):
                if (
                    not column.table
                    and column.parent is not ordered
                    and column.name in resolver.all_columns
                ):
                    columns_missing_from_scope.append(column)

    # Determine whether each reference in the having clause is to a column or an alias.
    having = scope.expression.args.get("having")

    if having:
        for column in having.find_all(exp.Column):
            if (
                not column.table
                and column.find_ancestor(exp.AggFunc)
                and column.name in resolver.all_columns
            ):
                columns_missing_from_scope.append(column)

    for column in columns_missing_from_scope:
        column_table = resolver.get_table(column.name)

        if column_table:
            column.set("table", column_table)


def _expand_stars(scope, resolver, using_column_tables):
    """
    Expand stars to lists of column selections.

    Honors ``* EXCEPT (...)`` and ``* REPLACE (...)``. Columns merged by a USING join
    are emitted once, as an aliased ``COALESCE``. If any expanded source has unknown
    columns (none, or containing ``"*"``), the projection list is left untouched.
    """

    new_selections = []
    except_columns = {}
    replace_columns = {}
    coalesced_columns = set()

    for expression in scope.selects:
        if isinstance(expression, exp.Star):
            tables = list(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
        elif expression.is_star:
            tables = [expression.table]
            _add_except_columns(expression.this, tables, except_columns)
            _add_replace_columns(expression.this, tables, replace_columns)
        else:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")
            columns = resolver.get_source_columns(table, only_visible=True)

            if not columns or "*" in columns:
                # Columns of this source are unknown, so the star cannot be expanded.
                return

            # EXCEPT/REPLACE entries are keyed by the identity of the table-name object.
            table_id = id(table)
            for name in columns:
                if name in using_column_tables and table in using_column_tables[name]:
                    if name in coalesced_columns:
                        continue

                    coalesced_columns.add(name)
                    coalesce_tables = using_column_tables[name]
                    coalesce = [exp.column(name, table=source) for source in coalesce_tables]

                    new_selections.append(
                        exp.alias_(
                            exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]), alias=name
                        )
                    )
                elif name not in except_columns.get(table_id, set()):
                    alias_ = replace_columns.get(table_id, {}).get(name, name)
                    column = exp.column(name, table)
                    new_selections.append(alias(column, alias_) if alias_ != name else column)

    scope.expression.set("expressions", new_selections)


def _add_except_columns(expression, tables, except_columns):
    """Record a star's ``EXCEPT`` column names for each table, keyed by ``id(table)``."""
    except_ = expression.args.get("except")

    if not except_:
        return

    columns = {e.name for e in except_}

    for table in tables:
        except_columns[id(table)] = columns


def _add_replace_columns(expression, tables, replace_columns):
    """Record a star's ``REPLACE`` mapping (column -> alias) for each table, keyed by ``id(table)``."""
    replace = expression.args.get("replace")

    if not replace:
        return

    columns = {e.this.name: e.alias for e in replace}

    for table in tables:
        replace_columns[id(table)] = columns


def _qualify_outputs(scope):
    """
    Ensure all output columns are aliased.

    Unnamed projections get ``_col_<i>``. Names from the scope's outer column list
    (if any) override the alias.
    """
    new_selections = []

    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.selects, scope.outer_column_list)
    ):
        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            alias_ = alias(exp.column(""), alias=selection.output_name or f"_col_{i}")
            alias_.set("this", selection)
            selection = alias_

        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    scope.expression.set("expressions", new_selections)


class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope, schema):
        self.scope = scope
        self.schema = schema
        # Lazily-populated caches.
        self._source_columns = None
        self._unambiguous_columns = None
        self._all_columns = None

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        If the column is not unambiguously provided by one source, but exactly one
        source has unknown columns (none, or containing ``"*"``), that source is used.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table identifier if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name:
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Walk up to the node that carries the alias naming this source.
            while node and node.alias != table_name:
                node = node.parent

        # The parent walk above may run out of ancestors and leave `node` as None.
        node_alias = node.args.get("alias") if node is not None else None
        if node_alias:
            return node_alias.this

        return exp.to_identifier(
            table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None
        )

    @property
    def all_columns(self):
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name, only_visible=False):
        """Resolve the source columns for a given source `name`"""
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        # If referencing a table, return the columns from the schema
        if isinstance(source, exp.Table):
            return self.schema.column_names(source, only_visible)

        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            return source.expression.alias_column_names

        # Otherwise, if referencing another scope, return that scope's named selects
        return source.expression.named_selects

    def _get_all_source_columns(self):
        """Map each selected and lateral source name to its columns (cached)."""
        if self._source_columns is None:
            self._source_columns = {
                k: self.get_source_columns(k)
                for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
            }
        return self._source_columns

    def _get_unambiguous_columns(self, source_columns):
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when it occurs exactly once across all sources.
        This includes occurrences duplicated within a single source.

        Args:
            source_columns (dict): Mapping of names to source columns
        Returns:
            dict: Mapping of column name to source name
        """
        if not source_columns:
            return {}

        unambiguous_columns = {}
        # Every column seen in any earlier source, including ones duplicated inside it.
        seen_columns = set()

        for table, columns in source_columns.items():
            unique = self._find_unique_columns(columns)
            ambiguous = seen_columns.intersection(columns)
            seen_columns.update(columns)
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns):
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
def
qualify_columns(expression, schema):
def qualify_columns(expression, schema):
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression (sqlglot.Expression): expression to qualify
        schema (dict|sqlglot.optimizer.Schema): Database schema
    Returns:
        sqlglot.Expression: qualified expression (the same object that was passed in)
    """
    normalized_schema = ensure_schema(schema)

    for current_scope in traverse_scope(expression):
        column_resolver = Resolver(current_scope, normalized_schema)

        # Drop column lists from CTE aliases first, then from derived-table aliases.
        for attribute in ("ctes", "derived_tables"):
            _pop_table_column_aliases(getattr(current_scope, attribute))

        using_tables = _expand_using(current_scope, column_resolver)
        _qualify_columns(current_scope, column_resolver)

        # Star expansion is skipped for UDTF scopes.
        is_udtf = isinstance(current_scope.expression, exp.UDTF)
        if not is_udtf:
            _expand_stars(current_scope, column_resolver, using_tables)

        _qualify_outputs(current_scope)
        _expand_alias_refs(current_scope, column_resolver)
        _expand_group_by(current_scope, column_resolver)
        _expand_order_by(current_scope)

    return expression
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression (sqlglot.Expression): expression to qualify
- schema (dict|sqlglot.optimizer.Schema): Database schema
Returns:
sqlglot.Expression: qualified expression
def
validate_qualify_columns(expression):
def validate_qualify_columns(expression):
    """
    Raise an `OptimizeError` if any columns aren't qualified.

    Only SELECT scopes are inspected. An unknown-table error for a non-correlated
    scope is raised immediately. Ambiguous (unqualified) columns are collected across
    all scopes and reported together at the end.
    """
    ambiguous = []

    for current_scope in traverse_scope(expression):
        if not isinstance(current_scope.expression, exp.Select):
            continue

        ambiguous.extend(current_scope.unqualified_columns)

        external = current_scope.external_columns
        if external and not current_scope.is_correlated_subquery:
            offending = current_scope.external_columns[0]
            raise OptimizeError(f"Unknown table: '{offending.table}' for column '{offending}'")

    if ambiguous:
        raise OptimizeError(f"Ambiguous columns: {ambiguous}")
    return expression
Raise an `OptimizeError` if any columns aren't qualified, or if a non-correlated subquery references a column from an unknown table.
class
Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope, schema):
        self.scope = scope
        self.schema = schema
        # Lazily-populated caches.
        self._source_columns = None
        self._unambiguous_columns = None
        self._all_columns = None

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        If the column is not unambiguously provided by one source, but exactly one
        source has unknown columns (none, or containing ``"*"``), that source is used.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table identifier if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name:
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Walk up to the node that carries the alias naming this source.
            while node and node.alias != table_name:
                node = node.parent

        # The parent walk above may run out of ancestors and leave `node` as None.
        node_alias = node.args.get("alias") if node is not None else None
        if node_alias:
            return node_alias.this

        return exp.to_identifier(
            table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None
        )

    @property
    def all_columns(self):
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name, only_visible=False):
        """Resolve the source columns for a given source `name`"""
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        # If referencing a table, return the columns from the schema
        if isinstance(source, exp.Table):
            return self.schema.column_names(source, only_visible)

        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            return source.expression.alias_column_names

        # Otherwise, if referencing another scope, return that scope's named selects
        return source.expression.named_selects

    def _get_all_source_columns(self):
        """Map each selected and lateral source name to its columns (cached)."""
        if self._source_columns is None:
            self._source_columns = {
                k: self.get_source_columns(k)
                for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
            }
        return self._source_columns

    def _get_unambiguous_columns(self, source_columns):
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when it occurs exactly once across all sources.
        This includes occurrences duplicated within a single source.

        Args:
            source_columns (dict): Mapping of names to source columns
        Returns:
            dict: Mapping of column name to source name
        """
        if not source_columns:
            return {}

        unambiguous_columns = {}
        # Every column seen in any earlier source, including ones duplicated inside it.
        # Tracking only the *unique* ones (as before) let duplicates slip through.
        seen_columns = set()

        for table, columns in source_columns.items():
            unique = self._find_unique_columns(columns)
            ambiguous = seen_columns.intersection(columns)
            seen_columns.update(columns)
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns):
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
    """
    Get the table for a column name.

    If the column is not unambiguously provided by one source, but exactly one
    source has unknown columns (none, or containing ``"*"``), that source is used.

    Args:
        column_name: The column name to find the table for.
    Returns:
        The table identifier if it can be found/inferred.
    """
    if self._unambiguous_columns is None:
        self._unambiguous_columns = self._get_unambiguous_columns(
            self._get_all_source_columns()
        )

    table_name = self._unambiguous_columns.get(column_name)

    if not table_name:
        sources_without_schema = tuple(
            source
            for source, columns in self._get_all_source_columns().items()
            if not columns or "*" in columns
        )
        if len(sources_without_schema) == 1:
            table_name = sources_without_schema[0]

    if table_name not in self.scope.selected_sources:
        return exp.to_identifier(table_name)

    node, _ = self.scope.selected_sources.get(table_name)

    if isinstance(node, exp.Subqueryable):
        # Walk up to the node that carries the alias naming this source.
        while node and node.alias != table_name:
            node = node.parent

    # The parent walk above may run out of ancestors and leave `node` as None.
    node_alias = node.args.get("alias") if node is not None else None
    if node_alias:
        return node_alias.this

    return exp.to_identifier(
        table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None
    )
Get the table for a column name.
Arguments:
- column_name: The column name to find the table for.
Returns:
The table identifier (`exp.Identifier`) if it can be found or inferred.
def
get_source_columns(self, name, only_visible=False):
def get_source_columns(self, name, only_visible=False):
    """
    Resolve the source columns for a given source `name`.

    Tables are looked up in the schema (optionally only visible columns). VALUES
    scopes report their alias column names. Any other scope reports its named selects.
    """
    sources = self.scope.sources
    if name not in sources:
        raise OptimizeError(f"Unknown table: {name}")
    source = sources[name]

    if isinstance(source, exp.Table):
        return self.schema.column_names(source, only_visible)

    inner = source.expression
    if isinstance(source, Scope) and isinstance(inner, exp.Values):
        return inner.alias_column_names
    return inner.named_selects
Resolve the source columns for a given source `name`.