sqlglot.optimizer.qualify_columns
import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema


def qualify_columns(expression, schema):
    """
    Rewrite sqlglot AST to have fully qualified columns.

    The AST is modified in place and the same expression is returned.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression (sqlglot.Expression): expression to qualify
        schema (dict|sqlglot.optimizer.Schema): Database schema
    Returns:
        sqlglot.Expression: qualified expression
    """
    schema = ensure_schema(schema)

    for scope in traverse_scope(expression):
        resolver = _Resolver(scope, schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        _expand_using(scope, resolver)
        _expand_group_by(scope, resolver)
        _qualify_columns(scope, resolver)
        _expand_order_by(scope)
        # Star expansion is skipped for UDTF scopes.
        if not isinstance(scope.expression, exp.UDTF):
            _expand_stars(scope, resolver)
        _qualify_outputs(scope)

    return expression


def validate_qualify_columns(expression):
    """
    Raise an `OptimizeError` if any columns aren't qualified.

    Raises:
        OptimizeError: if a non-correlated scope references columns of an unknown table,
            or if any SELECT scope still contains unqualified columns.
    Returns:
        The unchanged `expression`.
    """
    unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns.extend(scope.unqualified_columns)
            if scope.external_columns and not scope.is_correlated_subquery:
                raise OptimizeError(f"Unknown table: {scope.external_columns[0].table}")

    if unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
    return expression


def _pop_table_column_aliases(derived_tables):
    """
    Remove table column aliases.

    (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)

    Sources whose unnested expression is a UDTF keep their column aliases.
    """
    for derived_table in derived_tables:
        if isinstance(derived_table.unnest(), exp.UDTF):
            continue
        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.args.pop("columns", None)


def _expand_using(scope, resolver):
    """
    Rewrite `JOIN ... USING (...)` clauses into equivalent `JOIN ... ON ...` conditions.

    For each USING column, the left-hand side of the generated equality is the first source
    (in `scope.selected_sources` order) that is already available to the left of the join and
    exposes that column. Unqualified references to a USING column elsewhere in the scope are
    then replaced by a COALESCE over every source taking part in that USING join.

    Raises:
        OptimizeError: if a USING column is missing from the joined table or from every
            source to its left.
    """
    # Only this scope's joins: joins inside nested subqueries are handled by their own scope.
    joins = list(scope.find_all(exp.Join))
    # Use alias_or_name (as for `join_table` below) so un-aliased joined tables are recognised
    # too; with plain `alias` they'd map to "" and wrongly count as already-available sources.
    names = {join.this.alias_or_name for join in joins}
    # Sources that are not the target of a JOIN are available from the start.
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to source names
    column_tables = {}

    for join in joins:
        join_table = join.this.alias_or_name
        using = join.args.get("using")

        if not using:
            # Tables joined without USING are still visible to the USING joins after them.
            ordered.append(join_table)
            continue

        # First source providing each column, among the sources to the left of this join.
        columns = {}
        for source_name in scope.selected_sources:
            if source_name in ordered:
                for column in resolver.get_source_columns(source_name):
                    columns.setdefault(column, source_name)

        ordered.append(join_table)
        join_columns = resolver.get_source_columns(join_table)
        conditions = []

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                raise OptimizeError(f"Cannot automatically join: {identifier}")

            conditions.append(
                exp.condition(
                    exp.EQ(
                        this=exp.column(identifier, table=table),
                        expression=exp.column(identifier, table=join_table),
                    )
                )
            )

            tables = column_tables.setdefault(identifier, [])
            if table not in tables:
                tables.append(table)
            if join_table not in tables:
                tables.append(join_table)

        join.args.pop("using")
        join.set("on", exp.and_(*conditions))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce = [exp.column(column.name, table=table) for table in tables]
                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])

                # Ensure selects keep their output name
                if isinstance(column.parent, exp.Select):
                    replacement = exp.alias_(replacement, alias=column.name)

                scope.replace(column, replacement)


def _expand_group_by(scope, resolver):
    """
    Qualify GROUP BY columns and expand references to SELECT aliases and positions.

    An unqualified GROUP BY column that resolves to a source column is qualified with that
    source. Otherwise, if it matches a SELECT output name, it is replaced by a copy of the
    selected expression (without its alias). Positional references such as `GROUP BY 1`
    are expanded as well.
    """
    group = scope.expression.args.get("group")
    if not group:
        return

    # Replace references to select aliases
    def transform(node, *_):
        if isinstance(node, exp.Column) and not node.table:
            table = resolver.get_table(node.name)

            # Source columns get priority over select aliases
            if table:
                node.set("table", exp.to_identifier(table))
                return node

            selects = {s.alias_or_name: s for s in scope.selects}

            select = selects.get(node.name)
            if select:
                # The tree is being modified, so drop the scope's cached lookups.
                scope.clear_cache()
                if isinstance(select, exp.Alias):
                    select = select.this
                return select.copy()

        return node

    group.transform(transform, copy=False)
    group.set("expressions", _expand_positional_references(scope, group.expressions))
    scope.expression.set("group", group)


