sqlglot.schema
from __future__ import annotations

import abc
import typing as t

import sqlglot
from sqlglot import expressions as exp
from sqlglot._typing import T
from sqlglot.dialects.dialect import Dialect
from sqlglot.errors import ParseError, SchemaError
from sqlglot.helper import dict_depth
from sqlglot.trie import TrieResult, in_trie, new_trie

if t.TYPE_CHECKING:
    from sqlglot.dataframe.sql.types import StructType
    from sqlglot.dialects.dialect import DialectType

    ColumnMapping = t.Union[t.Dict, str, StructType, t.List]

# Table qualifier parts, ordered from least to most qualified
# (table name, then database, then catalog).
TABLE_ARGS = ("this", "db", "catalog")


class Schema(abc.ABC):
    """Abstract base class for database schemas."""

    dialect: DialectType

    @abc.abstractmethod
    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information to also be provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """

    @abc.abstractmethod
    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance.
            only_visible: whether to include invisible columns.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The list of column names.
        """

    @abc.abstractmethod
    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column, as an expression or a plain column name.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type.
        """

    @property
    @abc.abstractmethod
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        Table arguments this schema supports, e.g. `("this", "db", "catalog")`.
        """

    @property
    def empty(self) -> bool:
        """Returns whether or not the schema is empty."""
        return True


class AbstractMappingSchema(t.Generic[T]):
    """Base class that stores a nested mapping and indexes it with a trie for fast lookups."""

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}
        # The trie keys are reversed so lookups go catalog -> db -> table,
        # while `table_parts` produces (table, db, catalog) order.
        self.mapping_trie = new_trie(
            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
        )
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        return not self.mapping

    def depth(self) -> int:
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        if not self._supported_table_args and self.mapping:
            depth = self.depth()

            if not depth:  # None
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = TABLE_ARGS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        if isinstance(table.this, exp.ReadCSV):
            return [table.this.name]
        return [table.text(part) for part in TABLE_ARGS if table.text(part)]

    def find(
        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[T]:
        """
        Look up the value stored for `table` in the mapping.

        A partially-qualified table is resolved through the trie; if the prefix
        matches exactly one entry, the missing qualifiers are filled in. An
        ambiguous prefix raises (or returns None when `raise_on_missing` is False).
        """
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)

        if value == TrieResult.FAILED:
            return None

        if value == TrieResult.PREFIX:
            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)

            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                message = ", ".join(".".join(parts) for parts in possibilities)
                if raise_on_missing:
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None

        return self.nested_get(parts, raise_on_missing=raise_on_missing)

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[t.Any]:
        # Use `is None` so an explicitly-passed empty dict is respected
        # instead of silently falling back to `self.mapping`.
        return nested_get(
            self.mapping if d is None else d,
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )


class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = visible or {}
        self.normalize = normalize
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        self._depth = 0

        super().__init__(self._normalize(schema or {}))

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs: t.Any) -> MappingSchema:
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                return self._to_data_type(column_type.upper(), dialect=dialect)

        return exp.DataType.build("unknown")

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        for keys in flattened_schema:
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                raise SchemaError(
                    f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}."
                )

            normalized_keys = [
                self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys
            ]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parse `table` if needed and normalize each of its qualifier identifiers."""
        normalized_table = exp.maybe_parse(
            table, into=exp.Table, dialect=dialect or self.dialect, copy=True
        )

        for arg in TABLE_ARGS:
            value = normalized_table.args.get(arg)
            if isinstance(value, (str, exp.Identifier)):
                normalized_table.set(
                    arg,
                    exp.to_identifier(
                        self._normalize_name(
                            value, dialect=dialect, is_table=True, normalize=normalize
                        )
                    ),
                )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        )

    def depth(self) -> int:
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError as e:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                # Chain the original error so the root cause isn't lost.
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.") from e

        return self._type_mapping_cache[schema_type]


def normalize_name(
    name: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> str:
    try:
        identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
    except ParseError:
        return name if isinstance(name, str) else name.name

    name = identifier.name
    if not normalize:
        return name

    # This can be useful for normalize_identifier
    identifier.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name


def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    if isinstance(schema, Schema):
        return schema

    return MappingSchema(schema, **kwargs)


def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """Coerce any of the supported column-mapping forms into a plain dict."""
    if mapping is None:
        return {}
    elif isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        col_name_type_strs = [x.strip() for x in mapping.split(",")]
        return {
            name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip()
            for name_type_str in col_name_type_strs
        }
    # Check if mapping looks like a DataFrame StructType
    elif hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")


def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """Flatten a nested schema into a list of key paths of length `depth`."""
    tables = []
    keys = keys or []

    for k, v in schema.items():
        if depth >= 2:
            tables.extend(flatten_schema(v, depth - 1, keys + [k]))
        elif depth == 1:
            tables.append(keys + [k])

    return tables


def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.

    Returns:
        The value or None if it doesn't exist.
    """
    for name, key in path:
        d = d.get(key)  # type: ignore
        if d is None:
            if raise_on_missing:
                name = "table" if name == "this" else name
                raise ValueError(f"Unknown {name}: {key}")
            return None

    return d


