sqlglot.schema
1from __future__ import annotations 2 3import abc 4import typing as t 5 6import sqlglot 7from sqlglot import expressions as exp 8from sqlglot._typing import T 9from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE 10from sqlglot.errors import ParseError, SchemaError 11from sqlglot.helper import dict_depth 12from sqlglot.trie import in_trie, new_trie 13 14if t.TYPE_CHECKING: 15 from sqlglot.dataframe.sql.types import StructType 16 from sqlglot.dialects.dialect import DialectType 17 18 ColumnMapping = t.Union[t.Dict, str, StructType, t.List] 19 20TABLE_ARGS = ("this", "db", "catalog") 21 22 23class Schema(abc.ABC): 24 """Abstract base class for database schemas""" 25 26 dialect: DialectType 27 28 @abc.abstractmethod 29 def add_table( 30 self, 31 table: exp.Table | str, 32 column_mapping: t.Optional[ColumnMapping] = None, 33 dialect: DialectType = None, 34 ) -> None: 35 """ 36 Register or update a table. Some implementing classes may require column information to also be provided. 37 38 Args: 39 table: the `Table` expression instance or string representing the table. 40 column_mapping: a column mapping that describes the structure of the table. 41 dialect: the SQL dialect that will be used to parse `table` if it's a string. 42 """ 43 44 @abc.abstractmethod 45 def column_names( 46 self, 47 table: exp.Table | str, 48 only_visible: bool = False, 49 dialect: DialectType = None, 50 ) -> t.List[str]: 51 """ 52 Get the column names for a table. 53 54 Args: 55 table: the `Table` expression instance. 56 only_visible: whether to include invisible columns. 57 dialect: the SQL dialect that will be used to parse `table` if it's a string. 58 59 Returns: 60 The list of column names. 61 """ 62 63 @abc.abstractmethod 64 def get_column_type( 65 self, 66 table: exp.Table | str, 67 column: exp.Column, 68 dialect: DialectType = None, 69 ) -> exp.DataType: 70 """ 71 Get the `sqlglot.exp.DataType` type of a column in the schema. 72 73 Args: 74 table: the source table. 
75 column: the target column. 76 dialect: the SQL dialect that will be used to parse `table` if it's a string. 77 78 Returns: 79 The resulting column type. 80 """ 81 82 @property 83 @abc.abstractmethod 84 def supported_table_args(self) -> t.Tuple[str, ...]: 85 """ 86 Table arguments this schema support, e.g. `("this", "db", "catalog")` 87 """ 88 89 @property 90 def empty(self) -> bool: 91 """Returns whether or not the schema is empty.""" 92 return True 93 94 95class AbstractMappingSchema(t.Generic[T]): 96 def __init__( 97 self, 98 mapping: t.Optional[t.Dict] = None, 99 ) -> None: 100 self.mapping = mapping or {} 101 self.mapping_trie = new_trie( 102 tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self._depth()) 103 ) 104 self._supported_table_args: t.Tuple[str, ...] = tuple() 105 106 @property 107 def empty(self) -> bool: 108 return not self.mapping 109 110 def _depth(self) -> int: 111 return dict_depth(self.mapping) 112 113 @property 114 def supported_table_args(self) -> t.Tuple[str, ...]: 115 if not self._supported_table_args and self.mapping: 116 depth = self._depth() 117 118 if not depth: # None 119 self._supported_table_args = tuple() 120 elif 1 <= depth <= 3: 121 self._supported_table_args = TABLE_ARGS[:depth] 122 else: 123 raise SchemaError(f"Invalid mapping shape. 
Depth: {depth}") 124 125 return self._supported_table_args 126 127 def table_parts(self, table: exp.Table) -> t.List[str]: 128 if isinstance(table.this, exp.ReadCSV): 129 return [table.this.name] 130 return [table.text(part) for part in TABLE_ARGS if table.text(part)] 131 132 def find( 133 self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True 134 ) -> t.Optional[T]: 135 parts = self.table_parts(table)[0 : len(self.supported_table_args)] 136 value, trie = in_trie(self.mapping_trie if trie is None else trie, parts) 137 138 if value == 0: 139 return None 140 141 if value == 1: 142 possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1) 143 144 if len(possibilities) == 1: 145 parts.extend(possibilities[0]) 146 else: 147 message = ", ".join(".".join(parts) for parts in possibilities) 148 if raise_on_missing: 149 raise SchemaError(f"Ambiguous mapping for {table}: {message}.") 150 return None 151 152 return self.nested_get(parts, raise_on_missing=raise_on_missing) 153 154 def nested_get( 155 self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True 156 ) -> t.Optional[t.Any]: 157 return nested_get( 158 d or self.mapping, 159 *zip(self.supported_table_args, reversed(parts)), 160 raise_on_missing=raise_on_missing, 161 ) 162 163 164class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): 165 """ 166 Schema based on a nested mapping. 167 168 Args: 169 schema: Mapping in one of the following forms: 170 1. {table: {col: type}} 171 2. {db: {table: {col: type}}} 172 3. {catalog: {db: {table: {col: type}}}} 173 4. None - Tables will be added later 174 visible: Optional mapping of which columns in the schema are visible. If not provided, all columns 175 are assumed to be visible. The nesting should mirror that of the schema: 176 1. {table: set(*cols)}} 177 2. {db: {table: set(*cols)}}} 178 3. 
{catalog: {db: {table: set(*cols)}}}} 179 dialect: The dialect to be used for custom type mappings & parsing string arguments. 180 normalize: Whether to normalize identifier names according to the given dialect or not. 181 """ 182 183 def __init__( 184 self, 185 schema: t.Optional[t.Dict] = None, 186 visible: t.Optional[t.Dict] = None, 187 dialect: DialectType = None, 188 normalize: bool = True, 189 ) -> None: 190 self.dialect = dialect 191 self.visible = visible or {} 192 self.normalize = normalize 193 self._type_mapping_cache: t.Dict[str, exp.DataType] = {} 194 195 super().__init__(self._normalize(schema or {})) 196 197 @classmethod 198 def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema: 199 return MappingSchema( 200 schema=mapping_schema.mapping, 201 visible=mapping_schema.visible, 202 dialect=mapping_schema.dialect, 203 ) 204 205 def copy(self, **kwargs) -> MappingSchema: 206 return MappingSchema( 207 **{ # type: ignore 208 "schema": self.mapping.copy(), 209 "visible": self.visible.copy(), 210 "dialect": self.dialect, 211 **kwargs, 212 } 213 ) 214 215 def add_table( 216 self, 217 table: exp.Table | str, 218 column_mapping: t.Optional[ColumnMapping] = None, 219 dialect: DialectType = None, 220 ) -> None: 221 """ 222 Register or update a table. Updates are only performed if a new column mapping is provided. 223 224 Args: 225 table: the `Table` expression instance or string representing the table. 226 column_mapping: a column mapping that describes the structure of the table. 227 dialect: the SQL dialect that will be used to parse `table` if it's a string. 
228 """ 229 normalized_table = self._normalize_table( 230 self._ensure_table(table, dialect=dialect), dialect=dialect 231 ) 232 normalized_column_mapping = { 233 self._normalize_name(key, dialect=dialect): value 234 for key, value in ensure_column_mapping(column_mapping).items() 235 } 236 237 schema = self.find(normalized_table, raise_on_missing=False) 238 if schema and not normalized_column_mapping: 239 return 240 241 parts = self.table_parts(normalized_table) 242 243 nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping) 244 new_trie([parts], self.mapping_trie) 245 246 def column_names( 247 self, 248 table: exp.Table | str, 249 only_visible: bool = False, 250 dialect: DialectType = None, 251 ) -> t.List[str]: 252 normalized_table = self._normalize_table( 253 self._ensure_table(table, dialect=dialect), dialect=dialect 254 ) 255 256 schema = self.find(normalized_table) 257 if schema is None: 258 return [] 259 260 if not only_visible or not self.visible: 261 return list(schema) 262 263 visible = self.nested_get(self.table_parts(normalized_table), self.visible) or [] 264 return [col for col in schema if col in visible] 265 266 def get_column_type( 267 self, 268 table: exp.Table | str, 269 column: exp.Column, 270 dialect: DialectType = None, 271 ) -> exp.DataType: 272 normalized_table = self._normalize_table( 273 self._ensure_table(table, dialect=dialect), dialect=dialect 274 ) 275 normalized_column_name = self._normalize_name( 276 column if isinstance(column, str) else column.this, dialect=dialect 277 ) 278 279 table_schema = self.find(normalized_table, raise_on_missing=False) 280 if table_schema: 281 column_type = table_schema.get(normalized_column_name) 282 283 if isinstance(column_type, exp.DataType): 284 return column_type 285 elif isinstance(column_type, str): 286 return self._to_data_type(column_type.upper(), dialect=dialect) 287 288 raise SchemaError(f"Unknown column type '{column_type}'") 289 290 return exp.DataType.build("unknown") 291 
292 def _normalize(self, schema: t.Dict) -> t.Dict: 293 """ 294 Converts all identifiers in the schema into lowercase, unless they're quoted. 295 296 Args: 297 schema: the schema to normalize. 298 299 Returns: 300 The normalized schema mapping. 301 """ 302 flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1) 303 304 normalized_mapping: t.Dict = {} 305 for keys in flattened_schema: 306 columns = nested_get(schema, *zip(keys, keys)) 307 assert columns is not None 308 309 normalized_keys = [self._normalize_name(key, dialect=self.dialect) for key in keys] 310 for column_name, column_type in columns.items(): 311 nested_set( 312 normalized_mapping, 313 normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)], 314 column_type, 315 ) 316 317 return normalized_mapping 318 319 def _normalize_table(self, table: exp.Table, dialect: DialectType = None) -> exp.Table: 320 normalized_table = table.copy() 321 322 for arg in TABLE_ARGS: 323 value = normalized_table.args.get(arg) 324 if isinstance(value, (str, exp.Identifier)): 325 normalized_table.set( 326 arg, exp.to_identifier(self._normalize_name(value, dialect=dialect)) 327 ) 328 329 return normalized_table 330 331 def _normalize_name(self, name: str | exp.Identifier, dialect: DialectType = None) -> str: 332 dialect = dialect or self.dialect 333 334 try: 335 identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier) 336 except ParseError: 337 return name if isinstance(name, str) else name.name 338 339 name = identifier.name 340 341 if not self.normalize or identifier.quoted: 342 return name 343 344 return name.upper() if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE else name.lower() 345 346 def _depth(self) -> int: 347 # The columns themselves are a mapping, but we don't want to include those 348 return super()._depth() - 1 349 350 def _ensure_table(self, table: exp.Table | str, dialect: DialectType = None) -> exp.Table: 351 return exp.maybe_parse(table, into=exp.Table, 
dialect=dialect or self.dialect) 352 353 def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType: 354 """ 355 Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object. 356 357 Args: 358 schema_type: the type we want to convert. 359 dialect: the SQL dialect that will be used to parse `schema_type`, if needed. 360 361 Returns: 362 The resulting expression type. 363 """ 364 if schema_type not in self._type_mapping_cache: 365 dialect = dialect or self.dialect 366 367 try: 368 expression = exp.DataType.build(schema_type, dialect=dialect) 369 self._type_mapping_cache[schema_type] = expression 370 except AttributeError: 371 in_dialect = f" in dialect {dialect}" if dialect else "" 372 raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.") 373 374 return self._type_mapping_cache[schema_type] 375 376 377def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema: 378 if isinstance(schema, Schema): 379 return schema 380 381 return MappingSchema(schema, **kwargs) 382 383 384def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict: 385 if mapping is None: 386 return {} 387 elif isinstance(mapping, dict): 388 return mapping 389 elif isinstance(mapping, str): 390 col_name_type_strs = [x.strip() for x in mapping.split(",")] 391 return { 392 name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip() 393 for name_type_str in col_name_type_strs 394 } 395 # Check if mapping looks like a DataFrame StructType 396 elif hasattr(mapping, "simpleString"): 397 return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping} 398 elif isinstance(mapping, list): 399 return {x.strip(): None for x in mapping} 400 401 raise ValueError(f"Invalid mapping provided: {type(mapping)}") 402 403 404def flatten_schema( 405 schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None 406) -> t.List[t.List[str]]: 407 tables = [] 408 keys = keys or 
[] 409 410 for k, v in schema.items(): 411 if depth >= 2: 412 tables.extend(flatten_schema(v, depth - 1, keys + [k])) 413 elif depth == 1: 414 tables.append(keys + [k]) 415 416 return tables 417 418 419def nested_get( 420 d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True 421) -> t.Optional[t.Any]: 422 """ 423 Get a value for a nested dictionary. 424 425 Args: 426 d: the dictionary to search. 427 *path: tuples of (name, key), where: 428 `key` is the key in the dictionary to get. 429 `name` is a string to use in the error if `key` isn't found. 430 431 Returns: 432 The value or None if it doesn't exist. 433 """ 434 for name, key in path: 435 d = d.get(key) # type: ignore 436 if d is None: 437 if raise_on_missing: 438 name = "table" if name == "this" else name 439 raise ValueError(f"Unknown {name}: {key}") 440 return None 441 442 return d 443 444 445def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict: 446 """ 447 In-place set a value for a nested dictionary 448 449 Example: 450 >>> nested_set({}, ["top_key", "second_key"], "value") 451 {'top_key': {'second_key': 'value'}} 452 453 >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") 454 {'top_key': {'third_key': 'third_value', 'second_key': 'value'}} 455 456 Args: 457 d: dictionary to update. 458 keys: the keys that makeup the path to `value`. 459 value: the value to set in the dictionary for the given key path. 460 461 Returns: 462 The (possibly) updated dictionary. 463 """ 464 if not keys: 465 return d 466 467 if len(keys) == 1: 468 d[keys[0]] = value 469 return d 470 471 subd = d 472 for key in keys[:-1]: 473 if key not in subd: 474 subd = subd.setdefault(key, {}) 475 else: 476 subd = subd[key] 477 478 subd[keys[-1]] = value 479 return d
class Schema(abc.ABC):
    """Abstract base class for database schemas"""

    dialect: DialectType

    @abc.abstractmethod
    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
    ) -> None:
        """
        Register a new table or update an existing one. Some implementing classes may
        require column information to be provided as well.

        Args:
            table: the `Table` expression instance, or a string naming the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect used to parse `table` when it is a string.
        """

    @abc.abstractmethod
    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
    ) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to return only the visible columns, excluding invisible ones.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.

        Returns:
            The list of column names.
        """

    @abc.abstractmethod
    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column,
        dialect: DialectType = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` of a column in the schema.

        Args:
            table: the source table, as a `Table` expression or a string.
            column: the target column.
            dialect: the SQL dialect used to parse `table` when it is a string.

        Returns:
            The data type of the column.
        """

    @property
    @abc.abstractmethod
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        Table arguments this schema supports, e.g. `("this", "db", "catalog")`
        """

    @property
    def empty(self) -> bool:
        """Returns whether or not the schema is empty."""
        return True
Abstract base class for database schemas
@abc.abstractmethod
def add_table(
    self,
    table: exp.Table | str,
    column_mapping: t.Optional[ColumnMapping] = None,
    dialect: DialectType = None,
) -> None:
    """
    Register a new table or update an existing one. Some implementing classes may
    require column information to be provided as well.

    Args:
        table: the `Table` expression instance, or a string naming the table.
        column_mapping: a column mapping that describes the structure of the table.
        dialect: the SQL dialect used to parse `table` when it is a string.
    """
Register or update a table. Some implementing classes may require column information to also be provided.
Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
@abc.abstractmethod
def column_names(
    self,
    table: exp.Table | str,
    only_visible: bool = False,
    dialect: DialectType = None,
) -> t.List[str]:
    """
    Get the column names for a table.

    Args:
        table: the `Table` expression instance or string representing the table.
        only_visible: whether to return only the visible columns, excluding invisible ones.
        dialect: the SQL dialect that will be used to parse `table` if it's a string.

    Returns:
        The list of column names.
    """
Get the column names for a table.
Arguments:
- table: the `Table` expression instance or string representing the table.
- only_visible: whether to return only the visible columns, excluding invisible ones.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
Returns:
The list of column names.
@abc.abstractmethod
def get_column_type(
    self,
    table: exp.Table | str,
    column: exp.Column,
    dialect: DialectType = None,
) -> exp.DataType:
    """
    Get the `sqlglot.exp.DataType` of a column in the schema.

    Args:
        table: the source table, as a `Table` expression or a string.
        column: the target column.
        dialect: the SQL dialect used to parse `table` when it is a string.

    Returns:
        The data type of the column.
    """
Get the `sqlglot.exp.DataType` type of a column in the schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
Returns:
The resulting column type.
class AbstractMappingSchema(t.Generic[T]):
    """
    Schema lookups backed by a (possibly nested) dictionary whose outer levels are keyed by
    catalog, db and table names, and whose innermost values are of type `T`.
    """

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}
        # `flatten_schema` yields key paths outermost key first; the trie stores them reversed
        # so that it is keyed in `table_parts` order (table name, db, catalog).
        self.mapping_trie = new_trie(
            tuple(reversed(path)) for path in flatten_schema(self.mapping, depth=self._depth())
        )
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        """Whether the mapping has no entries."""
        return not self.mapping

    def _depth(self) -> int:
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """The table arguments (a prefix of `TABLE_ARGS`) that the mapping's depth can resolve."""
        if not self._supported_table_args and self.mapping:
            depth = self._depth()

            if not depth:  # the mapping has no table level at all
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = TABLE_ARGS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """Return the non-empty parts of `table` in (name, db, catalog) order."""
        if isinstance(table.this, exp.ReadCSV):
            # CSV-backed tables are identified by `table.this.name` alone.
            return [table.this.name]
        return [table.text(part) for part in TABLE_ARGS if table.text(part)]

    def find(
        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[T]:
        """
        Resolve `table` to its entry in the mapping.

        Only as many leading table parts (name, then db, then catalog) as the schema supports
        are used. If they only partially qualify the table and the trie holds exactly one
        completion, the missing parts are filled in from it.

        Args:
            table: the table to look up.
            trie: the trie to search; defaults to `self.mapping_trie`.
            raise_on_missing: whether to raise `SchemaError` on an ambiguous table (and let
                `nested_get` raise on a missing key) instead of returning None.

        Returns:
            The value stored for the table, or None if it cannot be resolved.
        """
        parts = self.table_parts(table)[: len(self.supported_table_args)]
        search_trie = self.mapping_trie if trie is None else trie
        # NOTE(review): `in_trie` status codes are assumed to be 0 = no match and
        # 1 = prefix-only match, per `sqlglot.trie` — confirm there.
        status, subtrie = in_trie(search_trie, parts)

        if status == 0:
            return None

        if status == 1:
            completions = flatten_schema(subtrie, depth=dict_depth(subtrie) - 1)
            if len(completions) != 1:
                if raise_on_missing:
                    options = ", ".join(".".join(completion) for completion in completions)
                    raise SchemaError(f"Ambiguous mapping for {table}: {options}.")
                return None
            parts.extend(completions[0])

        return self.nested_get(parts, raise_on_missing=raise_on_missing)

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
    ) -> t.Optional[t.Any]:
        """
        Look up the table given by `parts` in a nested mapping.

        Args:
            parts: the table parts in (name, db, catalog) order.
            d: the mapping to search; defaults to `self.mapping`. An explicitly passed empty
                mapping is searched as-is instead of being replaced by `self.mapping`.
            raise_on_missing: whether to raise `ValueError` on a missing key instead of
                returning None.

        Returns:
            The value found, or None.
        """
        # The mapping is nested outermost-first (catalog, db, table), so both the parts and
        # the argument names are reversed; the names only label error messages, and pairing
        # them this way makes each label match the mapping level being looked up.
        return nested_get(
            self.mapping if d is None else d,
            *zip(reversed(self.supported_table_args), reversed(parts)),
            raise_on_missing=raise_on_missing,
        )
Abstract base class for generic types.
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::
    class Mapping(Generic[KT, VT]):
        def __getitem__(self, key: KT) -> VT:
            ...
        # Etc.
This class can then be used as follows::
    def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
        try:
            return mapping[key]
        except KeyError:
            return default
def find(
    self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
) -> t.Optional[T]:
    """
    Resolve `table` to its entry in the mapping.

    Only as many leading table parts (name, then db, then catalog) as the schema supports
    are used. If they only partially qualify the table and the trie holds exactly one
    completion, the missing parts are filled in from it.

    Args:
        table: the table to look up.
        trie: the trie to search; defaults to `self.mapping_trie`.
        raise_on_missing: whether to raise `SchemaError` on an ambiguous table (and let
            `nested_get` raise on a missing key) instead of returning None.

    Returns:
        The value stored for the table, or None if it cannot be resolved.
    """
    parts = self.table_parts(table)[: len(self.supported_table_args)]
    search_trie = self.mapping_trie if trie is None else trie
    # NOTE(review): `in_trie` status codes are assumed to be 0 = no match and
    # 1 = prefix-only match, per `sqlglot.trie` — confirm there.
    status, subtrie = in_trie(search_trie, parts)

    if status == 0:
        return None

    if status == 1:
        completions = flatten_schema(subtrie, depth=dict_depth(subtrie) - 1)
        if len(completions) != 1:
            if raise_on_missing:
                options = ", ".join(".".join(completion) for completion in completions)
                raise SchemaError(f"Ambiguous mapping for {table}: {options}.")
            return None
        parts.extend(completions[0])

    return self.nested_get(parts, raise_on_missing=raise_on_missing)
class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(cols)}
            2. {db: {table: set(cols)}}
            3. {catalog: {db: {table: set(cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        # These attributes must be set before `_normalize` runs, since it reads them.
        self.dialect = dialect
        self.normalize = normalize
        self.visible = {} if not visible else visible
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}

        normalized = self._normalize(schema if schema else {})
        super().__init__(normalized)

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a new `MappingSchema` with the same mapping, visibility, dialect and normalization."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            # Without this the new schema would default to normalize=True and re-normalize
            # the identifiers of a schema that was built with normalization disabled.
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """
        Return a new `MappingSchema` built from shallow copies of this one's mapping and
        visibility, with the same dialect and normalization setting. Keyword arguments
        override the corresponding constructor arguments.
        """
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
        """
        parsed = self._ensure_table(table, dialect=dialect)
        normalized_table = self._normalize_table(parsed, dialect=dialect)

        columns = {}
        for name, column_type in ensure_column_mapping(column_mapping).items():
            columns[self._normalize_name(name, dialect=dialect)] = column_type

        # An already-registered table is left untouched unless new columns are supplied.
        existing = self.find(normalized_table, raise_on_missing=False)
        if existing and not columns:
            return

        parts = self.table_parts(normalized_table)
        # The mapping nests outermost-first, i.e. the reverse of `table_parts` order.
        nested_set(self.mapping, tuple(reversed(parts)), columns)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
    ) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to return only the columns listed in `self.visible`; has no
                effect when no visibility mapping was provided.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.

        Returns:
            The list of column names, or an empty list when `find` cannot resolve the table.
        """
        normalized_table = self._normalize_table(
            self._ensure_table(table, dialect=dialect), dialect=dialect
        )

        schema = self.find(normalized_table)
        if schema is None:
            return []

        # Without a visibility mapping every column counts as visible.
        if only_visible and self.visible:
            # NOTE(review): this uses the table's own (possibly partially qualified) parts
            # rather than the parts `find` resolved, so a short name on a nested `visible`
            # mapping may not be found — confirm intended behavior.
            visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
            return [col for col in schema if col in visible]

        return list(schema)

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column,
        dialect: DialectType = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` of a column in the schema.

        Args:
            table: the source table, as a `Table` expression or a string.
            column: the target column, as a `Column` expression or a plain string.
            dialect: the SQL dialect used to parse `table` and string column types.

        Returns:
            The column's type, or the "unknown" type when the table cannot be resolved.

        Raises:
            SchemaError: if the table is known but the column's type is neither a string nor a
                `DataType` (including when the column is missing).
        """
        normalized_table = self._normalize_table(
            self._ensure_table(table, dialect=dialect), dialect=dialect
        )
        raw_name = column if isinstance(column, str) else column.this
        column_name = self._normalize_name(raw_name, dialect=dialect)

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if not table_schema:
            return exp.DataType.build("unknown")

        column_type = table_schema.get(column_name)
        if isinstance(column_type, exp.DataType):
            return column_type
        if isinstance(column_type, str):
            return self._to_data_type(column_type.upper(), dialect=dialect)
        raise SchemaError(f"Unknown column type '{column_type}'")

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalize every table part and column name in the schema via `_normalize_name`.

        Unquoted identifiers are lowercased, or uppercased for dialects in
        `RESOLVES_IDENTIFIERS_AS_UPPERCASE`; quoted identifiers keep their case, as do all
        identifiers when `self.normalize` is False.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        normalized_mapping: t.Dict = {}
        for keys in flattened_schema:
            columns = nested_get(schema, *zip(keys, keys))
            assert columns is not None

            normalized_keys = [self._normalize_name(key, dialect=self.dialect) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(self, table: exp.Table, dialect: DialectType = None) -> exp.Table:
        """Return a copy of `table` whose name, db and catalog identifiers are normalized."""
        normalized_table = table.copy()

        for arg in TABLE_ARGS:
            value = normalized_table.args.get(arg)
            if isinstance(value, (str, exp.Identifier)):
                normalized_table.set(
                    arg, exp.to_identifier(self._normalize_name(value, dialect=dialect))
                )

        return normalized_table

    def _normalize_name(self, name: str | exp.Identifier, dialect: DialectType = None) -> str:
        """
        Normalize a single identifier name.

        Names that fail to parse are returned unchanged. Quoted identifiers, and all
        identifiers when `self.normalize` is False, keep their case; otherwise the name is
        uppercased for dialects in `RESOLVES_IDENTIFIERS_AS_UPPERCASE` and lowercased otherwise.
        """
        dialect = dialect or self.dialect

        try:
            identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
        except ParseError:
            return name if isinstance(name, str) else name.name

        name = identifier.name

        if not self.normalize or identifier.quoted:
            return name

        return name.upper() if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE else name.lower()

    def _depth(self) -> int:
        # The columns themselves are a mapping, but we don't want to include those
        return super()._depth() - 1

    def _ensure_table(self, table: exp.Table | str, dialect: DialectType = None) -> exp.Table:
        """Coerce `table` into a `Table` via `exp.maybe_parse`, defaulting to the schema's dialect."""
        return exp.maybe_parse(table, into=exp.Table, dialect=dialect or self.dialect)

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        # NOTE(review): the cache is keyed by the type string only, so the first dialect used
        # to build a given string wins for later calls with a different dialect — confirm.
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]
Schema based on a nested mapping.
Arguments:
- schema: Mapping in one of the following forms:
- {table: {col: type}}
- {db: {table: {col: type}}}
- {catalog: {db: {table: {col: type}}}}
- None - Tables will be added later
- visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
are assumed to be visible. The nesting should mirror that of the schema:
- {table: set(cols)}
- {db: {table: set(cols)}}
- {catalog: {db: {table: set(cols)}}}
- dialect: The dialect to be used for custom type mappings & parsing string arguments.
- normalize: Whether to normalize identifier names according to the given dialect or not.
def __init__(
    self,
    schema: t.Optional[t.Dict] = None,
    visible: t.Optional[t.Dict] = None,
    dialect: DialectType = None,
    normalize: bool = True,
) -> None:
    # These attributes must be set before `_normalize` runs, since it reads them.
    self.dialect = dialect
    self.normalize = normalize
    self.visible = {} if not visible else visible
    self._type_mapping_cache: t.Dict[str, exp.DataType] = {}

    normalized = self._normalize(schema if schema else {})
    super().__init__(normalized)
def add_table(
    self,
    table: exp.Table | str,
    column_mapping: t.Optional[ColumnMapping] = None,
    dialect: DialectType = None,
) -> None:
    """
    Register or update a table. Updates are only performed if a new column mapping is provided.

    Args:
        table: the `Table` expression instance or string representing the table.
        column_mapping: a column mapping that describes the structure of the table.
        dialect: the SQL dialect that will be used to parse `table` if it's a string.
    """
    parsed = self._ensure_table(table, dialect=dialect)
    normalized_table = self._normalize_table(parsed, dialect=dialect)

    columns = {}
    for name, column_type in ensure_column_mapping(column_mapping).items():
        columns[self._normalize_name(name, dialect=dialect)] = column_type

    # An already-registered table is left untouched unless new columns are supplied.
    existing = self.find(normalized_table, raise_on_missing=False)
    if existing and not columns:
        return

    parts = self.table_parts(normalized_table)
    # The mapping nests outermost-first, i.e. the reverse of `table_parts` order.
    nested_set(self.mapping, tuple(reversed(parts)), columns)
    new_trie([parts], self.mapping_trie)
Register or update a table. Updates are only performed if a new column mapping is provided.
Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
def column_names(
    self,
    table: exp.Table | str,
    only_visible: bool = False,
    dialect: DialectType = None,
) -> t.List[str]:
    """
    Get the column names for a table.

    Args:
        table: the `Table` expression instance or string representing the table.
        only_visible: whether to return only the columns listed in `self.visible`; has no
            effect when no visibility mapping was provided.
        dialect: the SQL dialect that will be used to parse `table` if it's a string.

    Returns:
        The list of column names, or an empty list when `find` cannot resolve the table.
    """
    normalized_table = self._normalize_table(
        self._ensure_table(table, dialect=dialect), dialect=dialect
    )

    schema = self.find(normalized_table)
    if schema is None:
        return []

    # Without a visibility mapping every column counts as visible.
    if only_visible and self.visible:
        # NOTE(review): this uses the table's own (possibly partially qualified) parts
        # rather than the parts `find` resolved, so a short name on a nested `visible`
        # mapping may not be found — confirm intended behavior.
        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    return list(schema)
Get the column names for a table.
Arguments:
- table: the `Table` expression instance or string representing the table.
- only_visible: whether to return only the visible columns, excluding invisible ones.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
Returns:
The list of column names.
def get_column_type(
    self,
    table: exp.Table | str,
    column: exp.Column,
    dialect: DialectType = None,
) -> exp.DataType:
    """
    Get the `sqlglot.exp.DataType` of a column in the schema.

    Args:
        table: the source table, as a `Table` expression or a string.
        column: the target column, as a `Column` expression or a plain string.
        dialect: the SQL dialect used to parse `table` and string column types.

    Returns:
        The column's type, or the "unknown" type when the table cannot be resolved.

    Raises:
        SchemaError: if the table is known but the column's type is neither a string nor a
            `DataType` (including when the column is missing).
    """
    normalized_table = self._normalize_table(
        self._ensure_table(table, dialect=dialect), dialect=dialect
    )
    raw_name = column if isinstance(column, str) else column.this
    column_name = self._normalize_name(raw_name, dialect=dialect)

    table_schema = self.find(normalized_table, raise_on_missing=False)
    if not table_schema:
        return exp.DataType.build("unknown")

    column_type = table_schema.get(column_name)
    if isinstance(column_type, exp.DataType):
        return column_type
    if isinstance(column_type, str):
        return self._to_data_type(column_type.upper(), dialect=dialect)
    raise SchemaError(f"Unknown column type '{column_type}'")
Get the `sqlglot.exp.DataType` type of a column in the schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse
table
if it's a string.
Returns:
The resulting column type.
Inherited Members
def _split_top_level(text: str, separator: str = ",") -> t.List[str]:
    """
    Split `text` on `separator`, ignoring separators nested inside (), <> or [].

    This keeps parameterized types such as `DECIMAL(10, 2)` or `MAP<STRING, INT>` intact.
    """
    pieces: t.List[str] = []
    depth = 0
    start = 0
    for index, char in enumerate(text):
        if char in "(<[":
            depth += 1
        elif char in ")>]":
            # Clamp at zero so stray closing brackets cannot disable splitting for good.
            depth = max(depth - 1, 0)
        elif char == separator and depth == 0:
            pieces.append(text[start:index])
            start = index + 1
    pieces.append(text[start:])
    return pieces


def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Normalize a column mapping into a `{column_name: column_type}` dictionary.

    Args:
        mapping: one of
            - None: treated as an empty mapping.
            - a dict: returned as-is (not copied).
            - a string such as "a: INT, b: DECIMAL(10, 2)": split into name/type pairs.
              Commas nested inside (), <> or [] do not separate columns, only the first
              ":" of each entry separates the name from the type, and empty entries
              (e.g. from a trailing comma) are skipped.
            - an object exposing `simpleString` (e.g. a DataFrame StructType): each
              field's `name` is mapped to its `dataType.simpleString()`.
            - a list of column names: each stripped name maps to None.

    Returns:
        The column mapping as a dictionary.

    Raises:
        ValueError: if the mapping type is unsupported, or a string entry has no ":".
    """
    if mapping is None:
        return {}
    if isinstance(mapping, dict):
        return mapping
    if isinstance(mapping, str):
        result: t.Dict = {}
        for entry in _split_top_level(mapping):
            entry = entry.strip()
            if not entry:
                continue
            name, colon, type_str = entry.partition(":")
            if not colon:
                raise ValueError(f"Invalid column definition in mapping: '{entry}'")
            result[name.strip()] = type_str.strip()
        return result
    # Check if mapping looks like a DataFrame StructType
    if hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
    if isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """
    Collect the key paths of a nested dictionary, descending `depth` levels.

    Args:
        schema: the nested dictionary to walk.
        depth: how many levels of keys make up one path; values below depth 1 yield nothing.
        keys: an optional prefix prepended to every returned path (not mutated).

    Returns:
        A list of key paths, in the dictionary's iteration order.
    """
    prefix = keys or []
    entries = schema.items()

    if depth == 1:
        return [prefix + [name] for name, _ in entries]
    if depth < 1:
        return []

    paths: t.List[t.List[str]] = []
    for name, child in entries:
        paths.extend(flatten_schema(child, depth - 1, prefix + [name]))
    return paths
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Walk a nested dictionary along `path` and return the value found there.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key to look up at that level.
            `name` labels the level in the error message if `key` isn't found
            (the label "this" is reported as "table").
        raise_on_missing: if True, a missing key raises; otherwise None is returned.

    Returns:
        The value at the end of the path (`d` itself when no path is given),
        or None if a key is missing and `raise_on_missing` is False.

    Raises:
        ValueError: if a key is missing and `raise_on_missing` is True.
    """
    current: t.Any = d
    for label, key in path:
        current = current.get(key)
        if current is not None:
            continue
        if not raise_on_missing:
            return None
        kind = "table" if label == "this" else label
        raise ValueError(f"Unknown {kind}: {key}")
    return current
Get a value for a nested dictionary.
Arguments:
- d: the dictionary to search.
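- *path: tuples of (name, key), where:
  `key` is the key in the dictionary to get.
  `name` is a string to use in the error if `key` isn't found.
- raise_on_missing: whether to raise a ValueError when a key is missing instead of returning None.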
Returns:
The value or None if it doesn't exist.
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    Set `value` at the key path `keys` inside `d`, creating intermediate dicts as needed.

    The dictionary is modified in place and also returned for convenience.

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that make up the path to `value`; an empty sequence leaves `d` unchanged.
        value: the value to store at the end of the key path.

    Returns:
        The same dictionary `d`, possibly updated.
    """
    if not keys:
        return d

    *parents, leaf = keys
    target = d
    for parent in parents:
        # Reuse an existing sub-dict, or insert an empty one and descend into it.
        target = target.setdefault(parent, {})
    target[leaf] = value
    return d
In-place set a value for a nested dictionary
Example:
>>> nested_set({}, ["top_key", "second_key"], "value")
{'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
{'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
- d: dictionary to update.
- keys: the keys that make up the path to `value`.
- value: the value to set in the dictionary for the given key path.
Returns:
The (possibly) updated dictionary.