sqlglot.schema
from __future__ import annotations

import abc
import typing as t

import sqlglot
from sqlglot import expressions as exp
from sqlglot._typing import T
from sqlglot.dialects.dialect import Dialect
from sqlglot.errors import ParseError, SchemaError
from sqlglot.helper import dict_depth
from sqlglot.trie import TrieResult, in_trie, new_trie

if t.TYPE_CHECKING:
    from sqlglot.dataframe.sql.types import StructType
    from sqlglot.dialects.dialect import DialectType

    ColumnMapping = t.Union[t.Dict, str, StructType, t.List]

# Table qualifier parts, ordered from innermost to outermost (table, db, catalog).
TABLE_ARGS = ("this", "db", "catalog")


class Schema(abc.ABC):
    """Abstract base class for database schemas"""

    dialect: DialectType

    @abc.abstractmethod
    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information to also be provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """

    @abc.abstractmethod
    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance.
            only_visible: whether to restrict the output to only the visible columns.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The list of column names.
        """

    @abc.abstractmethod
    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type.
        """

    @property
    @abc.abstractmethod
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        Table arguments this schema supports, e.g. `("this", "db", "catalog")`
        """

    @property
    def empty(self) -> bool:
        """Returns whether or not the schema is empty."""
        return True


class AbstractMappingSchema(t.Generic[T]):
    """Base implementation of a schema backed by a nested mapping and a trie for lookups."""

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}
        # Trie keys run from the table part outward (table, db, catalog), hence the reversal
        # of the outermost-first key paths produced by flatten_schema.
        self.mapping_trie = new_trie(
            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
        )
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        """Returns whether or not the underlying mapping is empty."""
        return not self.mapping

    def depth(self) -> int:
        """Returns the nesting depth of the underlying mapping."""
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        # Computed lazily from the mapping's depth and cached in _supported_table_args.
        if not self._supported_table_args and self.mapping:
            depth = self.depth()

            if not depth:  # The mapping is flat / empty
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = TABLE_ARGS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """Returns the non-empty qualifier parts of `table`, innermost first."""
        if isinstance(table.this, exp.ReadCSV):
            return [table.this.name]
        return [table.text(part) for part in TABLE_ARGS if table.text(part)]

    def find(
        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[T]:
        """
        Returns the schema entry for `table`, or None if it can't be found.

        Args:
            table: the table to look up.
            trie: an optional trie to search in, defaulting to this schema's trie.
            raise_on_missing: whether to raise on an ambiguous or missing table.
        """
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)

        if value == TrieResult.FAILED:
            return None

        if value == TrieResult.PREFIX:
            # The table is underqualified; it's only safe to resolve if exactly one
            # completion of the path exists in the schema.
            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)

            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                message = ", ".join(".".join(parts) for parts in possibilities)
                if raise_on_missing:
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None

        return self.nested_get(parts, raise_on_missing=raise_on_missing)

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
    ) -> t.Optional[t.Any]:
        """Looks up `parts` (innermost first) in the nested mapping `d` (default: self.mapping)."""
        return nested_get(
            d or self.mapping,
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )


class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = visible or {}
        self.normalize = normalize
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        self._depth = 0  # lazily-computed cache for depth()

        super().__init__(self._normalize(schema or {}))

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Alternate constructor: builds a new instance from an existing MappingSchema."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """Returns a copy of this schema; `kwargs` override the copied constructor arguments."""
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            # The table is already registered and no new columns were given: nothing to update.
            return

        parts = self.table_parts(normalized_table)

        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        # Passing the existing trie makes new_trie insert the new path in place.
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                return self._to_data_type(column_type.upper(), dialect=dialect)

        return exp.DataType.build("unknown")

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        # depth - 1 so the innermost {col: type} dicts are treated as values, not levels.
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        for keys in flattened_schema:
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                raise SchemaError(
                    f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}."
                )

            normalized_keys = [
                self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys
            ]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parses `table` if needed and normalizes each of its qualifier parts."""
        normalized_table = exp.maybe_parse(
            table, into=exp.Table, dialect=dialect or self.dialect, copy=True
        )

        for arg in TABLE_ARGS:
            value = normalized_table.args.get(arg)
            if isinstance(value, (str, exp.Identifier)):
                normalized_table.set(
                    arg,
                    exp.to_identifier(
                        self._normalize_name(
                            value, dialect=dialect, is_table=True, normalize=normalize
                        )
                    ),
                )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Returns `name` normalized per the dialect's identifier rules (if normalization is on)."""
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        try:
            identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
        except ParseError:
            # Fall back to the raw name if it can't be parsed as an identifier.
            return name if isinstance(name, str) else name.name

        name = identifier.name
        if not normalize:
            return name

        # This can be useful for normalize_identifier
        identifier.meta["is_table"] = is_table
        return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name

    def depth(self) -> int:
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]