def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that makeup the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    if len(keys) == 1:
        d[keys[0]] = value
        return d

    subd = d
    for key in keys[:-1]:
        # setdefault both creates missing intermediate dicts and descends
        # into existing ones in a single step.
        subd = subd.setdefault(key, {})

    subd[keys[-1]] = value
    return d
24class Schema(abc.ABC): 25 """Abstract base class for database schemas""" 26 27 dialect: DialectType 28 29 @abc.abstractmethod 30 def add_table( 31 self, 32 table: exp.Table | str, 33 column_mapping: t.Optional[ColumnMapping] = None, 34 dialect: DialectType = None, 35 normalize: t.Optional[bool] = None, 36 match_depth: bool = True, 37 ) -> None: 38 """ 39 Register or update a table. Some implementing classes may require column information to also be provided. 40 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 41 42 Args: 43 table: the `Table` expression instance or string representing the table. 44 column_mapping: a column mapping that describes the structure of the table. 45 dialect: the SQL dialect that will be used to parse `table` if it's a string. 46 normalize: whether to normalize identifiers according to the dialect of interest. 47 match_depth: whether to enforce that the table must match the schema's depth or not. 48 """ 49 50 @abc.abstractmethod 51 def column_names( 52 self, 53 table: exp.Table | str, 54 only_visible: bool = False, 55 dialect: DialectType = None, 56 normalize: t.Optional[bool] = None, 57 ) -> t.List[str]: 58 """ 59 Get the column names for a table. 60 61 Args: 62 table: the `Table` expression instance. 63 only_visible: whether to include invisible columns. 64 dialect: the SQL dialect that will be used to parse `table` if it's a string. 65 normalize: whether to normalize identifiers according to the dialect of interest. 66 67 Returns: 68 The list of column names. 69 """ 70 71 @abc.abstractmethod 72 def get_column_type( 73 self, 74 table: exp.Table | str, 75 column: exp.Column, 76 dialect: DialectType = None, 77 normalize: t.Optional[bool] = None, 78 ) -> exp.DataType: 79 """ 80 Get the `sqlglot.exp.DataType` type of a column in the schema. 81 82 Args: 83 table: the source table. 84 column: the target column. 
85 dialect: the SQL dialect that will be used to parse `table` if it's a string. 86 normalize: whether to normalize identifiers according to the dialect of interest. 87 88 Returns: 89 The resulting column type. 90 """ 91 92 @property 93 @abc.abstractmethod 94 def supported_table_args(self) -> t.Tuple[str, ...]: 95 """ 96 Table arguments this schema support, e.g. `("this", "db", "catalog")` 97 """ 98 99 @property 100 def empty(self) -> bool: 101 """Returns whether or not the schema is empty.""" 102 return True
Abstract base class for database schemas
29 @abc.abstractmethod 30 def add_table( 31 self, 32 table: exp.Table | str, 33 column_mapping: t.Optional[ColumnMapping] = None, 34 dialect: DialectType = None, 35 normalize: t.Optional[bool] = None, 36 match_depth: bool = True, 37 ) -> None: 38 """ 39 Register or update a table. Some implementing classes may require column information to also be provided. 40 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 41 42 Args: 43 table: the `Table` expression instance or string representing the table. 44 column_mapping: a column mapping that describes the structure of the table. 45 dialect: the SQL dialect that will be used to parse `table` if it's a string. 46 normalize: whether to normalize identifiers according to the dialect of interest. 47 match_depth: whether to enforce that the table must match the schema's depth or not. 48 """
Register or update a table. Some implementing classes may require column information to also be provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Arguments:
- table: the
Table
expression instance or string representing the table. - column_mapping: a column mapping that describes the structure of the table.
- dialect: the SQL dialect that will be used to parse
table
if it's a string. - normalize: whether to normalize identifiers according to the dialect of interest.
- match_depth: whether to enforce that the table must match the schema's depth or not.
50 @abc.abstractmethod 51 def column_names( 52 self, 53 table: exp.Table | str, 54 only_visible: bool = False, 55 dialect: DialectType = None, 56 normalize: t.Optional[bool] = None, 57 ) -> t.List[str]: 58 """ 59 Get the column names for a table. 60 61 Args: 62 table: the `Table` expression instance. 63 only_visible: whether to include invisible columns. 64 dialect: the SQL dialect that will be used to parse `table` if it's a string. 65 normalize: whether to normalize identifiers according to the dialect of interest. 66 67 Returns: 68 The list of column names. 69 """
Get the column names for a table.
Arguments:
- table: the
Table
expression instance. - only_visible: whether to include invisible columns.
- dialect: the SQL dialect that will be used to parse
table
if it's a string. - normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The list of column names.
71 @abc.abstractmethod 72 def get_column_type( 73 self, 74 table: exp.Table | str, 75 column: exp.Column, 76 dialect: DialectType = None, 77 normalize: t.Optional[bool] = None, 78 ) -> exp.DataType: 79 """ 80 Get the `sqlglot.exp.DataType` type of a column in the schema. 81 82 Args: 83 table: the source table. 84 column: the target column. 85 dialect: the SQL dialect that will be used to parse `table` if it's a string. 86 normalize: whether to normalize identifiers according to the dialect of interest. 87 88 Returns: 89 The resulting column type. 90 """
Get the sqlglot.exp.DataType
type of a column in the schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse
table
if it's a string. - normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The resulting column type.
105class AbstractMappingSchema(t.Generic[T]): 106 def __init__( 107 self, 108 mapping: t.Optional[t.Dict] = None, 109 ) -> None: 110 self.mapping = mapping or {} 111 self.mapping_trie = new_trie( 112 tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth()) 113 ) 114 self._supported_table_args: t.Tuple[str, ...] = tuple() 115 116 @property 117 def empty(self) -> bool: 118 return not self.mapping 119 120 def depth(self) -> int: 121 return dict_depth(self.mapping) 122 123 @property 124 def supported_table_args(self) -> t.Tuple[str, ...]: 125 if not self._supported_table_args and self.mapping: 126 depth = self.depth() 127 128 if not depth: # None 129 self._supported_table_args = tuple() 130 elif 1 <= depth <= 3: 131 self._supported_table_args = TABLE_ARGS[:depth] 132 else: 133 raise SchemaError(f"Invalid mapping shape. Depth: {depth}") 134 135 return self._supported_table_args 136 137 def table_parts(self, table: exp.Table) -> t.List[str]: 138 if isinstance(table.this, exp.ReadCSV): 139 return [table.this.name] 140 return [table.text(part) for part in TABLE_ARGS if table.text(part)] 141 142 def find( 143 self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True 144 ) -> t.Optional[T]: 145 parts = self.table_parts(table)[0 : len(self.supported_table_args)] 146 value, trie = in_trie(self.mapping_trie if trie is None else trie, parts) 147 148 if value == TrieResult.FAILED: 149 return None 150 151 if value == TrieResult.PREFIX: 152 possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1) 153 154 if len(possibilities) == 1: 155 parts.extend(possibilities[0]) 156 else: 157 message = ", ".join(".".join(parts) for parts in possibilities) 158 if raise_on_missing: 159 raise SchemaError(f"Ambiguous mapping for {table}: {message}.") 160 return None 161 162 return self.nested_get(parts, raise_on_missing=raise_on_missing) 163 164 def nested_get( 165 self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, 
raise_on_missing=True 166 ) -> t.Optional[t.Any]: 167 return nested_get( 168 d or self.mapping, 169 *zip(self.supported_table_args, reversed(parts)), 170 raise_on_missing=raise_on_missing, 171 )
Abstract base class for generic types.
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::
class Mapping(Generic[KT, VT]): def __getitem__(self, key: KT) -> VT: ... # Etc.
This class can then be used as follows::
    def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
        try:
            return mapping[key]
        except KeyError:
            return default
