sqlglot.optimizer.qualify_columns
import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema


def qualify_columns(expression, schema):
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression (sqlglot.Expression): expression to qualify
        schema (dict|sqlglot.optimizer.Schema): Database schema
    Returns:
        sqlglot.Expression: qualified expression
    """
    schema = ensure_schema(schema)

    # Every scope gets its own resolver; the rewrite steps below run in a fixed
    # order and each one mutates the scope's expression in place.
    for scope in traverse_scope(expression):
        resolver = _Resolver(scope, schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        _expand_using(scope, resolver)
        _qualify_columns(scope, resolver)
        if not isinstance(scope.expression, exp.UDTF):
            _expand_stars(scope, resolver)
            _qualify_outputs(scope)
        _expand_group_by(scope, resolver)
        _expand_order_by(scope)
    return expression


def validate_qualify_columns(expression):
    """
    Raise an `OptimizeError` if any columns aren't qualified.

    Args:
        expression: expression to validate.
    Returns:
        The same expression, when no problem was found.
    Raises:
        OptimizeError: if a non-correlated scope has external columns
            ("Unknown table") or any unqualified column remains
            ("Ambiguous columns").
    """
    unqualified_columns = []
    for scope in traverse_scope(expression):
        # NOTE(review): indentation was lost in the flattened listing; the
        # external-column check is assumed to belong to the Select branch.
        if isinstance(scope.expression, exp.Select):
            unqualified_columns.extend(scope.unqualified_columns)
            if scope.external_columns and not scope.is_correlated_subquery:
                raise OptimizeError(f"Unknown table: {scope.external_columns[0].table}")

    if unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
    return expression


def _pop_table_column_aliases(derived_tables):
    """
    Remove table column aliases.

    (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)

    Args:
        derived_tables: iterable of expressions whose "alias" arg, when set,
            has its "columns" entry removed in place.
    """
    for derived_table in derived_tables:
        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.args.pop("columns", None)


def _expand_using(scope, resolver):
    """
    Rewrite `JOIN ... USING (...)` into `JOIN ... ON` equality conditions.

    Unqualified references to a USING column are afterwards replaced by a
    COALESCE over every table that takes part in the join on that column.

    Raises:
        OptimizeError: if a USING column is missing from the join target or
            from every source that precedes it.
    """
    joins = list(scope.expression.find_all(exp.Join))
    names = {join.this.alias for join in joins}
    # Sources that are not join targets; join targets are appended as visited.
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to source names
    column_tables = {}

    for join in joins:
        using = join.args.get("using")

        if not using:
            continue

        join_table = join.this.alias_or_name

        # First source (among those already in `ordered`) providing each column.
        columns = {}

        for k in scope.selected_sources:
            if k in ordered:
                for column in resolver.get_source_columns(k):
                    if column not in columns:
                        columns[column] = k

        ordered.append(join_table)
        join_columns = resolver.get_source_columns(join_table)
        conditions = []

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                raise OptimizeError(f"Cannot automatically join: {identifier}")

            conditions.append(
                exp.condition(
                    exp.EQ(
                        this=exp.column(identifier, table=table),
                        expression=exp.column(identifier, table=join_table),
                    )
                )
            )

            tables = column_tables.setdefault(identifier, [])
            if table not in tables:
                tables.append(table)
            if join_table not in tables:
                tables.append(join_table)

        join.args.pop("using")
        join.set("on", exp.and_(*conditions))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce = [exp.column(column.name, table=table) for table in tables]
                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])

                # Ensure selects keep their output name
                if isinstance(column.parent, exp.Select):
                    replacement = exp.alias_(replacement, alias=column.name)

                scope.replace(column, replacement)


def _expand_group_by(scope, resolver):
    """
    Resolve select-alias and positional references inside GROUP BY.

    Unqualified columns that the resolver can attribute to a source are
    qualified; otherwise a matching select alias is replaced by a copy of the
    aliased expression. Integer literals are expanded as select positions.
    """
    group = scope.expression.args.get("group")
    if not group:
        return

    # Replace references to select aliases
    def transform(node, *_):
        if isinstance(node, exp.Column) and not node.table:
            table = resolver.get_table(node.name)

            # Source columns get priority over select aliases
            if table:
                node.set("table", exp.to_identifier(table))
                return node

            selects = {s.alias_or_name: s for s in scope.selects}

            select = selects.get(node.name)
            if select:
                scope.clear_cache()
                if isinstance(select, exp.Alias):
                    select = select.this
                return select.copy()

        return node

    group.transform(transform, copy=False)
    group.set("expressions", _expand_positional_references(scope, group.expressions))
    scope.expression.set("group", group)


def _expand_order_by(scope):
    """Replace positional references (e.g. `ORDER BY 1`) in ORDER BY terms."""
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(scope, (o.this for o in ordereds)),
    ):
        ordered.set("this", new_expression)


def _expand_positional_references(scope, expressions):
    """
    Replace integer literals with a copy of the select they point at.

    Args:
        scope: scope whose selects are addressed by position.
        expressions: iterable of expressions to inspect.
    Returns:
        list: one entry per input; non-integer nodes are returned as-is.
    Raises:
        OptimizeError: if a position is outside `1..len(scope.selects)`.
    """
    new_nodes = []
    for node in expressions:
        if node.is_int:
            selects = scope.selects
            position = int(node.name)
            # Positions are 1-based. Check the bounds explicitly: relying on
            # IndexError would let position 0 become index -1 and silently
            # resolve to the last select.
            if not 1 <= position <= len(selects):
                raise OptimizeError(f"Unknown output column: {node.name}")
            select = selects[position - 1]
            if isinstance(select, exp.Alias):
                select = select.this
            new_nodes.append(select.copy())
            scope.clear_cache()
        else:
            new_nodes.append(node)

    return new_nodes


def _qualify_columns(scope, resolver):
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            # An empty column list means the source has no known schema, so
            # nothing can be validated against it.
            if source_columns and column_name not in source_columns:
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            column_table = resolver.get_table(column_name)

            # column_table can be a '' because bigquery unnest has no table alias
            if column_table:
                column.set("table", exp.to_identifier(column_table))

    columns_missing_from_scope = []
    # Determine whether each reference in the order by clause is to a column or an alias.
    for ordered in scope.find_all(exp.Ordered):
        for column in ordered.find_all(exp.Column):
            if (
                not column.table
                and column.parent is not ordered
                and column.name in resolver.all_columns
            ):
                columns_missing_from_scope.append(column)

    # Determine whether each reference in the having clause is to a column or an alias.
    for having in scope.find_all(exp.Having):
        for column in having.find_all(exp.Column):
            if (
                not column.table
                and column.find_ancestor(exp.AggFunc)
                and column.name in resolver.all_columns
            ):
                columns_missing_from_scope.append(column)

    for column in columns_missing_from_scope:
        column_table = resolver.get_table(column.name)

        if column_table:
            column.set("table", exp.to_identifier(column_table))


def _expand_stars(scope, resolver):
    """
    Expand stars to lists of column selections

    Raises:
        OptimizeError: if a star refers to an unknown table or to a table
            whose columns cannot be resolved.
    """

    new_selections = []
    # Both mappings are keyed by `id(table)` of the table-name objects.
    except_columns = {}
    replace_columns = {}

    for expression in scope.selects:
        if isinstance(expression, exp.Star):
            tables = list(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
        elif isinstance(expression, exp.Column) and isinstance(expression.this, exp.Star):
            tables = [expression.table]
            _add_except_columns(expression.this, tables, except_columns)
            _add_replace_columns(expression.this, tables, replace_columns)
        else:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")
            columns = resolver.get_source_columns(table, only_visible=True)
            if not columns:
                raise OptimizeError(
                    f"Table has no schema/columns. Cannot expand star for table: {table}."
                )
            table_id = id(table)
            for name in columns:
                if name not in except_columns.get(table_id, set()):
                    alias_ = replace_columns.get(table_id, {}).get(name, name)
                    column = exp.column(name, table)
                    new_selections.append(alias(column, alias_) if alias_ != name else column)

    scope.expression.set("expressions", new_selections)


def _add_except_columns(expression, tables, except_columns):
    """Record the star's "except" column names for each table, keyed by `id(table)`."""
    except_ = expression.args.get("except")

    if not except_:
        return

    columns = {e.name for e in except_}

    for table in tables:
        except_columns[id(table)] = columns


