sqlglot.optimizer.qualify_columns
import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema


def qualify_columns(expression, schema):
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression (sqlglot.Expression): expression to qualify
        schema (dict|sqlglot.optimizer.Schema): Database schema
    Returns:
        sqlglot.Expression: qualified expression
    """
    schema = ensure_schema(schema)

    # Each scope gets its own resolver; the rewrite steps below run in a fixed order.
    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        _expand_using(scope, resolver)
        _qualify_columns(scope, resolver)
        # Star expansion and output aliasing are skipped for UDTF scopes
        if not isinstance(scope.expression, exp.UDTF):
            _expand_stars(scope, resolver)
            _qualify_outputs(scope)
        _expand_group_by(scope, resolver)
        _expand_order_by(scope)
    return expression


def validate_qualify_columns(expression):
    """Raise an `OptimizeError` if any columns aren't qualified"""
    unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns.extend(scope.unqualified_columns)
            # External columns are only acceptable inside a correlated subquery
            if scope.external_columns and not scope.is_correlated_subquery:
                column = scope.external_columns[0]
                raise OptimizeError(f"Unknown table: '{column.table}' for column '{column}'")

    if unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
    return expression


def _pop_table_column_aliases(derived_tables):
    """
    Remove table column aliases.

    (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived_table in derived_tables:
        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.args.pop("columns", None)


def _expand_using(scope, resolver):
    """
    Rewrite `JOIN ... USING` into an explicit `ON` condition.

    Every USING identifier becomes an equality between the first earlier source
    that exposes the column and the joined table. Unqualified references to such
    columns elsewhere in the scope are then replaced by a COALESCE over all
    participating sources.

    Raises:
        OptimizeError: if a USING identifier is missing from the earlier sources
            or from the joined table.
    """
    joins = list(scope.expression.find_all(exp.Join))
    names = {join.this.alias for join in joins}
    # Sources that are not introduced by a join, i.e. available before any join
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to source names
    column_tables = {}

    for join in joins:
        using = join.args.get("using")

        if not using:
            continue

        join_table = join.this.alias_or_name

        # Column name -> first already-available source exposing it
        columns = {}

        for k in scope.selected_sources:
            if k in ordered:
                for column in resolver.get_source_columns(k):
                    if column not in columns:
                        columns[column] = k

        # From here on, the joined table is available to later joins
        ordered.append(join_table)
        join_columns = resolver.get_source_columns(join_table)
        conditions = []

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                raise OptimizeError(f"Cannot automatically join: {identifier}")

            conditions.append(
                exp.condition(
                    exp.EQ(
                        this=exp.column(identifier, table=table),
                        expression=exp.column(identifier, table=join_table),
                    )
                )
            )

            tables = column_tables.setdefault(identifier, [])
            if table not in tables:
                tables.append(table)
            if join_table not in tables:
                tables.append(join_table)

        join.args.pop("using")
        join.set("on", exp.and_(*conditions))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce = [exp.column(column.name, table=table) for table in tables]
                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])

                # Ensure selects keep their output name
                if isinstance(column.parent, exp.Select):
                    replacement = exp.alias_(replacement, alias=column.name)

                scope.replace(column, replacement)


def _expand_group_by(scope, resolver):
    """
    Resolve the references used in the GROUP BY clause of `scope`.

    Unqualified columns are qualified with their source table when the resolver
    finds one; otherwise a reference to a select alias is replaced by a copy of
    the aliased expression. Positional references are expanded afterwards.
    """
    group = scope.expression.args.get("group")
    if not group:
        return

    # Replace references to select aliases
    def transform(node, *_):
        if isinstance(node, exp.Column) and not node.table:
            table = resolver.get_table(node.name)

            # Source columns get priority over select aliases
            if table:
                node.set("table", exp.to_identifier(table))
                return node

            selects = {s.alias_or_name: s for s in scope.selects}

            select = selects.get(node.name)
            if select:
                scope.clear_cache()
                if isinstance(select, exp.Alias):
                    select = select.this
                return select.copy()

        return node

    group.transform(transform, copy=False)
    group.set("expressions", _expand_positional_references(scope, group.expressions))
    scope.expression.set("group", group)


def _expand_order_by(scope):
    """Replace positional references in the ORDER BY clause of `scope`."""
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(scope, (o.this for o in ordereds)),
    ):
        ordered.set("this", new_expression)


def _expand_positional_references(scope, expressions):
    """
    Replace integer positions (1-based) with copies of the matching select expression.

    Args:
        scope: scope whose `selects` the positions refer to
        expressions: iterable of expressions; non-integer nodes are passed through
    Returns:
        list: the expressions with positional references expanded
    Raises:
        OptimizeError: if a position is outside `1..len(scope.selects)`.
    """
    new_nodes = []
    for node in expressions:
        if node.is_int:
            position = int(node.name)
            selects = scope.selects
            # An explicit range check is required: position 0 would otherwise
            # become index -1 and silently pick the last output column.
            if not 1 <= position <= len(selects):
                raise OptimizeError(f"Unknown output column: {node.name}")
            select = selects[position - 1]
            if isinstance(select, exp.Alias):
                select = select.this
            new_nodes.append(select.copy())
            scope.clear_cache()
        else:
            new_nodes.append(node)

    return new_nodes


def _qualify_columns(scope, resolver):
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        # An explicitly qualified column must exist in its (known) source
        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if source_columns and column_name not in source_columns and "*" not in source_columns:
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            column_table = resolver.get_table(column_name)

            # column_table can be a '' because bigquery unnest has no table alias
            if column_table:
                column.set("table", exp.to_identifier(column_table))

    columns_missing_from_scope = []
    # Determine whether each reference in the order by clause is to a column or an alias.
    for ordered in scope.find_all(exp.Ordered):
        for column in ordered.find_all(exp.Column):
            if (
                not column.table
                and column.parent is not ordered
                and column.name in resolver.all_columns
            ):
                columns_missing_from_scope.append(column)

    # Determine whether each reference in the having clause is to a column or an alias.
    for having in scope.find_all(exp.Having):
        for column in having.find_all(exp.Column):
            if (
                not column.table
                and column.find_ancestor(exp.AggFunc)
                and column.name in resolver.all_columns
            ):
                columns_missing_from_scope.append(column)

    for column in columns_missing_from_scope:
        column_table = resolver.get_table(column.name)

        if column_table:
            column.set("table", exp.to_identifier(column_table))


def _expand_stars(scope, resolver):
    """
    Expand stars to lists of column selections

    The select list is left untouched as soon as one star refers to a source
    whose columns are unknown (no columns, or a "*" placeholder).

    Raises:
        OptimizeError: if a star refers to a table that is not a source of the scope.
    """

    new_selections = []
    except_columns = {}
    replace_columns = {}

    for expression in scope.selects:
        if isinstance(expression, exp.Star):
            tables = list(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
        elif expression.is_star:
            tables = [expression.table]
            _add_except_columns(expression.this, tables, except_columns)
            _add_replace_columns(expression.this, tables, replace_columns)
        else:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")
            columns = resolver.get_source_columns(table, only_visible=True)

            if columns and "*" not in columns:
                # EXCEPT/REPLACE mappings are keyed by id() of the objects in `tables`
                table_id = id(table)
                for name in columns:
                    if name not in except_columns.get(table_id, set()):
                        alias_ = replace_columns.get(table_id, {}).get(name, name)
                        column = exp.column(name, table)
                        new_selections.append(alias(column, alias_) if alias_ != name else column)
            else:
                return
    scope.expression.set("expressions", new_selections)


def _add_except_columns(expression, tables, except_columns):
    """Record the names in the star's `except` arg for each table (keyed by `id(table)`)."""
    except_ = expression.args.get("except")

    if not except_:
        return

    columns = {e.name for e in except_}

    for table in tables:
        except_columns[id(table)] = columns


