sqlglot.schema
"""Schema abstractions used to resolve table/column metadata when qualifying and typing SQL."""

from __future__ import annotations

import abc
import typing as t

import sqlglot
from sqlglot import expressions as exp
from sqlglot.dialects.dialect import Dialect
from sqlglot.errors import ParseError, SchemaError
from sqlglot.helper import dict_depth
from sqlglot.trie import TrieResult, in_trie, new_trie

if t.TYPE_CHECKING:
    from sqlglot.dataframe.sql.types import StructType
    from sqlglot.dialects.dialect import DialectType

    # A column mapping may be a {name: type} dict, a "name: type, ..." string,
    # a dataframe StructType, or a plain list of column names (no types).
    ColumnMapping = t.Union[t.Dict, str, StructType, t.List]

# Table expression arg names, ordered from most specific (table name) to least
# specific (catalog). This order is relied on by table_parts() and the tries below.
TABLE_ARGS = ("this", "db", "catalog")


class Schema(abc.ABC):
    """Abstract base class for database schemas"""

    dialect: DialectType

    @abc.abstractmethod
    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information to also be provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """

    @abc.abstractmethod
    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance.
            only_visible: whether to include invisible columns.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The list of column names.
        """

    @abc.abstractmethod
    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type.
        """

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """
        Returns whether or not `column` appears in `table`'s schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            True if the column appears in the schema, False otherwise.
        """
        name = column if isinstance(column, str) else column.name
        return name in self.column_names(table, dialect=dialect, normalize=normalize)

    @property
    @abc.abstractmethod
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        Table arguments this schema supports, e.g. `("this", "db", "catalog")`
        """

    @property
    def empty(self) -> bool:
        """Returns whether or not the schema is empty."""
        return True


class AbstractMappingSchema:
    """Schema backed by a nested mapping, nesting from the least specific
    qualifier inwards, e.g. {catalog: {db: {table: ...}}} (up to 3 levels)."""

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}
        # The trie's key order is the reverse of the mapping's nesting: most
        # specific part first (table, db, catalog). This lets find() match a
        # table reference that omits trailing qualifiers via a prefix match.
        self.mapping_trie = new_trie(
            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
        )
        # Lazily computed by the supported_table_args property.
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        return not self.mapping

    def depth(self) -> int:
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        if not self._supported_table_args and self.mapping:
            depth = self.depth()

            if not depth:  # the mapping carries no nesting
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = TABLE_ARGS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """Return the table's qualifier parts, most specific first."""
        if isinstance(table.this, exp.ReadCSV):
            # READ_CSV sources are keyed by the ReadCSV expression's name only.
            return [table.this.name]
        return [table.text(part) for part in TABLE_ARGS if table.text(part)]

    def find(
        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[t.Any]:
        """Resolve `table` in the mapping, or return None if it can't be found.

        A partially qualified table is accepted when it matches exactly one
        entry; an ambiguous match raises `SchemaError` unless
        `raise_on_missing` is False, in which case None is returned.
        """
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)

        if value == TrieResult.FAILED:
            return None

        if value == TrieResult.PREFIX:
            # `parts` only partially qualifies a table; complete it if there is
            # exactly one candidate under the matched prefix.
            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)

            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                message = ", ".join(".".join(parts) for parts in possibilities)
                if raise_on_missing:
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None

        return self.nested_get(parts, raise_on_missing=raise_on_missing)

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[t.Any]:
        # `parts` is most-specific-first while the mapping nests
        # least-specific-first, hence the reversal.
        return nested_get(
            d or self.mapping,
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )


class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = visible or {}
        self.normalize = normalize
        # Caches string -> exp.DataType conversions across get_column_type calls.
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        # Lazily computed by depth(); 0 means "not computed yet".
        self._depth = 0

        super().__init__(self._normalize(schema or {}))

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Construct a new `MappingSchema` from another instance's mapping and settings."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """Return a copy of this schema; `kwargs` override constructor arguments."""
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            # The table is already registered and no new columns were provided.
            return

        parts = self.table_parts(normalized_table)

        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        # Update the trie in place so subsequent lookups can resolve the new table.
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        # Restrict to the columns marked visible for this table.
        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                return self._to_data_type(column_type, dialect=dialect)

        # The table or column couldn't be resolved, or the stored type had an
        # unexpected representation.
        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        # dict_depth - 1 excludes the innermost {column: type} level, so each
        # flattened entry is a path of table qualifiers.
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        for keys in flattened_schema:
            # The (name, key) pairs are identical here; the name is only used
            # by nested_get when building error messages.
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                raise SchemaError(
                    f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}."
                )

            normalized_keys = [
                self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys
            ]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parse `table` if it's a string and normalize each of its identifier parts."""
        normalized_table = exp.maybe_parse(
            table, into=exp.Table, dialect=dialect or self.dialect, copy=True
        )

        for arg in TABLE_ARGS:
            value = normalized_table.args.get(arg)
            if isinstance(value, (str, exp.Identifier)):
                normalized_table.set(
                    arg,
                    exp.to_identifier(
                        self._normalize_name(
                            value, dialect=dialect, is_table=True, normalize=normalize
                        )
                    ),
                )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        # Falls back to the schema's own dialect / normalize setting when the
        # caller doesn't override them.
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        )

    def depth(self) -> int:
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect
            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]