def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """Returns `schema` unchanged if it's already a Schema, else wraps it in a MappingSchema."""
    if isinstance(schema, Schema):
        return schema

    return MappingSchema(schema, **kwargs)


def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Coerces any supported column-mapping representation into a dict of {column: type}.

    Supported forms: None, dict, "col1: type1, col2: type2" strings, DataFrame
    StructType-like objects (anything with `simpleString`), and lists of column names
    (which map to a None type).

    Raises:
        ValueError: if `mapping` is of an unsupported type.
    """
    if mapping is None:
        return {}
    elif isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        col_name_type_strs = [x.strip() for x in mapping.split(",")]
        return {
            name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip()
            for name_type_str in col_name_type_strs
        }
    # Check if mapping looks like a DataFrame StructType
    elif hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")


def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """
    Flattens the first `depth` levels of a nested mapping into a list of key paths.

    Args:
        schema: the nested mapping to flatten.
        depth: how many levels of nesting to descend into.
        keys: the prefix of keys accumulated so far (used in the recursive calls).

    Returns:
        A list of key paths, each path being a list of keys from outermost to innermost.
    """
    tables = []
    keys = keys or []

    for k, v in schema.items():
        if depth >= 2:
            tables.extend(flatten_schema(v, depth - 1, keys + [k]))
        elif depth == 1:
            tables.append(keys + [k])

    return tables


def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.

    Returns:
        The value or None if it doesn't exist.
    """
    for name, key in path:
        d = d.get(key)  # type: ignore
        if d is None:
            if raise_on_missing:
                name = "table" if name == "this" else name
                raise ValueError(f"Unknown {name}: {key}")
            return None

    return d


def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that make up the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    if len(keys) == 1:
        d[keys[0]] = value
        return d

    subd = d
    for key in keys[:-1]:
        if key not in subd:
            subd = subd.setdefault(key, {})
        else:
            subd = subd[key]

    subd[keys[-1]] = value
    return d