def _add_replace_columns(expression, tables, replace_columns):
    """Record the star's `replace` arg as a name -> alias map for each table (keyed by `id(table)`)."""
    replace = expression.args.get("replace")

    if not replace:
        return

    columns = {e.this.name: e.alias for e in replace}

    for table in tables:
        replace_columns[id(table)] = columns


def _qualify_outputs(scope):
    """Ensure all output columns are aliased"""
    new_selections = []

    # zip_longest: the select list and the outer column list may differ in length
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.selects, scope.outer_column_list)
    ):
        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            # Build the alias around a placeholder, then attach the selection itself
            alias_ = alias(exp.column(""), alias=selection.output_name or f"_col_{i}")
            alias_.set("this", selection)
            selection = alias_

        # A name from the outer column list overrides the alias chosen above
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    scope.expression.set("expressions", new_selections)


class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope, schema):
        self.scope = scope
        self.schema = schema
        # Lazily computed caches
        self._source_columns = None
        self._unambiguous_columns = None
        self._all_columns = None

    def get_table(self, column_name: str) -> t.Optional[str]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table = self._unambiguous_columns.get(column_name)

        if not table:
            # Fall back to the single source whose columns are unknown, if there is exactly one
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                return sources_without_schema[0]

        return table

    @property
    def all_columns(self):
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name, only_visible=False):
        """Resolve the source columns for a given source `name`"""
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        # If referencing a table, return the columns from the schema
        if isinstance(source, exp.Table):
            return self.schema.column_names(source, only_visible)

        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            return source.expression.alias_column_names

        # Otherwise, if referencing another scope, return that scope's named selects
        return source.expression.named_selects

    def _get_all_source_columns(self):
        """Return (and cache) a mapping of source name to its columns for selected and lateral sources."""
        if self._source_columns is None:
            self._source_columns = {
                k: self.get_source_columns(k)
                for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
            }
        return self._source_columns

    def _get_unambiguous_columns(self, source_columns):
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when exactly one source exposes it, exactly once.

        Args:
            source_columns (dict): Mapping of names to source columns
        Returns:
            dict: Mapping of column name to source name
        """
        if not source_columns:
            return {}

        source_columns = list(source_columns.items())

        first_table, first_columns = source_columns[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
        # Track every name seen so far, including names duplicated within a source,
        # so that such a name is still ambiguous against a later source.
        all_columns = set(first_columns)

        for table, columns in source_columns[1:]:
            unique = self._find_unique_columns(columns)
            # Compare against all of this source's names, not only its unique ones
            ambiguous = all_columns.intersection(columns)
            all_columns.update(columns)
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns):
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
def
qualify_columns(expression, schema):
def qualify_columns(expression, schema):
    """
    Qualify the columns of a sqlglot AST with the source they belong to.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression (sqlglot.Expression): expression to qualify
        schema (dict|sqlglot.optimizer.Schema): Database schema
    Returns:
        sqlglot.Expression: the expression that was passed in, after qualification
    """
    normalized_schema = ensure_schema(schema)

    for current_scope in traverse_scope(expression):
        column_resolver = Resolver(current_scope, normalized_schema)

        # Strip column lists from CTE and derived-table aliases first
        _pop_table_column_aliases(current_scope.ctes)
        _pop_table_column_aliases(current_scope.derived_tables)

        _expand_using(current_scope, column_resolver)
        _qualify_columns(current_scope, column_resolver)

        # Stars and output aliases are only handled for non-UDTF scopes
        is_udtf = isinstance(current_scope.expression, exp.UDTF)
        if not is_udtf:
            _expand_stars(current_scope, column_resolver)
            _qualify_outputs(current_scope)

        _expand_group_by(current_scope, column_resolver)
        _expand_order_by(current_scope)

    return expression
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression (sqlglot.Expression): expression to qualify
- schema (dict|sqlglot.optimizer.Schema): Database schema
Returns:
sqlglot.Expression: qualified expression
def
validate_qualify_columns(expression):
def validate_qualify_columns(expression):
    """
    Check that every column in `expression` has been qualified.

    Raises:
        OptimizeError: if a SELECT scope that is not a correlated subquery has
            external columns, or if any SELECT scope still has unqualified columns.
    Returns:
        The expression that was passed in.
    """
    pending = []

    for current_scope in traverse_scope(expression):
        # Only SELECT scopes are inspected
        if not isinstance(current_scope.expression, exp.Select):
            continue

        pending.extend(current_scope.unqualified_columns)

        external = current_scope.external_columns
        if external and not current_scope.is_correlated_subquery:
            column = external[0]
            raise OptimizeError(f"Unknown table: '{column.table}' for column '{column}'")

    if not pending:
        return expression

    unqualified_columns = pending
    raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
Raise an `OptimizeError` if any columns aren't qualified.
class
Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope, schema):
        self.scope = scope
        self.schema = schema
        # Lazily computed caches
        self._source_columns = None
        self._unambiguous_columns = None
        self._all_columns = None

    def get_table(self, column_name: str) -> t.Optional[str]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table = self._unambiguous_columns.get(column_name)

        if not table:
            # Fall back to the single source whose columns are unknown, if there is exactly one
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                return sources_without_schema[0]

        return table

    @property
    def all_columns(self):
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name, only_visible=False):
        """
        Resolve the source columns for a given source `name`

        Raises:
            OptimizeError: if `name` is not a source of the scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        # If referencing a table, return the columns from the schema
        if isinstance(source, exp.Table):
            return self.schema.column_names(source, only_visible)

        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            return source.expression.alias_column_names

        # Otherwise, if referencing another scope, return that scope's named selects
        return source.expression.named_selects

    def _get_all_source_columns(self):
        """Return (and cache) a mapping of source name to its columns for selected and lateral sources."""
        if self._source_columns is None:
            self._source_columns = {
                k: self.get_source_columns(k)
                for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
            }
        return self._source_columns

    def _get_unambiguous_columns(self, source_columns):
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when exactly one source exposes it, exactly once.

        Args:
            source_columns (dict): Mapping of names to source columns
        Returns:
            dict: Mapping of column name to source name
        """
        if not source_columns:
            return {}

        source_columns = list(source_columns.items())

        first_table, first_columns = source_columns[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
        # Track every name seen so far, including names duplicated within a source,
        # so that such a name is still ambiguous against a later source.
        all_columns = set(first_columns)

        for table, columns in source_columns[1:]:
            unique = self._find_unique_columns(columns)
            # Compare against all of this source's names, not only its unique ones
            ambiguous = all_columns.intersection(columns)
            all_columns.update(columns)
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns):
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
def
get_table(self, column_name: str) -> Optional[str]:
344 def get_table(self, column_name: str) -> t.Optional[str]: 345 """ 346 Get the table for a column name. 347 348 Args: 349 column_name: The column name to find the table for. 350 Returns: 351 The table name if it can be found/inferred. 352 """ 353 if self._unambiguous_columns is None: 354 self._unambiguous_columns = self._get_unambiguous_columns( 355 self._get_all_source_columns() 356 ) 357 358 table = self._unambiguous_columns.get(column_name) 359 360 if not table: 361 sources_without_schema = tuple( 362 source 363 for source, columns in self._get_all_source_columns().items() 364 if not columns or "*" in columns 365 ) 366 if len(sources_without_schema) == 1: 367 return sources_without_schema[0] 368 369 return table
Get the table for a column name.
Arguments:
- column_name: The column name to find the table for.
Returns:
The table name if it can be found/inferred.
def
get_source_columns(self, name, only_visible=False):
def get_source_columns(self, name, only_visible=False):
    """
    Return the column names exposed by the source called `name`.

    Args:
        name: key of the source in `self.scope.sources`
        only_visible: passed through to the schema lookup for table sources
    Raises:
        OptimizeError: if `name` is not a source of the scope.
    """
    known_sources = self.scope.sources
    if name not in known_sources:
        raise OptimizeError(f"Unknown table: {name}")

    source = known_sources[name]

    if isinstance(source, exp.Table):
        # Tables: column names come from the schema
        return self.schema.column_names(source, only_visible)

    source_expression = source.expression
    if isinstance(source, Scope) and isinstance(source_expression, exp.Values):
        # VALUES scopes: column names come from the alias
        return source_expression.alias_column_names

    # Any other source: use the named selects of its expression
    return source_expression.named_selects
Resolve the source columns for a given source `name`.