sqlglot.schema
"""Database schema abstractions used to register tables and resolve column names/types."""

from __future__ import annotations

import abc
import typing as t

from sqlglot import expressions as exp
from sqlglot.dialects.dialect import Dialect
from sqlglot.errors import SchemaError
from sqlglot.helper import dict_depth, first
from sqlglot.trie import TrieResult, in_trie, new_trie

if t.TYPE_CHECKING:
    from sqlglot.dialects.dialect import DialectType

    # A column mapping can be a {name: type} dict, a "name: type, ..." string, or a list of names.
    ColumnMapping = t.Union[t.Dict, str, t.List]


class Schema(abc.ABC):
    """Abstract base class for database schemas"""

    dialect: DialectType

    @abc.abstractmethod
    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information to also be provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """

    @abc.abstractmethod
    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.Sequence[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance.
            only_visible: whether to include invisible columns.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The sequence of column names.
        """

    @abc.abstractmethod
    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type.
        """

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """
        Returns whether `column` appears in `table`'s schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            True if the column appears in the schema, False otherwise.
        """
        name = column if isinstance(column, str) else column.name
        return name in self.column_names(table, dialect=dialect, normalize=normalize)

    @property
    @abc.abstractmethod
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        Table arguments this schema supports, e.g. `("this", "db", "catalog")`
        """

    @property
    def empty(self) -> bool:
        """Returns whether the schema is empty."""
        return True


class AbstractMappingSchema:
    """Base for schemas backed by a nested dict mapping, with trie-based table lookup."""

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}
        # The trie is keyed by reversed table parts, i.e. (table, db, catalog).
        self.mapping_trie = new_trie(
            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
        )
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        """Returns whether the mapping is empty."""
        return not self.mapping

    def depth(self) -> int:
        """Nesting depth of the underlying mapping."""
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """Table parts supported by this schema's nesting level, computed lazily."""
        if not self._supported_table_args and self.mapping:
            depth = self.depth()

            if not depth:  # None
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = exp.TABLE_PARTS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """Returns the non-empty name parts of `table`, e.g. [table, db, catalog]."""
        if isinstance(table.this, exp.ReadCSV):
            return [table.this.name]
        return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)]

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Optional[t.Any]:
        """
        Returns the schema of a given table.

        Args:
            table: the target table.
            raise_on_missing: whether to raise in case the schema is not found.
            ensure_data_types: whether to convert `str` types to their `DataType` equivalents.

        Returns:
            The schema of the target table.
        """
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        value, trie = in_trie(self.mapping_trie, parts)

        if value == TrieResult.FAILED:
            return None

        if value == TrieResult.PREFIX:
            # The table is underqualified; it's only unambiguous if exactly one
            # completion exists in the trie.
            possibilities = flatten_schema(trie)

            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                message = ", ".join(".".join(parts) for parts in possibilities)
                if raise_on_missing:
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None

        return self.nested_get(parts, raise_on_missing=raise_on_missing)

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
    ) -> t.Optional[t.Any]:
        """Fetches the value stored under `parts` (catalog-first) in `d` or `self.mapping`."""
        return nested_get(
            d or self.mapping,
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )


class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}}
            2. {db: {table: set(*cols)}}}
            3. {catalog: {db: {table: set(*cols)}}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = {} if visible is None else visible
        self.normalize = normalize
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        self._depth = 0
        schema = {} if schema is None else schema

        super().__init__(self._normalize(schema) if self.normalize else schema)

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Alternate constructor: build a new instance from an existing `MappingSchema`."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Optional[t.Any]:
        """Like `AbstractMappingSchema.find`, optionally coercing `str` types to `DataType`."""
        schema = super().find(
            table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types
        )
        if ensure_data_types and isinstance(schema, dict):
            schema = {
                col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
                for col, dtype in schema.items()
            }

        return schema

    def copy(self, **kwargs) -> MappingSchema:
        """Returns a shallow copy, with any of the constructor arguments overridden by `kwargs`."""
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            # Table is already registered and no new columns were given: keep existing entry.
            return

        parts = self.table_parts(normalized_table)

        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """Returns the column names for `table`, optionally filtered to visible columns."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """Returns `column`'s type in `table`, or the "unknown" type if it can't be resolved."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """Returns whether `column` appears in `table`'s schema (after normalization)."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        flattened_schema = flatten_schema(schema)
        error_msg = "Table {} must match the schema's nesting level: {}."

        for keys in flattened_schema:
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                raise SchemaError(error_msg.format(".".join(keys[:-1]), len(flattened_schema[0])))
            if isinstance(first(columns.values()), dict):
                raise SchemaError(
                    error_msg.format(
                        ".".join(keys + flatten_schema(columns)[0]), len(flattened_schema[0])
                    ),
                )

            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parses `table` if needed and normalizes its identifier parts."""
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for arg in exp.TABLE_PARTS:
                value = normalized_table.args.get(arg)
                if isinstance(value, exp.Identifier):
                    normalized_table.set(
                        arg,
                        normalize_name(value, dialect=dialect, is_table=True, normalize=normalize),
                    )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Returns the normalized text of an identifier name."""
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        ).name

    def depth(self) -> int:
        """Nesting depth of the schema, excluding the column level. Cached after first use."""
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect
            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]


def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> exp.Identifier:
    """Parses `identifier` if needed and normalizes it per `dialect` rules."""
    if isinstance(identifier, str):
        identifier = exp.parse_identifier(identifier, dialect=dialect)

    if not normalize:
        return identifier

    # this is used for normalize_identifier, bigquery has special rules pertaining tables
    identifier.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(identifier)


def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """Returns `schema` unchanged if it's already a `Schema`, else wraps it in a `MappingSchema`."""
    if isinstance(schema, Schema):
        return schema

    return MappingSchema(schema, **kwargs)


def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """Coerces a `ColumnMapping` (dict, "name: type, ..." string, or list of names) to a dict."""
    if mapping is None:
        return {}
    elif isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        col_name_type_strs = [x.strip() for x in mapping.split(",")]
        return {
            name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip()
            for name_type_str in col_name_type_strs
        }
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")


