sqlglot.optimizer.qualify_columns
import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema


def qualify_columns(expression, schema):
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression (sqlglot.Expression): expression to qualify
        schema (dict|sqlglot.optimizer.Schema): Database schema
    Returns:
        sqlglot.Expression: qualified expression
    """
    schema = ensure_schema(schema)

    # Each scope yielded by traverse_scope is rewritten in place with its own Resolver.
    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)
        _qualify_columns(scope, resolver)
        # Star expansion and output aliasing are skipped for UDTF scopes.
        if not isinstance(scope.expression, exp.UDTF):
            _expand_stars(scope, resolver, using_column_tables)
            _qualify_outputs(scope)
        _expand_alias_refs(scope, resolver)
        _expand_group_by(scope, resolver)
        _expand_order_by(scope)

    return expression


def validate_qualify_columns(expression):
    """Raise an `OptimizeError` if any columns aren't qualified"""
    unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns.extend(scope.unqualified_columns)
            # Only the first offending external column is reported.
            if scope.external_columns and not scope.is_correlated_subquery:
                column = scope.external_columns[0]
                raise OptimizeError(f"Unknown table: '{column.table}' for column '{column}'")

    if unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
    return expression


def _pop_table_column_aliases(derived_tables):
    """
    Remove table column aliases.

    (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived_table in derived_tables:
        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.args.pop("columns", None)


def _expand_using(scope, resolver):
    """
    Rewrite every `JOIN ... USING (...)` in the scope into an explicit `ON` condition.

    Unqualified references to a USING column are replaced with a COALESCE over
    the tables that share it.

    Args:
        scope: scope whose joins are rewritten in place
        resolver: Resolver used to look up the columns of each source
    Returns:
        dict: Mapping of USING column name to an insertion-ordered dict whose
            keys are the names of the sources joined on that column.
    Raises:
        OptimizeError: if a USING column cannot be found on both sides of the join
            while column information is available for both sides.
    """
    joins = list(scope.find_all(exp.Join))
    names = {join.this.alias for join in joins}
    # Sources that are not the target of a join; joined tables are appended as we go.
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables = {}

    for join in joins:
        using = join.args.get("using")

        if not using:
            continue

        join_table = join.this.alias_or_name

        # Column name -> first already-joined source that exposes it.
        columns = {}

        for k in scope.selected_sources:
            if k in ordered:
                for column in resolver.get_source_columns(k):
                    if column not in columns:
                        columns[column] = k

        source_table = ordered[-1]
        ordered.append(join_table)
        join_columns = resolver.get_source_columns(join_table)
        conditions = []

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                # Only fail when both sides have known columns; otherwise join optimistically.
                if columns and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table
            conditions.append(
                exp.condition(
                    exp.EQ(
                        this=exp.column(identifier, table=table),
                        expression=exp.column(identifier, table=join_table),
                    )
                )
            )

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})
            if table not in tables:
                tables[table] = None
            if join_table not in tables:
                tables[join_table] = None

        join.args.pop("using")
        join.set("on", exp.and_(*conditions))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce = [exp.column(column.name, table=table) for table in tables]
                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])

                # Ensure selects keep their output name
                if isinstance(column.parent, exp.Select):
                    replacement = exp.alias_(replacement, alias=column.name)

                scope.replace(column, replacement)

    return column_tables


def _expand_alias_refs(scope, resolver):
    """
    Resolve unqualified columns in the selects and in WHERE / GROUP BY / HAVING.

    Each unqualified column is either qualified with the table returned by
    `resolver.get_table` or replaced by a copy of the select expression that
    carries the same alias. For the selects, WHERE and GROUP BY a source column
    wins over a select alias; for HAVING the select alias is tried first.
    """
    # Lazily built mapping of output name -> select expression.
    selects = {}

    # Replace references to select aliases
    def transform(node, source_first=True):
        if isinstance(node, exp.Column) and not node.table:
            table = resolver.get_table(node.name)

            # Source columns get priority over select aliases
            if source_first and table:
                node.set("table", table)
                return node

            if not selects:
                for s in scope.selects:
                    selects[s.alias_or_name] = s
            select = selects.get(node.name)

            if select:
                scope.clear_cache()
                if isinstance(select, exp.Alias):
                    select = select.this
                return select.copy()

            node.set("table", table)
        elif isinstance(node, exp.Expression) and not isinstance(node, exp.Subqueryable):
            # Recurse into children, but do not descend into nested subqueryables.
            exp.replace_children(node, transform, source_first)

        return node

    for select in scope.expression.selects:
        transform(select)

    for modifier, source_first in (
        ("where", True),
        ("group", True),
        ("having", False),
    ):
        transform(scope.expression.args.get(modifier), source_first=source_first)


def _expand_group_by(scope, resolver):
    """Replace positional references (e.g. GROUP BY 1) with the selected expressions.

    `resolver` is accepted but not used by this function.
    """
    group = scope.expression.args.get("group")
    if not group:
        return

    group.set("expressions", _expand_positional_references(scope, group.expressions))
    scope.expression.set("group", group)


def _expand_order_by(scope):
    """Replace positional references (e.g. ORDER BY 1) with the selected expressions."""
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(scope, (o.this for o in ordereds)),
    ):
        ordered.set("this", new_expression)


def _expand_positional_references(scope, expressions):
    """
    Swap integer literals for copies of the select expression at that 1-based position.

    Args:
        scope: scope whose selects are referenced
        expressions: iterable of expressions to inspect
    Returns:
        list: the expressions, with integer positions replaced
    Raises:
        OptimizeError: if a position is smaller than 1 or larger than the number of selects
    """
    new_nodes = []
    for node in expressions:
        if node.is_int:
            position = int(node.name)
            selects = scope.selects
            # Positions are 1-based. Without the lower bound, 0 would become index -1
            # and silently resolve to the last select.
            if not 1 <= position <= len(selects):
                raise OptimizeError(f"Unknown output column: {node.name}")
            select = selects[position - 1]
            if isinstance(select, exp.Alias):
                select = select.this
            new_nodes.append(select.copy())
            scope.clear_cache()
        else:
            new_nodes.append(node)

    return new_nodes


def _qualify_columns(scope, resolver):
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if source_columns and column_name not in source_columns and "*" not in source_columns:
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            column_table = resolver.get_table(column_name)

            # column_table can be a '' because bigquery unnest has no table alias
            if column_table:
                column.set("table", column_table)
        elif column_table not in scope.sources:
            # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
            # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))

            root, *parts = column.parts

            if root.name in scope.sources:
                # struct is already qualified, but we still need to change the AST representation
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    columns_missing_from_scope = []

    # Determine whether each reference in the order by clause is to a column or an alias.
    order = scope.expression.args.get("order")

    if order:
        for ordered in order.expressions:
            for column in ordered.find_all(exp.Column):
                if (
                    not column.table
                    and column.parent is not ordered
                    and column.name in resolver.all_columns
                ):
                    columns_missing_from_scope.append(column)

    # Determine whether each reference in the having clause is to a column or an alias.
    having = scope.expression.args.get("having")

    if having:
        for column in having.find_all(exp.Column):
            if (
                not column.table
                and column.find_ancestor(exp.AggFunc)
                and column.name in resolver.all_columns
            ):
                columns_missing_from_scope.append(column)

    for column in columns_missing_from_scope:
        column_table = resolver.get_table(column.name)

        if column_table:
            column.set("table", column_table)


def _expand_stars(scope, resolver, using_column_tables):
    """Expand stars to lists of column selections"""

    new_selections = []
    except_columns = {}
    replace_columns = {}
    coalesced_columns = set()

    for expression in scope.selects:
        if isinstance(expression, exp.Star):
            tables = list(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
        elif expression.is_star:
            tables = [expression.table]
            _add_except_columns(expression.this, tables, except_columns)
            _add_replace_columns(expression.this, tables, replace_columns)
        else:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")
            columns = resolver.get_source_columns(table, only_visible=True)

            if columns and "*" not in columns:
                # EXCEPT / REPLACE entries are keyed by the identity of the table name object.
                table_id = id(table)
                for name in columns:
                    if name in using_column_tables and table in using_column_tables[name]:
                        # A USING column is emitted once, as a COALESCE over all joined tables.
                        if name in coalesced_columns:
                            continue

                        coalesced_columns.add(name)
                        coalesce_tables = using_column_tables[name]
                        coalesce = [exp.column(name, table=t_) for t_ in coalesce_tables]

                        new_selections.append(
                            exp.alias_(
                                exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]), alias=name
                            )
                        )
                    elif name not in except_columns.get(table_id, set()):
                        alias_ = replace_columns.get(table_id, {}).get(name, name)
                        column = exp.column(name, table)
                        new_selections.append(alias(column, alias_) if alias_ != name else column)
            else:
                # Columns of this source are unknown: leave the selections untouched.
                return
    scope.expression.set("expressions", new_selections)


def _add_except_columns(expression, tables, except_columns):
    """Record the names in the star's `except` arg for every table, keyed by `id(table)`."""
    except_ = expression.args.get("except")

    if not except_:
        return

    columns = {e.name for e in except_}

    for table in tables:
        except_columns[id(table)] = columns