def _add_replace_columns(expression, tables, replace_columns):
    """Record the star's "replace" name-to-alias mapping for each table, keyed by `id(table)`."""
    replace = expression.args.get("replace")

    if not replace:
        return

    columns = {e.this.name: e.alias for e in replace}

    for table in tables:
        replace_columns[id(table)] = columns


def _qualify_outputs(scope):
    """Ensure all output columns are aliased"""
    new_selections = []

    # zip_longest pads the shorter of the two sequences with None.
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.selects, scope.outer_column_list)
    ):
        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias):
            # Build the alias around a placeholder column, then swap in the
            # real selection node.
            alias_ = alias(exp.column(""), alias=selection.output_name or f"_col_{i}")
            alias_.set("this", selection)
            selection = alias_

        # A name from the outer column list overrides the selection's alias.
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    scope.expression.set("expressions", new_selections)


class _Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope, schema):
        self.scope = scope
        self.schema = schema
        # Lazily computed caches; None means "not computed yet".
        self._source_columns = None
        self._unambiguous_columns = None
        self._all_columns = None

    def get_table(self, column_name: str) -> t.Optional[str]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table = self._unambiguous_columns.get(column_name)

        if not table:
            # If exactly one source has no known columns, the column is
            # attributed to that source.
            sources_without_schema = tuple(
                source for source, columns in self._get_all_source_columns().items() if not columns
            )
            if len(sources_without_schema) == 1:
                return sources_without_schema[0]

        return table

    @property
    def all_columns(self):
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name, only_visible=False):
        """
        Resolve the source columns for a given source `name`

        Raises:
            OptimizeError: if `name` is not a source of this scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        # If referencing a table, return the columns from the schema
        if isinstance(source, exp.Table):
            return self.schema.column_names(source, only_visible)

        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            return source.expression.alias_column_names

        # Otherwise, if referencing another scope, return that scope's named selects
        return source.expression.named_selects

    def _get_all_source_columns(self):
        """Return (and cache) the columns of every selected and lateral source."""
        if self._source_columns is None:
            self._source_columns = {
                k: self.get_source_columns(k)
                for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
            }
        return self._source_columns

    def _get_unambiguous_columns(self, source_columns):
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns (dict): Mapping of names to source columns
        Returns:
            dict: Mapping of column name to source name
        """
        if not source_columns:
            return {}

        source_columns = list(source_columns.items())

        first_table, first_columns = source_columns[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns[1:]:
            unique = self._find_unique_columns(columns)
            # Names already seen in an earlier source are ambiguous.
            ambiguous = set(all_columns).intersection(unique)
            all_columns.update(columns)
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns):
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(_Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
def
qualify_columns(expression, schema):
def qualify_columns(expression, schema):
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression (sqlglot.Expression): expression to qualify
        schema (dict|sqlglot.optimizer.Schema): Database schema
    Returns:
        sqlglot.Expression: qualified expression
    """
    normalized_schema = ensure_schema(schema)

    for current in traverse_scope(expression):
        column_resolver = _Resolver(current, normalized_schema)

        # Strip column lists from table aliases before anything is resolved.
        _pop_table_column_aliases(current.ctes)
        _pop_table_column_aliases(current.derived_tables)

        _expand_using(current, column_resolver)
        _qualify_columns(current, column_resolver)

        # Star expansion and output aliasing are skipped for UDTF scopes.
        is_udtf = isinstance(current.expression, exp.UDTF)
        if not is_udtf:
            _expand_stars(current, column_resolver)
            _qualify_outputs(current)

        _expand_group_by(current, column_resolver)
        _expand_order_by(current)

    return expression
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression (sqlglot.Expression): expression to qualify
- schema (dict|sqlglot.optimizer.Schema): Database schema
Returns:
sqlglot.Expression: qualified expression
def
validate_qualify_columns(expression):
def validate_qualify_columns(expression):
    """
    Raise an `OptimizeError` if any columns aren't qualified.

    Returns the expression unchanged when every inspected scope is clean.
    """
    leftovers = []

    for current in traverse_scope(expression):
        # Only SELECT scopes are inspected.
        if not isinstance(current.expression, exp.Select):
            continue

        leftovers.extend(current.unqualified_columns)

        # NOTE(review): indentation was lost in the flattened listing; this
        # check is assumed to sit inside the Select branch — confirm upstream.
        if current.external_columns and not current.is_correlated_subquery:
            raise OptimizeError(f"Unknown table: {current.external_columns[0].table}")

    if leftovers:
        raise OptimizeError(f"Ambiguous columns: {leftovers}")

    return expression
Raise an `OptimizeError` if any columns aren't qualified.