def flatten_schema(
    schema: t.Dict, depth: t.Optional[int] = None, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """Returns the key paths of `schema` down to (but excluding) the leaf level."""
    tables = []
    keys = keys or []
    depth = dict_depth(schema) - 1 if depth is None else depth

    for k, v in schema.items():
        if depth == 1 or not isinstance(v, dict):
            tables.append(keys + [k])
        elif depth >= 2:
            tables.extend(flatten_schema(v, depth - 1, keys + [k]))

    return tables


def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.

    Returns:
        The value or None if it doesn't exist.
    """
    for name, key in path:
        d = d.get(key)  # type: ignore
        if d is None:
            if raise_on_missing:
                name = "table" if name == "this" else name
                raise ValueError(f"Unknown {name}: {key}")
            return None

    return d


def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that makeup the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    if len(keys) == 1:
        d[keys[0]] = value
        return d

    subd = d
    for key in keys[:-1]:
        if key not in subd:
            subd = subd.setdefault(key, {})
        else:
            subd = subd[key]

    subd[keys[-1]] = value
    return d
class Schema(abc.ABC):
    """Abstract base class for database schemas"""

    dialect: DialectType

    @abc.abstractmethod
    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information to also be provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """

    @abc.abstractmethod
    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.Sequence[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance.
            only_visible: whether to include invisible columns.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The sequence of column names.
        """

    @abc.abstractmethod
    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type.
        """

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """
        Returns whether `column` appears in `table`'s schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            True if the column appears in the schema, False otherwise.
        """
        name = column if isinstance(column, str) else column.name
        return name in self.column_names(table, dialect=dialect, normalize=normalize)

    @property
    @abc.abstractmethod
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        Table arguments this schema supports, e.g. `("this", "db", "catalog")`
        """

    @property
    def empty(self) -> bool:
        """Returns whether the schema is empty."""
        return True
Abstract base class for database schemas
@abc.abstractmethod
def add_table(
    self,
    table: exp.Table | str,
    column_mapping: t.Optional[ColumnMapping] = None,
    dialect: DialectType = None,
    normalize: t.Optional[bool] = None,
    match_depth: bool = True,
) -> None:
    """
    Register or update a table. Some implementing classes may require column information to also be provided.
    The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

    Args:
        table: the `Table` expression instance or string representing the table.
        column_mapping: a column mapping that describes the structure of the table.
        dialect: the SQL dialect that will be used to parse `table` if it's a string.
        normalize: whether to normalize identifiers according to the dialect of interest.
        match_depth: whether to enforce that the table must match the schema's depth or not.
    """
Register or update a table. Some implementing classes may require column information to also be provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
- match_depth: whether to enforce that the table must match the schema's depth or not.
@abc.abstractmethod
def column_names(
    self,
    table: exp.Table | str,
    only_visible: bool = False,
    dialect: DialectType = None,
    normalize: t.Optional[bool] = None,
) -> t.Sequence[str]:
    """
    Get the column names for a table.

    Args:
        table: the `Table` expression instance.
        only_visible: whether to include invisible columns.
        dialect: the SQL dialect that will be used to parse `table` if it's a string.
        normalize: whether to normalize identifiers according to the dialect of interest.

    Returns:
        The sequence of column names.
    """
Get the column names for a table.
Arguments:
- table: the `Table` expression instance.
- only_visible: whether to include invisible columns.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The sequence of column names.
@abc.abstractmethod
def get_column_type(
    self,
    table: exp.Table | str,
    column: exp.Column | str,
    dialect: DialectType = None,
    normalize: t.Optional[bool] = None,
) -> exp.DataType:
    """
    Get the `sqlglot.exp.DataType` type of a column in the schema.

    Args:
        table: the source table.
        column: the target column.
        dialect: the SQL dialect that will be used to parse `table` if it's a string.
        normalize: whether to normalize identifiers according to the dialect of interest.

    Returns:
        The resulting column type.
    """
Get the `sqlglot.exp.DataType` type of a column in the schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The resulting column type.
def has_column(
    self,
    table: exp.Table | str,
    column: exp.Column | str,
    dialect: DialectType = None,
    normalize: t.Optional[bool] = None,
) -> bool:
    """
    Returns whether `column` appears in `table`'s schema.

    Args:
        table: the source table.
        column: the target column.
        dialect: the SQL dialect that will be used to parse `table` if it's a string.
        normalize: whether to normalize identifiers according to the dialect of interest.

    Returns:
        True if the column appears in the schema, False otherwise.
    """
    name = column if isinstance(column, str) else column.name
    return name in self.column_names(table, dialect=dialect, normalize=normalize)
Returns whether `column` appears in `table`'s schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
True if the column appears in the schema, False otherwise.
class AbstractMappingSchema:
    """Base for schemas backed by a nested dict mapping, with trie-based table lookup."""

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}
        # The trie is keyed by reversed table parts, i.e. (table, db, catalog).
        self.mapping_trie = new_trie(
            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
        )
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        """Returns whether the mapping is empty."""
        return not self.mapping

    def depth(self) -> int:
        """Nesting depth of the underlying mapping."""
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """Table parts supported by this schema's nesting level, computed lazily."""
        if not self._supported_table_args and self.mapping:
            depth = self.depth()

            if not depth:  # None
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = exp.TABLE_PARTS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """Returns the non-empty name parts of `table`, e.g. [table, db, catalog]."""
        if isinstance(table.this, exp.ReadCSV):
            return [table.this.name]
        return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)]

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Optional[t.Any]:
        """
        Returns the schema of a given table.

        Args:
            table: the target table.
            raise_on_missing: whether to raise in case the schema is not found.
            ensure_data_types: whether to convert `str` types to their `DataType` equivalents.

        Returns:
            The schema of the target table.
        """
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        value, trie = in_trie(self.mapping_trie, parts)

        if value == TrieResult.FAILED:
            return None

        if value == TrieResult.PREFIX:
            # The table is underqualified; it's only unambiguous if exactly one
            # completion exists in the trie.
            possibilities = flatten_schema(trie)

            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                message = ", ".join(".".join(parts) for parts in possibilities)
                if raise_on_missing:
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None

        return self.nested_get(parts, raise_on_missing=raise_on_missing)

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
    ) -> t.Optional[t.Any]:
        """Fetches the value stored under `parts` (catalog-first) in `d` or `self.mapping`."""
        return nested_get(
            d or self.mapping,
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )
@property
def supported_table_args(self) -> t.Tuple[str, ...]:
    """Table parts supported by this schema's nesting level, computed lazily."""
    if not self._supported_table_args and self.mapping:
        depth = self.depth()

        if not depth:  # None
            self._supported_table_args = tuple()
        elif 1 <= depth <= 3:
            self._supported_table_args = exp.TABLE_PARTS[:depth]
        else:
            raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

    return self._supported_table_args
