sqlglot.optimizer.qualify_columns
from __future__ import annotations

import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot._typing import E
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get
from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema


def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether or not to expand references to aliases.
        infer_schema: Whether or not to infer the schema if missing.

    Returns:
        The qualified expression.
    """
    schema = ensure_schema(schema)
    # When the caller does not decide, inference is enabled only for an empty schema.
    infer_schema = schema.empty if infer_schema is None else infer_schema
    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        # A fresh resolver per scope: its caches are tied to that scope's sources.
        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # With an empty schema, alias references are expanded before qualification...
        if schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        _qualify_columns(scope, resolver)

        # ...and with a populated schema they are expanded after it.
        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        if not isinstance(scope.expression, exp.UDTF):
            _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
            _qualify_outputs(scope)

        _expand_group_by(scope)
        _expand_order_by(scope, resolver)

    return expression


def validate_qualify_columns(expression: E) -> E:
    """Raise an `OptimizeError` if any columns aren't qualified"""
    unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns.extend(scope.unqualified_columns)
            # External columns are only tolerated for correlated subqueries and pivoted scopes.
            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
                column = scope.external_columns[0]
                raise OptimizeError(
                    f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}"""
                )

    if unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
    return expression


def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived_table in derived_tables:
        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.args.pop("columns", None)


def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """
    Rewrite `JOIN ... USING (...)` clauses in `scope` into explicit `ON` conditions.

    Unqualified references to a USING column are additionally replaced by a COALESCE
    over the tables that take part in the join.

    Args:
        scope: The scope whose joins are rewritten in place.
        resolver: Used to look up the columns of each source.

    Returns:
        Mapping of each USING column name to a dict whose keys are the (ordered)
        names of the sources joined on that column.

    Raises:
        OptimizeError: If a USING column cannot be matched on both sides of the join
            even though columns are known for both sides.
    """
    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    # Sources that are not introduced by a join; joined tables are appended as they are processed.
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for join in joins:
        using = join.args.get("using")

        if not using:
            continue

        join_table = join.alias_or_name

        columns = {}

        # For every column, remember the first already-ordered source that provides it.
        for source_name in scope.selected_sources:
            if source_name in ordered:
                for column_name in resolver.get_source_columns(source_name):
                    if column_name not in columns:
                        columns[column_name] = source_name

        source_table = ordered[-1]
        ordered.append(join_table)
        join_columns = resolver.get_source_columns(join_table)
        conditions = []

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                # Only an error when both sides have known columns; otherwise fall through.
                if columns and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table
            conditions.append(
                exp.condition(
                    exp.EQ(
                        this=exp.column(identifier, table=table),
                        expression=exp.column(identifier, table=join_table),
                    )
                )
            )

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})
            if table not in tables:
                tables[table] = None
            if join_table not in tables:
                tables[join_table] = None

        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce = [exp.column(column.name, table=table) for table in tables]
                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])

                # Ensure selects keep their output name
                if isinstance(column.parent, exp.Select):
                    replacement = alias(replacement, alias=column.name, copy=False)

                scope.replace(column, replacement)

    return column_tables


def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
    """
    Expand references to projection aliases within a SELECT.

    Columns in later projections and in the WHERE, GROUP BY, HAVING and QUALIFY
    clauses that name an earlier projection alias are replaced by the aliased
    expression. Scopes whose expression is not a `Select` are left untouched.
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    # alias name -> (aliased expression, 1-based position of the projection)
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        if not node:
            return

        for column, *_ in walk_in_scope(node):
            if not isinstance(column, exp.Column):
                continue

            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
            # True when the alias holds an aggregate and the reference sits inside another aggregate.
            double_agg = (
                (alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc))
                if alias_expr
                else False
            )

            if table and (not alias_expr or double_agg):
                column.set("table", table)
            elif not column.table and alias_expr and not double_agg:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    # Literal aliases are referenced by position when requested, else left as-is.
                    if literal_index:
                        column.replace(exp.Literal.number(i))
                else:
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(scope.expression.selects):
        replace_columns(projection)

        # Registered after expansion, so a projection only sees aliases defined before it.
        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)
    scope.clear_cache()


def _expand_group_by(scope: Scope) -> None:
    """Replace positional references in the GROUP BY clause, if there is one."""
    expression = scope.expression
    group = expression.args.get("group")
    if not group:
        return

    group.set("expressions", _expand_positional_references(scope, group.expressions))
    expression.set("group", group)