def _expand_order_by(scope):
    """Replace positional references in ORDER BY (e.g. `ORDER BY 1`) with the selected expressions."""
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(scope, (o.this for o in ordereds)),
    ):
        ordered.set("this", new_expression)


def _expand_positional_references(scope, expressions):
    """
    Replace 1-based integer literals with copies of the corresponding SELECT expressions.

    Args:
        scope: the scope whose `selects` the positions refer to.
        expressions: the GROUP BY / ORDER BY expressions to expand.
    Returns:
        list: the expressions, with positional references replaced by copies of the selected
            expressions (aliases are unwrapped); other expressions are returned unchanged.
    Raises:
        OptimizeError: if a position is outside `1..len(scope.selects)`.
    """
    new_nodes = []
    for node in expressions:
        if node.is_int:
            selects = scope.selects
            position = int(node.name)
            # Positions are 1-based. Check the lower bound explicitly: position 0 would
            # otherwise become selects[-1] and silently pick the last output column.
            if not 1 <= position <= len(selects):
                raise OptimizeError(f"Unknown output column: {node.name}")
            select = selects[position - 1]
            if isinstance(select, exp.Alias):
                select = select.this
            new_nodes.append(select.copy())
            scope.clear_cache()
        else:
            new_nodes.append(node)

    return new_nodes


def _qualify_columns(scope, resolver):
    """
    Disambiguate columns, ensuring each column specifies a source.

    Raises:
        OptimizeError: if a column is qualified with a source of this scope whose known
            columns do not include it.
    """
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if source_columns and column_name not in source_columns:
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            column_table = resolver.get_table(column_name)

            # column_table can be a '' because bigquery unnest has no table alias
            if column_table:
                column.set("table", exp.to_identifier(column_table))

    # Unqualified ORDER BY / HAVING references may name either a source column or a SELECT
    # alias. Presumably `scope.columns` leaves these out; only those that match a real
    # source column are collected and qualified below.
    columns_missing_from_scope = []
    # Determine whether each reference in the order by clause is to a column or an alias.
    for ordered in scope.find_all(exp.Ordered):
        for column in ordered.find_all(exp.Column):
            if (
                not column.table
                and column.parent is not ordered
                and column.name in resolver.all_columns
            ):
                columns_missing_from_scope.append(column)

    # Determine whether each reference in the having clause is to a column or an alias.
    for having in scope.find_all(exp.Having):
        for column in having.find_all(exp.Column):
            if (
                not column.table
                and column.find_ancestor(exp.AggFunc)
                and column.name in resolver.all_columns
            ):
                columns_missing_from_scope.append(column)

    for column in columns_missing_from_scope:
        column_table = resolver.get_table(column.name)

        if column_table:
            column.set("table", exp.to_identifier(column_table))


def _expand_stars(scope, resolver):
    """
    Expand stars to lists of column selections.

    Both a bare `*` (all selected sources) and `table.*` are expanded, honouring any
    EXCEPT / REPLACE modifiers attached to the star.

    Raises:
        OptimizeError: if a star refers to an unknown table, or to a source with no known columns.
    """
    new_selections = []
    # Both mappings are keyed by id() of the table-name objects in `tables` below.
    except_columns = {}
    replace_columns = {}

    for expression in scope.selects:
        if isinstance(expression, exp.Star):
            tables = list(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
        elif isinstance(expression, exp.Column) and isinstance(expression.this, exp.Star):
            tables = [expression.table]
            _add_except_columns(expression.this, tables, except_columns)
            _add_replace_columns(expression.this, tables, replace_columns)
        else:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")
            columns = resolver.get_source_columns(table, only_visible=True)
            if not columns:
                raise OptimizeError(
                    f"Table has no schema/columns. Cannot expand star for table: {table}."
                )
            table_id = id(table)
            for name in columns:
                if name not in except_columns.get(table_id, set()):
                    alias_ = replace_columns.get(table_id, {}).get(name, name)
                    column = exp.column(name, table)
                    new_selections.append(alias(column, alias_) if alias_ != name else column)

    scope.expression.set("expressions", new_selections)


def _add_except_columns(expression, tables, except_columns):
    """Record the star's EXCEPT column names for each of `tables`, keyed by `id(table)`."""
    except_ = expression.args.get("except")

    if not except_:
        return

    columns = {e.name for e in except_}

    for table in tables:
        except_columns[id(table)] = columns


def _add_replace_columns(expression, tables, replace_columns):
    """Record the star's REPLACE mapping (column name -> alias) for each of `tables`, keyed by `id(table)`."""
    replace = expression.args.get("replace")

    if not replace:
        return

    columns = {e.this.name: e.alias for e in replace}

    for table in tables:
        replace_columns[id(table)] = columns


def _qualify_outputs(scope):
    """
    Ensure all output columns are aliased.

    A subquery used as a projection gets a `_col_<i>` table alias if it has no output name.
    Any other non-Alias projection is wrapped in an alias named after its output name
    (or `_col_<i>`). Names from `scope.outer_column_list`, when present, override the
    alias positionally.
    """
    new_selections = []

    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.selects, scope.outer_column_list)
    ):
        # zip_longest pads with None when there are more outer column names than
        # projections; there is nothing left to alias, and None must not be selected.
        if selection is None:
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias):
            # Build the Alias around a placeholder, then attach `selection` itself
            # (presumably so that `alias()` does not copy the original node).
            alias_ = alias(exp.column(""), alias=selection.output_name or f"_col_{i}")
            alias_.set("this", selection)
            selection = alias_

        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    scope.expression.set("expressions", new_selections)