def find(
    self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
) -> t.Optional[t.Any]:
    """
    Returns the schema of a given table.

    Args:
        table: the target table.
        raise_on_missing: whether to raise in case the schema is not found.
        ensure_data_types: whether to convert `str` types to their `DataType` equivalents.

    Returns:
        The schema of the target table.
    """
    parts = self.table_parts(table)[0 : len(self.supported_table_args)]
    value, trie = in_trie(self.mapping_trie, parts)

    if value == TrieResult.FAILED:
        return None

    if value == TrieResult.PREFIX:
        # The table is underqualified; it's only unambiguous if exactly one
        # completion exists in the trie.
        possibilities = flatten_schema(trie)

        if len(possibilities) == 1:
            parts.extend(possibilities[0])
        else:
            message = ", ".join(".".join(parts) for parts in possibilities)
            if raise_on_missing:
                raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
            return None

    return self.nested_get(parts, raise_on_missing=raise_on_missing)
Returns the schema of a given table.
Arguments:
- table: the target table.
- raise_on_missing: whether to raise in case the schema is not found.
- ensure_data_types: whether to convert `str` types to their `DataType` equivalents.
Returns:
The schema of the target table.
class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = {} if visible is None else visible
        self.normalize = normalize
        # Cache of type-string -> exp.DataType conversions, shared across lookups.
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        # Computed lazily by depth(); 0 means "not computed yet".
        self._depth = 0
        schema = {} if schema is None else schema

        super().__init__(self._normalize(schema) if self.normalize else schema)

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a new `MappingSchema` from an existing one, reusing its settings."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Optional[t.Any]:
        """
        Returns the schema of a given table.

        Args:
            table: the target table.
            raise_on_missing: whether to raise in case the schema is not found.
            ensure_data_types: whether to convert `str` types to their `DataType` equivalents.

        Returns:
            The schema of the target table.
        """
        schema = super().find(
            table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types
        )
        if ensure_data_types and isinstance(schema, dict):
            # Column types may be stored as raw strings; coerce them on demand.
            schema = {
                col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
                for col, dtype in schema.items()
            }

        return schema

    def copy(self, **kwargs) -> MappingSchema:
        """Return a copy of this schema; `kwargs` override the copied constructor arguments."""
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            # Table is already registered and no new column info was given: keep the old entry.
            return

        parts = self.table_parts(normalized_table)

        # Parts are reversed to match the key order used by the nested mapping.
        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance.
            only_visible: whether to include invisible columns.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The list of column names.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type, or the "unknown" type if the column isn't found.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                # Stored as a raw type string; parse (and cache) it.
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """
        Returns whether `column` appears in `table`'s schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            True if the column appears in the schema, False otherwise.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        flattened_schema = flatten_schema(schema)
        error_msg = "Table {} must match the schema's nesting level: {}."

        for keys in flattened_schema:
            # (name, key) pairs are identical here; names are only used in error messages.
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                raise SchemaError(error_msg.format(".".join(keys[:-1]), len(flattened_schema[0])))
            if isinstance(first(columns.values()), dict):
                # A dict value at column level means the table path was under-qualified.
                raise SchemaError(
                    error_msg.format(
                        ".".join(keys + flatten_schema(columns)[0]), len(flattened_schema[0])
                    ),
                )

            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parse `table` if it's a string and normalize each of its identifier parts."""
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        # Copy only when we're going to mutate the expression below.
        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for arg in exp.TABLE_PARTS:
                value = normalized_table.args.get(arg)
                if isinstance(value, exp.Identifier):
                    normalized_table.set(
                        arg,
                        normalize_name(value, dialect=dialect, is_table=True, normalize=normalize),
                    )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Normalize a single identifier and return its plain-string name."""
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        ).name

    def depth(self) -> int:
        """Return the schema's nesting level (1 = table, 2 = db.table, 3 = catalog.db.table)."""
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect
            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]
