sqlglot.schema
"""Schema abstractions: an abstract `Schema` interface plus a nested-mapping implementation."""

from __future__ import annotations

import abc
import typing as t

import sqlglot
from sqlglot import expressions as exp
from sqlglot._typing import T
from sqlglot.dialects.dialect import Dialect
from sqlglot.errors import ParseError, SchemaError
from sqlglot.helper import dict_depth
from sqlglot.trie import TrieResult, in_trie, new_trie

if t.TYPE_CHECKING:
    from sqlglot.dataframe.sql.types import StructType
    from sqlglot.dialects.dialect import DialectType

    ColumnMapping = t.Union[t.Dict, str, StructType, t.List]

# Table qualifier argument names, ordered from least to most qualified.
TABLE_ARGS = ("this", "db", "catalog")


class Schema(abc.ABC):
    """Abstract base class for database schemas"""

    dialect: DialectType

    @abc.abstractmethod
    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information to also be provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """

    @abc.abstractmethod
    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance.
            only_visible: whether to include invisible columns.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The list of column names.
        """

    @abc.abstractmethod
    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type.
        """

    @property
    @abc.abstractmethod
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        Table arguments this schema support, e.g. `("this", "db", "catalog")`
        """

    @property
    def empty(self) -> bool:
        """Returns whether or not the schema is empty."""
        return True


class AbstractMappingSchema(t.Generic[T]):
    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}
        # Trie keys are the reversed table path (table, db, catalog) so lookups
        # can proceed from the least to the most qualified part.
        self.mapping_trie = new_trie(
            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
        )
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        return not self.mapping

    def depth(self) -> int:
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        # Lazily computed from the mapping's nesting depth and cached.
        if not self._supported_table_args and self.mapping:
            depth = self.depth()

            if not depth:  # None
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = TABLE_ARGS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        if isinstance(table.this, exp.ReadCSV):
            return [table.this.name]
        return [table.text(part) for part in TABLE_ARGS if table.text(part)]

    def find(
        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[T]:
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)

        if value == TrieResult.FAILED:
            return None

        if value == TrieResult.PREFIX:
            # Partial match: only valid if exactly one table completes the prefix.
            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)

            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                message = ", ".join(".".join(parts) for parts in possibilities)
                if raise_on_missing:
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None

        return self.nested_get(parts, raise_on_missing=raise_on_missing)

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
    ) -> t.Optional[t.Any]:
        return nested_get(
            d or self.mapping,
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )


class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}}
            2. {db: {table: set(*cols)}}}
            3. {catalog: {db: {table: set(*cols)}}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = visible or {}
        self.normalize = normalize
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        self._depth = 0

        super().__init__(self._normalize(schema or {}))

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # Keep the existing entry when the table is already known and no new
        # column information was provided.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                return self._to_data_type(column_type.upper(), dialect=dialect)

        return exp.DataType.build("unknown")

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        for keys in flattened_schema:
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                raise SchemaError(
                    f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}."
                )

            normalized_keys = [
                self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys
            ]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        normalized_table = exp.maybe_parse(
            table, into=exp.Table, dialect=dialect or self.dialect, copy=True
        )

        for arg in TABLE_ARGS:
            value = normalized_table.args.get(arg)
            if isinstance(value, (str, exp.Identifier)):
                normalized_table.set(
                    arg,
                    exp.to_identifier(
                        self._normalize_name(
                            value, dialect=dialect, is_table=True, normalize=normalize
                        )
                    ),
                )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        )

    def depth(self) -> int:
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect
            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]


