sqlglot.schema
1from __future__ import annotations 2 3import abc 4import typing as t 5 6import sqlglot 7from sqlglot import expressions as exp 8from sqlglot._typing import T 9from sqlglot.dialects.dialect import Dialect 10from sqlglot.errors import ParseError, SchemaError 11from sqlglot.helper import dict_depth 12from sqlglot.trie import TrieResult, in_trie, new_trie 13 14if t.TYPE_CHECKING: 15 from sqlglot.dataframe.sql.types import StructType 16 from sqlglot.dialects.dialect import DialectType 17 18 ColumnMapping = t.Union[t.Dict, str, StructType, t.List] 19 20TABLE_ARGS = ("this", "db", "catalog") 21 22 23class Schema(abc.ABC): 24 """Abstract base class for database schemas""" 25 26 dialect: DialectType 27 28 @abc.abstractmethod 29 def add_table( 30 self, 31 table: exp.Table | str, 32 column_mapping: t.Optional[ColumnMapping] = None, 33 dialect: DialectType = None, 34 normalize: t.Optional[bool] = None, 35 match_depth: bool = True, 36 ) -> None: 37 """ 38 Register or update a table. Some implementing classes may require column information to also be provided. 39 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 40 41 Args: 42 table: the `Table` expression instance or string representing the table. 43 column_mapping: a column mapping that describes the structure of the table. 44 dialect: the SQL dialect that will be used to parse `table` if it's a string. 45 normalize: whether to normalize identifiers according to the dialect of interest. 46 match_depth: whether to enforce that the table must match the schema's depth or not. 47 """ 48 49 @abc.abstractmethod 50 def column_names( 51 self, 52 table: exp.Table | str, 53 only_visible: bool = False, 54 dialect: DialectType = None, 55 normalize: t.Optional[bool] = None, 56 ) -> t.List[str]: 57 """ 58 Get the column names for a table. 59 60 Args: 61 table: the `Table` expression instance. 62 only_visible: whether to include invisible columns. 
63 dialect: the SQL dialect that will be used to parse `table` if it's a string. 64 normalize: whether to normalize identifiers according to the dialect of interest. 65 66 Returns: 67 The list of column names. 68 """ 69 70 @abc.abstractmethod 71 def get_column_type( 72 self, 73 table: exp.Table | str, 74 column: exp.Column | str, 75 dialect: DialectType = None, 76 normalize: t.Optional[bool] = None, 77 ) -> exp.DataType: 78 """ 79 Get the `sqlglot.exp.DataType` type of a column in the schema. 80 81 Args: 82 table: the source table. 83 column: the target column. 84 dialect: the SQL dialect that will be used to parse `table` if it's a string. 85 normalize: whether to normalize identifiers according to the dialect of interest. 86 87 Returns: 88 The resulting column type. 89 """ 90 91 def has_column( 92 self, 93 table: exp.Table | str, 94 column: exp.Column | str, 95 dialect: DialectType = None, 96 normalize: t.Optional[bool] = None, 97 ) -> bool: 98 """ 99 Returns whether or not `column` appears in `table`'s schema. 100 101 Args: 102 table: the source table. 103 column: the target column. 104 dialect: the SQL dialect that will be used to parse `table` if it's a string. 105 normalize: whether to normalize identifiers according to the dialect of interest. 106 107 Returns: 108 True if the column appears in the schema, False otherwise. 109 """ 110 name = column if isinstance(column, str) else column.name 111 return name in self.column_names(table, dialect=dialect, normalize=normalize) 112 113 @property 114 @abc.abstractmethod 115 def supported_table_args(self) -> t.Tuple[str, ...]: 116 """ 117 Table arguments this schema support, e.g. 
`("this", "db", "catalog")` 118 """ 119 120 @property 121 def empty(self) -> bool: 122 """Returns whether or not the schema is empty.""" 123 return True 124 125 126class AbstractMappingSchema(t.Generic[T]): 127 def __init__( 128 self, 129 mapping: t.Optional[t.Dict] = None, 130 ) -> None: 131 self.mapping = mapping or {} 132 self.mapping_trie = new_trie( 133 tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth()) 134 ) 135 self._supported_table_args: t.Tuple[str, ...] = tuple() 136 137 @property 138 def empty(self) -> bool: 139 return not self.mapping 140 141 def depth(self) -> int: 142 return dict_depth(self.mapping) 143 144 @property 145 def supported_table_args(self) -> t.Tuple[str, ...]: 146 if not self._supported_table_args and self.mapping: 147 depth = self.depth() 148 149 if not depth: # None 150 self._supported_table_args = tuple() 151 elif 1 <= depth <= 3: 152 self._supported_table_args = TABLE_ARGS[:depth] 153 else: 154 raise SchemaError(f"Invalid mapping shape. 
Depth: {depth}") 155 156 return self._supported_table_args 157 158 def table_parts(self, table: exp.Table) -> t.List[str]: 159 if isinstance(table.this, exp.ReadCSV): 160 return [table.this.name] 161 return [table.text(part) for part in TABLE_ARGS if table.text(part)] 162 163 def find( 164 self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True 165 ) -> t.Optional[T]: 166 parts = self.table_parts(table)[0 : len(self.supported_table_args)] 167 value, trie = in_trie(self.mapping_trie if trie is None else trie, parts) 168 169 if value == TrieResult.FAILED: 170 return None 171 172 if value == TrieResult.PREFIX: 173 possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1) 174 175 if len(possibilities) == 1: 176 parts.extend(possibilities[0]) 177 else: 178 message = ", ".join(".".join(parts) for parts in possibilities) 179 if raise_on_missing: 180 raise SchemaError(f"Ambiguous mapping for {table}: {message}.") 181 return None 182 183 return self.nested_get(parts, raise_on_missing=raise_on_missing) 184 185 def nested_get( 186 self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True 187 ) -> t.Optional[t.Any]: 188 return nested_get( 189 d or self.mapping, 190 *zip(self.supported_table_args, reversed(parts)), 191 raise_on_missing=raise_on_missing, 192 ) 193 194 195class MappingSchema(AbstractMappingSchema, Schema): 196 """ 197 Schema based on a nested mapping. 198 199 Args: 200 schema: Mapping in one of the following forms: 201 1. {table: {col: type}} 202 2. {db: {table: {col: type}}} 203 3. {catalog: {db: {table: {col: type}}}} 204 4. None - Tables will be added later 205 visible: Optional mapping of which columns in the schema are visible. If not provided, all columns 206 are assumed to be visible. The nesting should mirror that of the schema: 207 1. {table: set(*cols)}} 208 2. {db: {table: set(*cols)}}} 209 3. 
{catalog: {db: {table: set(*cols)}}}} 210 dialect: The dialect to be used for custom type mappings & parsing string arguments. 211 normalize: Whether to normalize identifier names according to the given dialect or not. 212 """ 213 214 def __init__( 215 self, 216 schema: t.Optional[t.Dict] = None, 217 visible: t.Optional[t.Dict] = None, 218 dialect: DialectType = None, 219 normalize: bool = True, 220 ) -> None: 221 self.dialect = dialect 222 self.visible = visible or {} 223 self.normalize = normalize 224 self._type_mapping_cache: t.Dict[str, exp.DataType] = {} 225 self._depth = 0 226 227 super().__init__(self._normalize(schema or {})) 228 229 @classmethod 230 def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema: 231 return MappingSchema( 232 schema=mapping_schema.mapping, 233 visible=mapping_schema.visible, 234 dialect=mapping_schema.dialect, 235 normalize=mapping_schema.normalize, 236 ) 237 238 def copy(self, **kwargs) -> MappingSchema: 239 return MappingSchema( 240 **{ # type: ignore 241 "schema": self.mapping.copy(), 242 "visible": self.visible.copy(), 243 "dialect": self.dialect, 244 "normalize": self.normalize, 245 **kwargs, 246 } 247 ) 248 249 def add_table( 250 self, 251 table: exp.Table | str, 252 column_mapping: t.Optional[ColumnMapping] = None, 253 dialect: DialectType = None, 254 normalize: t.Optional[bool] = None, 255 match_depth: bool = True, 256 ) -> None: 257 """ 258 Register or update a table. Updates are only performed if a new column mapping is provided. 259 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 260 261 Args: 262 table: the `Table` expression instance or string representing the table. 263 column_mapping: a column mapping that describes the structure of the table. 264 dialect: the SQL dialect that will be used to parse `table` if it's a string. 265 normalize: whether to normalize identifiers according to the dialect of interest. 
266 match_depth: whether to enforce that the table must match the schema's depth or not. 267 """ 268 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 269 270 if match_depth and not self.empty and len(normalized_table.parts) != self.depth(): 271 raise SchemaError( 272 f"Table {normalized_table.sql(dialect=self.dialect)} must match the " 273 f"schema's nesting level: {self.depth()}." 274 ) 275 276 normalized_column_mapping = { 277 self._normalize_name(key, dialect=dialect, normalize=normalize): value 278 for key, value in ensure_column_mapping(column_mapping).items() 279 } 280 281 schema = self.find(normalized_table, raise_on_missing=False) 282 if schema and not normalized_column_mapping: 283 return 284 285 parts = self.table_parts(normalized_table) 286 287 nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping) 288 new_trie([parts], self.mapping_trie) 289 290 def column_names( 291 self, 292 table: exp.Table | str, 293 only_visible: bool = False, 294 dialect: DialectType = None, 295 normalize: t.Optional[bool] = None, 296 ) -> t.List[str]: 297 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 298 299 schema = self.find(normalized_table) 300 if schema is None: 301 return [] 302 303 if not only_visible or not self.visible: 304 return list(schema) 305 306 visible = self.nested_get(self.table_parts(normalized_table), self.visible) or [] 307 return [col for col in schema if col in visible] 308 309 def get_column_type( 310 self, 311 table: exp.Table | str, 312 column: exp.Column | str, 313 dialect: DialectType = None, 314 normalize: t.Optional[bool] = None, 315 ) -> exp.DataType: 316 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 317 318 normalized_column_name = self._normalize_name( 319 column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize 320 ) 321 322 table_schema = self.find(normalized_table, 
raise_on_missing=False) 323 if table_schema: 324 column_type = table_schema.get(normalized_column_name) 325 326 if isinstance(column_type, exp.DataType): 327 return column_type 328 elif isinstance(column_type, str): 329 return self._to_data_type(column_type, dialect=dialect) 330 331 return exp.DataType.build("unknown") 332 333 def has_column( 334 self, 335 table: exp.Table | str, 336 column: exp.Column | str, 337 dialect: DialectType = None, 338 normalize: t.Optional[bool] = None, 339 ) -> bool: 340 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 341 342 normalized_column_name = self._normalize_name( 343 column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize 344 ) 345 346 table_schema = self.find(normalized_table, raise_on_missing=False) 347 return normalized_column_name in table_schema if table_schema else False 348 349 def _normalize(self, schema: t.Dict) -> t.Dict: 350 """ 351 Normalizes all identifiers in the schema. 352 353 Args: 354 schema: the schema to normalize. 355 356 Returns: 357 The normalized schema mapping. 358 """ 359 normalized_mapping: t.Dict = {} 360 flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1) 361 362 for keys in flattened_schema: 363 columns = nested_get(schema, *zip(keys, keys)) 364 365 if not isinstance(columns, dict): 366 raise SchemaError( 367 f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}." 
368 ) 369 370 normalized_keys = [ 371 self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys 372 ] 373 for column_name, column_type in columns.items(): 374 nested_set( 375 normalized_mapping, 376 normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)], 377 column_type, 378 ) 379 380 return normalized_mapping 381 382 def _normalize_table( 383 self, 384 table: exp.Table | str, 385 dialect: DialectType = None, 386 normalize: t.Optional[bool] = None, 387 ) -> exp.Table: 388 normalized_table = exp.maybe_parse( 389 table, into=exp.Table, dialect=dialect or self.dialect, copy=True 390 ) 391 392 for arg in TABLE_ARGS: 393 value = normalized_table.args.get(arg) 394 if isinstance(value, (str, exp.Identifier)): 395 normalized_table.set( 396 arg, 397 exp.to_identifier( 398 self._normalize_name( 399 value, dialect=dialect, is_table=True, normalize=normalize 400 ) 401 ), 402 ) 403 404 return normalized_table 405 406 def _normalize_name( 407 self, 408 name: str | exp.Identifier, 409 dialect: DialectType = None, 410 is_table: bool = False, 411 normalize: t.Optional[bool] = None, 412 ) -> str: 413 return normalize_name( 414 name, 415 dialect=dialect or self.dialect, 416 is_table=is_table, 417 normalize=self.normalize if normalize is None else normalize, 418 ) 419 420 def depth(self) -> int: 421 if not self.empty and not self._depth: 422 # The columns themselves are a mapping, but we don't want to include those 423 self._depth = super().depth() - 1 424 return self._depth 425 426 def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType: 427 """ 428 Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object. 429 430 Args: 431 schema_type: the type we want to convert. 432 dialect: the SQL dialect that will be used to parse `schema_type`, if needed. 433 434 Returns: 435 The resulting expression type. 
436 """ 437 if schema_type not in self._type_mapping_cache: 438 dialect = dialect or self.dialect 439 udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES 440 441 try: 442 expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt) 443 self._type_mapping_cache[schema_type] = expression 444 except AttributeError: 445 in_dialect = f" in dialect {dialect}" if dialect else "" 446 raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.") 447 448 return self._type_mapping_cache[schema_type] 449 450 451def normalize_name( 452 name: str | exp.Identifier, 453 dialect: DialectType = None, 454 is_table: bool = False, 455 normalize: t.Optional[bool] = True, 456) -> str: 457 try: 458 identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier) 459 except ParseError: 460 return name if isinstance(name, str) else name.name 461 462 name = identifier.name 463 if not normalize: 464 return name 465 466 # This can be useful for normalize_identifier 467 identifier.meta["is_table"] = is_table 468 return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name 469 470 471def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema: 472 if isinstance(schema, Schema): 473 return schema 474 475 return MappingSchema(schema, **kwargs) 476 477 478def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict: 479 if mapping is None: 480 return {} 481 elif isinstance(mapping, dict): 482 return mapping 483 elif isinstance(mapping, str): 484 col_name_type_strs = [x.strip() for x in mapping.split(",")] 485 return { 486 name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip() 487 for name_type_str in col_name_type_strs 488 } 489 # Check if mapping looks like a DataFrame StructType 490 elif hasattr(mapping, "simpleString"): 491 return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping} 492 elif isinstance(mapping, list): 493 return {x.strip(): None for x in 
mapping} 494 495 raise ValueError(f"Invalid mapping provided: {type(mapping)}") 496 497 498def flatten_schema( 499 schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None 500) -> t.List[t.List[str]]: 501 tables = [] 502 keys = keys or [] 503 504 for k, v in schema.items(): 505 if depth >= 2: 506 tables.extend(flatten_schema(v, depth - 1, keys + [k])) 507 elif depth == 1: 508 tables.append(keys + [k]) 509 510 return tables 511 512 513def nested_get( 514 d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True 515) -> t.Optional[t.Any]: 516 """ 517 Get a value for a nested dictionary. 518 519 Args: 520 d: the dictionary to search. 521 *path: tuples of (name, key), where: 522 `key` is the key in the dictionary to get. 523 `name` is a string to use in the error if `key` isn't found. 524 525 Returns: 526 The value or None if it doesn't exist. 527 """ 528 for name, key in path: 529 d = d.get(key) # type: ignore 530 if d is None: 531 if raise_on_missing: 532 name = "table" if name == "this" else name 533 raise ValueError(f"Unknown {name}: {key}") 534 return None 535 536 return d 537 538 539def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict: 540 """ 541 In-place set a value for a nested dictionary 542 543 Example: 544 >>> nested_set({}, ["top_key", "second_key"], "value") 545 {'top_key': {'second_key': 'value'}} 546 547 >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") 548 {'top_key': {'third_key': 'third_value', 'second_key': 'value'}} 549 550 Args: 551 d: dictionary to update. 552 keys: the keys that makeup the path to `value`. 553 value: the value to set in the dictionary for the given key path. 554 555 Returns: 556 The (possibly) updated dictionary. 
557 """ 558 if not keys: 559 return d 560 561 if len(keys) == 1: 562 d[keys[0]] = value 563 return d 564 565 subd = d 566 for key in keys[:-1]: 567 if key not in subd: 568 subd = subd.setdefault(key, {}) 569 else: 570 subd = subd[key] 571 572 subd[keys[-1]] = value 573 return d
24class Schema(abc.ABC): 25 """Abstract base class for database schemas""" 26 27 dialect: DialectType 28 29 @abc.abstractmethod 30 def add_table( 31 self, 32 table: exp.Table | str, 33 column_mapping: t.Optional[ColumnMapping] = None, 34 dialect: DialectType = None, 35 normalize: t.Optional[bool] = None, 36 match_depth: bool = True, 37 ) -> None: 38 """ 39 Register or update a table. Some implementing classes may require column information to also be provided. 40 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 41 42 Args: 43 table: the `Table` expression instance or string representing the table. 44 column_mapping: a column mapping that describes the structure of the table. 45 dialect: the SQL dialect that will be used to parse `table` if it's a string. 46 normalize: whether to normalize identifiers according to the dialect of interest. 47 match_depth: whether to enforce that the table must match the schema's depth or not. 48 """ 49 50 @abc.abstractmethod 51 def column_names( 52 self, 53 table: exp.Table | str, 54 only_visible: bool = False, 55 dialect: DialectType = None, 56 normalize: t.Optional[bool] = None, 57 ) -> t.List[str]: 58 """ 59 Get the column names for a table. 60 61 Args: 62 table: the `Table` expression instance. 63 only_visible: whether to include invisible columns. 64 dialect: the SQL dialect that will be used to parse `table` if it's a string. 65 normalize: whether to normalize identifiers according to the dialect of interest. 66 67 Returns: 68 The list of column names. 69 """ 70 71 @abc.abstractmethod 72 def get_column_type( 73 self, 74 table: exp.Table | str, 75 column: exp.Column | str, 76 dialect: DialectType = None, 77 normalize: t.Optional[bool] = None, 78 ) -> exp.DataType: 79 """ 80 Get the `sqlglot.exp.DataType` type of a column in the schema. 81 82 Args: 83 table: the source table. 84 column: the target column. 
85 dialect: the SQL dialect that will be used to parse `table` if it's a string. 86 normalize: whether to normalize identifiers according to the dialect of interest. 87 88 Returns: 89 The resulting column type. 90 """ 91 92 def has_column( 93 self, 94 table: exp.Table | str, 95 column: exp.Column | str, 96 dialect: DialectType = None, 97 normalize: t.Optional[bool] = None, 98 ) -> bool: 99 """ 100 Returns whether or not `column` appears in `table`'s schema. 101 102 Args: 103 table: the source table. 104 column: the target column. 105 dialect: the SQL dialect that will be used to parse `table` if it's a string. 106 normalize: whether to normalize identifiers according to the dialect of interest. 107 108 Returns: 109 True if the column appears in the schema, False otherwise. 110 """ 111 name = column if isinstance(column, str) else column.name 112 return name in self.column_names(table, dialect=dialect, normalize=normalize) 113 114 @property 115 @abc.abstractmethod 116 def supported_table_args(self) -> t.Tuple[str, ...]: 117 """ 118 Table arguments this schema support, e.g. `("this", "db", "catalog")` 119 """ 120 121 @property 122 def empty(self) -> bool: 123 """Returns whether or not the schema is empty.""" 124 return True
Abstract base class for database schemas
29 @abc.abstractmethod 30 def add_table( 31 self, 32 table: exp.Table | str, 33 column_mapping: t.Optional[ColumnMapping] = None, 34 dialect: DialectType = None, 35 normalize: t.Optional[bool] = None, 36 match_depth: bool = True, 37 ) -> None: 38 """ 39 Register or update a table. Some implementing classes may require column information to also be provided. 40 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 41 42 Args: 43 table: the `Table` expression instance or string representing the table. 44 column_mapping: a column mapping that describes the structure of the table. 45 dialect: the SQL dialect that will be used to parse `table` if it's a string. 46 normalize: whether to normalize identifiers according to the dialect of interest. 47 match_depth: whether to enforce that the table must match the schema's depth or not. 48 """
Register or update a table. Some implementing classes may require column information to also be provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
- match_depth: whether to enforce that the table must match the schema's depth or not.
50 @abc.abstractmethod 51 def column_names( 52 self, 53 table: exp.Table | str, 54 only_visible: bool = False, 55 dialect: DialectType = None, 56 normalize: t.Optional[bool] = None, 57 ) -> t.List[str]: 58 """ 59 Get the column names for a table. 60 61 Args: 62 table: the `Table` expression instance. 63 only_visible: whether to include invisible columns. 64 dialect: the SQL dialect that will be used to parse `table` if it's a string. 65 normalize: whether to normalize identifiers according to the dialect of interest. 66 67 Returns: 68 The list of column names. 69 """
Get the column names for a table.
Arguments:
- table: the `Table` expression instance.
- only_visible: whether to restrict the result to only the visible columns.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The list of column names.
71 @abc.abstractmethod 72 def get_column_type( 73 self, 74 table: exp.Table | str, 75 column: exp.Column | str, 76 dialect: DialectType = None, 77 normalize: t.Optional[bool] = None, 78 ) -> exp.DataType: 79 """ 80 Get the `sqlglot.exp.DataType` type of a column in the schema. 81 82 Args: 83 table: the source table. 84 column: the target column. 85 dialect: the SQL dialect that will be used to parse `table` if it's a string. 86 normalize: whether to normalize identifiers according to the dialect of interest. 87 88 Returns: 89 The resulting column type. 90 """
Get the `sqlglot.exp.DataType` type of a column in the schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The resulting column type.
92 def has_column( 93 self, 94 table: exp.Table | str, 95 column: exp.Column | str, 96 dialect: DialectType = None, 97 normalize: t.Optional[bool] = None, 98 ) -> bool: 99 """ 100 Returns whether or not `column` appears in `table`'s schema. 101 102 Args: 103 table: the source table. 104 column: the target column. 105 dialect: the SQL dialect that will be used to parse `table` if it's a string. 106 normalize: whether to normalize identifiers according to the dialect of interest. 107 108 Returns: 109 True if the column appears in the schema, False otherwise. 110 """ 111 name = column if isinstance(column, str) else column.name 112 return name in self.column_names(table, dialect=dialect, normalize=normalize)
Returns whether or not `column` appears in `table`'s schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
True if the column appears in the schema, False otherwise.
127class AbstractMappingSchema(t.Generic[T]): 128 def __init__( 129 self, 130 mapping: t.Optional[t.Dict] = None, 131 ) -> None: 132 self.mapping = mapping or {} 133 self.mapping_trie = new_trie( 134 tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth()) 135 ) 136 self._supported_table_args: t.Tuple[str, ...] = tuple() 137 138 @property 139 def empty(self) -> bool: 140 return not self.mapping 141 142 def depth(self) -> int: 143 return dict_depth(self.mapping) 144 145 @property 146 def supported_table_args(self) -> t.Tuple[str, ...]: 147 if not self._supported_table_args and self.mapping: 148 depth = self.depth() 149 150 if not depth: # None 151 self._supported_table_args = tuple() 152 elif 1 <= depth <= 3: 153 self._supported_table_args = TABLE_ARGS[:depth] 154 else: 155 raise SchemaError(f"Invalid mapping shape. Depth: {depth}") 156 157 return self._supported_table_args 158 159 def table_parts(self, table: exp.Table) -> t.List[str]: 160 if isinstance(table.this, exp.ReadCSV): 161 return [table.this.name] 162 return [table.text(part) for part in TABLE_ARGS if table.text(part)] 163 164 def find( 165 self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True 166 ) -> t.Optional[T]: 167 parts = self.table_parts(table)[0 : len(self.supported_table_args)] 168 value, trie = in_trie(self.mapping_trie if trie is None else trie, parts) 169 170 if value == TrieResult.FAILED: 171 return None 172 173 if value == TrieResult.PREFIX: 174 possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1) 175 176 if len(possibilities) == 1: 177 parts.extend(possibilities[0]) 178 else: 179 message = ", ".join(".".join(parts) for parts in possibilities) 180 if raise_on_missing: 181 raise SchemaError(f"Ambiguous mapping for {table}: {message}.") 182 return None 183 184 return self.nested_get(parts, raise_on_missing=raise_on_missing) 185 186 def nested_get( 187 self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, 
raise_on_missing=True 188 ) -> t.Optional[t.Any]: 189 return nested_get( 190 d or self.mapping, 191 *zip(self.supported_table_args, reversed(parts)), 192 raise_on_missing=raise_on_missing, 193 )
Abstract base class for generic types.
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::
class Mapping(Generic[KT, VT]): def __getitem__(self, key: KT) -> VT: ... # Etc.
This class can then be used as follows::
def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT: try: return mapping[key] except KeyError: return default
164 def find( 165 self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True 166 ) -> t.Optional[T]: 167 parts = self.table_parts(table)[0 : len(self.supported_table_args)] 168 value, trie = in_trie(self.mapping_trie if trie is None else trie, parts) 169 170 if value == TrieResult.FAILED: 171 return None 172 173 if value == TrieResult.PREFIX: 174 possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1) 175 176 if len(possibilities) == 1: 177 parts.extend(possibilities[0]) 178 else: 179 message = ", ".join(".".join(parts) for parts in possibilities) 180 if raise_on_missing: 181 raise SchemaError(f"Ambiguous mapping for {table}: {message}.") 182 return None 183 184 return self.nested_get(parts, raise_on_missing=raise_on_missing)
class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}}
            2. {db: {table: set(*cols)}}}
            3. {catalog: {db: {table: set(*cols)}}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.normalize = normalize
        self.visible = visible or {}
        # Lazily populated cache of string type -> parsed exp.DataType; see _to_data_type.
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        # Cached nesting depth; 0 means "not computed yet" (see depth()).
        self._depth = 0

        super().__init__(self._normalize(schema or {}))

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Create a new MappingSchema mirroring `mapping_schema`'s data and configuration."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """Return a copy of this schema; `kwargs` override the constructor arguments."""
        init_args: t.Dict[str, t.Any] = {
            "schema": self.mapping.copy(),
            "visible": self.visible.copy(),
            "dialect": self.dialect,
            "normalize": self.normalize,
        }
        init_args.update(kwargs)
        return MappingSchema(**init_args)  # type: ignore

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {}
        for column_name, column_type in ensure_column_mapping(column_mapping).items():
            normalized_key = self._normalize_name(column_name, dialect=dialect, normalize=normalize)
            normalized_column_mapping[normalized_key] = column_type

        # A known table with no new columns requires no update.
        if not normalized_column_mapping and self.find(normalized_table, raise_on_missing=False):
            return

        parts = self.table_parts(normalized_table)

        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """Return the known column names for `table` ([] if unknown), optionally only visible ones."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        table_schema = self.find(normalized_table)
        if table_schema is None:
            return []

        if not only_visible or not self.visible:
            return list(table_schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [column for column in table_schema if column in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """Return the type of `column` in `table`, or the "unknown" type when unresolvable."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            if isinstance(column_type, str):
                # String types are parsed (and cached) on demand.
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """Whether `column` appears in `table`'s schema."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if not table_schema:
            return False
        return normalized_column_name in table_schema

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        for keys in flattened_schema:
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                raise SchemaError(
                    f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}."
                )

            normalized_keys = [
                self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys
            ]
            for column_name, column_type in columns.items():
                normalized_column = self._normalize_name(column_name, dialect=self.dialect)
                nested_set(
                    normalized_mapping,
                    normalized_keys + [normalized_column],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parse `table` if it's a string and normalize each of its name parts in place."""
        normalized_table = exp.maybe_parse(
            table, into=exp.Table, dialect=dialect or self.dialect, copy=True
        )

        for arg in TABLE_ARGS:
            value = normalized_table.args.get(arg)
            if isinstance(value, (str, exp.Identifier)):
                normalized_name = self._normalize_name(
                    value, dialect=dialect, is_table=True, normalize=normalize
                )
                normalized_table.set(arg, exp.to_identifier(normalized_name))

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Normalize a single identifier, falling back to this schema's dialect and settings."""
        if normalize is None:
            normalize = self.normalize
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=normalize,
        )

    def depth(self) -> int:
        """Nesting depth of the schema, excluding the column mappings themselves."""
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect
            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")
            self._type_mapping_cache[schema_type] = expression

        return self._type_mapping_cache[schema_type]
