sqlglot.optimizer.qualify_columns
import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema


def qualify_columns(expression, schema):
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression (sqlglot.Expression): expression to qualify
        schema (dict|sqlglot.optimizer.Schema): Database schema
    Returns:
        sqlglot.Expression: qualified expression
    """
    schema = ensure_schema(schema)

    # Every scope gets its own resolver; the passes below mutate the AST in place,
    # so their relative order matters.
    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)
        _qualify_columns(scope, resolver)
        # Star expansion and output aliasing are skipped for UDTF scopes
        if not isinstance(scope.expression, exp.UDTF):
            _expand_stars(scope, resolver, using_column_tables)
            _qualify_outputs(scope)
        _expand_alias_refs(scope, resolver)
        _expand_group_by(scope, resolver)
        _expand_order_by(scope)

    return expression


def validate_qualify_columns(expression):
    """Raise an `OptimizeError` if any columns aren't qualified"""
    unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns.extend(scope.unqualified_columns)
            # External columns are only acceptable in correlated subqueries
            if scope.external_columns and not scope.is_correlated_subquery:
                column = scope.external_columns[0]
                raise OptimizeError(f"Unknown table: '{column.table}' for column '{column}'")

    if unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
    return expression


def _pop_table_column_aliases(derived_tables):
    """
    Remove table column aliases.

    (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2))
    """
    for derived_table in derived_tables:
        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.args.pop("columns", None)


def _expand_using(scope, resolver):
    """
    Rewrite `JOIN ... USING (...)` clauses of the scope into explicit `ON` conditions.

    Unqualified references to a USING column are afterwards replaced by a COALESCE
    over the tables that participate in the join on that column.

    Args:
        scope: the scope whose joins are rewritten in place
        resolver (Resolver): resolver bound to `scope`
    Returns:
        dict: mapping of USING column name to a dict whose keys are the
        participating source names, in join order
    Raises:
        OptimizeError: if a USING column is missing from the joined table or
        from every source that precedes the join
    """
    joins = list(scope.find_all(exp.Join))
    names = {join.this.alias for join in joins}
    # Sources that are not the target of a join come first
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables = {}

    for join in joins:
        using = join.args.get("using")

        if not using:
            continue

        join_table = join.this.alias_or_name

        # Column name -> first preceding source that provides it
        columns = {}

        for k in scope.selected_sources:
            if k in ordered:
                for column in resolver.get_source_columns(k):
                    if column not in columns:
                        columns[column] = k

        ordered.append(join_table)
        join_columns = resolver.get_source_columns(join_table)
        conditions = []

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                raise OptimizeError(f"Cannot automatically join: {identifier}")

            conditions.append(
                exp.condition(
                    exp.EQ(
                        this=exp.column(identifier, table=table),
                        expression=exp.column(identifier, table=join_table),
                    )
                )
            )

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})
            if table not in tables:
                tables[table] = None
            if join_table not in tables:
                tables[join_table] = None

        join.args.pop("using")
        join.set("on", exp.and_(*conditions))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce = [exp.column(column.name, table=table) for table in tables]
                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])

                # Ensure selects keep their output name
                if isinstance(column.parent, exp.Select):
                    replacement = exp.alias_(replacement, alias=column.name)

                scope.replace(column, replacement)

    return column_tables


def _expand_alias_refs(scope, resolver):
    """
    Replace unqualified columns that refer to a select alias with a copy of the
    aliased expression.

    The select list, WHERE and GROUP BY give priority to source columns, while
    HAVING gives priority to select aliases.

    Args:
        scope: the scope whose expression is rewritten in place
        resolver (Resolver): resolver bound to `scope`
    """
    # Lazily built mapping of output name -> select expression
    selects = {}

    # Replace references to select aliases
    def transform(node, source_first=True):
        if isinstance(node, exp.Column) and not node.table:
            table = resolver.get_table(node.name)

            # Source columns get priority over select aliases
            if source_first and table:
                node.set("table", table)
                return node

            if not selects:
                for s in scope.selects:
                    selects[s.alias_or_name] = s
            select = selects.get(node.name)

            if select:
                scope.clear_cache()
                if isinstance(select, exp.Alias):
                    select = select.this
                return select.copy()

            node.set("table", table)
        elif isinstance(node, exp.Expression) and not isinstance(node, exp.Subqueryable):
            # Nested subqueryables are not descended into
            exp.replace_children(node, transform, source_first)

        return node

    for select in scope.expression.selects:
        transform(select)

    for modifier, source_first in (
        ("where", True),
        ("group", True),
        ("having", False),
    ):
        # A missing modifier yields None, which `transform` returns untouched
        transform(scope.expression.args.get(modifier), source_first=source_first)