def normalize_name(
    name: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> str:
    try:
        identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
    except ParseError:
        return name if isinstance(name, str) else name.name

    name = identifier.name
    if not normalize:
        return name

    # This can be useful for normalize_identifier
    identifier.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name


def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    if isinstance(schema, Schema):
        return schema

    return MappingSchema(schema, **kwargs)


def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    if mapping is None:
        return {}
    elif isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        col_name_type_strs = [x.strip() for x in mapping.split(",")]
        return {
            name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip()
            for name_type_str in col_name_type_strs
        }
    # Check if mapping looks like a DataFrame StructType
    elif hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")


def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    tables = []
    keys = keys or []

    for k, v in schema.items():
        if depth >= 2:
            tables.extend(flatten_schema(v, depth - 1, keys + [k]))
        elif depth == 1:
            tables.append(keys + [k])

    return tables


def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.

    Returns:
        The value or None if it doesn't exist.
    """
    for name, key in path:
        d = d.get(key)  # type: ignore
        if d is None:
            if raise_on_missing:
                name = "table" if name == "this" else name
                raise ValueError(f"Unknown {name}: {key}")
            return None

    return d


def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that makeup the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    if len(keys) == 1:
        d[keys[0]] = value
        return d

    subd = d
    for key in keys[:-1]:
        if key not in subd:
            subd = subd.setdefault(key, {})
        else:
            subd = subd[key]

    subd[keys[-1]] = value
    return d
24class Schema(abc.ABC): 25 """Abstract base class for database schemas""" 26 27 dialect: DialectType 28 29 @abc.abstractmethod 30 def add_table( 31 self, 32 table: exp.Table | str, 33 column_mapping: t.Optional[ColumnMapping] = None, 34 dialect: DialectType = None, 35 normalize: t.Optional[bool] = None, 36 match_depth: bool = True, 37 ) -> None: 38 """ 39 Register or update a table. Some implementing classes may require column information to also be provided. 40 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 41 42 Args: 43 table: the `Table` expression instance or string representing the table. 44 column_mapping: a column mapping that describes the structure of the table. 45 dialect: the SQL dialect that will be used to parse `table` if it's a string. 46 normalize: whether to normalize identifiers according to the dialect of interest. 47 match_depth: whether to enforce that the table must match the schema's depth or not. 48 """ 49 50 @abc.abstractmethod 51 def column_names( 52 self, 53 table: exp.Table | str, 54 only_visible: bool = False, 55 dialect: DialectType = None, 56 normalize: t.Optional[bool] = None, 57 ) -> t.List[str]: 58 """ 59 Get the column names for a table. 60 61 Args: 62 table: the `Table` expression instance. 63 only_visible: whether to include invisible columns. 64 dialect: the SQL dialect that will be used to parse `table` if it's a string. 65 normalize: whether to normalize identifiers according to the dialect of interest. 66 67 Returns: 68 The list of column names. 69 """ 70 71 @abc.abstractmethod 72 def get_column_type( 73 self, 74 table: exp.Table | str, 75 column: exp.Column, 76 dialect: DialectType = None, 77 normalize: t.Optional[bool] = None, 78 ) -> exp.DataType: 79 """ 80 Get the `sqlglot.exp.DataType` type of a column in the schema. 81 82 Args: 83 table: the source table. 84 column: the target column. 
85 dialect: the SQL dialect that will be used to parse `table` if it's a string. 86 normalize: whether to normalize identifiers according to the dialect of interest. 87 88 Returns: 89 The resulting column type. 90 """ 91 92 @property 93 @abc.abstractmethod 94 def supported_table_args(self) -> t.Tuple[str, ...]: 95 """ 96 Table arguments this schema support, e.g. `("this", "db", "catalog")` 97 """ 98 99 @property 100 def empty(self) -> bool: 101 """Returns whether or not the schema is empty.""" 102 return True
Abstract base class for database schemas
29 @abc.abstractmethod 30 def add_table( 31 self, 32 table: exp.Table | str, 33 column_mapping: t.Optional[ColumnMapping] = None, 34 dialect: DialectType = None, 35 normalize: t.Optional[bool] = None, 36 match_depth: bool = True, 37 ) -> None: 38 """ 39 Register or update a table. Some implementing classes may require column information to also be provided. 40 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 41 42 Args: 43 table: the `Table` expression instance or string representing the table. 44 column_mapping: a column mapping that describes the structure of the table. 45 dialect: the SQL dialect that will be used to parse `table` if it's a string. 46 normalize: whether to normalize identifiers according to the dialect of interest. 47 match_depth: whether to enforce that the table must match the schema's depth or not. 48 """
Register or update a table. Some implementing classes may require column information to also be provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
- match_depth: whether to enforce that the table must match the schema's depth or not.
50 @abc.abstractmethod 51 def column_names( 52 self, 53 table: exp.Table | str, 54 only_visible: bool = False, 55 dialect: DialectType = None, 56 normalize: t.Optional[bool] = None, 57 ) -> t.List[str]: 58 """ 59 Get the column names for a table. 60 61 Args: 62 table: the `Table` expression instance. 63 only_visible: whether to include invisible columns. 64 dialect: the SQL dialect that will be used to parse `table` if it's a string. 65 normalize: whether to normalize identifiers according to the dialect of interest. 66 67 Returns: 68 The list of column names. 69 """
Get the column names for a table.
Arguments:
- table: the `Table` expression instance.
- only_visible: whether to include invisible columns.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The list of column names.
71 @abc.abstractmethod 72 def get_column_type( 73 self, 74 table: exp.Table | str, 75 column: exp.Column, 76 dialect: DialectType = None, 77 normalize: t.Optional[bool] = None, 78 ) -> exp.DataType: 79 """ 80 Get the `sqlglot.exp.DataType` type of a column in the schema. 81 82 Args: 83 table: the source table. 84 column: the target column. 85 dialect: the SQL dialect that will be used to parse `table` if it's a string. 86 normalize: whether to normalize identifiers according to the dialect of interest. 87 88 Returns: 89 The resulting column type. 90 """
Get the `sqlglot.exp.DataType` type of a column in the schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The resulting column type.
105class AbstractMappingSchema(t.Generic[T]): 106 def __init__( 107 self, 108 mapping: t.Optional[t.Dict] = None, 109 ) -> None: 110 self.mapping = mapping or {} 111 self.mapping_trie = new_trie( 112 tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth()) 113 ) 114 self._supported_table_args: t.Tuple[str, ...] = tuple() 115 116 @property 117 def empty(self) -> bool: 118 return not self.mapping 119 120 def depth(self) -> int: 121 return dict_depth(self.mapping) 122 123 @property 124 def supported_table_args(self) -> t.Tuple[str, ...]: 125 if not self._supported_table_args and self.mapping: 126 depth = self.depth() 127 128 if not depth: # None 129 self._supported_table_args = tuple() 130 elif 1 <= depth <= 3: 131 self._supported_table_args = TABLE_ARGS[:depth] 132 else: 133 raise SchemaError(f"Invalid mapping shape. Depth: {depth}") 134 135 return self._supported_table_args 136 137 def table_parts(self, table: exp.Table) -> t.List[str]: 138 if isinstance(table.this, exp.ReadCSV): 139 return [table.this.name] 140 return [table.text(part) for part in TABLE_ARGS if table.text(part)] 141 142 def find( 143 self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True 144 ) -> t.Optional[T]: 145 parts = self.table_parts(table)[0 : len(self.supported_table_args)] 146 value, trie = in_trie(self.mapping_trie if trie is None else trie, parts) 147 148 if value == TrieResult.FAILED: 149 return None 150 151 if value == TrieResult.PREFIX: 152 possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1) 153 154 if len(possibilities) == 1: 155 parts.extend(possibilities[0]) 156 else: 157 message = ", ".join(".".join(parts) for parts in possibilities) 158 if raise_on_missing: 159 raise SchemaError(f"Ambiguous mapping for {table}: {message}.") 160 return None 161 162 return self.nested_get(parts, raise_on_missing=raise_on_missing) 163 164 def nested_get( 165 self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, 
raise_on_missing=True 166 ) -> t.Optional[t.Any]: 167 return nested_get( 168 d or self.mapping, 169 *zip(self.supported_table_args, reversed(parts)), 170 raise_on_missing=raise_on_missing, 171 )
Abstract base class for generic types.

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as:

    class Mapping(Generic[KT, VT]):
        def __getitem__(self, key: KT) -> VT:
            ...  # Etc.