Schema based on a nested mapping.
Arguments:
- schema: Mapping in one of the following forms:
- {table: {col: type}}
- {db: {table: {col: type}}}
- {catalog: {db: {table: {col: type}}}}
- None - Tables will be added later
- visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
are assumed to be visible. The nesting should mirror that of the schema:
- {table: set(*cols)}
- {db: {table: set(*cols)}}
- {catalog: {db: {table: set(*cols)}}}
- dialect: The dialect to be used for custom type mappings & parsing string arguments.
- normalize: Whether to normalize identifier names according to the given dialect or not.
215 def __init__( 216 self, 217 schema: t.Optional[t.Dict] = None, 218 visible: t.Optional[t.Dict] = None, 219 dialect: DialectType = None, 220 normalize: bool = True, 221 ) -> None: 222 self.dialect = dialect 223 self.visible = visible or {} 224 self.normalize = normalize 225 self._type_mapping_cache: t.Dict[str, exp.DataType] = {} 226 self._depth = 0 227 228 super().__init__(self._normalize(schema or {}))
250 def add_table( 251 self, 252 table: exp.Table | str, 253 column_mapping: t.Optional[ColumnMapping] = None, 254 dialect: DialectType = None, 255 normalize: t.Optional[bool] = None, 256 match_depth: bool = True, 257 ) -> None: 258 """ 259 Register or update a table. Updates are only performed if a new column mapping is provided. 260 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 261 262 Args: 263 table: the `Table` expression instance or string representing the table. 264 column_mapping: a column mapping that describes the structure of the table. 265 dialect: the SQL dialect that will be used to parse `table` if it's a string. 266 normalize: whether to normalize identifiers according to the dialect of interest. 267 match_depth: whether to enforce that the table must match the schema's depth or not. 268 """ 269 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 270 271 if match_depth and not self.empty and len(normalized_table.parts) != self.depth(): 272 raise SchemaError( 273 f"Table {normalized_table.sql(dialect=self.dialect)} must match the " 274 f"schema's nesting level: {self.depth()}." 275 ) 276 277 normalized_column_mapping = { 278 self._normalize_name(key, dialect=dialect, normalize=normalize): value 279 for key, value in ensure_column_mapping(column_mapping).items() 280 } 281 282 schema = self.find(normalized_table, raise_on_missing=False) 283 if schema and not normalized_column_mapping: 284 return 285 286 parts = self.table_parts(normalized_table) 287 288 nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping) 289 new_trie([parts], self.mapping_trie)
Register or update a table. Updates are only performed if a new column mapping is provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
- match_depth: whether to enforce that the table must match the schema's depth or not.
291 def column_names( 292 self, 293 table: exp.Table | str, 294 only_visible: bool = False, 295 dialect: DialectType = None, 296 normalize: t.Optional[bool] = None, 297 ) -> t.List[str]: 298 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 299 300 schema = self.find(normalized_table) 301 if schema is None: 302 return [] 303 304 if not only_visible or not self.visible: 305 return list(schema) 306 307 visible = self.nested_get(self.table_parts(normalized_table), self.visible) or [] 308 return [col for col in schema if col in visible]
Get the column names for a table.
Arguments:
- table: the `Table` expression instance.
- only_visible: whether to restrict the result to visible columns only.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The list of column names.
310 def get_column_type( 311 self, 312 table: exp.Table | str, 313 column: exp.Column | str, 314 dialect: DialectType = None, 315 normalize: t.Optional[bool] = None, 316 ) -> exp.DataType: 317 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 318 319 normalized_column_name = self._normalize_name( 320 column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize 321 ) 322 323 table_schema = self.find(normalized_table, raise_on_missing=False) 324 if table_schema: 325 column_type = table_schema.get(normalized_column_name) 326 327 if isinstance(column_type, exp.DataType): 328 return column_type 329 elif isinstance(column_type, str): 330 return self._to_data_type(column_type, dialect=dialect) 331 332 return exp.DataType.build("unknown")
Get the `sqlglot.exp.DataType` type of a column in the schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The resulting column type.
334 def has_column( 335 self, 336 table: exp.Table | str, 337 column: exp.Column | str, 338 dialect: DialectType = None, 339 normalize: t.Optional[bool] = None, 340 ) -> bool: 341 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 342 343 normalized_column_name = self._normalize_name( 344 column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize 345 ) 346 347 table_schema = self.find(normalized_table, raise_on_missing=False) 348 return normalized_column_name in table_schema if table_schema else False
Returns whether or not `column` appears in `table`'s schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
True if the column appears in the schema, False otherwise.
Inherited Members
452def normalize_name( 453 name: str | exp.Identifier, 454 dialect: DialectType = None, 455 is_table: bool = False, 456 normalize: t.Optional[bool] = True, 457) -> str: 458 try: 459 identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier) 460 except ParseError: 461 return name if isinstance(name, str) else name.name 462 463 name = identifier.name 464 if not normalize: 465 return name 466 467 # This can be useful for normalize_identifier 468 identifier.meta["is_table"] = is_table 469 return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name
479def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict: 480 if mapping is None: 481 return {} 482 elif isinstance(mapping, dict): 483 return mapping 484 elif isinstance(mapping, str): 485 col_name_type_strs = [x.strip() for x in mapping.split(",")] 486 return { 487 name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip() 488 for name_type_str in col_name_type_strs 489 } 490 # Check if mapping looks like a DataFrame StructType 491 elif hasattr(mapping, "simpleString"): 492 return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping} 493 elif isinstance(mapping, list): 494 return {x.strip(): None for x in mapping} 495 496 raise ValueError(f"Invalid mapping provided: {type(mapping)}")
499def flatten_schema( 500 schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None 501) -> t.List[t.List[str]]: 502 tables = [] 503 keys = keys or [] 504 505 for k, v in schema.items(): 506 if depth >= 2: 507 tables.extend(flatten_schema(v, depth - 1, keys + [k])) 508 elif depth == 1: 509 tables.append(keys + [k]) 510 511 return tables
514def nested_get( 515 d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True 516) -> t.Optional[t.Any]: 517 """ 518 Get a value for a nested dictionary. 519 520 Args: 521 d: the dictionary to search. 522 *path: tuples of (name, key), where: 523 `key` is the key in the dictionary to get. 524 `name` is a string to use in the error if `key` isn't found. 525 526 Returns: 527 The value or None if it doesn't exist. 528 """ 529 for name, key in path: 530 d = d.get(key) # type: ignore 531 if d is None: 532 if raise_on_missing: 533 name = "table" if name == "this" else name 534 raise ValueError(f"Unknown {name}: {key}") 535 return None 536 537 return d
Get a value for a nested dictionary.
Arguments:
- d: the dictionary to search.
- *path: tuples of (name, key), where:
  `key` is the key in the dictionary to get.
  `name` is a string to use in the error if `key` isn't found.
Returns:
The value or None if it doesn't exist.
540def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict: 541 """ 542 In-place set a value for a nested dictionary 543 544 Example: 545 >>> nested_set({}, ["top_key", "second_key"], "value") 546 {'top_key': {'second_key': 'value'}} 547 548 >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") 549 {'top_key': {'third_key': 'third_value', 'second_key': 'value'}} 550 551 Args: 552 d: dictionary to update. 553 keys: the keys that makeup the path to `value`. 554 value: the value to set in the dictionary for the given key path. 555 556 Returns: 557 The (possibly) updated dictionary. 558 """ 559 if not keys: 560 return d 561 562 if len(keys) == 1: 563 d[keys[0]] = value 564 return d 565 566 subd = d 567 for key in keys[:-1]: 568 if key not in subd: 569 subd = subd.setdefault(key, {}) 570 else: 571 subd = subd[key] 572 573 subd[keys[-1]] = value 574 return d
In-place set a value for a nested dictionary
Example:
>>> nested_set({}, ["top_key", "second_key"], "value") {'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
- d: dictionary to update.
- keys: the keys that make up the path to `value`.
- value: the value to set in the dictionary for the given key path.
Returns:
The (possibly) updated dictionary.