sqlglot.schema
"""Schema abstractions used to resolve tables, column names and column types."""

from __future__ import annotations

import abc
import typing as t

import sqlglot
from sqlglot import expressions as exp
from sqlglot.errors import ParseError, SchemaError
from sqlglot.helper import dict_depth
from sqlglot.trie import in_trie, new_trie

if t.TYPE_CHECKING:
    from sqlglot.dataframe.sql.types import StructType
    from sqlglot.dialects.dialect import DialectType

    ColumnMapping = t.Union[t.Dict, str, StructType, t.List]

# Table expression args, ordered from the innermost (table name) to the outermost (catalog).
TABLE_ARGS = ("this", "db", "catalog")

T = t.TypeVar("T")


class Schema(abc.ABC):
    """Abstract base class for database schemas"""

    @abc.abstractmethod
    def add_table(
        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information to also be provided.

        Args:
            table: table expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
        """

    @abc.abstractmethod
    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to return only the visible columns, excluding invisible ones.

        Returns:
            The list of column names.
        """

    @abc.abstractmethod
    def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType:
        """
        Get the :class:`sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.

        Returns:
            The resulting column type.
        """

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        Table arguments this schema supports, e.g. `("this", "db", "catalog")`
        """
        raise NotImplementedError


class AbstractMappingSchema(t.Generic[T]):
    """
    Base class for schemas backed by a nested dict of the form `{catalog: {db: {table: T}}}`,
    where the outer catalog/db levels are optional.

    A trie over the stored table paths, keyed innermost-first (table, db, catalog), is kept
    alongside the mapping and is used by `find` to resolve table references.
    """

    def __init__(
        self,
        mapping: dict | None = None,
    ) -> None:
        self.mapping = mapping or {}
        # flatten_schema yields outermost-first key paths; reverse them so the trie is keyed in
        # the same order as `table_parts` (table, db, catalog).
        self.mapping_trie = new_trie(
            tuple(reversed(path)) for path in flatten_schema(self.mapping, depth=self._depth())
        )
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    def _depth(self) -> int:
        """Nesting depth of the mapping, as reported by `dict_depth`."""
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        The prefix of `TABLE_ARGS` that this mapping's nesting depth can address.

        Computed lazily from `_depth()` and cached once it is non-empty.

        Raises:
            SchemaError: if the depth is not between 1 and 3.
        """
        if not self._supported_table_args and self.mapping:
            depth = self._depth()

            if not depth:  # None
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = TABLE_ARGS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """
        Return the non-empty name parts of `table`, innermost first (table, db, catalog).

        A table whose `this` is a `ReadCSV` expression is identified by that expression's name alone.
        """
        if isinstance(table.this, exp.ReadCSV):
            return [table.this.name]
        return [table.text(part) for part in TABLE_ARGS if table.text(part)]

    def find(
        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[T]:
        """
        Look up the mapping entry stored for `table`.

        At most `len(self.supported_table_args)` name parts are used. When the trie reports a
        partial match (presumably `in_trie` returns 1 for a prefix match and 0 for no match;
        confirm against `sqlglot.trie`), the path is completed if exactly one candidate exists.

        Args:
            table: the table to look up.
            trie: the trie to search; defaults to `self.mapping_trie`.
            raise_on_missing: whether to raise instead of returning None when the entry is
                ambiguous or missing.

        Returns:
            The stored entry, or None if it cannot be resolved.

        Raises:
            SchemaError: if the reference is ambiguous and `raise_on_missing` is set.
            ValueError: if a resolved path is missing from the mapping and `raise_on_missing` is set.
        """
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)

        if value == 0:
            return None
        elif value == 1:
            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                message = ", ".join(".".join(path) for path in possibilities)
                if raise_on_missing:
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None
        return self._nested_get(parts, raise_on_missing=raise_on_missing)

    def _nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[t.Any]:
        """
        Look up the entry addressed by `parts` in `d` (defaults to `self.mapping`).

        Args:
            parts: table name parts, innermost first (table, db, catalog). Only the innermost
                `len(self.supported_table_args)` parts are used, so an over-qualified reference
                (e.g. `catalog.db.table` against a `{table: ...}` mapping) resolves by table name.
            d: the nested dict to search. An explicitly passed empty dict is searched as-is
                instead of being silently replaced by `self.mapping`.
            raise_on_missing: whether a missing key raises instead of returning None.

        Returns:
            The stored value, or None if a key is missing and `raise_on_missing` is False.
        """
        # Pair each part with its own table arg (this/db/catalog) so error messages name the
        # right level, then walk the dict from the outermost level inwards.
        labelled = list(zip(self.supported_table_args, parts))
        return _nested_get(
            self.mapping if d is None else d,
            *reversed(labelled),
            raise_on_missing=raise_on_missing,
        )


class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema (dict): Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
            Unlike `schema`, these names are not normalized, so they must already match the
            normalized (lowercased unless quoted) schema names.
        dialect (str): The dialect to be used for custom type mappings.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
    ) -> None:
        # `_normalize` parses names with `self.dialect`, so it must be set first.
        self.dialect = dialect
        self.visible = visible or {}
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        super().__init__(self._normalize(schema or {}))

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a new `MappingSchema` from another one's mapping, visibility and dialect."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """
        Return a new `MappingSchema` built from shallow copies of this one's state.

        Any keyword argument overrides the corresponding constructor argument.

        NOTE(review): the copied mapping goes through `_normalize` again, and its keys are plain
        strings, so names whose case was preserved because they were quoted look like they
        would be lowercased in the copy. Confirm whether this is intended.
        """
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                **kwargs,
            }
        )

    def add_table(
        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
        """
        normalized_table = self._normalize_table(self._ensure_table(table))
        normalized_column_mapping = {
            self._normalize_name(key): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            # Already registered with columns, and nothing new was provided.
            return

        parts = self.table_parts(normalized_table)

        _nested_set(
            self.mapping,
            tuple(reversed(parts)),
            normalized_column_mapping,
        )
        # Also record the path in the lookup trie used by `find`.
        new_trie([parts], self.mapping_trie)

    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to return only the columns listed in `visible` for this table.
                Ignored when no `visible` mapping was provided.

        Returns:
            The column names in schema order, or an empty list if no stored table matches.

        Raises:
            SchemaError: if the table reference is ambiguous.
            ValueError: if `only_visible` is set and `visible` has no entry for the table.
        """
        table_ = self._normalize_table(self._ensure_table(table))
        schema = self.find(table_)

        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self._nested_get(self.table_parts(table_), self.visible)
        return [col for col in schema if col in visible]  # type: ignore

    def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
        """
        Get the :class:`sqlglot.exp.DataType` type of a column in the schema.

        String types are upper-cased and parsed with this schema's dialect (results are cached).

        Args:
            table: the source table, as an expression or a string.
            column: the target column, as an expression or a string.

        Returns:
            The column's type, or an "unknown" `DataType` when the table is not found or has
            no columns.

        Raises:
            SchemaError: if the table is found but the column's type is neither a string nor a
                `DataType`, including when the column is missing from the table.
        """
        column_name = self._normalize_name(column if isinstance(column, str) else column.this)
        table_ = self._normalize_table(self._ensure_table(table))

        table_schema = self.find(table_, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                return self._to_data_type(column_type.upper())
            raise SchemaError(f"Unknown column type '{column_type}'")

        return exp.DataType.build("unknown")

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Converts all identifiers in the schema into lowercase, unless they're quoted.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        normalized_mapping: t.Dict = {}
        for keys in flattened_schema:
            # Each path is a (catalog, db, table)-style key path; use the keys as their own labels.
            columns = _nested_get(schema, *zip(keys, keys))
            assert columns is not None

            normalized_keys = [self._normalize_name(key) for key in keys]
            for column_name, column_type in columns.items():
                _nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(self, table: exp.Table) -> exp.Table:
        """Return a copy of `table` whose this/db/catalog names have been normalized."""
        normalized_table = table.copy()
        for arg in TABLE_ARGS:
            value = normalized_table.args.get(arg)
            if isinstance(value, (str, exp.Identifier)):
                normalized_table.set(arg, self._normalize_name(value))

        return normalized_table

    def _normalize_name(self, name: str | exp.Identifier) -> str:
        """
        Lowercase `name` unless it is a quoted identifier.

        Names that cannot be parsed as an identifier are returned unchanged.
        """
        try:
            identifier = sqlglot.maybe_parse(name, dialect=self.dialect, into=exp.Identifier)
        except ParseError:
            return name if isinstance(name, str) else name.name

        return identifier.name if identifier.quoted else identifier.name.lower()

    def _depth(self) -> int:
        # The columns themselves are a mapping, but we don't want to include those
        return super()._depth() - 1

    def _ensure_table(self, table: exp.Table | str) -> exp.Table:
        """Return `table` as a `Table` expression, parsing it with this schema's dialect if needed."""
        if isinstance(table, exp.Table):
            return table

        table_ = sqlglot.parse_one(table, read=self.dialect, into=exp.Table)
        if not table_:
            raise SchemaError(f"Not a valid table '{table}'")

        return table_

    def _to_data_type(self, schema_type: str) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            try:
                expression = exp.maybe_parse(schema_type, into=exp.DataType, dialect=self.dialect)
                if expression is None:
                    raise ValueError(f"Could not parse {schema_type}")
                self._type_mapping_cache[schema_type] = expression  # type: ignore
            except AttributeError as e:
                raise SchemaError(f"Failed to convert type {schema_type}") from e

        return self._type_mapping_cache[schema_type]


def ensure_schema(schema: t.Any, dialect: DialectType = None) -> Schema:
    """Return `schema` unchanged if it already is a `Schema`, otherwise wrap it in a `MappingSchema`."""
    if isinstance(schema, Schema):
        return schema

    return MappingSchema(schema, dialect=dialect)


def _split_top_level(text: str, separator: str = ",") -> t.List[str]:
    """
    Split `text` on `separator`, ignoring separators nested inside (), <> or [] brackets.

    This keeps parameterized types such as `decimal(10, 2)` or `struct<a:int,b:int>` intact.
    Unbalanced closing brackets are ignored rather than driving the depth negative.
    """
    pieces: t.List[str] = []
    depth = 0
    start = 0
    for index, char in enumerate(text):
        if char in "(<[":
            depth += 1
        elif char in ")>]":
            depth = max(depth - 1, 0)
        elif char == separator and depth == 0:
            pieces.append(text[start:index])
            start = index + 1
    pieces.append(text[start:])
    return pieces


def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Coerce any supported column mapping into a `{column_name: column_type}` dict.

    Args:
        mapping: one of
            - a dict, returned as-is (not copied);
            - a string of comma-separated `name: type` pairs, e.g. `"a: int, b: decimal(10, 2)"`.
              Commas nested inside (), <> or [] do not split entries, only the first `:` of an
              entry separates the name from the type, and empty entries are skipped;
            - an object with a `simpleString` attribute (treated as a DataFrame `StructType`),
              whose iterated fields provide `.name` and `.dataType.simpleString()`;
            - a list of column names, each mapped to None (type unknown);
            - None, which yields an empty dict.

    Returns:
        The column mapping as a dict.

    Raises:
        ValueError: if a string entry has no `:` separator, or the mapping's type is unsupported.
    """
    if isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        columns: t.Dict = {}
        for raw_entry in _split_top_level(mapping):
            entry = raw_entry.strip()
            if not entry:
                continue
            name, sep, column_type = entry.partition(":")
            if not sep:
                raise ValueError(f"Invalid column definition '{entry}': expected 'name: type'")
            columns[name.strip()] = column_type.strip()
        return columns
    # Check if mapping looks like a DataFrame StructType
    elif hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}  # type: ignore
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}
    elif mapping is None:
        return {}
    raise ValueError(f"Invalid mapping provided: {type(mapping)}")


def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """
    Collect every key path of `schema` that is exactly `depth` levels deep.

    Args:
        schema: the nested dict to walk.
        depth: the number of key levels that make up one path.
        keys: a prefix prepended to every returned path (used by the recursion).

    Returns:
        The key paths, outermost key first, in dict iteration order.
    """
    tables = []
    keys = keys or []

    for k, v in schema.items():
        if depth >= 2:
            tables.extend(flatten_schema(v, depth - 1, keys + [k]))
        elif depth == 1:
            tables.append(keys + [k])
    return tables


def _nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.

    Returns:
        The value or None if it doesn't exist.

    Raises:
        ValueError: if a key is missing and `raise_on_missing` is set.
    """
    for name, key in path:
        d = d.get(key)  # type: ignore
        if d is None:
            if raise_on_missing:
                name = "table" if name == "this" else name
                raise ValueError(f"Unknown {name}: {key}")
            return None
    return d


def _nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> _nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> _nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that makeup the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    subd = d
    for key in keys[:-1]:
        # Reuse the existing intermediate dict, or create one if the level is missing.
        subd = subd.setdefault(key, {})

    subd[keys[-1]] = value
    return d
class Schema(abc.ABC):
    """Abstract base class for database schemas"""

    @abc.abstractmethod
    def add_table(
        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information to also be provided.

        Args:
            table: table expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
        """

    @abc.abstractmethod
    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to return only the visible columns, excluding invisible ones.

        Returns:
            The list of column names.
        """

    @abc.abstractmethod
    def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType:
        """
        Get the :class:`sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.

        Returns:
            The resulting column type.
        """

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        Table arguments this schema supports, e.g. `("this", "db", "catalog")`
        """
        raise NotImplementedError
Abstract base class for database schemas
@abc.abstractmethod
def add_table(
    self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
) -> None:
    """
    Register a new table with the schema, or update an already registered one.

    Depending on the implementation, column information may be mandatory.

    Args:
        table: the table, given either as an expression instance or as a string.
        column_mapping: a column mapping describing the table's structure.
    """
Register or update a table. Some implementing classes may require column information to also be provided.
Arguments:
- table: table expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
@abc.abstractmethod
def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
    """
    Get the column names for a table.

    Args:
        table: the `Table` expression instance or string representing the table.
        only_visible: whether to return only the visible columns, excluding invisible ones.

    Returns:
        The list of column names.
    """
Get the column names for a table.
Arguments:
- table: the `Table` expression instance or string representing the table.
- only_visible: whether to return only the visible columns, excluding invisible ones.
Returns:
The list of column names.
@abc.abstractmethod
def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType:
    """
    Resolve a column's :class:`sqlglot.exp.DataType` from the schema.

    Args:
        table: the table the column belongs to.
        column: the column whose type is requested.

    Returns:
        The type recorded for the column.
    """
Get the `sqlglot.exp.DataType` type of a column in the schema.
Arguments:
- table: the source table.
- column: the target column.
Returns:
The resulting column type.
class AbstractMappingSchema(t.Generic[T]):
    """
    Base class for schemas backed by a nested dict of the form `{catalog: {db: {table: T}}}`,
    where the outer catalog/db levels are optional.

    A trie over the stored table paths, keyed innermost-first (table, db, catalog), is kept
    alongside the mapping and is used by `find` to resolve table references.
    """

    def __init__(
        self,
        mapping: dict | None = None,
    ) -> None:
        self.mapping = mapping or {}
        # flatten_schema yields outermost-first key paths; reverse them so the trie is keyed in
        # the same order as `table_parts` (table, db, catalog).
        self.mapping_trie = new_trie(
            tuple(reversed(path)) for path in flatten_schema(self.mapping, depth=self._depth())
        )
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    def _depth(self) -> int:
        """Nesting depth of the mapping, as reported by `dict_depth`."""
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        The prefix of `TABLE_ARGS` that this mapping's nesting depth can address.

        Computed lazily from `_depth()` and cached once it is non-empty.

        Raises:
            SchemaError: if the depth is not between 1 and 3.
        """
        if not self._supported_table_args and self.mapping:
            depth = self._depth()

            if not depth:  # None
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = TABLE_ARGS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """
        Return the non-empty name parts of `table`, innermost first (table, db, catalog).

        A table whose `this` is a `ReadCSV` expression is identified by that expression's name alone.
        """
        if isinstance(table.this, exp.ReadCSV):
            return [table.this.name]
        return [table.text(part) for part in TABLE_ARGS if table.text(part)]

    def find(
        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[T]:
        """
        Look up the mapping entry stored for `table`.

        At most `len(self.supported_table_args)` name parts are used. When the trie reports a
        partial match (presumably `in_trie` returns 1 for a prefix match and 0 for no match;
        confirm against `sqlglot.trie`), the path is completed if exactly one candidate exists.

        Args:
            table: the table to look up.
            trie: the trie to search; defaults to `self.mapping_trie`.
            raise_on_missing: whether to raise instead of returning None when the entry is
                ambiguous or missing.

        Returns:
            The stored entry, or None if it cannot be resolved.

        Raises:
            SchemaError: if the reference is ambiguous and `raise_on_missing` is set.
            ValueError: if a resolved path is missing from the mapping and `raise_on_missing` is set.
        """
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)

        if value == 0:
            return None
        elif value == 1:
            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                message = ", ".join(".".join(path) for path in possibilities)
                if raise_on_missing:
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None
        return self._nested_get(parts, raise_on_missing=raise_on_missing)

    def _nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[t.Any]:
        """
        Look up the entry addressed by `parts` in `d` (defaults to `self.mapping`).

        Args:
            parts: table name parts, innermost first (table, db, catalog). Only the innermost
                `len(self.supported_table_args)` parts are used, so an over-qualified reference
                (e.g. `catalog.db.table` against a `{table: ...}` mapping) resolves by table name.
            d: the nested dict to search. An explicitly passed empty dict is searched as-is
                instead of being silently replaced by `self.mapping`.
            raise_on_missing: whether a missing key raises instead of returning None.

        Returns:
            The stored value, or None if a key is missing and `raise_on_missing` is False.
        """
        # Pair each part with its own table arg (this/db/catalog) so error messages name the
        # right level, then walk the dict from the outermost level inwards.
        labelled = list(zip(self.supported_table_args, parts))
        return _nested_get(
            self.mapping if d is None else d,
            *reversed(labelled),
            raise_on_missing=raise_on_missing,
        )
Abstract base class for generic types.
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::
class Mapping(Generic[KT, VT]):
    def __getitem__(self, key: KT) -> VT:
        ...
    # Etc.
This class can then be used as follows::
def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
    try:
        return mapping[key]
    except KeyError:
        return default
def find(
    self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
) -> t.Optional[T]:
    """
    Resolve `table` to the entry stored for it in the mapping.

    Only as many name parts as `supported_table_args` allows are considered. A partial match
    (presumably signalled by `in_trie` returning 1; confirm against `sqlglot.trie`) is
    completed when exactly one stored path extends it.

    Args:
        table: the table to resolve.
        trie: an alternative trie to search instead of `self.mapping_trie`.
        raise_on_missing: raise on ambiguous or missing entries rather than returning None.

    Returns:
        The stored entry, or None when nothing can be resolved.
    """
    parts = self.table_parts(table)[: len(self.supported_table_args)]
    match_kind, trie = in_trie(trie if trie is not None else self.mapping_trie, parts)

    if match_kind == 0:
        return None

    if match_kind == 1:
        candidates = flatten_schema(trie, depth=dict_depth(trie) - 1)
        if len(candidates) != 1:
            listing = ", ".join(".".join(candidate) for candidate in candidates)
            if raise_on_missing:
                raise SchemaError(f"Ambiguous mapping for {table}: {listing}.")
            return None
        parts.extend(candidates[0])

    return self._nested_get(parts, raise_on_missing=raise_on_missing)
class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema (dict): Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
            Unlike `schema`, these names are not normalized, so they must already match the
            normalized (lowercased unless quoted) schema names.
        dialect (str): The dialect to be used for custom type mappings.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
    ) -> None:
        # `_normalize` parses names with `self.dialect`, so it must be set first.
        self.dialect = dialect
        self.visible = visible or {}
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        super().__init__(self._normalize(schema or {}))

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a new `MappingSchema` from another one's mapping, visibility and dialect."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """
        Return a new `MappingSchema` built from shallow copies of this one's state.

        Any keyword argument overrides the corresponding constructor argument.

        NOTE(review): the copied mapping goes through `_normalize` again, and its keys are plain
        strings, so names whose case was preserved because they were quoted look like they
        would be lowercased in the copy. Confirm whether this is intended.
        """
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                **kwargs,
            }
        )

    def add_table(
        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
        """
        normalized_table = self._normalize_table(self._ensure_table(table))
        normalized_column_mapping = {
            self._normalize_name(key): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            # Already registered with columns, and nothing new was provided.
            return

        parts = self.table_parts(normalized_table)

        _nested_set(
            self.mapping,
            tuple(reversed(parts)),
            normalized_column_mapping,
        )
        # Also record the path in the lookup trie used by `find`.
        new_trie([parts], self.mapping_trie)

    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to return only the columns listed in `visible` for this table.
                Ignored when no `visible` mapping was provided.

        Returns:
            The column names in schema order, or an empty list if no stored table matches.

        Raises:
            SchemaError: if the table reference is ambiguous.
            ValueError: if `only_visible` is set and `visible` has no entry for the table.
        """
        table_ = self._normalize_table(self._ensure_table(table))
        schema = self.find(table_)

        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self._nested_get(self.table_parts(table_), self.visible)
        return [col for col in schema if col in visible]  # type: ignore

    def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
        """
        Get the :class:`sqlglot.exp.DataType` type of a column in the schema.

        String types are upper-cased and parsed with this schema's dialect (results are cached).

        Args:
            table: the source table, as an expression or a string.
            column: the target column, as an expression or a string.

        Returns:
            The column's type, or an "unknown" `DataType` when the table is not found or has
            no columns.

        Raises:
            SchemaError: if the table is found but the column's type is neither a string nor a
                `DataType`, including when the column is missing from the table.
        """
        column_name = self._normalize_name(column if isinstance(column, str) else column.this)
        table_ = self._normalize_table(self._ensure_table(table))

        table_schema = self.find(table_, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                return self._to_data_type(column_type.upper())
            raise SchemaError(f"Unknown column type '{column_type}'")

        return exp.DataType.build("unknown")

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Converts all identifiers in the schema into lowercase, unless they're quoted.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        normalized_mapping: t.Dict = {}
        for keys in flattened_schema:
            # Each path is a (catalog, db, table)-style key path; use the keys as their own labels.
            columns = _nested_get(schema, *zip(keys, keys))
            assert columns is not None

            normalized_keys = [self._normalize_name(key) for key in keys]
            for column_name, column_type in columns.items():
                _nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(self, table: exp.Table) -> exp.Table:
        """Return a copy of `table` whose this/db/catalog names have been normalized."""
        normalized_table = table.copy()
        for arg in TABLE_ARGS:
            value = normalized_table.args.get(arg)
            if isinstance(value, (str, exp.Identifier)):
                normalized_table.set(arg, self._normalize_name(value))

        return normalized_table

    def _normalize_name(self, name: str | exp.Identifier) -> str:
        """
        Lowercase `name` unless it is a quoted identifier.

        Names that cannot be parsed as an identifier are returned unchanged.
        """
        try:
            identifier = sqlglot.maybe_parse(name, dialect=self.dialect, into=exp.Identifier)
        except ParseError:
            return name if isinstance(name, str) else name.name

        return identifier.name if identifier.quoted else identifier.name.lower()

    def _depth(self) -> int:
        # The columns themselves are a mapping, but we don't want to include those
        return super()._depth() - 1

    def _ensure_table(self, table: exp.Table | str) -> exp.Table:
        """Return `table` as a `Table` expression, parsing it with this schema's dialect if needed."""
        if isinstance(table, exp.Table):
            return table

        table_ = sqlglot.parse_one(table, read=self.dialect, into=exp.Table)
        if not table_:
            raise SchemaError(f"Not a valid table '{table}'")

        return table_

    def _to_data_type(self, schema_type: str) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            try:
                expression = exp.maybe_parse(schema_type, into=exp.DataType, dialect=self.dialect)
                if expression is None:
                    raise ValueError(f"Could not parse {schema_type}")
                self._type_mapping_cache[schema_type] = expression  # type: ignore
            except AttributeError as e:
                raise SchemaError(f"Failed to convert type {schema_type}") from e

        return self._type_mapping_cache[schema_type]
Schema based on a nested mapping.
Arguments:
- schema (dict): Mapping in one of the following forms:
- {table: {col: type}}
- {db: {table: {col: type}}}
- {catalog: {db: {table: {col: type}}}}
- None - Tables will be added later
- visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns
are assumed to be visible. The nesting should mirror that of the schema:
- {table: set(*cols)}
- {db: {table: set(*cols)}}
- {catalog: {db: {table: set(*cols)}}}
- dialect (str): The dialect to be used for custom type mappings.
def __init__(
    self,
    schema: t.Optional[t.Dict] = None,
    visible: t.Optional[t.Dict] = None,
    dialect: DialectType = None,
) -> None:
    """
    Build the schema, normalizing every name in `schema` using `dialect`.

    Args:
        schema: the nested table/column mapping; None is treated as an empty mapping.
        visible: optional mapping of visible columns, mirroring the nesting of `schema`.
        dialect: the dialect used to parse names and column types.
    """
    # `_normalize` reads `self.dialect`, so assign it before normalizing.
    self.dialect = dialect
    self.visible = visible if visible else {}
    self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
    initial_mapping = self._normalize(schema if schema else {})
    super().__init__(initial_mapping)
def add_table(
    self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
) -> None:
    """
    Add a table to the schema, or replace its columns.

    A table that is already registered with columns is left untouched unless a non-empty
    column mapping is supplied.

    Args:
        table: the `Table` expression instance or string representing the table.
        column_mapping: a column mapping that describes the structure of the table.
    """
    normalized_table = self._normalize_table(self._ensure_table(table))

    columns: t.Dict = {}
    for raw_name, col_type in ensure_column_mapping(column_mapping).items():
        columns[self._normalize_name(raw_name)] = col_type

    existing = self.find(normalized_table, raise_on_missing=False)
    if existing and not columns:
        # Nothing new to record for an already-known table.
        return

    table_path = self.table_parts(normalized_table)
    _nested_set(self.mapping, tuple(reversed(table_path)), columns)
    # Keep the lookup trie in sync with the mapping.
    new_trie([table_path], self.mapping_trie)
Register or update a table. Updates are only performed if a new column mapping is provided.
Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
    """
    List a table's column names.

    Args:
        table: the `Table` expression instance or string representing the table.
        only_visible: restrict the result to the columns listed in `visible` for this table
            (has no effect when no `visible` mapping was provided).

    Returns:
        The column names in schema order, or an empty list if the table is unknown.
    """
    normalized = self._normalize_table(self._ensure_table(table))
    schema = self.find(normalized)

    if schema is None:
        return []

    if only_visible and self.visible:
        visible = self._nested_get(self.table_parts(normalized), self.visible)
        return [col for col in schema if col in visible]  # type: ignore

    return list(schema)
Get the column names for a table.
Arguments:
- table: the `Table` expression instance or string representing the table.
- only_visible: whether to return only the visible columns, excluding invisible ones.
Returns:
The list of column names.
def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
    """
    Resolve the :class:`sqlglot.exp.DataType` of a column.

    Args:
        table: the source table, as an expression or a string.
        column: the target column, as an expression or a string.

    Returns:
        The column's type; an "unknown" type when the table is not found or has no columns.

    Raises:
        SchemaError: when the table is found but the stored type is neither a string nor a
            `DataType` (which includes a missing column).
    """
    raw_column = column if isinstance(column, str) else column.this
    column_name = self._normalize_name(raw_column)
    normalized = self._normalize_table(self._ensure_table(table))

    table_schema = self.find(normalized, raise_on_missing=False)
    if not table_schema:
        return exp.DataType.build("unknown")

    column_type = table_schema.get(column_name)
    if isinstance(column_type, exp.DataType):
        return column_type
    if isinstance(column_type, str):
        return self._to_data_type(column_type.upper())
    raise SchemaError(f"Unknown column type '{column_type}'")
Get the `sqlglot.exp.DataType` type of a column in the schema.
Arguments:
- table: the source table.
- column: the target column.
Returns:
The resulting column type.
Inherited Members
def _split_top_level(text: str, separator: str = ",") -> t.List[str]:
    """
    Split `text` on `separator`, ignoring separators nested inside (), <> or [] brackets.

    This keeps parameterized types such as `decimal(10, 2)` or `struct<a:int,b:int>` intact.
    Unbalanced closing brackets are ignored rather than driving the depth negative.
    """
    pieces: t.List[str] = []
    depth = 0
    start = 0
    for index, char in enumerate(text):
        if char in "(<[":
            depth += 1
        elif char in ")>]":
            depth = max(depth - 1, 0)
        elif char == separator and depth == 0:
            pieces.append(text[start:index])
            start = index + 1
    pieces.append(text[start:])
    return pieces


def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Coerce any supported column mapping into a `{column_name: column_type}` dict.

    Args:
        mapping: one of
            - a dict, returned as-is (not copied);
            - a string of comma-separated `name: type` pairs, e.g. `"a: int, b: decimal(10, 2)"`.
              Commas nested inside (), <> or [] do not split entries, only the first `:` of an
              entry separates the name from the type, and empty entries are skipped;
            - an object with a `simpleString` attribute (treated as a DataFrame `StructType`),
              whose iterated fields provide `.name` and `.dataType.simpleString()`;
            - a list of column names, each mapped to None (type unknown);
            - None, which yields an empty dict.

    Returns:
        The column mapping as a dict.

    Raises:
        ValueError: if a string entry has no `:` separator, or the mapping's type is unsupported.
    """
    if isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        columns: t.Dict = {}
        for raw_entry in _split_top_level(mapping):
            entry = raw_entry.strip()
            if not entry:
                continue
            name, sep, column_type = entry.partition(":")
            if not sep:
                raise ValueError(f"Invalid column definition '{entry}': expected 'name: type'")
            columns[name.strip()] = column_type.strip()
        return columns
    # Check if mapping looks like a DataFrame StructType
    elif hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}  # type: ignore
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}
    elif mapping is None:
        return {}
    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """
    Collect every key path of `schema` that is exactly `depth` levels deep.

    Args:
        schema: the nested dict to walk.
        depth: the number of key levels that make up one path; a depth below 1 yields nothing.
        keys: a prefix prepended to every returned path (used by the recursion).

    Returns:
        The key paths, outermost key first, in dict iteration order.
    """
    prefix = keys or []
    items = schema.items()

    if depth == 1:
        return [prefix + [name] for name, _ in items]
    if depth < 1:
        return []

    paths: t.List[t.List[str]] = []
    for name, child in items:
        paths.extend(flatten_schema(child, depth - 1, prefix + [name]))
    return paths