def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
    """
    Replace positional references in the ORDER BY clause, if there is one.

    Unqualified columns inside aggregate functions of an ordering term are
    qualified via `resolver`. When the query also has a GROUP BY, ordering terms
    are rewritten to refer to the matching projection's output name.
    """
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(scope, (o.this for o in ordereds)),
    ):
        for agg in ordered.find_all(exp.AggFunc):
            for col in agg.find_all(exp.Column):
                if not col.table:
                    col.set("table", resolver.get_table(col.name))

        ordered.set("this", new_expression)

    if scope.expression.args.get("group"):
        # projected expression -> column named after the projection's alias/name
        selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

        for ordered in ordereds:
            ordered = ordered.this

            ordered.replace(
                exp.to_identifier(_select_by_pos(scope, ordered).alias)
                if ordered.is_int
                else selects.get(ordered, ordered)
            )


def _expand_positional_references(scope: Scope, expressions: t.Iterable[E]) -> t.List[E]:
    """
    Resolve integer position references against the scope's projections.

    Integer nodes are replaced by a copy of the projected expression they point
    at, unless that expression is itself a literal, in which case the integer
    node is kept. Non-integer nodes are returned unchanged.
    """
    new_nodes = []
    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node)).this

            if isinstance(select, exp.Literal):
                new_nodes.append(node)
            else:
                new_nodes.append(select.copy())
                scope.clear_cache()
        else:
            new_nodes.append(node)

    return new_nodes


def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """
    Return the projection at the 1-based position given by the literal `node`.

    Raises:
        OptimizeError: If the position is past the last projection.
    """
    # NOTE(review): a position of 0 yields index -1, which selects the last projection
    # instead of raising - confirm whether position 0 can reach this function.
    try:
        return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias)
    except IndexError:
        raise OptimizeError(f"Unknown output column: {node.name}")


def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            # A "*" entry means the source's columns are not fully known, so nothing is rejected.
            if source_columns and column_name not in source_columns and "*" not in source_columns:
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            column_table = resolver.get_table(column_name)

            # column_table can be a '' because bigquery unnest has no table alias
            if column_table:
                column.set("table", column_table)
        elif column_table not in scope.sources and (
            not scope.parent or column_table not in scope.parent.sources
        ):
            # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
            # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))

            root, *parts = column.parts

            if root.name in scope.sources:
                # struct is already qualified, but we still need to change the AST representation
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    # Columns inside PIVOT expressions are qualified separately, only when the name is known.
    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)


def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
) -> None:
    """Expand stars to lists of column selections"""

    new_selections = []
    # Both mappings are keyed by id() of the table-name object taken from `tables` below.
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, str]] = {}
    coalesced_columns = set()

    # TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future
    pivot_columns = None
    pivot_output_columns = None
    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))

    has_pivoted_source = pivot and not pivot.args.get("unpivot")
    if pivot and has_pivoted_source:
        pivot_columns = set(col.output_name for col in pivot.find_all(exp.Column))

        pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])]
        if not pivot_output_columns:
            pivot_output_columns = [col.alias_or_name for col in pivot.expressions]

    for expression in scope.expression.selects:
        if isinstance(expression, exp.Star):
            # Bare `*`: expand over every selected source.
            tables = list(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
        elif expression.is_star:
            # Qualified `tbl.*`: expand over that single table.
            tables = [expression.table]
            _add_except_columns(expression.this, tables, except_columns)
            _add_replace_columns(expression.this, tables, replace_columns)
        else:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            if columns and "*" not in columns:
                if pivot and has_pivoted_source and pivot_columns and pivot_output_columns:
                    implicit_columns = [col for col in columns if col not in pivot_columns]
                    new_selections.extend(
                        exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in implicit_columns + pivot_output_columns
                    )
                    continue

                table_id = id(table)
                for name in columns:
                    if name in using_column_tables and table in using_column_tables[name]:
                        # A USING column is emitted once, as a COALESCE over its tables.
                        if name in coalesced_columns:
                            continue

                        coalesced_columns.add(name)
                        tables = using_column_tables[name]
                        coalesce = [exp.column(name, table=table) for table in tables]

                        new_selections.append(
                            alias(
                                exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
                                alias=name,
                                copy=False,
                            )
                        )
                    elif name not in except_columns.get(table_id, set()):
                        alias_ = replace_columns.get(table_id, {}).get(name, name)
                        column = exp.column(name, table=table)
                        new_selections.append(
                            alias(column, alias_, copy=False) if alias_ != name else column
                        )
            else:
                # Columns of this source are unknown: give up and leave the projections untouched.
                return

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections:
        scope.expression.set("expressions", new_selections)


def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    """Record the star's `except` column names in `except_columns` for each of `tables`."""
    except_ = expression.args.get("except")

    if not except_:
        return

    columns = {e.name for e in except_}

    for table in tables:
        except_columns[id(table)] = columns


def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    """Record the star's `replace` entries in `replace_columns` for each of `tables`."""
    replace = expression.args.get("replace")

    if not replace:
        return

    columns = {e.this.name: e.alias for e in replace}

    for table in tables:
        replace_columns[id(table)] = columns


def _qualify_outputs(scope: Scope) -> None:
    """Ensure all output columns are aliased"""
    new_selections = []

    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
    ):
        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
            )
        # A name from the outer column list overrides whatever alias was chosen above.
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    scope.expression.set("expressions", new_selections)