class _Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope, schema):
        self.scope = scope
        self.schema = schema
        # Lazily computed caches (see _get_all_source_columns, get_table, all_columns).
        self._source_columns = None
        self._unambiguous_columns = None
        self._all_columns = None

    def get_table(self, column_name: str) -> t.Optional[str]:
        """
        Get the table for a column name.

        If the column is not unambiguously provided by a single source, but exactly one
        selected source has no known columns, that source is returned.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table = self._unambiguous_columns.get(column_name)

        if not table:
            sources_without_schema = tuple(
                source for source, columns in self._get_all_source_columns().items() if not columns
            )
            if len(sources_without_schema) == 1:
                return sources_without_schema[0]

        return table

    @property
    def all_columns(self):
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name, only_visible=False):
        """
        Resolve the source columns for a given source `name`.

        Raises:
            OptimizeError: if `name` is not a source of this scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        # If referencing a table, return the columns from the schema
        if isinstance(source, exp.Table):
            return self.schema.column_names(source, only_visible)

        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            return source.expression.alias_column_names

        # Otherwise, if referencing another scope, return that scope's named selects
        return source.expression.named_selects

    def _get_all_source_columns(self):
        """Return (and cache) a mapping of each selected source name to its columns."""
        if self._source_columns is None:
            self._source_columns = {
                k: self.get_source_columns(k) for k in self.scope.selected_sources
            }
        return self._source_columns

    def _get_unambiguous_columns(self, source_columns):
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when its name appears exactly once across all sources
        combined, so a name repeated within one source or shared by two sources is excluded.

        Args:
            source_columns (dict): Mapping of names to source columns
        Returns:
            dict: Mapping of column name to source name
        """
        counts = {}
        owners = {}
        for table, columns in source_columns.items():
            for column in columns:
                counts[column] = counts.get(column, 0) + 1
                owners.setdefault(column, table)

        return {column: owners[column] for column, count in counts.items() if count == 1}

    @staticmethod
    def _find_unique_columns(columns):
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(_Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
def
qualify_columns(expression, schema):
def qualify_columns(expression, schema):
    """
    Qualify every column in a sqlglot AST with the source it comes from.

    Each scope of `expression` is processed in turn:
    - table column aliases are dropped;
    - USING joins are rewritten;
    - GROUP BY and ORDER BY references are expanded;
    - columns are qualified;
    - stars are expanded (except in UDTF scopes);
    - every output is aliased.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression (sqlglot.Expression): expression to qualify
        schema (dict|sqlglot.optimizer.Schema): Database schema
    Returns:
        sqlglot.Expression: the same expression object, now fully qualified
    """
    normalized_schema = ensure_schema(schema)

    for current_scope in traverse_scope(expression):
        column_resolver = _Resolver(current_scope, normalized_schema)

        # Drop table column alias lists: CTEs first, then derived tables.
        _pop_table_column_aliases(current_scope.ctes)
        _pop_table_column_aliases(current_scope.derived_tables)

        _expand_using(current_scope, column_resolver)
        _expand_group_by(current_scope, column_resolver)
        _qualify_columns(current_scope, column_resolver)
        _expand_order_by(current_scope)

        is_udtf = isinstance(current_scope.expression, exp.UDTF)
        if not is_udtf:
            _expand_stars(current_scope, column_resolver)

        _qualify_outputs(current_scope)

    return expression
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression (sqlglot.Expression): expression to qualify
- schema (dict|sqlglot.optimizer.Schema): Database schema
Returns:
sqlglot.Expression: qualified expression
def
validate_qualify_columns(expression):
def validate_qualify_columns(expression):
    """
    Check that every column in `expression` has been qualified.

    Raises:
        OptimizeError: as soon as a non-correlated scope references columns of an unknown
            table, or, after all scopes are checked, if any SELECT scope still has
            unqualified columns.
    Returns:
        The unchanged `expression`.
    """
    ambiguous = []
    for current_scope in traverse_scope(expression):
        if isinstance(current_scope.expression, exp.Select):
            ambiguous += current_scope.unqualified_columns
        external = current_scope.external_columns
        if external and not current_scope.is_correlated_subquery:
            raise OptimizeError(f"Unknown table: {external[0].table}")

    if not ambiguous:
        return expression
    raise OptimizeError(f"Ambiguous columns: {ambiguous}")
Raise an `OptimizeError` if any columns aren't qualified.