def _expand_group_by(scope, resolver):
    """
    Replace positional references in the GROUP BY clause (e.g. `GROUP BY 1`)
    with copies of the select expressions they point to.

    `resolver` is not used here; it is kept for signature compatibility.
    """
    group = scope.expression.args.get("group")
    if not group:
        return

    group.set("expressions", _expand_positional_references(scope, group.expressions))
    scope.expression.set("group", group)


def _expand_order_by(scope):
    """
    Replace positional references in the ORDER BY clause (e.g. `ORDER BY 1`)
    with copies of the select expressions they point to.
    """
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(scope, (o.this for o in ordereds)),
    ):
        ordered.set("this", new_expression)


def _expand_positional_references(scope, expressions):
    """
    Resolve 1-based integer positions against the select list of `scope`.

    Args:
        scope: the scope providing the select list
        expressions: iterable of expressions that may be integer literals
    Returns:
        list: the input expressions, with every integer literal replaced by a
        copy of the select expression at that position (aliases are unwrapped)
    Raises:
        OptimizeError: if a position is outside of the select list
    """
    new_nodes = []
    for node in expressions:
        if node.is_int:
            position = int(node.name)

            # Positions are 1-based. Without this check, position 0 would turn into
            # index -1 and silently resolve to the last select expression.
            if position < 1:
                raise OptimizeError(f"Unknown output column: {node.name}")

            try:
                select = scope.selects[position - 1]
            except IndexError:
                raise OptimizeError(f"Unknown output column: {node.name}")
            if isinstance(select, exp.Alias):
                select = select.this
            new_nodes.append(select.copy())
            scope.clear_cache()
        else:
            new_nodes.append(node)

    return new_nodes


def _qualify_columns(scope, resolver):
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            # Only validate when the source's columns are known and not a wildcard
            if source_columns and column_name not in source_columns and "*" not in source_columns:
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            column_table = resolver.get_table(column_name)

            # column_table can be a '' because bigquery unnest has no table alias
            if column_table:
                column.set("table", column_table)
        elif column_table not in scope.sources:
            # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
            # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))

            root, *parts = column.parts

            if root.name in scope.sources:
                # struct is already qualified, but we still need to change the AST representation
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    columns_missing_from_scope = []

    # Determine whether each reference in the order by clause is to a column or an alias.
    order = scope.expression.args.get("order")

    if order:
        for ordered in order.expressions:
            for column in ordered.find_all(exp.Column):
                if (
                    not column.table
                    and column.parent is not ordered
                    and column.name in resolver.all_columns
                ):
                    columns_missing_from_scope.append(column)

    # Determine whether each reference in the having clause is to a column or an alias.
    having = scope.expression.args.get("having")

    if having:
        for column in having.find_all(exp.Column):
            if (
                not column.table
                and column.find_ancestor(exp.AggFunc)
                and column.name in resolver.all_columns
            ):
                columns_missing_from_scope.append(column)

    for column in columns_missing_from_scope:
        column_table = resolver.get_table(column.name)

        if column_table:
            column.set("table", column_table)


def _expand_stars(scope, resolver, using_column_tables):
    """Expand stars to lists of column selections"""

    new_selections = []
    # Both mappings are keyed by id() of the table-name object taken from the star
    # expression they belong to (see _add_except_columns / _add_replace_columns).
    except_columns = {}
    replace_columns = {}
    coalesced_columns = set()

    for expression in scope.selects:
        if isinstance(expression, exp.Star):
            tables = list(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
        elif expression.is_star:
            tables = [expression.table]
            _add_except_columns(expression.this, tables, except_columns)
            _add_replace_columns(expression.this, tables, replace_columns)
        else:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")
            columns = resolver.get_source_columns(table, only_visible=True)

            if columns and "*" not in columns:
                table_id = id(table)
                for name in columns:
                    if name in using_column_tables and table in using_column_tables[name]:
                        # A USING column is emitted once, as a COALESCE over its tables
                        if name in coalesced_columns:
                            continue

                        coalesced_columns.add(name)
                        coalesce = [
                            exp.column(name, table=using_table)
                            for using_table in using_column_tables[name]
                        ]

                        new_selections.append(
                            exp.alias_(
                                exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]), alias=name
                            )
                        )
                    elif name not in except_columns.get(table_id, set()):
                        alias_ = replace_columns.get(table_id, {}).get(name, name)
                        column = exp.column(name, table)
                        new_selections.append(alias(column, alias_) if alias_ != name else column)
            else:
                # The columns of this source are unknown, so the selections are left untouched
                return
    scope.expression.set("expressions", new_selections)