This class can then be used as follows:

    def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
        try:
            return mapping[key]
        except KeyError:
            return default
142 def find( 143 self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True 144 ) -> t.Optional[T]: 145 parts = self.table_parts(table)[0 : len(self.supported_table_args)] 146 value, trie = in_trie(self.mapping_trie if trie is None else trie, parts) 147 148 if value == TrieResult.FAILED: 149 return None 150 151 if value == TrieResult.PREFIX: 152 possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1) 153 154 if len(possibilities) == 1: 155 parts.extend(possibilities[0]) 156 else: 157 message = ", ".join(".".join(parts) for parts in possibilities) 158 if raise_on_missing: 159 raise SchemaError(f"Ambiguous mapping for {table}: {message}.") 160 return None 161 162 return self.nested_get(parts, raise_on_missing=raise_on_missing)
class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = visible or {}
        self.normalize = normalize
        # Lazily populated cache of type strings -> parsed exp.DataType (see _to_data_type).
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        # Cached schema depth; 0 means "not computed yet" (see depth()).
        self._depth = 0

        # Normalize identifiers up front so lookups and the trie agree on casing/quoting.
        super().__init__(self._normalize(schema or {}))

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a new MappingSchema from an existing one, reusing its configuration."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """Return a copy of this schema; `kwargs` override the copied constructor arguments."""
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.

        Raises:
            SchemaError: if `match_depth` is set and the table's qualifier count does not
                match the schema's nesting level.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # If the table is already registered and no new columns were given, keep it as-is.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        # Update both the nested mapping and the lookup trie in place.
        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to include invisible columns.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The list of column names, or an empty list if the table is not found.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column, as an expression or a plain column name string.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type, or the "unknown" type if it cannot be resolved.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                # Type stored as text (e.g. "INT"): parse once and cache the result.
                return self._to_data_type(column_type.upper(), dialect=dialect)

        return exp.DataType.build("unknown")

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.

        Raises:
            SchemaError: if an entry's nesting depth is inconsistent with the schema's.
        """
        normalized_mapping: t.Dict = {}
        # depth - 1 so the innermost {col: type} dicts are treated as leaves, not levels.
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        for keys in flattened_schema:
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                raise SchemaError(
                    f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}."
                )

            normalized_keys = [
                self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys
            ]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parse `table` if it's a string and normalize each of its qualifier identifiers."""
        normalized_table = exp.maybe_parse(
            table, into=exp.Table, dialect=dialect or self.dialect, copy=True
        )

        for arg in TABLE_ARGS:
            value = normalized_table.args.get(arg)
            if isinstance(value, (str, exp.Identifier)):
                normalized_table.set(
                    arg,
                    exp.to_identifier(
                        self._normalize_name(
                            value, dialect=dialect, is_table=True, normalize=normalize
                        )
                    ),
                )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Normalize one identifier, defaulting to this schema's dialect/normalize settings."""
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        )

    def depth(self) -> int:
        """Schema nesting depth, excluding the innermost {col: type} level; cached."""
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.

        Raises:
            SchemaError: if the type string cannot be built into a `DataType`.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect
            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]