24class Schema(abc.ABC): 25 """Abstract base class for database schemas""" 26 27 dialect: DialectType 28 29 @abc.abstractmethod 30 def add_table( 31 self, 32 table: exp.Table | str, 33 column_mapping: t.Optional[ColumnMapping] = None, 34 dialect: DialectType = None, 35 normalize: t.Optional[bool] = None, 36 match_depth: bool = True, 37 ) -> None: 38 """ 39 Register or update a table. Some implementing classes may require column information to also be provided. 40 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 41 42 Args: 43 table: the `Table` expression instance or string representing the table. 44 column_mapping: a column mapping that describes the structure of the table. 45 dialect: the SQL dialect that will be used to parse `table` if it's a string. 46 normalize: whether to normalize identifiers according to the dialect of interest. 47 match_depth: whether to enforce that the table must match the schema's depth or not. 48 """ 49 50 @abc.abstractmethod 51 def column_names( 52 self, 53 table: exp.Table | str, 54 only_visible: bool = False, 55 dialect: DialectType = None, 56 normalize: t.Optional[bool] = None, 57 ) -> t.List[str]: 58 """ 59 Get the column names for a table. 60 61 Args: 62 table: the `Table` expression instance. 63 only_visible: whether to include invisible columns. 64 dialect: the SQL dialect that will be used to parse `table` if it's a string. 65 normalize: whether to normalize identifiers according to the dialect of interest. 66 67 Returns: 68 The list of column names. 69 """ 70 71 @abc.abstractmethod 72 def get_column_type( 73 self, 74 table: exp.Table | str, 75 column: exp.Column, 76 dialect: DialectType = None, 77 normalize: t.Optional[bool] = None, 78 ) -> exp.DataType: 79 """ 80 Get the `sqlglot.exp.DataType` type of a column in the schema. 81 82 Args: 83 table: the source table. 84 column: the target column. 
85 dialect: the SQL dialect that will be used to parse `table` if it's a string. 86 normalize: whether to normalize identifiers according to the dialect of interest. 87 88 Returns: 89 The resulting column type. 90 """ 91 92 @property 93 @abc.abstractmethod 94 def supported_table_args(self) -> t.Tuple[str, ...]: 95 """ 96 Table arguments this schema support, e.g. `("this", "db", "catalog")` 97 """ 98 99 @property 100 def empty(self) -> bool: 101 """Returns whether or not the schema is empty.""" 102 return True
Abstract base class for database schemas
29 @abc.abstractmethod 30 def add_table( 31 self, 32 table: exp.Table | str, 33 column_mapping: t.Optional[ColumnMapping] = None, 34 dialect: DialectType = None, 35 normalize: t.Optional[bool] = None, 36 match_depth: bool = True, 37 ) -> None: 38 """ 39 Register or update a table. Some implementing classes may require column information to also be provided. 40 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 41 42 Args: 43 table: the `Table` expression instance or string representing the table. 44 column_mapping: a column mapping that describes the structure of the table. 45 dialect: the SQL dialect that will be used to parse `table` if it's a string. 46 normalize: whether to normalize identifiers according to the dialect of interest. 47 match_depth: whether to enforce that the table must match the schema's depth or not. 48 """
Register or update a table. Some implementing classes may require column information to also be provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
- match_depth: whether to enforce that the table must match the schema's depth or not.
50 @abc.abstractmethod 51 def column_names( 52 self, 53 table: exp.Table | str, 54 only_visible: bool = False, 55 dialect: DialectType = None, 56 normalize: t.Optional[bool] = None, 57 ) -> t.List[str]: 58 """ 59 Get the column names for a table. 60 61 Args: 62 table: the `Table` expression instance. 63 only_visible: whether to include invisible columns. 64 dialect: the SQL dialect that will be used to parse `table` if it's a string. 65 normalize: whether to normalize identifiers according to the dialect of interest. 66 67 Returns: 68 The list of column names. 69 """
Get the column names for a table.
Arguments:
- table: the `Table` expression instance.
- only_visible: whether to restrict the output to only the visible columns.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The list of column names.
71 @abc.abstractmethod 72 def get_column_type( 73 self, 74 table: exp.Table | str, 75 column: exp.Column, 76 dialect: DialectType = None, 77 normalize: t.Optional[bool] = None, 78 ) -> exp.DataType: 79 """ 80 Get the `sqlglot.exp.DataType` type of a column in the schema. 81 82 Args: 83 table: the source table. 84 column: the target column. 85 dialect: the SQL dialect that will be used to parse `table` if it's a string. 86 normalize: whether to normalize identifiers according to the dialect of interest. 87 88 Returns: 89 The resulting column type. 90 """
Get the `sqlglot.exp.DataType` type of a column in the schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The resulting column type.
105class AbstractMappingSchema(t.Generic[T]): 106 def __init__( 107 self, 108 mapping: t.Optional[t.Dict] = None, 109 ) -> None: 110 self.mapping = mapping or {} 111 self.mapping_trie = new_trie( 112 tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth()) 113 ) 114 self._supported_table_args: t.Tuple[str, ...] = tuple() 115 116 @property 117 def empty(self) -> bool: 118 return not self.mapping 119 120 def depth(self) -> int: 121 return dict_depth(self.mapping) 122 123 @property 124 def supported_table_args(self) -> t.Tuple[str, ...]: 125 if not self._supported_table_args and self.mapping: 126 depth = self.depth() 127 128 if not depth: # None 129 self._supported_table_args = tuple() 130 elif 1 <= depth <= 3: 131 self._supported_table_args = TABLE_ARGS[:depth] 132 else: 133 raise SchemaError(f"Invalid mapping shape. Depth: {depth}") 134 135 return self._supported_table_args 136 137 def table_parts(self, table: exp.Table) -> t.List[str]: 138 if isinstance(table.this, exp.ReadCSV): 139 return [table.this.name] 140 return [table.text(part) for part in TABLE_ARGS if table.text(part)] 141 142 def find( 143 self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True 144 ) -> t.Optional[T]: 145 parts = self.table_parts(table)[0 : len(self.supported_table_args)] 146 value, trie = in_trie(self.mapping_trie if trie is None else trie, parts) 147 148 if value == TrieResult.FAILED: 149 return None 150 151 if value == TrieResult.PREFIX: 152 possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1) 153 154 if len(possibilities) == 1: 155 parts.extend(possibilities[0]) 156 else: 157 message = ", ".join(".".join(parts) for parts in possibilities) 158 if raise_on_missing: 159 raise SchemaError(f"Ambiguous mapping for {table}: {message}.") 160 return None 161 162 return self.nested_get(parts, raise_on_missing=raise_on_missing) 163 164 def nested_get( 165 self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, 
raise_on_missing=True 166 ) -> t.Optional[t.Any]: 167 return nested_get( 168 d or self.mapping, 169 *zip(self.supported_table_args, reversed(parts)), 170 raise_on_missing=raise_on_missing, 171 )
Abstract base class for generic types.
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::
class Mapping(Generic[KT, VT]):
    def __getitem__(self, key: KT) -> VT:
        ...
        # Etc.