def _add_except_columns(expression, tables, except_columns):
    """Record the names listed in the star's `except` arg for each of `tables`."""
    except_ = expression.args.get("except")

    if not except_:
        return

    columns = {e.name for e in except_}

    for table in tables:
        except_columns[id(table)] = columns


def _add_replace_columns(expression, tables, replace_columns):
    """Record the star's `replace` arg (column name -> alias) for each of `tables`."""
    replace = expression.args.get("replace")

    if not replace:
        return

    columns = {e.this.name: e.alias for e in replace}

    for table in tables:
        replace_columns[id(table)] = columns


def _qualify_outputs(scope):
    """Ensure all output columns are aliased"""
    new_selections = []

    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.selects, scope.outer_column_list)
    ):
        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            # Build the alias around a placeholder column, then swap in the real selection
            alias_ = alias(exp.column(""), alias=selection.output_name or f"_col_{i}")
            alias_.set("this", selection)
            selection = alias_

        # Column names supplied by the outer scope override the computed alias
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    scope.expression.set("expressions", new_selections)


class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope, schema):
        self.scope = scope
        self.schema = schema
        # Caches, populated on first use
        self._source_columns = None
        self._unambiguous_columns = None
        self._all_columns = None

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name:
            # Fall back to the only source whose columns are unknown (or a wildcard), if any
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Walk up to the ancestor that carries the alias
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return node_alias.this

        return exp.to_identifier(
            table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None
        )

    @property
    def all_columns(self):
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name, only_visible=False):
        """Resolve the source columns for a given source `name`"""
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        # If referencing a table, return the columns from the schema
        if isinstance(source, exp.Table):
            return self.schema.column_names(source, only_visible)

        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            return source.expression.alias_column_names

        # Otherwise, if referencing another scope, return that scope's named selects
        return source.expression.named_selects

    def _get_all_source_columns(self):
        """Return (and cache) a mapping of source name to its columns, for selected and lateral sources."""
        if self._source_columns is None:
            self._source_columns = {
                k: self.get_source_columns(k)
                for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
            }
        return self._source_columns

    def _get_unambiguous_columns(self, source_columns):
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when it appears exactly once across all sources,
        i.e. in a single source and only once within that source.

        Args:
            source_columns (dict): Mapping of names to source columns
        Returns:
            dict: Mapping of column name to source name
        """
        if not source_columns:
            return {}

        unambiguous_columns = {}
        # Every column name seen so far, including names duplicated within one source.
        # Tracking only the unique names would let a name that is duplicated in one
        # source resolve to another source that happens to list it once.
        seen_columns = set()

        for table, columns in source_columns.items():
            unique = self._find_unique_columns(columns)
            ambiguous = seen_columns.intersection(columns)
            seen_columns.update(columns)
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns):
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
def
qualify_columns(expression, schema):
def qualify_columns(expression, schema):
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression (sqlglot.Expression): expression to qualify
        schema (dict|sqlglot.optimizer.Schema): Database schema
    Returns:
        sqlglot.Expression: qualified expression
    """
    normalized_schema = ensure_schema(schema)

    for current in traverse_scope(expression):
        scope_resolver = Resolver(current, normalized_schema)

        # Strip column lists from CTE and derived-table aliases first
        _pop_table_column_aliases(current.ctes)
        _pop_table_column_aliases(current.derived_tables)

        joined_columns = _expand_using(current, scope_resolver)
        _qualify_columns(current, scope_resolver)

        # UDTF scopes get neither star expansion nor output aliasing
        is_udtf = isinstance(current.expression, exp.UDTF)
        if not is_udtf:
            _expand_stars(current, scope_resolver, joined_columns)
            _qualify_outputs(current)

        _expand_alias_refs(current, scope_resolver)
        _expand_group_by(current, scope_resolver)
        _expand_order_by(current)

    return expression
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression (sqlglot.Expression): expression to qualify
- schema (dict|sqlglot.optimizer.Schema): Database schema
Returns:
sqlglot.Expression: qualified expression
def
validate_qualify_columns(expression):
def validate_qualify_columns(expression):
    """Raise an `OptimizeError` if any columns aren't qualified"""
    leftovers = []

    for scope in traverse_scope(expression):
        # Only SELECT scopes are inspected
        if not isinstance(scope.expression, exp.Select):
            continue

        leftovers.extend(scope.unqualified_columns)

        # External columns are fine as long as the scope is a correlated subquery
        if not scope.external_columns or scope.is_correlated_subquery:
            continue

        column = scope.external_columns[0]
        raise OptimizeError(f"Unknown table: '{column.table}' for column '{column}'")

    if leftovers:
        raise OptimizeError(f"Ambiguous columns: {leftovers}")

    return expression