def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""
    return expression.transform(
        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
    )


class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # The three caches below are filled lazily on first use.
        self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # Fall back to the single source with unknown columns, if there is exactly one.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Walk up to the ancestor that carries the alias `table_name`.
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
        """Resolve the source columns for a given source `name`."""
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        # If the source's columns are aliased, their aliases shadow the corresponding column names
        return [alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)]

    def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
        """Return (and cache) the columns of every selected and lateral source, keyed by source name."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.List[str]]
    ) -> t.Dict[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = self._find_unique_columns(columns)
            # A column already seen in an earlier source is ambiguous from here on.
            ambiguous = set(all_columns).intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts: t.Dict[str, int] = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
def
qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether or not to expand references to aliases.
        infer_schema: Whether or not to infer the schema if missing.

    Returns:
        The qualified expression.
    """
    schema = ensure_schema(schema)
    # When the caller does not decide, inference is enabled only for an empty schema.
    infer_schema = schema.empty if infer_schema is None else infer_schema
    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        # A fresh resolver per scope: its caches are tied to that scope's sources.
        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # With an empty schema, alias references are expanded before qualification...
        if schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        _qualify_columns(scope, resolver)

        # ...and with a populated schema they are expanded after it.
        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        # Star expansion and output aliasing are skipped for UDTF scopes.
        if not isinstance(scope.expression, exp.UDTF):
            _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
            _qualify_outputs(scope)

        _expand_group_by(scope)
        _expand_order_by(scope, resolver)

    return expression
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot >>> schema = {"tbl": {"col": "INT"}} >>> expression = sqlglot.parse_one("SELECT col FROM tbl") >>> qualify_columns(expression, schema).sql() 'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression: Expression to qualify.
- schema: Database schema.
- expand_alias_refs: Whether or not to expand references to aliases.
- infer_schema: Whether or not to infer the schema if missing.
Returns:
The qualified expression.
def
validate_qualify_columns(expression: ~E) -> ~E:
def validate_qualify_columns(expression: E) -> E:
    """
    Check that every column in `expression` has been qualified.

    Args:
        expression: The (already qualified) expression to check.

    Returns:
        The same expression, unchanged.

    Raises:
        OptimizeError: As soon as a SELECT scope has an external column that is not
            explained by a correlated subquery or a pivot, or, after all scopes were
            visited, if any unqualified (ambiguous) columns were collected.
    """
    ambiguous = []

    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        ambiguous += scope.unqualified_columns

        # External columns are acceptable only for correlated subqueries and pivots.
        if not scope.external_columns or scope.is_correlated_subquery or scope.pivots:
            continue

        column = scope.external_columns[0]
        table_hint = f" for table: '{column.table}'" if column.table else ""
        raise OptimizeError(f"Column '{column}' could not be resolved{table_hint}")

    if ambiguous:
        raise OptimizeError(f"Ambiguous columns: {ambiguous}")

    return expression
