sqlglot.schema
from __future__ import annotations

import abc
import typing as t

import sqlglot
from sqlglot import expressions as exp
from sqlglot.errors import SchemaError
from sqlglot.helper import dict_depth
from sqlglot.trie import in_trie, new_trie

if t.TYPE_CHECKING:
    from sqlglot.dataframe.sql.types import StructType
    from sqlglot.dialects.dialect import DialectType

    ColumnMapping = t.Union[t.Dict, str, StructType, t.List]

TABLE_ARGS = ("this", "db", "catalog")

T = t.TypeVar("T")


class Schema(abc.ABC):
    """Abstract base class for database schemas."""

    @abc.abstractmethod
    def add_table(
        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information
        to also be provided.

        Args:
            table: table expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
        """

    @abc.abstractmethod
    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to return only the visible columns, leaving out invisible ones.

        Returns:
            The list of column names.
        """

    @abc.abstractmethod
    def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
        """
        Get the :class:`sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column, as a `Column` expression or a plain column name.

        Returns:
            The resulting column type.
        """

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        Table arguments this schema supports, e.g. `("this", "db", "catalog")`.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError


class AbstractMappingSchema(t.Generic[T]):
    """
    Generic lookup structure over a nested mapping, indexed by a trie of reversed table paths.

    Args:
        mapping: the nested mapping to wrap. A falsy value is replaced by an empty dict.
    """

    def __init__(
        self,
        mapping: dict | None = None,
    ) -> None:
        self.mapping = mapping or {}
        self.mapping_trie = self._build_trie(self.mapping)
        # Computed lazily by the `supported_table_args` property.
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    def _build_trie(self, schema: t.Dict) -> t.Dict:
        """Build a trie out of the reversed key paths of `schema`, flattened to `self._depth()`."""
        return new_trie(
            tuple(reversed(path)) for path in flatten_schema(schema, depth=self._depth())
        )

    def _depth(self) -> int:
        """Return the nesting depth of `self.mapping`."""
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        The leading `TABLE_ARGS` entries that match the depth of the mapping.

        The value is computed on first access once the mapping is non-empty and then reused.

        Raises:
            SchemaError: if the mapping depth is non-zero and outside the 1..3 range.
        """
        if self._supported_table_args or not self.mapping:
            return self._supported_table_args

        depth = self._depth()
        if not depth:  # None
            self._supported_table_args = tuple()
        elif 1 <= depth <= 3:
            self._supported_table_args = TABLE_ARGS[:depth]
        else:
            raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """Return the non-empty texts of `table` for each of `TABLE_ARGS`, in that order."""
        source = table.this
        if isinstance(source, exp.ReadCSV):
            return [source.name]

        texts = (table.text(arg) for arg in TABLE_ARGS)
        return [text for text in texts if text]

    def find(
        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[T]:
        """
        Look up the mapping entry that corresponds to `table`.

        Args:
            table: the table expression to look up.
            trie: the trie to search; defaults to `self.mapping_trie`.
            raise_on_missing: whether an ambiguous or missing entry raises instead of
                producing `None`.

        Returns:
            The entry stored for the table, or `None` if it cannot be resolved.

        Raises:
            SchemaError: if several entries match and `raise_on_missing` is set.
        """
        search_trie = self.mapping_trie if trie is None else trie
        parts = self.table_parts(table)[: len(self.supported_table_args)]
        value, subtrie = in_trie(search_trie, parts)

        if value == 0:
            return None

        if value == 1:
            # Only part of a path matched: complete it if there is exactly one continuation.
            candidates = flatten_schema(subtrie, depth=dict_depth(subtrie) - 1)
            if len(candidates) != 1:
                message = ", ".join(".".join(candidate) for candidate in candidates)
                if raise_on_missing:
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None
            parts.extend(candidates[0])

        return self._nested_get(parts, raise_on_missing=raise_on_missing)

    def _nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
    ) -> t.Optional[t.Any]:
        """
        Walk `d` (or `self.mapping` when `d` is falsy) along the reversed `parts`.

        Each key is paired with an entry of `supported_table_args`, which is only used to
        label the error raised by the module-level `_nested_get`.
        """
        target = d or self.mapping
        path = zip(self.supported_table_args, reversed(parts))
        return _nested_get(target, *path, raise_on_missing=raise_on_missing)


class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema (dict): Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect (str): The dialect to be used for custom type mappings.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
    ) -> None:
        # `dialect` must be set before `_normalize` runs, since name normalization reads it.
        self.dialect = dialect
        self.visible = visible or {}
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        super().__init__(self._normalize(schema or {}))

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a new `MappingSchema` from the mapping, visibility and dialect of another one."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """
        Create a new `MappingSchema` from shallow copies of this schema's mapping and visibility.

        Args:
            **kwargs: constructor arguments that override the copied ones.
        """
        params = {
            "schema": self.mapping.copy(),
            "visible": self.visible.copy(),
            "dialect": self.dialect,
        }
        params.update(kwargs)
        return MappingSchema(**params)  # type: ignore

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Converts all identifiers in the schema into lowercase, unless they're quoted.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        normalized_mapping: t.Dict = {}
        for keys in flattened_schema:
            columns = _nested_get(schema, *zip(keys, keys))
            assert columns is not None

            normalized_keys = [self._normalize_name(key) for key in keys]
            for column_name, column_type in columns.items():
                _nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def add_table(
        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
        """
        table_ = self._ensure_table(table)
        columns = ensure_column_mapping(column_mapping)
        existing = self.find(table_, raise_on_missing=False)

        # A known table is left untouched when no new columns are supplied.
        if existing and not columns:
            return

        _nested_set(self.mapping, self.table_parts(table_)[::-1], columns)
        self.mapping_trie = self._build_trie(self.mapping)

    def _normalize_name(self, name: str) -> str:
        """
        Return `name` lowercased, unless it parses as a quoted identifier.

        If `name` cannot be parsed as an identifier, it is converted with `exp.to_identifier`.
        """
        try:
            identifier: t.Optional[exp.Expression] = sqlglot.parse_one(
                name, read=self.dialect, into=exp.Identifier
            )
        except Exception:
            # Deliberate best-effort fallback; `Exception` rather than a bare `except` so that
            # KeyboardInterrupt and SystemExit are not swallowed.
            identifier = exp.to_identifier(name)
        assert isinstance(identifier, exp.Identifier)

        if identifier.quoted:
            return identifier.name
        return identifier.name.lower()

    def _depth(self) -> int:
        # The columns themselves are a mapping, but we don't want to include those
        return super()._depth() - 1

    def _ensure_table(self, table: exp.Table | str) -> exp.Table:
        """
        Convert `table` with `exp.to_table`.

        Raises:
            SchemaError: if the conversion yields a falsy value.
        """
        table_ = exp.to_table(table)

        if not table_:
            raise SchemaError(f"Not a valid table '{table}'")

        return table_

    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to return only the columns listed in `self.visible`.
                Ignored when no visibility mapping was provided.

        Returns:
            The list of column names, or an empty list if the table is not found.
        """
        table_ = self._ensure_table(table)
        schema = self.find(table_)

        if schema is None:
            return []

        if only_visible and self.visible:
            visible = self._nested_get(self.table_parts(table_), self.visible)
            return [col for col in schema if col in visible]  # type: ignore

        return list(schema)

    def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
        """
        Get the :class:`sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column, as a `Column` expression or a plain column name.

        Returns:
            The resulting column type, or an `UNKNOWN` data type if the table is not found.

        Raises:
            SchemaError: if `table` cannot be converted into a table, or the stored column type
                is neither a `DataType` nor a string.
        """
        column_name = column if isinstance(column, str) else column.name
        table_ = exp.to_table(table)
        if not table_:
            raise SchemaError(f"Could not convert table '{table}'")

        table_schema = self.find(table_, raise_on_missing=False)
        if not table_schema:
            return exp.DataType(this=exp.DataType.Type.UNKNOWN)

        column_type = table_schema.get(column_name)
        if isinstance(column_type, exp.DataType):
            return column_type
        if isinstance(column_type, str):
            return self._to_data_type(column_type.upper())
        raise SchemaError(f"Unknown column type '{column_type}'")

    def _to_data_type(self, schema_type: str) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object.

        Results are memoized per type string in `self._type_mapping_cache`.

        Args:
            schema_type: the type we want to convert.

        Returns:
            The resulting expression type.

        Raises:
            ValueError: if parsing yields `None`.
            SchemaError: if parsing raises an `AttributeError`.
        """
        if schema_type not in self._type_mapping_cache:
            try:
                expression = exp.maybe_parse(schema_type, into=exp.DataType, dialect=self.dialect)
                if expression is None:
                    raise ValueError(f"Could not parse {schema_type}")
                self._type_mapping_cache[schema_type] = expression  # type: ignore
            except AttributeError:
                raise SchemaError(f"Failed to convert type {schema_type}")

        return self._type_mapping_cache[schema_type]


def ensure_schema(schema: t.Any, dialect: DialectType = None) -> Schema:
    """Return `schema` unchanged if it is a `Schema`, otherwise wrap it in a `MappingSchema`."""
    if isinstance(schema, Schema):
        return schema

    return MappingSchema(schema, dialect=dialect)


def _split_top_level_commas(text: str) -> t.List[str]:
    """Split `text` on the commas that are not nested inside `()`, `<>` or `[]`."""
    pieces: t.List[str] = []
    current: t.List[str] = []
    depth = 0

    for char in text:
        if char in "(<[":
            depth += 1
        elif char in ")>]":
            depth = max(depth - 1, 0)

        if char == "," and depth == 0:
            pieces.append("".join(current))
            current = []
        else:
            current.append(char)

    pieces.append("".join(current))
    return pieces


def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Convert the supported column mapping representations into a dictionary.

    Args:
        mapping: one of
            - a dict, returned as is;
            - a string of comma separated `name: type` entries, e.g. `"a: int, b: decimal(10, 2)"`.
              Commas nested inside brackets belong to the type, only the first colon of an
              entry separates the name from the type, and blank entries are skipped;
            - an object exposing `simpleString` (looks like a DataFrame StructType), whose
              fields are mapped from `name` to `dataType.simpleString()`;
            - a list of column names, mapped to `None` types;
            - None, which yields an empty dict.

    Returns:
        The column mapping as a dictionary.

    Raises:
        ValueError: if `mapping` has an unsupported type, or a string entry has no `:` separator.
    """
    if isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        columns: t.Dict = {}
        for entry in _split_top_level_commas(mapping):
            entry = entry.strip()
            if not entry:
                continue

            name, separator, type_ = entry.partition(":")
            if not separator:
                raise ValueError(
                    f"Invalid column mapping entry '{entry}'. Expected the form 'name: type'"
                )
            columns[name.strip()] = type_.strip()
        return columns
    # Check if mapping looks like a DataFrame StructType
    elif hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}  # type: ignore
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}
    elif mapping is None:
        return {}
    raise ValueError(f"Invalid mapping provided: {type(mapping)}")


def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """
    List the key paths found at nesting level `depth` of `schema`.

    Args:
        schema: the nested dictionary to walk.
        depth: how many levels of keys each path contains. Values below 1 yield no paths.
        keys: a prefix that is prepended to every path.

    Returns:
        One list of keys per path, in dictionary iteration order.
    """
    prefix = keys or []
    paths = []

    for name, child in schema.items():
        path = prefix + [name]
        if depth == 1:
            paths.append(path)
        elif depth >= 2:
            paths.extend(flatten_schema(child, depth - 1, path))

    return paths


def _nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.
        raise_on_missing: whether a missing key raises instead of producing `None`.

    Returns:
        The value or None if it doesn't exist.

    Raises:
        ValueError: if a key is missing and `raise_on_missing` is set.
    """
    for name, key in path:
        d = d.get(key)  # type: ignore
        if d is None:
            if raise_on_missing:
                name = "table" if name == "this" else name
                raise ValueError(f"Unknown {name}: {key}")
            return None
    return d


def _nested_set(d: t.Dict, keys: t.List[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> _nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> _nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that makeup the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    *parents, last = keys
    node = d
    for key in parents:
        # Reuse the existing child dictionary, creating it when it is missing.
        node = node.setdefault(key, {})

    node[last] = value
    return d
class Schema(abc.ABC):
    """Abstract base class for database schemas."""

    @abc.abstractmethod
    def add_table(
        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information
        to also be provided.

        Args:
            table: table expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
        """

    @abc.abstractmethod
    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to return only the visible columns, leaving out invisible ones.

        Returns:
            The list of column names.
        """

    @abc.abstractmethod
    def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
        """
        Get the :class:`sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column, as a `Column` expression or a plain column name.

        Returns:
            The resulting column type.
        """

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        Table arguments this schema supports, e.g. `("this", "db", "catalog")`.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError
Abstract base class for database schemas
@abc.abstractmethod
def add_table(
    self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
) -> None:
    """
    Register a table, or update it if it is already known.

    Implementing classes may require the column information to be supplied as well.

    Args:
        table: table expression instance or string representing the table.
        column_mapping: a column mapping that describes the structure of the table.
    """
Register or update a table. Some implementing classes may require column information to also be provided.
Arguments:
- table: table expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
@abc.abstractmethod
def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
    """
    Get the column names for a table.

    Args:
        table: the `Table` expression instance or string representing the table.
        only_visible: whether to return only the visible columns, leaving out invisible ones.

    Returns:
        The list of column names.
    """
Get the column names for a table.
Arguments:
- table: the `Table` expression instance.
- only_visible: whether to return only the visible columns, leaving out invisible ones.
Returns:
The list of column names.
@abc.abstractmethod
def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
    """
    Get the :class:`sqlglot.exp.DataType` type of a column in the schema.

    Args:
        table: the source table.
        column: the target column, as a `Column` expression or a plain column name.

    Returns:
        The resulting column type.
    """
Get the `sqlglot.exp.DataType` type of a column in the schema.
Arguments:
- table: the source table.
- column: the target column.
Returns:
The resulting column type.
class AbstractMappingSchema(t.Generic[T]):
    """
    Generic lookup structure over a nested mapping, indexed by a trie of reversed table paths.

    Args:
        mapping: the nested mapping to wrap. A falsy value is replaced by an empty dict.
    """

    def __init__(
        self,
        mapping: dict | None = None,
    ) -> None:
        self.mapping = mapping or {}
        self.mapping_trie = self._build_trie(self.mapping)
        # Computed lazily by the `supported_table_args` property.
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    def _build_trie(self, schema: t.Dict) -> t.Dict:
        """Build a trie out of the reversed key paths of `schema`, flattened to `self._depth()`."""
        return new_trie(
            tuple(reversed(path)) for path in flatten_schema(schema, depth=self._depth())
        )

    def _depth(self) -> int:
        """Return the nesting depth of `self.mapping`."""
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        The leading `TABLE_ARGS` entries that match the depth of the mapping.

        The value is computed on first access once the mapping is non-empty and then reused.

        Raises:
            SchemaError: if the mapping depth is non-zero and outside the 1..3 range.
        """
        if self._supported_table_args or not self.mapping:
            return self._supported_table_args

        depth = self._depth()
        if not depth:  # None
            self._supported_table_args = tuple()
        elif 1 <= depth <= 3:
            self._supported_table_args = TABLE_ARGS[:depth]
        else:
            raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """Return the non-empty texts of `table` for each of `TABLE_ARGS`, in that order."""
        source = table.this
        if isinstance(source, exp.ReadCSV):
            return [source.name]

        texts = (table.text(arg) for arg in TABLE_ARGS)
        return [text for text in texts if text]

    def find(
        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[T]:
        """
        Look up the mapping entry that corresponds to `table`.

        Args:
            table: the table expression to look up.
            trie: the trie to search; defaults to `self.mapping_trie`.
            raise_on_missing: whether an ambiguous or missing entry raises instead of
                producing `None`.

        Returns:
            The entry stored for the table, or `None` if it cannot be resolved.

        Raises:
            SchemaError: if several entries match and `raise_on_missing` is set.
        """
        search_trie = self.mapping_trie if trie is None else trie
        parts = self.table_parts(table)[: len(self.supported_table_args)]
        value, subtrie = in_trie(search_trie, parts)

        if value == 0:
            return None

        if value == 1:
            # Only part of a path matched: complete it if there is exactly one continuation.
            candidates = flatten_schema(subtrie, depth=dict_depth(subtrie) - 1)
            if len(candidates) != 1:
                message = ", ".join(".".join(candidate) for candidate in candidates)
                if raise_on_missing:
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None
            parts.extend(candidates[0])

        return self._nested_get(parts, raise_on_missing=raise_on_missing)

    def _nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
    ) -> t.Optional[t.Any]:
        """
        Walk `d` (or `self.mapping` when `d` is falsy) along the reversed `parts`.

        Each key is paired with an entry of `supported_table_args`, which is only used to
        label the error raised by the module-level `_nested_get`.
        """
        target = d or self.mapping
        path = zip(self.supported_table_args, reversed(parts))
        return _nested_get(target, *path, raise_on_missing=raise_on_missing)
Abstract base class for generic types.
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::

    class Mapping(Generic[KT, VT]):
        def __getitem__(self, key: KT) -> VT:
            ...
        # Etc.

This class can then be used as follows::

    def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
        try:
            return mapping[key]
        except KeyError:
            return default
def find(
    self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
) -> t.Optional[T]:
    """
    Look up the mapping entry that corresponds to `table`.

    Args:
        table: the table expression to look up.
        trie: the trie to search; defaults to `self.mapping_trie`.
        raise_on_missing: whether an ambiguous or missing entry raises instead of
            producing `None`.

    Returns:
        The entry stored for the table, or `None` if it cannot be resolved.

    Raises:
        SchemaError: if several entries match and `raise_on_missing` is set.
    """
    search_trie = self.mapping_trie if trie is None else trie
    parts = self.table_parts(table)[: len(self.supported_table_args)]
    value, subtrie = in_trie(search_trie, parts)

    if value == 0:
        return None

    if value == 1:
        # Only part of a path matched: complete it if there is exactly one continuation.
        candidates = flatten_schema(subtrie, depth=dict_depth(subtrie) - 1)
        if len(candidates) != 1:
            message = ", ".join(".".join(candidate) for candidate in candidates)
            if raise_on_missing:
                raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
            return None
        parts.extend(candidates[0])

    return self._nested_get(parts, raise_on_missing=raise_on_missing)
class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema (dict): Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect (str): The dialect to be used for custom type mappings.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
    ) -> None:
        # `dialect` must be set before `_normalize` runs, since name normalization reads it.
        self.dialect = dialect
        self.visible = visible or {}
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        super().__init__(self._normalize(schema or {}))

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a new `MappingSchema` from the mapping, visibility and dialect of another one."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """
        Create a new `MappingSchema` from shallow copies of this schema's mapping and visibility.

        Args:
            **kwargs: constructor arguments that override the copied ones.
        """
        params = {
            "schema": self.mapping.copy(),
            "visible": self.visible.copy(),
            "dialect": self.dialect,
        }
        params.update(kwargs)
        return MappingSchema(**params)  # type: ignore

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Converts all identifiers in the schema into lowercase, unless they're quoted.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        normalized_mapping: t.Dict = {}
        for keys in flattened_schema:
            columns = _nested_get(schema, *zip(keys, keys))
            assert columns is not None

            normalized_keys = [self._normalize_name(key) for key in keys]
            for column_name, column_type in columns.items():
                _nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def add_table(
        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
        """
        table_ = self._ensure_table(table)
        columns = ensure_column_mapping(column_mapping)
        existing = self.find(table_, raise_on_missing=False)

        # A known table is left untouched when no new columns are supplied.
        if existing and not columns:
            return

        _nested_set(self.mapping, self.table_parts(table_)[::-1], columns)
        self.mapping_trie = self._build_trie(self.mapping)

    def _normalize_name(self, name: str) -> str:
        """
        Return `name` lowercased, unless it parses as a quoted identifier.

        If `name` cannot be parsed as an identifier, it is converted with `exp.to_identifier`.
        """
        try:
            identifier: t.Optional[exp.Expression] = sqlglot.parse_one(
                name, read=self.dialect, into=exp.Identifier
            )
        except Exception:
            # Deliberate best-effort fallback; `Exception` rather than a bare `except` so that
            # KeyboardInterrupt and SystemExit are not swallowed.
            identifier = exp.to_identifier(name)
        assert isinstance(identifier, exp.Identifier)

        if identifier.quoted:
            return identifier.name
        return identifier.name.lower()

    def _depth(self) -> int:
        # The columns themselves are a mapping, but we don't want to include those
        return super()._depth() - 1

    def _ensure_table(self, table: exp.Table | str) -> exp.Table:
        """
        Convert `table` with `exp.to_table`.

        Raises:
            SchemaError: if the conversion yields a falsy value.
        """
        table_ = exp.to_table(table)

        if not table_:
            raise SchemaError(f"Not a valid table '{table}'")

        return table_

    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to return only the columns listed in `self.visible`.
                Ignored when no visibility mapping was provided.

        Returns:
            The list of column names, or an empty list if the table is not found.
        """
        table_ = self._ensure_table(table)
        schema = self.find(table_)

        if schema is None:
            return []

        if only_visible and self.visible:
            visible = self._nested_get(self.table_parts(table_), self.visible)
            return [col for col in schema if col in visible]  # type: ignore

        return list(schema)

    def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
        """
        Get the :class:`sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column, as a `Column` expression or a plain column name.

        Returns:
            The resulting column type, or an `UNKNOWN` data type if the table is not found.

        Raises:
            SchemaError: if `table` cannot be converted into a table, or the stored column type
                is neither a `DataType` nor a string.
        """
        column_name = column if isinstance(column, str) else column.name
        table_ = exp.to_table(table)
        if not table_:
            raise SchemaError(f"Could not convert table '{table}'")

        table_schema = self.find(table_, raise_on_missing=False)
        if not table_schema:
            return exp.DataType(this=exp.DataType.Type.UNKNOWN)

        column_type = table_schema.get(column_name)
        if isinstance(column_type, exp.DataType):
            return column_type
        if isinstance(column_type, str):
            return self._to_data_type(column_type.upper())
        raise SchemaError(f"Unknown column type '{column_type}'")

    def _to_data_type(self, schema_type: str) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object.

        Results are memoized per type string in `self._type_mapping_cache`.

        Args:
            schema_type: the type we want to convert.

        Returns:
            The resulting expression type.

        Raises:
            ValueError: if parsing yields `None`.
            SchemaError: if parsing raises an `AttributeError`.
        """
        if schema_type not in self._type_mapping_cache:
            try:
                expression = exp.maybe_parse(schema_type, into=exp.DataType, dialect=self.dialect)
                if expression is None:
                    raise ValueError(f"Could not parse {schema_type}")
                self._type_mapping_cache[schema_type] = expression  # type: ignore
            except AttributeError:
                raise SchemaError(f"Failed to convert type {schema_type}")

        return self._type_mapping_cache[schema_type]
Schema based on a nested mapping.
Arguments:
- schema (dict): Mapping in one of the following forms:
- {table: {col: type}}
- {db: {table: {col: type}}}
- {catalog: {db: {table: {col: type}}}}
- None - Tables will be added later
- visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns
are assumed to be visible. The nesting should mirror that of the schema:
- {table: set(*cols)}
- {db: {table: set(*cols)}}
- {catalog: {db: {table: set(*cols)}}}
- dialect (str): The dialect to be used for custom type mappings.
def __init__(
    self,
    schema: t.Optional[t.Dict] = None,
    visible: t.Optional[t.Dict] = None,
    dialect: DialectType = None,
) -> None:
    """
    Store the dialect and visibility mapping, then normalize `schema` and hand it to the
    parent initializer.

    Args:
        schema: the nested schema mapping; a falsy value is treated as an empty mapping.
        visible: the nested visibility mapping; a falsy value is treated as an empty mapping.
        dialect: the dialect used when parsing identifiers and types.
    """
    self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
    self.visible = visible if visible else {}
    # Must be assigned before `_normalize` is called, since name normalization reads it.
    self.dialect = dialect

    normalized = self._normalize(schema if schema else {})
    super().__init__(normalized)
def add_table(
    self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
) -> None:
    """
    Register or update a table. Updates are only performed if a new column mapping is provided.

    Args:
        table: the `Table` expression instance or string representing the table.
        column_mapping: a column mapping that describes the structure of the table.
    """
    normalized_table = self._ensure_table(table)
    columns = ensure_column_mapping(column_mapping)
    existing = self.find(normalized_table, raise_on_missing=False)

    # A known table is left untouched when no new columns are supplied.
    if existing and not columns:
        return

    path = self.table_parts(normalized_table)[::-1]
    _nested_set(self.mapping, path, columns)
    self.mapping_trie = self._build_trie(self.mapping)
Register or update a table. Updates are only performed if a new column mapping is provided.
Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
    """
    Get the column names for a table.

    Args:
        table: the `Table` expression instance or string representing the table.
        only_visible: whether to return only the columns listed in `self.visible`.
            Ignored when no visibility mapping was provided.

    Returns:
        The list of column names, or an empty list if the table is not found.
    """
    table_ = self._ensure_table(table)
    schema = self.find(table_)

    if schema is None:
        return []

    if only_visible and self.visible:
        visible = self._nested_get(self.table_parts(table_), self.visible)
        return [col for col in schema if col in visible]  # type: ignore

    return list(schema)
Get the column names for a table.
Arguments:
- table: the `Table` expression instance.
- only_visible: whether to return only the visible columns, leaving out invisible ones.
Returns:
The list of column names.
def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
    """
    Get the :class:`sqlglot.exp.DataType` type of a column in the schema.

    Args:
        table: the source table.
        column: the target column, as a `Column` expression or a plain column name.

    Returns:
        The resulting column type, or an `UNKNOWN` data type if the table is not found.

    Raises:
        SchemaError: if `table` cannot be converted into a table, or the stored column type
            is neither a `DataType` nor a string.
    """
    column_name = column if isinstance(column, str) else column.name
    table_ = exp.to_table(table)
    if not table_:
        raise SchemaError(f"Could not convert table '{table}'")

    table_schema = self.find(table_, raise_on_missing=False)
    if not table_schema:
        return exp.DataType(this=exp.DataType.Type.UNKNOWN)

    column_type = table_schema.get(column_name)
    if isinstance(column_type, exp.DataType):
        return column_type
    if isinstance(column_type, str):
        return self._to_data_type(column_type.upper())
    raise SchemaError(f"Unknown column type '{column_type}'")
Get the `sqlglot.exp.DataType` type of a column in the schema.
Arguments:
- table: the source table.
- column: the target column.
Returns:
The resulting column type.
Inherited Members
def _split_top_level_commas(text: str) -> t.List[str]:
    """Split `text` on the commas that are not nested inside `()`, `<>` or `[]`."""
    pieces: t.List[str] = []
    current: t.List[str] = []
    depth = 0

    for char in text:
        if char in "(<[":
            depth += 1
        elif char in ")>]":
            depth = max(depth - 1, 0)

        if char == "," and depth == 0:
            pieces.append("".join(current))
            current = []
        else:
            current.append(char)

    pieces.append("".join(current))
    return pieces


def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Convert the supported column mapping representations into a dictionary.

    Args:
        mapping: one of
            - a dict, returned as is;
            - a string of comma separated `name: type` entries, e.g. `"a: int, b: decimal(10, 2)"`.
              Commas nested inside brackets belong to the type, only the first colon of an
              entry separates the name from the type, and blank entries are skipped;
            - an object exposing `simpleString` (looks like a DataFrame StructType), whose
              fields are mapped from `name` to `dataType.simpleString()`;
            - a list of column names, mapped to `None` types;
            - None, which yields an empty dict.

    Returns:
        The column mapping as a dictionary.

    Raises:
        ValueError: if `mapping` has an unsupported type, or a string entry has no `:` separator.
    """
    if isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        columns: t.Dict = {}
        for entry in _split_top_level_commas(mapping):
            entry = entry.strip()
            if not entry:
                continue

            name, separator, type_ = entry.partition(":")
            if not separator:
                raise ValueError(
                    f"Invalid column mapping entry '{entry}'. Expected the form 'name: type'"
                )
            columns[name.strip()] = type_.strip()
        return columns
    # Check if mapping looks like a DataFrame StructType
    elif hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}  # type: ignore
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}
    elif mapping is None:
        return {}
    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """
    List the key paths found at nesting level `depth` of `schema`.

    Args:
        schema: the nested dictionary to walk.
        depth: how many levels of keys each path contains. Values below 1 yield no paths.
        keys: a prefix that is prepended to every path.

    Returns:
        One list of keys per path, in dictionary iteration order.
    """
    prefix = keys or []
    paths = []

    for name, child in schema.items():
        path = prefix + [name]
        if depth == 1:
            paths.append(path)
        elif depth >= 2:
            paths.extend(flatten_schema(child, depth - 1, path))

    return paths