This class can then be used as follows::
    def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
        try:
            return mapping[key]
        except KeyError:
            return default
142 def find( 143 self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True 144 ) -> t.Optional[T]: 145 parts = self.table_parts(table)[0 : len(self.supported_table_args)] 146 value, trie = in_trie(self.mapping_trie if trie is None else trie, parts) 147 148 if value == TrieResult.FAILED: 149 return None 150 151 if value == TrieResult.PREFIX: 152 possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1) 153 154 if len(possibilities) == 1: 155 parts.extend(possibilities[0]) 156 else: 157 message = ", ".join(".".join(parts) for parts in possibilities) 158 if raise_on_missing: 159 raise SchemaError(f"Ambiguous mapping for {table}: {message}.") 160 return None 161 162 return self.nested_get(parts, raise_on_missing=raise_on_missing)
class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = visible or {}
        self.normalize = normalize
        # Cache of raw type string -> parsed exp.DataType, filled by _to_data_type.
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        # Lazily computed by depth(); 0 means "not computed yet".
        self._depth = 0

        # Normalize all identifiers up front so later lookups agree with the mapping.
        super().__init__(self._normalize(schema or {}))

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a new MappingSchema sharing the given schema's mapping and settings."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """Return a copy of this schema, with any constructor argument overridable via kwargs."""
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # If the table is already registered and no new columns were given, keep it as-is.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        # Register the new path in the existing trie in place.
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """
        Get the column names for a table, optionally restricted to its visible columns.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to exclude columns not listed in the `visible` mapping.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The list of column names; empty if the table is unknown.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column, as an expression or a plain name.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type; "unknown" if the table or column isn't found.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            # Types may be stored pre-parsed or as raw strings, e.g. "INT".
            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                return self._to_data_type(column_type.upper(), dialect=dialect)

        return exp.DataType.build("unknown")

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        # Flatten to key paths that stop just above the column mappings.
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        for keys in flattened_schema:
            # nested_get expects (name, key) pairs; the keys double as their own labels here.
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                raise SchemaError(
                    f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}."
                )

            normalized_keys = [
                self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys
            ]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parse `table` if needed and normalize each of its name parts."""
        normalized_table = exp.maybe_parse(
            table, into=exp.Table, dialect=dialect or self.dialect, copy=True
        )

        for arg in TABLE_ARGS:
            value = normalized_table.args.get(arg)
            if isinstance(value, (str, exp.Identifier)):
                normalized_table.set(
                    arg,
                    exp.to_identifier(
                        self._normalize_name(
                            value, dialect=dialect, is_table=True, normalize=normalize
                        )
                    ),
                )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Return `name` normalized per the dialect's identifier rules (when enabled)."""
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        try:
            identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
        except ParseError:
            # Unparseable names are returned as-is rather than failing the lookup.
            return name if isinstance(name, str) else name.name

        name = identifier.name
        if not normalize:
            return name

        # This can be useful for normalize_identifier
        identifier.meta["is_table"] = is_table
        return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name

    def depth(self) -> int:
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect)
                self._type_mapping_cache[schema_type] = expression
            # NOTE(review): only AttributeError is mapped to SchemaError here; other parse
            # failures propagate — confirm this is intentional.
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]
Schema based on a nested mapping.
Arguments:
- schema: Mapping in one of the following forms:
- {table: {col: type}}
- {db: {table: {col: type}}}
- {catalog: {db: {table: {col: type}}}}
- None - Tables will be added later
- visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
are assumed to be visible. The nesting should mirror that of the schema:
- {table: set(*cols)}
- {db: {table: set(*cols)}}
- {catalog: {db: {table: set(*cols)}}}
- dialect: The dialect to be used for custom type mappings & parsing string arguments.
- normalize: Whether to normalize identifier names according to the given dialect or not.
193 def __init__( 194 self, 195 schema: t.Optional[t.Dict] = None, 196 visible: t.Optional[t.Dict] = None, 197 dialect: DialectType = None, 198 normalize: bool = True, 199 ) -> None: 200 self.dialect = dialect 201 self.visible = visible or {} 202 self.normalize = normalize 203 self._type_mapping_cache: t.Dict[str, exp.DataType] = {} 204 self._depth = 0 205 206 super().__init__(self._normalize(schema or {}))
429def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict: 430 if mapping is None: 431 return {} 432 elif isinstance(mapping, dict): 433 return mapping 434 elif isinstance(mapping, str): 435 col_name_type_strs = [x.strip() for x in mapping.split(",")] 436 return { 437 name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip() 438 for name_type_str in col_name_type_strs 439 } 440 # Check if mapping looks like a DataFrame StructType 441 elif hasattr(mapping, "simpleString"): 442 return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping} 443 elif isinstance(mapping, list): 444 return {x.strip(): None for x in mapping} 445 446 raise ValueError(f"Invalid mapping provided: {type(mapping)}")
449def flatten_schema( 450 schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None 451) -> t.List[t.List[str]]: 452 tables = [] 453 keys = keys or [] 454 455 for k, v in schema.items(): 456 if depth >= 2: 457 tables.extend(flatten_schema(v, depth - 1, keys + [k])) 458 elif depth == 1: 459 tables.append(keys + [k]) 460 461 return tables
464def nested_get( 465 d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True 466) -> t.Optional[t.Any]: 467 """ 468 Get a value for a nested dictionary. 469 470 Args: 471 d: the dictionary to search. 472 *path: tuples of (name, key), where: 473 `key` is the key in the dictionary to get. 474 `name` is a string to use in the error if `key` isn't found. 475 476 Returns: 477 The value or None if it doesn't exist. 478 """ 479 for name, key in path: 480 d = d.get(key) # type: ignore 481 if d is None: 482 if raise_on_missing: 483 name = "table" if name == "this" else name 484 raise ValueError(f"Unknown {name}: {key}") 485 return None 486 487 return d
Get a value for a nested dictionary.
Arguments:
- d: the dictionary to search.
- *path: tuples of (name, key), where `key` is the key in the dictionary to get and `name` is a string to use in the error if `key` isn't found.
Returns:
The value or None if it doesn't exist.
490def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict: 491 """ 492 In-place set a value for a nested dictionary 493 494 Example: 495 >>> nested_set({}, ["top_key", "second_key"], "value") 496 {'top_key': {'second_key': 'value'}} 497 498 >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") 499 {'top_key': {'third_key': 'third_value', 'second_key': 'value'}} 500 501 Args: 502 d: dictionary to update. 503 keys: the keys that makeup the path to `value`. 504 value: the value to set in the dictionary for the given key path. 505 506 Returns: 507 The (possibly) updated dictionary. 508 """ 509 if not keys: 510 return d 511 512 if len(keys) == 1: 513 d[keys[0]] = value 514 return d 515 516 subd = d 517 for key in keys[:-1]: 518 if key not in subd: 519 subd = subd.setdefault(key, {}) 520 else: 521 subd = subd[key] 522 523 subd[keys[-1]] = value 524 return d
In-place set a value for a nested dictionary
Example:
>>> nested_set({}, ["top_key", "second_key"], "value") {'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
- d: dictionary to update.
- keys: the keys that make up the path to `value`.
- value: the value to set in the dictionary for the given key path.
Returns:
The (possibly) updated dictionary.