def normalize_name(
    name: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> str:
    """Normalize an identifier name per `dialect`'s rules.

    Names that fail to parse as identifiers are returned unnormalized.
    """
    try:
        identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
    except ParseError:
        return name if isinstance(name, str) else name.name

    name = identifier.name
    if not normalize:
        return name

    # This can be useful for normalize_identifier
    identifier.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name


def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """Coerce `schema` into a `Schema`, wrapping dicts (or None) in a `MappingSchema`."""
    if isinstance(schema, Schema):
        return schema

    return MappingSchema(schema, **kwargs)


def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """Coerce any supported `ColumnMapping` form into a {column: type} dict."""
    if mapping is None:
        return {}
    elif isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        col_name_type_strs = [x.strip() for x in mapping.split(",")]
        return {
            name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip()
            for name_type_str in col_name_type_strs
        }
    # Check if mapping looks like a DataFrame StructType
    elif hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
    elif isinstance(mapping, list):
        # A bare list of column names carries no type information.
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")


def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """Flatten a nested schema dict into a list of key paths of length `depth`."""
    tables = []
    keys = keys or []

    for k, v in schema.items():
        if depth >= 2:
            tables.extend(flatten_schema(v, depth - 1, keys + [k]))
        elif depth == 1:
            tables.append(keys + [k])

    return tables


def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.

    Returns:
        The value or None if it doesn't exist.
    """
    for name, key in path:
        d = d.get(key)  # type: ignore
        if d is None:
            if raise_on_missing:
                name = "table" if name == "this" else name
                raise ValueError(f"Unknown {name}: {key}")
            return None

    return d


def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that makeup the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    if len(keys) == 1:
        d[keys[0]] = value
        return d

    subd = d
    for key in keys[:-1]:
        if key not in subd:
            subd = subd.setdefault(key, {})
        else:
            subd = subd[key]

    subd[keys[-1]] = value
    return d
23class Schema(abc.ABC): 24 """Abstract base class for database schemas""" 25 26 dialect: DialectType 27 28 @abc.abstractmethod 29 def add_table( 30 self, 31 table: exp.Table | str, 32 column_mapping: t.Optional[ColumnMapping] = None, 33 dialect: DialectType = None, 34 normalize: t.Optional[bool] = None, 35 match_depth: bool = True, 36 ) -> None: 37 """ 38 Register or update a table. Some implementing classes may require column information to also be provided. 39 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 40 41 Args: 42 table: the `Table` expression instance or string representing the table. 43 column_mapping: a column mapping that describes the structure of the table. 44 dialect: the SQL dialect that will be used to parse `table` if it's a string. 45 normalize: whether to normalize identifiers according to the dialect of interest. 46 match_depth: whether to enforce that the table must match the schema's depth or not. 47 """ 48 49 @abc.abstractmethod 50 def column_names( 51 self, 52 table: exp.Table | str, 53 only_visible: bool = False, 54 dialect: DialectType = None, 55 normalize: t.Optional[bool] = None, 56 ) -> t.List[str]: 57 """ 58 Get the column names for a table. 59 60 Args: 61 table: the `Table` expression instance. 62 only_visible: whether to include invisible columns. 63 dialect: the SQL dialect that will be used to parse `table` if it's a string. 64 normalize: whether to normalize identifiers according to the dialect of interest. 65 66 Returns: 67 The list of column names. 68 """ 69 70 @abc.abstractmethod 71 def get_column_type( 72 self, 73 table: exp.Table | str, 74 column: exp.Column | str, 75 dialect: DialectType = None, 76 normalize: t.Optional[bool] = None, 77 ) -> exp.DataType: 78 """ 79 Get the `sqlglot.exp.DataType` type of a column in the schema. 80 81 Args: 82 table: the source table. 83 column: the target column. 
84 dialect: the SQL dialect that will be used to parse `table` if it's a string. 85 normalize: whether to normalize identifiers according to the dialect of interest. 86 87 Returns: 88 The resulting column type. 89 """ 90 91 def has_column( 92 self, 93 table: exp.Table | str, 94 column: exp.Column | str, 95 dialect: DialectType = None, 96 normalize: t.Optional[bool] = None, 97 ) -> bool: 98 """ 99 Returns whether or not `column` appears in `table`'s schema. 100 101 Args: 102 table: the source table. 103 column: the target column. 104 dialect: the SQL dialect that will be used to parse `table` if it's a string. 105 normalize: whether to normalize identifiers according to the dialect of interest. 106 107 Returns: 108 True if the column appears in the schema, False otherwise. 109 """ 110 name = column if isinstance(column, str) else column.name 111 return name in self.column_names(table, dialect=dialect, normalize=normalize) 112 113 @property 114 @abc.abstractmethod 115 def supported_table_args(self) -> t.Tuple[str, ...]: 116 """ 117 Table arguments this schema support, e.g. `("this", "db", "catalog")` 118 """ 119 120 @property 121 def empty(self) -> bool: 122 """Returns whether or not the schema is empty.""" 123 return True
Abstract base class for database schemas
28 @abc.abstractmethod 29 def add_table( 30 self, 31 table: exp.Table | str, 32 column_mapping: t.Optional[ColumnMapping] = None, 33 dialect: DialectType = None, 34 normalize: t.Optional[bool] = None, 35 match_depth: bool = True, 36 ) -> None: 37 """ 38 Register or update a table. Some implementing classes may require column information to also be provided. 39 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 40 41 Args: 42 table: the `Table` expression instance or string representing the table. 43 column_mapping: a column mapping that describes the structure of the table. 44 dialect: the SQL dialect that will be used to parse `table` if it's a string. 45 normalize: whether to normalize identifiers according to the dialect of interest. 46 match_depth: whether to enforce that the table must match the schema's depth or not. 47 """
Register or update a table. Some implementing classes may require column information to also be provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
- match_depth: whether to enforce that the table must match the schema's depth or not.
49 @abc.abstractmethod 50 def column_names( 51 self, 52 table: exp.Table | str, 53 only_visible: bool = False, 54 dialect: DialectType = None, 55 normalize: t.Optional[bool] = None, 56 ) -> t.List[str]: 57 """ 58 Get the column names for a table. 59 60 Args: 61 table: the `Table` expression instance. 62 only_visible: whether to include invisible columns. 63 dialect: the SQL dialect that will be used to parse `table` if it's a string. 64 normalize: whether to normalize identifiers according to the dialect of interest. 65 66 Returns: 67 The list of column names. 68 """
Get the column names for a table.
Arguments:
- table: the `Table` expression instance.
- only_visible: whether to include invisible columns.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.

Returns:
The list of column names.
70 @abc.abstractmethod 71 def get_column_type( 72 self, 73 table: exp.Table | str, 74 column: exp.Column | str, 75 dialect: DialectType = None, 76 normalize: t.Optional[bool] = None, 77 ) -> exp.DataType: 78 """ 79 Get the `sqlglot.exp.DataType` type of a column in the schema. 80 81 Args: 82 table: the source table. 83 column: the target column. 84 dialect: the SQL dialect that will be used to parse `table` if it's a string. 85 normalize: whether to normalize identifiers according to the dialect of interest. 86 87 Returns: 88 The resulting column type. 89 """
Get the `sqlglot.exp.DataType` type of a column in the schema.

Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.

Returns:
The resulting column type.
91 def has_column( 92 self, 93 table: exp.Table | str, 94 column: exp.Column | str, 95 dialect: DialectType = None, 96 normalize: t.Optional[bool] = None, 97 ) -> bool: 98 """ 99 Returns whether or not `column` appears in `table`'s schema. 100 101 Args: 102 table: the source table. 103 column: the target column. 104 dialect: the SQL dialect that will be used to parse `table` if it's a string. 105 normalize: whether to normalize identifiers according to the dialect of interest. 106 107 Returns: 108 True if the column appears in the schema, False otherwise. 109 """ 110 name = column if isinstance(column, str) else column.name 111 return name in self.column_names(table, dialect=dialect, normalize=normalize)
Returns whether or not `column` appears in `table`'s schema.

Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.

Returns:
True if the column appears in the schema, False otherwise.
126class AbstractMappingSchema: 127 def __init__( 128 self, 129 mapping: t.Optional[t.Dict] = None, 130 ) -> None: 131 self.mapping = mapping or {} 132 self.mapping_trie = new_trie( 133 tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth()) 134 ) 135 self._supported_table_args: t.Tuple[str, ...] = tuple() 136 137 @property 138 def empty(self) -> bool: 139 return not self.mapping 140 141 def depth(self) -> int: 142 return dict_depth(self.mapping) 143 144 @property 145 def supported_table_args(self) -> t.Tuple[str, ...]: 146 if not self._supported_table_args and self.mapping: 147 depth = self.depth() 148 149 if not depth: # None 150 self._supported_table_args = tuple() 151 elif 1 <= depth <= 3: 152 self._supported_table_args = TABLE_ARGS[:depth] 153 else: 154 raise SchemaError(f"Invalid mapping shape. Depth: {depth}") 155 156 return self._supported_table_args 157 158 def table_parts(self, table: exp.Table) -> t.List[str]: 159 if isinstance(table.this, exp.ReadCSV): 160 return [table.this.name] 161 return [table.text(part) for part in TABLE_ARGS if table.text(part)] 162 163 def find( 164 self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True 165 ) -> t.Optional[t.Any]: 166 parts = self.table_parts(table)[0 : len(self.supported_table_args)] 167 value, trie = in_trie(self.mapping_trie if trie is None else trie, parts) 168 169 if value == TrieResult.FAILED: 170 return None 171 172 if value == TrieResult.PREFIX: 173 possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1) 174 175 if len(possibilities) == 1: 176 parts.extend(possibilities[0]) 177 else: 178 message = ", ".join(".".join(parts) for parts in possibilities) 179 if raise_on_missing: 180 raise SchemaError(f"Ambiguous mapping for {table}: {message}.") 181 return None 182 183 return self.nested_get(parts, raise_on_missing=raise_on_missing) 184 185 def nested_get( 186 self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True 187 ) 
-> t.Optional[t.Any]: 188 return nested_get( 189 d or self.mapping, 190 *zip(self.supported_table_args, reversed(parts)), 191 raise_on_missing=raise_on_missing, 192 )
163 def find( 164 self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True 165 ) -> t.Optional[t.Any]: 166 parts = self.table_parts(table)[0 : len(self.supported_table_args)] 167 value, trie = in_trie(self.mapping_trie if trie is None else trie, parts) 168 169 if value == TrieResult.FAILED: 170 return None 171 172 if value == TrieResult.PREFIX: 173 possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1) 174 175 if len(possibilities) == 1: 176 parts.extend(possibilities[0]) 177 else: 178 message = ", ".join(".".join(parts) for parts in possibilities) 179 if raise_on_missing: 180 raise SchemaError(f"Ambiguous mapping for {table}: {message}.") 181 return None 182 183 return self.nested_get(parts, raise_on_missing=raise_on_missing)
class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = visible or {}
        self.normalize = normalize
        # Cache of type string -> parsed exp.DataType, so each type is parsed at most once.
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        # Lazily computed nesting depth; 0 means "not computed yet" (see depth()).
        self._depth = 0

        super().__init__(self._normalize(schema or {}))

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a new MappingSchema sharing `mapping_schema`'s mapping and configuration."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """Return a copy of this schema, optionally overriding constructor arguments via `kwargs`."""
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the
        schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # An already-registered table is only updated when new column info is provided.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        # NOTE(review): table_parts appears to return the qualifiers innermost-first;
        # reversing yields the mapping's catalog -> db -> table key order — confirm
        # against AbstractMappingSchema.table_parts.
        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to restrict the result to visible columns only.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The list of column names; empty if the table is not in the schema.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type; the "unknown" type if the table or column is missing.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                # Types stored as strings are parsed (and cached) lazily.
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """
        Returns whether or not `column` appears in `table`'s schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            True if the column appears in the schema, False otherwise.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        # depth - 1 excludes the innermost {col: type} level from the flattened paths.
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        for keys in flattened_schema:
            # nested_get expects (name, key) pairs; each key doubles as its own label.
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                raise SchemaError(
                    f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}."
                )

            normalized_keys = [
                self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys
            ]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parse `table` if needed and normalize each of its qualifier identifiers."""
        normalized_table = exp.maybe_parse(
            table, into=exp.Table, dialect=dialect or self.dialect, copy=True
        )

        for arg in TABLE_ARGS:
            value = normalized_table.args.get(arg)
            if isinstance(value, (str, exp.Identifier)):
                normalized_table.set(
                    arg,
                    exp.to_identifier(
                        self._normalize_name(
                            value, dialect=dialect, is_table=True, normalize=normalize
                        )
                    ),
                )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Normalize a single identifier, falling back to this schema's dialect/normalize settings."""
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        )

    def depth(self) -> int:
        """Return the schema's nesting depth, computed once and cached."""
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.

        Raises:
            SchemaError: if the type string cannot be built into a DataType.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect
            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                self._type_mapping_cache[schema_type] = expression
            except (AttributeError, ParseError):
                # DataType.build raises ParseError on malformed type strings; catch it
                # alongside AttributeError so all build failures surface uniformly as
                # a SchemaError instead of leaking a raw parser exception.
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]
Schema based on a nested mapping.
Arguments:
- schema: Mapping in one of the following forms:
- {table: {col: type}}
- {db: {table: {col: type}}}
- {catalog: {db: {table: {col: type}}}}
- None - Tables will be added later
- visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
are assumed to be visible. The nesting should mirror that of the schema:
- {table: set(*cols)}
- {db: {table: set(*cols)}}
- {catalog: {db: {table: set(*cols)}}}
- dialect: The dialect to be used for custom type mappings & parsing string arguments.
- normalize: Whether to normalize identifier names according to the given dialect or not.
214 def __init__( 215 self, 216 schema: t.Optional[t.Dict] = None, 217 visible: t.Optional[t.Dict] = None, 218 dialect: DialectType = None, 219 normalize: bool = True, 220 ) -> None: 221 self.dialect = dialect 222 self.visible = visible or {} 223 self.normalize = normalize 224 self._type_mapping_cache: t.Dict[str, exp.DataType] = {} 225 self._depth = 0 226 227 super().__init__(self._normalize(schema or {}))
249 def add_table( 250 self, 251 table: exp.Table | str, 252 column_mapping: t.Optional[ColumnMapping] = None, 253 dialect: DialectType = None, 254 normalize: t.Optional[bool] = None, 255 match_depth: bool = True, 256 ) -> None: 257 """ 258 Register or update a table. Updates are only performed if a new column mapping is provided. 259 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 260 261 Args: 262 table: the `Table` expression instance or string representing the table. 263 column_mapping: a column mapping that describes the structure of the table. 264 dialect: the SQL dialect that will be used to parse `table` if it's a string. 265 normalize: whether to normalize identifiers according to the dialect of interest. 266 match_depth: whether to enforce that the table must match the schema's depth or not. 267 """ 268 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 269 270 if match_depth and not self.empty and len(normalized_table.parts) != self.depth(): 271 raise SchemaError( 272 f"Table {normalized_table.sql(dialect=self.dialect)} must match the " 273 f"schema's nesting level: {self.depth()}." 274 ) 275 276 normalized_column_mapping = { 277 self._normalize_name(key, dialect=dialect, normalize=normalize): value 278 for key, value in ensure_column_mapping(column_mapping).items() 279 } 280 281 schema = self.find(normalized_table, raise_on_missing=False) 282 if schema and not normalized_column_mapping: 283 return 284 285 parts = self.table_parts(normalized_table) 286 287 nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping) 288 new_trie([parts], self.mapping_trie)
Register or update a table. Updates are only performed if a new column mapping is provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
- match_depth: whether to enforce that the table must match the schema's depth or not.
290 def column_names( 291 self, 292 table: exp.Table | str, 293 only_visible: bool = False, 294 dialect: DialectType = None, 295 normalize: t.Optional[bool] = None, 296 ) -> t.List[str]: 297 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 298 299 schema = self.find(normalized_table) 300 if schema is None: 301 return [] 302 303 if not only_visible or not self.visible: 304 return list(schema) 305 306 visible = self.nested_get(self.table_parts(normalized_table), self.visible) or [] 307 return [col for col in schema if col in visible]
Get the column names for a table.
Arguments:
- table: the `Table` expression instance.
- only_visible: whether to restrict the output to visible columns only.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The list of column names.
309 def get_column_type( 310 self, 311 table: exp.Table | str, 312 column: exp.Column | str, 313 dialect: DialectType = None, 314 normalize: t.Optional[bool] = None, 315 ) -> exp.DataType: 316 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 317 318 normalized_column_name = self._normalize_name( 319 column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize 320 ) 321 322 table_schema = self.find(normalized_table, raise_on_missing=False) 323 if table_schema: 324 column_type = table_schema.get(normalized_column_name) 325 326 if isinstance(column_type, exp.DataType): 327 return column_type 328 elif isinstance(column_type, str): 329 return self._to_data_type(column_type, dialect=dialect) 330 331 return exp.DataType.build("unknown")
Get the `sqlglot.exp.DataType` type of a column in the schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The resulting column type.
333 def has_column( 334 self, 335 table: exp.Table | str, 336 column: exp.Column | str, 337 dialect: DialectType = None, 338 normalize: t.Optional[bool] = None, 339 ) -> bool: 340 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 341 342 normalized_column_name = self._normalize_name( 343 column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize 344 ) 345 346 table_schema = self.find(normalized_table, raise_on_missing=False) 347 return normalized_column_name in table_schema if table_schema else False
Returns whether or not `column` appears in `table`'s schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
True if the column appears in the schema, False otherwise.
Inherited Members
451def normalize_name( 452 name: str | exp.Identifier, 453 dialect: DialectType = None, 454 is_table: bool = False, 455 normalize: t.Optional[bool] = True, 456) -> str: 457 try: 458 identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier) 459 except ParseError: 460 return name if isinstance(name, str) else name.name 461 462 name = identifier.name 463 if not normalize: 464 return name 465 466 # This can be useful for normalize_identifier 467 identifier.meta["is_table"] = is_table 468 return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name
478def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict: 479 if mapping is None: 480 return {} 481 elif isinstance(mapping, dict): 482 return mapping 483 elif isinstance(mapping, str): 484 col_name_type_strs = [x.strip() for x in mapping.split(",")] 485 return { 486 name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip() 487 for name_type_str in col_name_type_strs 488 } 489 # Check if mapping looks like a DataFrame StructType 490 elif hasattr(mapping, "simpleString"): 491 return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping} 492 elif isinstance(mapping, list): 493 return {x.strip(): None for x in mapping} 494 495 raise ValueError(f"Invalid mapping provided: {type(mapping)}")
498def flatten_schema( 499 schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None 500) -> t.List[t.List[str]]: 501 tables = [] 502 keys = keys or [] 503 504 for k, v in schema.items(): 505 if depth >= 2: 506 tables.extend(flatten_schema(v, depth - 1, keys + [k])) 507 elif depth == 1: 508 tables.append(keys + [k]) 509 510 return tables
513def nested_get( 514 d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True 515) -> t.Optional[t.Any]: 516 """ 517 Get a value for a nested dictionary. 518 519 Args: 520 d: the dictionary to search. 521 *path: tuples of (name, key), where: 522 `key` is the key in the dictionary to get. 523 `name` is a string to use in the error if `key` isn't found. 524 525 Returns: 526 The value or None if it doesn't exist. 527 """ 528 for name, key in path: 529 d = d.get(key) # type: ignore 530 if d is None: 531 if raise_on_missing: 532 name = "table" if name == "this" else name 533 raise ValueError(f"Unknown {name}: {key}") 534 return None 535 536 return d
Get a value for a nested dictionary.
Arguments:
- d: the dictionary to search.
- *path: tuples of (name, key), where:
  `key` is the key in the dictionary to get.
  `name` is a string to use in the error if `key` isn't found.
Returns:
The value or None if it doesn't exist.
539def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict: 540 """ 541 In-place set a value for a nested dictionary 542 543 Example: 544 >>> nested_set({}, ["top_key", "second_key"], "value") 545 {'top_key': {'second_key': 'value'}} 546 547 >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") 548 {'top_key': {'third_key': 'third_value', 'second_key': 'value'}} 549 550 Args: 551 d: dictionary to update. 552 keys: the keys that makeup the path to `value`. 553 value: the value to set in the dictionary for the given key path. 554 555 Returns: 556 The (possibly) updated dictionary. 557 """ 558 if not keys: 559 return d 560 561 if len(keys) == 1: 562 d[keys[0]] = value 563 return d 564 565 subd = d 566 for key in keys[:-1]: 567 if key not in subd: 568 subd = subd.setdefault(key, {}) 569 else: 570 subd = subd[key] 571 572 subd[keys[-1]] = value 573 return d
In-place set a value for a nested dictionary
Example:
>>> nested_set({}, ["top_key", "second_key"], "value") {'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
- d: dictionary to update.
- keys: the keys that make up the path to `value`.
- value: the value to set in the dictionary for the given key path.
Returns:
The (possibly) updated dictionary.