Schema based on a nested mapping.
Arguments:
- schema: Mapping in one of the following forms:
- {table: {col: type}}
- {db: {table: {col: type}}}
- {catalog: {db: {table: {col: type}}}}
- None - Tables will be added later
- visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
are assumed to be visible. The nesting should mirror that of the schema:
- {table: set(*cols)}
- {db: {table: set(*cols)}}
- {catalog: {db: {table: set(*cols)}}}
- dialect: The dialect to be used for custom type mappings & parsing string arguments.
- normalize: Whether to normalize identifier names according to the given dialect or not.
193 def __init__( 194 self, 195 schema: t.Optional[t.Dict] = None, 196 visible: t.Optional[t.Dict] = None, 197 dialect: DialectType = None, 198 normalize: bool = True, 199 ) -> None: 200 self.dialect = dialect 201 self.visible = visible or {} 202 self.normalize = normalize 203 self._type_mapping_cache: t.Dict[str, exp.DataType] = {} 204 self._depth = 0 205 206 super().__init__(self._normalize(schema or {}))
def normalize_name(
    name: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> str:
    """Return `name` as a plain string, optionally normalized per `dialect`."""
    try:
        identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
    except ParseError:
        # Unparseable input: fall back to the raw text.
        return name if isinstance(name, str) else name.name

    if not normalize:
        return identifier.name

    # This can be useful for normalize_identifier
    identifier.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Coerce the supported column mapping forms into a {column: type} dict.

    Args:
        mapping: one of None, a dict, a "col: type, ..." string, a StructType-like
            object (anything exposing `simpleString`), or a list of column names.

    Returns:
        A dictionary mapping column names to types (types are None for a plain list).

    Raises:
        ValueError: if `mapping` is of an unsupported type.
    """
    if mapping is None:
        return {}
    elif isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        column_mapping = {}
        for name_type_str in mapping.split(","):
            # Split on the first colon only, so types that themselves contain
            # colons (e.g. "struct<a:int>") are not truncated. Also avoids
            # splitting the same string twice.
            parts = name_type_str.split(":", 1)
            column_mapping[parts[0].strip()] = parts[1].strip()
        return column_mapping
    # Check if mapping looks like a DataFrame StructType
    elif hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """Collect the key paths of length `depth` through a nested schema mapping."""
    prefix = keys or []

    if depth == 1:
        # Base case: each remaining key completes one path.
        return [prefix + [name] for name in schema]

    if depth < 1:
        return []

    paths: t.List[t.List[str]] = []
    for name, nested in schema.items():
        paths.extend(flatten_schema(nested, depth - 1, prefix + [name]))
    return paths
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where `key` is the key in the dictionary
            to get and `name` is a string to use in the error if `key` isn't found.
        raise_on_missing: whether to raise `ValueError` when a key is missing.

    Returns:
        The value, or None if it doesn't exist and `raise_on_missing` is False.
    """
    current = d
    for name, key in path:
        current = current.get(key)  # type: ignore
        if current is None:
            if not raise_on_missing:
                return None
            # "this" is the expression arg name for the table itself; report it
            # as "table" for readability.
            label = "table" if name == "this" else name
            raise ValueError(f"Unknown {label}: {key}")

    return current
Get a value for a nested dictionary.
Arguments:
- d: the dictionary to search.
- *path: tuples of (name, key), where `key` is the key in the dictionary to get, and `name` is a string to use in the error if `key` isn't found.
Returns:
The value or None if it doesn't exist.
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary.

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that make up the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    subd = d
    for key in keys[:-1]:
        # setdefault both creates a missing intermediate dict and descends into
        # an existing one, replacing the redundant manual membership check.
        subd = subd.setdefault(key, {})

    subd[keys[-1]] = value
    return d
In-place set a value for a nested dictionary
Example:
>>> nested_set({}, ["top_key", "second_key"], "value") {'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
- d: dictionary to update.
- keys: the keys that make up the path to `value`.
- value: the value to set in the dictionary for the given key path.
Returns:
The (possibly) updated dictionary.