Raise an `OptimizeError` if any columns aren't qualified.
def
quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Quote every identifier in `expression` that needs quoting.

    Args:
        expression: The expression to transform; it is modified in place (`copy=False`).
        dialect: Dialect whose `quote_identifier` rule is applied to each node.
        identify: Forwarded unchanged to the dialect's `quote_identifier`.

    Returns:
        The result of the transformation.
    """
    quoting_dialect = Dialect.get_or_raise(dialect)
    return expression.transform(quoting_dialect.quote_identifier, identify=identify, copy=False)
Makes sure all identifiers that need to be quoted are quoted.
class
Resolver:
class Resolver:
    """
    Resolve column names against the sources of a single scope.

    The per-scope lookups are computed on first use and cached on the instance,
    so one resolver can be shared by all functions that work on the same scope.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        self._infer_schema = infer_schema
        # Lazily populated caches.
        self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Find the table that a column name belongs to.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table identifier if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # If exactly one source has unknown columns, assume the column lives there.
            schemaless = [
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            ]
            if len(schemaless) == 1:
                table_name = schemaless[0]

        selected = self.scope.selected_sources
        if table_name not in selected:
            return exp.to_identifier(table_name)

        node, _ = selected.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Climb to the ancestor that carries the alias `table_name`.
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        return exp.to_identifier(node_alias.this if node_alias else table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """The set of column names available from any source in this scope."""
        if self._all_columns is None:
            self._all_columns = set(
                itertools.chain.from_iterable(self._get_all_source_columns().values())
            )
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
        """
        Return the column names exposed by the source called `name`.

        Raises:
            OptimizeError: If `name` is not a source of this scope.
        """
        sources = self.scope.sources
        if name not in sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        selected = self.scope.selected_sources.get(name)
        node = selected[0] if selected else None

        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        # A column alias, where present, shadows the underlying column name.
        return [
            column_alias or column
            for column, column_alias in itertools.zip_longest(columns, column_aliases)
        ]

    def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
        """Return (and cache) the columns of every selected and lateral source."""
        if self._source_columns is None:
            source_items = itertools.chain(
                self.scope.selected_sources.items(), self.scope.lateral_sources.items()
            )
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, _ in source_items
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.List[str]]
    ) -> t.Dict[str, str]:
        """
        Determine which column names identify exactly one source.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        pairs = iter(source_columns.items())
        first_table, first_columns = next(pairs)

        unambiguous_columns = dict.fromkeys(self._find_unique_columns(first_columns), first_table)
        all_columns = set(unambiguous_columns)

        for table, columns in pairs:
            unique = self._find_unique_columns(columns)
            # Names already seen in an earlier source can no longer be resolved uniquely.
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
        """
        Return the names that occur exactly once in `columns`.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        Duplicated names are dropped because they are ambiguous.
        """
        seen: t.Set[str] = set()
        repeated: t.Set[str] = set()
        for column in columns:
            if column in seen:
                repeated.add(column)
            else:
                seen.add(column)
        return seen - repeated
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
489 def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True): 490 self.scope = scope 491 self.schema = schema 492 self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None 493 self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None 494 self._all_columns: t.Optional[t.Set[str]] = None 495 self._infer_schema = infer_schema
def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
    """
    Get the table for a column name.

    The name is first looked up among the columns that belong to exactly one
    source. If that fails and schema inference is enabled, the single source
    whose columns are unknown (empty, or containing "*") is used, provided there
    is exactly one such source.

    Args:
        column_name: The column name to find the table for.
    Returns:
        The table name if it can be found/inferred.
    """
    # Build the unambiguous-column mapping once and cache it on the instance.
    if self._unambiguous_columns is None:
        self._unambiguous_columns = self._get_unambiguous_columns(
            self._get_all_source_columns()
        )

    table_name = self._unambiguous_columns.get(column_name)

    if not table_name and self._infer_schema:
        sources_without_schema = tuple(
            source
            for source, columns in self._get_all_source_columns().items()
            if not columns or "*" in columns
        )
        if len(sources_without_schema) == 1:
            table_name = sources_without_schema[0]

    # Covers both "nothing found" (table_name is None) and names that are not selected sources.
    if table_name not in self.scope.selected_sources:
        return exp.to_identifier(table_name)

    node, _ = self.scope.selected_sources.get(table_name)

    if isinstance(node, exp.Subqueryable):
        # Walk up to the ancestor that carries the alias `table_name`.
        # NOTE(review): if no ancestor matches, `node` ends up None and the `.args`
        # access below would fail - presumably unreachable; confirm.
        while node and node.alias != table_name:
            node = node.parent

    # Prefer the identifier stored on the alias node over a freshly built one.
    node_alias = node.args.get("alias")
    if node_alias:
        return exp.to_identifier(node_alias.this)

    return exp.to_identifier(table_name)
Get the table for a column name.
Arguments:
- column_name: The column name to find the table for.
Returns:
The table name if it can be found/inferred.
def
get_source_columns(self, name: str, only_visible: bool = False) -> List[str]:
def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
    """
    Return the column names exposed by the source called `name`.

    Args:
        name: Name of a source in this resolver's scope.
        only_visible: Forwarded to the schema lookup when the source is a table.

    Returns:
        The source's column names, with column aliases taking precedence
        position by position where they are defined.

    Raises:
        OptimizeError: If `name` is not a source of this scope.
    """
    sources = self.scope.sources
    if name not in sources:
        raise OptimizeError(f"Unknown table: {name}")

    source = sources[name]

    if isinstance(source, exp.Table):
        columns = self.schema.column_names(source, only_visible)
    elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
        columns = source.expression.alias_column_names
    else:
        columns = source.expression.named_selects

    selected = self.scope.selected_sources.get(name)
    node = selected[0] if selected else None

    if isinstance(node, Scope):
        column_aliases = node.expression.alias_column_names
    elif isinstance(node, exp.Expression):
        column_aliases = node.alias_column_names
    else:
        column_aliases = []

    # A column alias, where present, shadows the underlying column name.
    return [
        column_alias or column
        for column, column_alias in itertools.zip_longest(columns, column_aliases)
    ]
Resolve the source columns for a given source `name`.