def _add_replace_columns(expression, tables, replace_columns):
    """Record the star's `replace` arg as {column name: alias} per table, keyed by `id(table)`."""
    replace = expression.args.get("replace")

    if not replace:
        return

    columns = {e.this.name: e.alias for e in replace}

    for table in tables:
        replace_columns[id(table)] = columns


def _qualify_outputs(scope):
    """Ensure all output columns are aliased"""
    new_selections = []

    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.selects, scope.outer_column_list)
    ):
        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            # Build the alias around a placeholder column, then swap in the real selection.
            alias_ = alias(exp.column(""), alias=selection.output_name or f"_col_{i}")
            alias_.set("this", selection)
            selection = alias_

        # A name from the outer column list overrides the selection's own alias.
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    scope.expression.set("expressions", new_selections)


class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope, schema):
        self.scope = scope
        self.schema = schema
        # Caches, filled on first use.
        self._source_columns = None
        self._unambiguous_columns = None
        self._all_columns = None

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name:
            # If exactly one source has unknown columns, assume the column belongs to it.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Walk up to the ancestor that carries the alias this source is known by.
            while node and node.alias != table_name:
                node = node.parent

        # The walk above can run off the top of the tree; fall back to a bare
        # identifier instead of dereferencing None.
        if node is None:
            return exp.to_identifier(table_name)

        node_alias = node.args.get("alias")
        if node_alias:
            return node_alias.this

        return exp.to_identifier(
            table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None
        )

    @property
    def all_columns(self):
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name, only_visible=False):
        """Resolve the source columns for a given source `name`"""
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        # If referencing a table, return the columns from the schema
        if isinstance(source, exp.Table):
            return self.schema.column_names(source, only_visible)

        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            return source.expression.alias_column_names

        # Otherwise, if referencing another scope, return that scope's named selects
        return source.expression.named_selects

    def _get_all_source_columns(self):
        """Return (and cache) a mapping of source name to its columns for this scope."""
        if self._source_columns is None:
            self._source_columns = {
                k: self.get_source_columns(k)
                for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
            }
        return self._source_columns

    def _get_unambiguous_columns(self, source_columns):
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns (dict): Mapping of names to source columns
        Returns:
            dict: Mapping of column name to source name
        """
        if not source_columns:
            return {}

        source_columns = list(source_columns.items())

        first_table, first_columns = source_columns[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
        # Seed with every column of the first source, including names duplicated inside it,
        # so a later source cannot claim such a name as unambiguous.
        all_columns = set(first_columns)

        for table, columns in source_columns[1:]:
            unique = self._find_unique_columns(columns)
            ambiguous = set(all_columns).intersection(unique)
            all_columns.update(columns)
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns):
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
def
qualify_columns(expression, schema):
def qualify_columns(expression, schema):
    """
    Fully qualify every column of a sqlglot AST, rewriting it in place.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression (sqlglot.Expression): expression to qualify
        schema (dict|sqlglot.optimizer.Schema): Database schema
    Returns:
        sqlglot.Expression: the same expression object, now qualified
    """
    normalized_schema = ensure_schema(schema)

    for current_scope in traverse_scope(expression):
        column_resolver = Resolver(current_scope, normalized_schema)

        # Drop table-level column aliases from CTEs first, then from derived tables.
        _pop_table_column_aliases(current_scope.ctes)
        _pop_table_column_aliases(current_scope.derived_tables)

        joined_column_tables = _expand_using(current_scope, column_resolver)
        _qualify_columns(current_scope, column_resolver)

        # UDTF scopes skip star expansion and output aliasing.
        if isinstance(current_scope.expression, exp.UDTF):
            pass
        else:
            _expand_stars(current_scope, column_resolver, joined_column_tables)
            _qualify_outputs(current_scope)

        _expand_alias_refs(current_scope, column_resolver)
        _expand_group_by(current_scope, column_resolver)
        _expand_order_by(current_scope)

    return expression
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot >>> schema = {"tbl": {"col": "INT"}} >>> expression = sqlglot.parse_one("SELECT col FROM tbl") >>> qualify_columns(expression, schema).sql() 'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression (sqlglot.Expression): expression to qualify
- schema (dict|sqlglot.optimizer.Schema): Database schema
Returns:
sqlglot.Expression: qualified expression
def
validate_qualify_columns(expression):
def validate_qualify_columns(expression):
    """
    Check that every column in `expression` is qualified.

    Args:
        expression: the expression whose scopes are inspected
    Returns:
        The expression that was passed in, unchanged.
    Raises:
        OptimizeError: for the first external column found in a SELECT scope that is
            not a correlated subquery, or when any SELECT scope still has
            unqualified columns.
    """
    pending = []

    for scope in traverse_scope(expression):
        # Only SELECT scopes are inspected.
        if not isinstance(scope.expression, exp.Select):
            continue

        pending.extend(scope.unqualified_columns)

        if not scope.external_columns or scope.is_correlated_subquery:
            continue

        column = scope.external_columns[0]
        raise OptimizeError(f"Unknown table: '{column.table}' for column '{column}'")

    if pending:
        raise OptimizeError(f"Ambiguous columns: {pending}")

    return expression
Raise an `OptimizeError` if any columns aren't qualified.
class
Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope, schema):
        self.scope = scope
        self.schema = schema
        # Caches, filled on first use.
        self._source_columns = None
        self._unambiguous_columns = None
        self._all_columns = None

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name:
            # If exactly one source has unknown columns, assume the column belongs to it.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Walk up to the ancestor that carries the alias this source is known by.
            while node and node.alias != table_name:
                node = node.parent

        # The walk above can run off the top of the tree; fall back to a bare
        # identifier instead of dereferencing None.
        if node is None:
            return exp.to_identifier(table_name)

        node_alias = node.args.get("alias")
        if node_alias:
            return node_alias.this

        return exp.to_identifier(
            table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None
        )

    @property
    def all_columns(self):
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name, only_visible=False):
        """Resolve the source columns for a given source `name`"""
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        # If referencing a table, return the columns from the schema
        if isinstance(source, exp.Table):
            return self.schema.column_names(source, only_visible)

        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            return source.expression.alias_column_names

        # Otherwise, if referencing another scope, return that scope's named selects
        return source.expression.named_selects

    def _get_all_source_columns(self):
        """Return (and cache) a mapping of source name to its columns for this scope."""
        if self._source_columns is None:
            self._source_columns = {
                k: self.get_source_columns(k)
                for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
            }
        return self._source_columns

    def _get_unambiguous_columns(self, source_columns):
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns (dict): Mapping of names to source columns
        Returns:
            dict: Mapping of column name to source name
        """
        if not source_columns:
            return {}

        source_columns = list(source_columns.items())

        first_table, first_columns = source_columns[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
        # Seed with every column of the first source, including names duplicated inside it,
        # so a later source cannot claim such a name as unambiguous.
        all_columns = set(first_columns)

        for table, columns in source_columns[1:]:
            unique = self._find_unique_columns(columns)
            ambiguous = set(all_columns).intersection(unique)
            all_columns.update(columns)
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns):
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
    """
    Get the table for a column name.

    Args:
        column_name: The column name to find the table for.
    Returns:
        The table name if it can be found/inferred.
    """
    if self._unambiguous_columns is None:
        self._unambiguous_columns = self._get_unambiguous_columns(
            self._get_all_source_columns()
        )

    table_name = self._unambiguous_columns.get(column_name)

    if not table_name:
        # If exactly one source has unknown columns, assume the column belongs to it.
        sources_without_schema = tuple(
            source
            for source, columns in self._get_all_source_columns().items()
            if not columns or "*" in columns
        )
        if len(sources_without_schema) == 1:
            table_name = sources_without_schema[0]

    if table_name not in self.scope.selected_sources:
        return exp.to_identifier(table_name)

    node, _ = self.scope.selected_sources.get(table_name)

    if isinstance(node, exp.Subqueryable):
        # Walk up to the ancestor that carries the alias this source is known by.
        while node and node.alias != table_name:
            node = node.parent

    # The walk above can run off the top of the tree; fall back to a bare
    # identifier instead of dereferencing None.
    if node is None:
        return exp.to_identifier(table_name)

    node_alias = node.args.get("alias")
    if node_alias:
        return node_alias.this

    return exp.to_identifier(
        table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None
    )
Get the table for a column name.
Arguments:
- column_name: The column name to find the table for.
Returns:
The table name if it can be found/inferred.
def
get_source_columns(self, name, only_visible=False):
def get_source_columns(self, name, only_visible=False):
    """
    Return the columns exposed by the source registered as `name` in this scope.

    Args:
        name: key into `self.scope.sources`
        only_visible: forwarded to `self.schema.column_names` when the source is a table
    Returns:
        The schema's column names for a table, the alias column names for a
        VALUES scope, or the named selects of the source's expression otherwise.
    Raises:
        OptimizeError: if `name` is not a source of this scope.
    """
    if name not in self.scope.sources:
        raise OptimizeError(f"Unknown table: {name}")

    source = self.scope.sources[name]

    # Physical table: the schema is the source of truth.
    if isinstance(source, exp.Table):
        return self.schema.column_names(source, only_visible)

    is_values_scope = isinstance(source, Scope) and isinstance(source.expression, exp.Values)
    if is_values_scope:
        return source.expression.alias_column_names

    # Any other source: use the named selects of its expression.
    return source.expression.named_selects
Resolve the source columns for a given source name