142 def find( 143 self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True 144 ) -> t.Optional[T]: 145 parts = self.table_parts(table)[0 : len(self.supported_table_args)] 146 value, trie = in_trie(self.mapping_trie if trie is None else trie, parts) 147 148 if value == TrieResult.FAILED: 149 return None 150 151 if value == TrieResult.PREFIX: 152 possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1) 153 154 if len(possibilities) == 1: 155 parts.extend(possibilities[0]) 156 else: 157 message = ", ".join(".".join(parts) for parts in possibilities) 158 if raise_on_missing: 159 raise SchemaError(f"Ambiguous mapping for {table}: {message}.") 160 return None 161 162 return self.nested_get(parts, raise_on_missing=raise_on_missing)
class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = visible or {}
        self.normalize = normalize
        # Cache of raw type strings -> parsed exp.DataType (filled by _to_data_type).
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        # Lazily computed by depth(); 0 means "not yet computed".
        self._depth = 0

        # Must run last: _normalize reads self.dialect and self.normalize (via
        # _normalize_name).
        super().__init__(self._normalize(schema or {}))

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a new MappingSchema from another one's state (the mapping is shared, not copied)."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """Return a copy of this schema; any constructor argument can be overridden via kwargs."""
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # Keep the existing entry when no (new) columns were supplied.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        # Mapping is nested catalog-first, so reverse the name-first parts.
        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        # Presumably registers the new path in the existing trie in place — see
        # sqlglot.trie.new_trie's second argument.
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """
        Get the column names for a table; an empty list if the table isn't found.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to exclude columns absent from the `visible` mapping.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The list of column names.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column, as an expression or a raw name string.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type; the "unknown" type when it cannot be resolved.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                # Type strings are parsed (and cached) on first use.
                return self._to_data_type(column_type.upper(), dialect=dialect)

        return exp.DataType.build("unknown")

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        # Flatten only down to the table level; the last level holds the columns.
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        for keys in flattened_schema:
            # zip(keys, keys) produces the (name, key) pairs nested_get expects.
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                raise SchemaError(
                    f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}."
                )

            normalized_keys = [
                self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys
            ]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parse `table` if it's a string, then normalize each of its identifier parts."""
        normalized_table = exp.maybe_parse(
            table, into=exp.Table, dialect=dialect or self.dialect, copy=True
        )

        for arg in TABLE_ARGS:
            value = normalized_table.args.get(arg)
            if isinstance(value, (str, exp.Identifier)):
                normalized_table.set(
                    arg,
                    exp.to_identifier(
                        self._normalize_name(
                            value, dialect=dialect, is_table=True, normalize=normalize
                        )
                    ),
                )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Normalize one identifier, falling back to this schema's dialect/normalize defaults."""
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        )

    def depth(self) -> int:
        """Nesting depth of the schema excluding the column level (cached after first call)."""
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.

        Raises:
            SchemaError: if the type string cannot be built into a DataType.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]