Schema based on a nested mapping.
Arguments:
- schema: Mapping in one of the following forms:
- {table: {col: type}}
- {db: {table: {col: type}}}
- {catalog: {db: {table: {col: type}}}}
- None - Tables will be added later
- visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
are assumed to be visible. The nesting should mirror that of the schema:
- {table: set(*cols)}
- {db: {table: set(*cols)}}
- {catalog: {db: {table: set(*cols)}}}
- dialect: The dialect to be used for custom type mappings & parsing string arguments.
- normalize: Whether to normalize identifier names according to the given dialect or not.
def __init__(
    self,
    schema: t.Optional[t.Dict] = None,
    visible: t.Optional[t.Dict] = None,
    dialect: DialectType = None,
    normalize: bool = True,
) -> None:
    """Store the schema settings and pass the (optionally normalized) mapping to the base class."""
    self.dialect = dialect
    self.normalize = normalize
    self.visible = visible if visible is not None else {}
    # Per-instance cache of type-string -> exp.DataType conversions.
    self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
    # Filled in lazily by depth(); 0 means "not computed yet".
    self._depth = 0

    raw_schema = schema if schema is not None else {}
    super().__init__(self._normalize(raw_schema) if self.normalize else raw_schema)
def find(
    self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
) -> t.Optional[t.Any]:
    """Return the schema of `table`, coercing string types to `exp.DataType` when requested."""
    result = super().find(
        table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types
    )

    if not ensure_data_types or not isinstance(result, dict):
        return result

    # Column types may be stored as raw strings; convert them on the way out.
    return {
        col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
        for col, dtype in result.items()
    }
Returns the schema of a given table.
Arguments:
- table: the target table.
- raise_on_missing: whether to raise in case the schema is not found.
- ensure_data_types: whether to convert `str` types to their `DataType` equivalents.
Returns:
The schema of the target table.
def add_table(
    self,
    table: exp.Table | str,
    column_mapping: t.Optional[ColumnMapping] = None,
    dialect: DialectType = None,
    normalize: t.Optional[bool] = None,
    match_depth: bool = True,
) -> None:
    """
    Register or update a table. Updates are only performed if a new column mapping is provided.
    The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

    Args:
        table: the `Table` expression instance or string representing the table.
        column_mapping: a column mapping that describes the structure of the table.
        dialect: the SQL dialect that will be used to parse `table` if it's a string.
        normalize: whether to normalize identifiers according to the dialect of interest.
        match_depth: whether to enforce that the table must match the schema's depth or not.
    """
    normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

    if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
        raise SchemaError(
            f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
            f"schema's nesting level: {self.depth()}."
        )

    normalized_column_mapping = {}
    for key, value in ensure_column_mapping(column_mapping).items():
        normalized_key = self._normalize_name(key, dialect=dialect, normalize=normalize)
        normalized_column_mapping[normalized_key] = value

    # Without new column info, an already-registered table keeps its old entry.
    if not normalized_column_mapping and self.find(normalized_table, raise_on_missing=False):
        return

    parts = self.table_parts(normalized_table)

    # Parts are reversed to match the key order used by the nested mapping.
    nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
    new_trie([parts], self.mapping_trie)
Register or update a table. Updates are only performed if a new column mapping is provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
- match_depth: whether to enforce that the table must match the schema's depth or not.
def column_names(
    self,
    table: exp.Table | str,
    only_visible: bool = False,
    dialect: DialectType = None,
    normalize: t.Optional[bool] = None,
) -> t.List[str]:
    """Return the column names of `table`, optionally filtered by the visibility mapping."""
    normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

    schema = self.find(normalized_table)
    if schema is None:
        return []

    if only_visible and self.visible:
        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    return list(schema)