Raise an OptimizeError
if any columns aren't qualified
class
Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope, schema):
        self.scope = scope
        self.schema = schema
        # Caches, populated on first use
        self._source_columns = None
        self._unambiguous_columns = None
        self._all_columns = None

    def get_table(self, column_name: str) -> "t.Optional[exp.Identifier]":
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name:
            # Fall back to the only source whose columns are unknown (or a wildcard), if any
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Walk up to the ancestor that carries the alias
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return node_alias.this

        return exp.to_identifier(
            table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None
        )

    @property
    def all_columns(self):
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name, only_visible=False):
        """Resolve the source columns for a given source `name`"""
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        # If referencing a table, return the columns from the schema
        if isinstance(source, exp.Table):
            return self.schema.column_names(source, only_visible)

        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            return source.expression.alias_column_names

        # Otherwise, if referencing another scope, return that scope's named selects
        return source.expression.named_selects

    def _get_all_source_columns(self):
        """Return (and cache) a mapping of source name to its columns, for selected and lateral sources."""
        if self._source_columns is None:
            self._source_columns = {
                k: self.get_source_columns(k)
                for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
            }
        return self._source_columns

    def _get_unambiguous_columns(self, source_columns):
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when it appears exactly once across all sources,
        i.e. in a single source and only once within that source.

        Args:
            source_columns (dict): Mapping of names to source columns
        Returns:
            dict: Mapping of column name to source name
        """
        if not source_columns:
            return {}

        unambiguous_columns = {}
        # Every column name seen so far, including names duplicated within one source.
        # Tracking only the unique names would let a name that is duplicated in one
        # source resolve to another source that happens to list it once.
        seen_columns = set()

        for table, columns in source_columns.items():
            unique = self._find_unique_columns(columns)
            ambiguous = seen_columns.intersection(columns)
            seen_columns.update(columns)
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns):
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
    """
    Get the table for a column name.

    Args:
        column_name: The column name to find the table for.
    Returns:
        The table name if it can be found/inferred.
    """
    # Build the column -> source mapping on first use
    if self._unambiguous_columns is None:
        all_source_columns = self._get_all_source_columns()
        self._unambiguous_columns = self._get_unambiguous_columns(all_source_columns)

    table_name = self._unambiguous_columns.get(column_name)

    if not table_name:
        # Sources whose columns are unknown or contain a wildcard
        schemaless = [
            source
            for source, columns in self._get_all_source_columns().items()
            if not columns or "*" in columns
        ]
        # With exactly one such source, the column is attributed to it
        if len(schemaless) == 1:
            (table_name,) = schemaless

    if table_name not in self.scope.selected_sources:
        return exp.to_identifier(table_name)

    node, _ = self.scope.selected_sources.get(table_name)

    if isinstance(node, exp.Subqueryable):
        # Climb the tree until reaching the node that carries the alias
        while node and node.alias != table_name:
            node = node.parent

    node_alias = node.args.get("alias")
    if node_alias:
        return node_alias.this

    if isinstance(node, exp.Table):
        quoted = node.this.quoted
    else:
        quoted = None

    return exp.to_identifier(table_name, quoted=quoted)
Get the table for a column name.
Arguments:
- column_name: The column name to find the table for.
Returns:
The table name if it can be found/inferred.
def
get_source_columns(self, name, only_visible=False):
def get_source_columns(self, name, only_visible=False):
    """Resolve the source columns for a given source `name`"""
    if name not in self.scope.sources:
        raise OptimizeError(f"Unknown table: {name}")

    source = self.scope.sources[name]

    # Tables take their columns from the schema
    if isinstance(source, exp.Table):
        return self.schema.column_names(source, only_visible)

    # A VALUES scope exposes the column names of its alias
    if isinstance(source, Scope):
        if isinstance(source.expression, exp.Values):
            return source.expression.alias_column_names

    # Anything else exposes the named selects of its expression
    return source.expression.named_selects
Resolve the source columns for a given source name