Schema based on a nested mapping.
Arguments:
- schema: Mapping in one of the following forms:
- {table: {col: type}}
- {db: {table: {col: type}}}
- {catalog: {db: {table: {col: type}}}}
- None - Tables will be added later
- visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
are assumed to be visible. The nesting should mirror that of the schema:
- {table: set(*cols)}
- {db: {table: set(*cols)}}
- {catalog: {db: {table: set(*cols)}}}
- dialect: The dialect to be used for custom type mappings & parsing string arguments.
- normalize: Whether to normalize identifier names according to the given dialect or not.
193 def __init__( 194 self, 195 schema: t.Optional[t.Dict] = None, 196 visible: t.Optional[t.Dict] = None, 197 dialect: DialectType = None, 198 normalize: bool = True, 199 ) -> None: 200 self.dialect = dialect 201 self.visible = visible or {} 202 self.normalize = normalize 203 self._type_mapping_cache: t.Dict[str, exp.DataType] = {} 204 self._depth = 0 205 206 super().__init__(self._normalize(schema or {}))
413def normalize_name( 414 name: str | exp.Identifier, 415 dialect: DialectType = None, 416 is_table: bool = False, 417 normalize: t.Optional[bool] = True, 418) -> str: 419 try: 420 identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier) 421 except ParseError: 422 return name if isinstance(name, str) else name.name 423 424 name = identifier.name 425 if not normalize: 426 return name 427 428 # This can be useful for normalize_identifier 429 identifier.meta["is_table"] = is_table 430 return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name
440def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict: 441 if mapping is None: 442 return {} 443 elif isinstance(mapping, dict): 444 return mapping 445 elif isinstance(mapping, str): 446 col_name_type_strs = [x.strip() for x in mapping.split(",")] 447 return { 448 name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip() 449 for name_type_str in col_name_type_strs 450 } 451 # Check if mapping looks like a DataFrame StructType 452 elif hasattr(mapping, "simpleString"): 453 return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping} 454 elif isinstance(mapping, list): 455 return {x.strip(): None for x in mapping} 456 457 raise ValueError(f"Invalid mapping provided: {type(mapping)}")
460def flatten_schema( 461 schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None 462) -> t.List[t.List[str]]: 463 tables = [] 464 keys = keys or [] 465 466 for k, v in schema.items(): 467 if depth >= 2: 468 tables.extend(flatten_schema(v, depth - 1, keys + [k])) 469 elif depth == 1: 470 tables.append(keys + [k]) 471 472 return tables
475def nested_get( 476 d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True 477) -> t.Optional[t.Any]: 478 """ 479 Get a value for a nested dictionary. 480 481 Args: 482 d: the dictionary to search. 483 *path: tuples of (name, key), where: 484 `key` is the key in the dictionary to get. 485 `name` is a string to use in the error if `key` isn't found. 486 487 Returns: 488 The value or None if it doesn't exist. 489 """ 490 for name, key in path: 491 d = d.get(key) # type: ignore 492 if d is None: 493 if raise_on_missing: 494 name = "table" if name == "this" else name 495 raise ValueError(f"Unknown {name}: {key}") 496 return None 497 498 return d
Get a value for a nested dictionary.
Arguments:
- d: the dictionary to search.
- *path: tuples of (name, key), where:
`key` is the key in the dictionary to get.
`name` is a string to use in the error if `key` isn't found.
Returns:
The value or None if it doesn't exist.
501def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict: 502 """ 503 In-place set a value for a nested dictionary 504 505 Example: 506 >>> nested_set({}, ["top_key", "second_key"], "value") 507 {'top_key': {'second_key': 'value'}} 508 509 >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") 510 {'top_key': {'third_key': 'third_value', 'second_key': 'value'}} 511 512 Args: 513 d: dictionary to update. 514 keys: the keys that makeup the path to `value`. 515 value: the value to set in the dictionary for the given key path. 516 517 Returns: 518 The (possibly) updated dictionary. 519 """ 520 if not keys: 521 return d 522 523 if len(keys) == 1: 524 d[keys[0]] = value 525 return d 526 527 subd = d 528 for key in keys[:-1]: 529 if key not in subd: 530 subd = subd.setdefault(key, {}) 531 else: 532 subd = subd[key] 533 534 subd[keys[-1]] = value 535 return d
In-place set a value for a nested dictionary
Example:
>>> nested_set({}, ["top_key", "second_key"], "value") {'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
- d: dictionary to update.
- keys: the keys that make up the path to `value`.
- value: the value to set in the dictionary for the given key path.
Returns:
The (possibly) updated dictionary.