Get the column names for a table.
Arguments:
- table: the `Table` expression instance.
- only_visible: whether to include invisible columns.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The sequence of column names.
def get_column_type(
    self,
    table: exp.Table | str,
    column: exp.Column | str,
    dialect: DialectType = None,
    normalize: t.Optional[bool] = None,
) -> exp.DataType:
    """Look up the type of `column` in `table`, or the "unknown" type if not found."""
    normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

    column_name = self._normalize_name(
        column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
    )

    table_schema = self.find(normalized_table, raise_on_missing=False)
    column_type = table_schema.get(column_name) if table_schema else None

    if isinstance(column_type, exp.DataType):
        return column_type
    if isinstance(column_type, str):
        # Stored as a raw type string; parse (and cache) it.
        return self._to_data_type(column_type, dialect=dialect)

    return exp.DataType.build("unknown")
Get the `sqlglot.exp.DataType` type of a column in the schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The resulting column type.
def has_column(
    self,
    table: exp.Table | str,
    column: exp.Column | str,
    dialect: DialectType = None,
    normalize: t.Optional[bool] = None,
) -> bool:
    """Return True if `column` appears in `table`'s schema, False otherwise."""
    normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

    column_name = self._normalize_name(
        column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
    )

    table_schema = self.find(normalized_table, raise_on_missing=False)
    if not table_schema:
        return False
    return column_name in table_schema
Returns whether `column` appears in `table`'s schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
True if the column appears in the schema, False otherwise.
Inherited Members
def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> exp.Identifier:
    """Parse `identifier` if it's a string, then normalize it per `dialect` rules."""
    ident = (
        exp.parse_identifier(identifier, dialect=dialect)
        if isinstance(identifier, str)
        else identifier
    )

    if not normalize:
        return ident

    # this is used for normalize_identifier, bigquery has special rules pertaining tables
    ident.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(ident)
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Coerce a column mapping given as None, a dict, a "name: type, ..." string,
    or a list of column names into a plain ``{column: type}`` dict.

    Raises:
        ValueError: if `mapping` is of an unsupported type.
    """
    if mapping is None:
        return {}

    if isinstance(mapping, dict):
        return mapping

    if isinstance(mapping, str):
        # "col1: type1, col2: type2" -> {"col1": "type1", "col2": "type2"}
        columns = {}
        for entry in mapping.split(","):
            pieces = entry.split(":")
            columns[pieces[0].strip()] = pieces[1].strip()
        return columns

    if isinstance(mapping, list):
        # Bare column names carry no type information.
        return dict.fromkeys((name.strip() for name in mapping), None)

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema(
    schema: t.Dict, depth: t.Optional[int] = None, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """
    Return the list of key paths leading to each table in a nested schema mapping.

    Args:
        schema: the nested schema mapping to flatten.
        depth: remaining nesting levels; inferred from the mapping when omitted.
        keys: the path prefix accumulated by recursive calls.

    Returns:
        One key path (e.g. ``["catalog", "db", "table"]``) per table.
    """
    prefix = keys or []
    if depth is None:
        # Subtract one so the column mapping itself doesn't count as a nesting level.
        depth = dict_depth(schema) - 1

    paths: t.List[t.List[str]] = []
    for name, value in schema.items():
        if depth == 1 or not isinstance(value, dict):
            paths.append(prefix + [name])
        elif depth >= 2:
            paths.extend(flatten_schema(value, depth - 1, prefix + [name]))

    return paths
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.

    Returns:
        The value or None if it doesn't exist.
    """
    current = d
    for name, key in path:
        current = current.get(key)  # type: ignore
        if current is None:
            if not raise_on_missing:
                return None
            # "this" is the internal expression arg name; surface it as "table".
            label = "table" if name == "this" else name
            raise ValueError(f"Unknown {label}: {key}")

    return current
Get a value for a nested dictionary.
Arguments:
- d: the dictionary to search.
- *path: tuples of (name, key), where:
  `key` is the key in the dictionary to get.
  `name` is a string to use in the error if `key` isn't found.
Returns:
The value or None if it doesn't exist.
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that make up the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    node = d
    for key in keys[:-1]:
        # setdefault both creates missing intermediate dicts and descends into them.
        node = node.setdefault(key, {})

    node[keys[-1]] = value
    return d
In-place set a value for a nested dictionary
Example:
>>> nested_set({}, ["top_key", "second_key"], "value") {'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
- d: dictionary to update.
- keys: the keys that make up the path to `value`.
- value: the value to set in the dictionary for the given key path.
Returns:
The (possibly) updated dictionary.