Edit on GitHub

sqlglot.schema

  1from __future__ import annotations
  2
  3import abc
  4import typing as t
  5
  6import sqlglot
  7from sqlglot import expressions as exp
  8from sqlglot._typing import T
  9from sqlglot.dialects.dialect import Dialect
 10from sqlglot.errors import ParseError, SchemaError
 11from sqlglot.helper import dict_depth
 12from sqlglot.trie import TrieResult, in_trie, new_trie
 13
 14if t.TYPE_CHECKING:
 15    from sqlglot.dataframe.sql.types import StructType
 16    from sqlglot.dialects.dialect import DialectType
 17
 18    ColumnMapping = t.Union[t.Dict, str, StructType, t.List]
 19
# Names of the `exp.Table` args that make up a table reference, ordered from the
# table name itself ("this") outwards to its db and then its catalog.
TABLE_ARGS = ("this", "db", "catalog")
 21
 22
class Schema(abc.ABC):
    """
    Abstract base class for database schemas.

    Implementations expose table registration, column name lookup and column type lookup.
    """

    # SQL dialect associated with the schema (`MappingSchema` falls back to it when a
    # call does not pass its own dialect).
    dialect: DialectType

    @abc.abstractmethod
    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information to also be provided.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
        """

    @abc.abstractmethod
    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
    ) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to return only the visible columns instead of all columns.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.

        Returns:
            The list of column names.
        """

    @abc.abstractmethod
    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column,
        dialect: DialectType = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.

        Returns:
            The resulting column type.
        """

    @property
    @abc.abstractmethod
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        Table arguments this schema supports, e.g. `("this", "db", "catalog")`
        """

    @property
    def empty(self) -> bool:
        """
        Returns whether or not the schema is empty.

        This base implementation always reports True.
        """
        return True
 93
 94
 95class AbstractMappingSchema(t.Generic[T]):
 96    def __init__(
 97        self,
 98        mapping: t.Optional[t.Dict] = None,
 99    ) -> None:
100        self.mapping = mapping or {}
101        self.mapping_trie = new_trie(
102            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self._depth())
103        )
104        self._supported_table_args: t.Tuple[str, ...] = tuple()
105
106    @property
107    def empty(self) -> bool:
108        return not self.mapping
109
110    def _depth(self) -> int:
111        return dict_depth(self.mapping)
112
113    @property
114    def supported_table_args(self) -> t.Tuple[str, ...]:
115        if not self._supported_table_args and self.mapping:
116            depth = self._depth()
117
118            if not depth:  # None
119                self._supported_table_args = tuple()
120            elif 1 <= depth <= 3:
121                self._supported_table_args = TABLE_ARGS[:depth]
122            else:
123                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
124
125        return self._supported_table_args
126
127    def table_parts(self, table: exp.Table) -> t.List[str]:
128        if isinstance(table.this, exp.ReadCSV):
129            return [table.this.name]
130        return [table.text(part) for part in TABLE_ARGS if table.text(part)]
131
132    def find(
133        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
134    ) -> t.Optional[T]:
135        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
136        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
137
138        if value == TrieResult.FAILED:
139            return None
140
141        if value == TrieResult.PREFIX:
142            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
143
144            if len(possibilities) == 1:
145                parts.extend(possibilities[0])
146            else:
147                message = ", ".join(".".join(parts) for parts in possibilities)
148                if raise_on_missing:
149                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
150                return None
151
152        return self.nested_get(parts, raise_on_missing=raise_on_missing)
153
154    def nested_get(
155        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
156    ) -> t.Optional[t.Any]:
157        return nested_get(
158            d or self.mapping,
159            *zip(self.supported_table_args, reversed(parts)),
160            raise_on_missing=raise_on_missing,
161        )
162
163
class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        # These attributes are read by `_normalize` / `_normalize_name`, so they have to
        # be in place before the normalized mapping is handed to the parent constructor.
        self.dialect = dialect
        self.visible = visible or {}
        self.normalize = normalize
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}

        super().__init__(self._normalize(schema or {}))

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """
        Build a new `MappingSchema` from the mapping, visibility, dialect and
        normalization setting of an existing one.
        """
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            # Carry the flag over so that a schema created with `normalize=False` does
            # not get its identifiers normalized when it is rebuilt.
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """
        Create a new `MappingSchema` with the same settings as this one.

        Args:
            **kwargs: constructor arguments that override the copied settings.
        """
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
        """
        normalized_table = self._normalize_table(table, dialect=dialect)

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # An already registered table is left untouched when no new columns are given.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        # The mapping nests from the outermost qualifier down to the table, while the
        # trie is keyed table name first, which is the order `table_parts` returns.
        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
    ) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to return only the columns listed in `visible`. This has
                no effect when no visibility mapping was provided.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.

        Returns:
            The list of column names, which is empty if the table can't be found.
        """
        normalized_table = self._normalize_table(table, dialect=dialect)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column,
        dialect: DialectType = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column. A plain string column name is accepted as well.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.

        Returns:
            The resulting column type, or the "unknown" type if the table or the column
            can't be resolved.
        """
        normalized_table = self._normalize_table(table, dialect=dialect)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                return self._to_data_type(column_type.upper(), dialect=dialect)

        return exp.DataType.build("unknown")

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        # One level is subtracted because the innermost dicts hold columns, not tables.
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        normalized_mapping: t.Dict = {}
        for keys in flattened_schema:
            columns = nested_get(schema, *zip(keys, keys))
            assert columns is not None

            normalized_keys = [
                self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys
            ]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(self, table: exp.Table | str, dialect: DialectType = None) -> exp.Table:
        """
        Parse `table` if needed and normalize each of its `TABLE_ARGS` identifiers.

        Args:
            table: the `Table` expression instance or string representing the table.
            dialect: the SQL dialect used for parsing and normalizing; falls back to `self.dialect`.

        Returns:
            A normalized copy of the table expression.
        """
        normalized_table = exp.maybe_parse(
            table, into=exp.Table, dialect=dialect or self.dialect, copy=True
        )

        for arg in TABLE_ARGS:
            value = normalized_table.args.get(arg)
            if isinstance(value, (str, exp.Identifier)):
                normalized_table.set(
                    arg,
                    exp.to_identifier(self._normalize_name(value, dialect=dialect, is_table=True)),
                )

        return normalized_table

    def _normalize_name(
        self, name: str | exp.Identifier, dialect: DialectType = None, is_table: bool = False
    ) -> str:
        """
        Resolve `name` to the identifier name used as a key in the mapping.

        Args:
            name: the identifier, either as a string or as an `Identifier` expression.
            dialect: the SQL dialect used for parsing and normalizing; falls back to `self.dialect`.
            is_table: whether the identifier names a table (or one of its qualifiers).

        Returns:
            The identifier name, normalized for the dialect unless `self.normalize` is False.
            A name that can't be parsed as an identifier is returned as is.
        """
        dialect = dialect or self.dialect

        try:
            identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
        except ParseError:
            return name if isinstance(name, str) else name.name

        name = identifier.name
        if not self.normalize:
            return name

        # This can be useful for normalize_identifier
        identifier.meta["is_table"] = is_table
        return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name

    def _depth(self) -> int:
        # The columns themselves are a mapping, but we don't want to include those
        return super()._depth() - 1

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.

        Raises:
            SchemaError: if the type can't be built.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError as err:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                # Keep the original failure attached as the cause of the schema error.
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.") from err

        return self._type_mapping_cache[schema_type]
374
375
def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """
    Return `schema` unchanged if it already is a `Schema`, otherwise wrap it in a `MappingSchema`.

    Args:
        schema: a `Schema` instance, a nested mapping, or None.
        **kwargs: forwarded to the `MappingSchema` constructor when wrapping is needed.

    Returns:
        A `Schema` instance.
    """
    return schema if isinstance(schema, Schema) else MappingSchema(schema, **kwargs)
381
382
def _split_column_definitions(definitions: str) -> t.List[str]:
    """Split `definitions` on the commas that are not nested inside `()`, `<>` or `[]`."""
    entries: t.List[str] = []
    current: t.List[str] = []
    depth = 0

    for char in definitions:
        if char in "(<[":
            depth += 1
        elif char in ")>]":
            depth = max(depth - 1, 0)

        if char == "," and depth == 0:
            entries.append("".join(current))
            current = []
        else:
            current.append(char)

    entries.append("".join(current))
    return entries


def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Coerce the supported column mapping representations into a `{column name: type}` dict.

    Args:
        mapping: the column mapping, in one of the following forms:
            1. None, which produces an empty dict.
            2. A dict, which is returned as is.
            3. A string of comma separated `name: type` pairs, e.g. `"a: int, b: decimal(10, 2)"`.
               Only the first colon of a pair separates the name from the type, and commas
               nested inside of a type don't start a new pair. A blank string produces an empty dict.
            4. An object exposing `simpleString` (DataFrame `StructType`-like), whose fields
               are mapped by name to the `simpleString()` of their data type.
            5. A list of column names, which are all mapped to None.

    Returns:
        The column mapping as a dict.

    Raises:
        ValueError: if `mapping` has an unsupported type, or if a string entry has no `name: type` form.
    """
    if mapping is None:
        return {}
    if isinstance(mapping, dict):
        return mapping
    if isinstance(mapping, str):
        if not mapping.strip():
            return {}

        column_mapping = {}
        for definition in _split_column_definitions(mapping):
            # Only the first colon is a separator: the type itself may contain colons.
            name, separator, column_type = definition.partition(":")
            if not separator:
                raise ValueError(
                    f"Invalid column definition '{definition.strip()}', expected 'name: type'"
                )
            column_mapping[name.strip()] = column_type.strip()
        return column_mapping
    # Check if mapping looks like a DataFrame StructType
    if hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
    if isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
401
402
def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """
    List the key paths of a nested mapping down to a given depth.

    Args:
        schema: the nested mapping to walk.
        depth: how many levels of keys each path should contain.
        keys: the keys leading up to `schema`, which are prepended to every path.

    Returns:
        One list of keys per path, in the iteration order of the mapping. Nothing is
        produced for a depth below 1.
    """
    prefix = keys or []
    paths = []

    for key, value in schema.items():
        path = prefix + [key]
        if depth >= 2:
            # There are more levels to go: descend with the extended path.
            paths.extend(flatten_schema(value, depth - 1, path))
        elif depth == 1:
            paths.append(path)

    return paths
416
417
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.
        raise_on_missing: whether a missing key raises an error instead of producing None.

    Returns:
        The value or None if it doesn't exist.

    Raises:
        ValueError: if a key is missing (or maps to None) and `raise_on_missing` is set.
    """
    current = d

    for name, key in path:
        current = current.get(key)  # type: ignore
        if current is not None:
            continue

        if not raise_on_missing:
            return None

        # The "this" arg of a table is reported under a friendlier name.
        label = "table" if name == "this" else name
        raise ValueError(f"Unknown {label}: {key}")

    return current
442
443
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that make up the path to `value`. Nothing is set if it's empty.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    *parents, leaf = keys

    # Walk down to the innermost dictionary, creating the missing levels on the way.
    target = d
    for key in parents:
        target = target.setdefault(key, {})

    target[leaf] = value
    return d
TABLE_ARGS = ('this', 'db', 'catalog')
class Schema(abc.ABC):
class Schema(abc.ABC):
    """
    Abstract base class for database schemas.

    Implementations expose table registration, column name lookup and column type lookup.
    """

    # SQL dialect associated with the schema (`MappingSchema` falls back to it when a
    # call does not pass its own dialect).
    dialect: DialectType

    @abc.abstractmethod
    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information to also be provided.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
        """

    @abc.abstractmethod
    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
    ) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to return only the visible columns instead of all columns.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.

        Returns:
            The list of column names.
        """

    @abc.abstractmethod
    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column,
        dialect: DialectType = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.

        Returns:
            The resulting column type.
        """

    @property
    @abc.abstractmethod
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        Table arguments this schema supports, e.g. `("this", "db", "catalog")`
        """

    @property
    def empty(self) -> bool:
        """
        Returns whether or not the schema is empty.

        This base implementation always reports True.
        """
        return True

Abstract base class for database schemas

@abc.abstractmethod
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> None:
    @abc.abstractmethod
    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information to also be provided.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
                It is optional in the signature and defaults to None.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
        """

Register or update a table. Some implementing classes may require column information to also be provided.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
@abc.abstractmethod
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> List[str]:
    @abc.abstractmethod
    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
    ) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to return only the visible columns instead of all columns.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.

        Returns:
            The list of column names.
        """

Get the column names for a table.

Arguments:
  • table: the Table expression instance or string representing the table.
  • only_visible: whether to return only the visible columns instead of all columns.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
Returns:

The list of column names.

@abc.abstractmethod
def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> sqlglot.expressions.DataType:
    @abc.abstractmethod
    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column,
        dialect: DialectType = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table, as a `Table` expression instance or a string.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.

        Returns:
            The resulting column type.
        """

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
Returns:

The resulting column type.

supported_table_args: Tuple[str, ...]

Table arguments this schema supports, e.g. ("this", "db", "catalog")

empty: bool

Returns whether or not the schema is empty.

class AbstractMappingSchema(typing.Generic[~T]):
 96class AbstractMappingSchema(t.Generic[T]):
 97    def __init__(
 98        self,
 99        mapping: t.Optional[t.Dict] = None,
100    ) -> None:
101        self.mapping = mapping or {}
102        self.mapping_trie = new_trie(
103            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self._depth())
104        )
105        self._supported_table_args: t.Tuple[str, ...] = tuple()
106
107    @property
108    def empty(self) -> bool:
109        return not self.mapping
110
111    def _depth(self) -> int:
112        return dict_depth(self.mapping)
113
114    @property
115    def supported_table_args(self) -> t.Tuple[str, ...]:
116        if not self._supported_table_args and self.mapping:
117            depth = self._depth()
118
119            if not depth:  # None
120                self._supported_table_args = tuple()
121            elif 1 <= depth <= 3:
122                self._supported_table_args = TABLE_ARGS[:depth]
123            else:
124                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
125
126        return self._supported_table_args
127
128    def table_parts(self, table: exp.Table) -> t.List[str]:
129        if isinstance(table.this, exp.ReadCSV):
130            return [table.this.name]
131        return [table.text(part) for part in TABLE_ARGS if table.text(part)]
132
133    def find(
134        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
135    ) -> t.Optional[T]:
136        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
137        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
138
139        if value == TrieResult.FAILED:
140            return None
141
142        if value == TrieResult.PREFIX:
143            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
144
145            if len(possibilities) == 1:
146                parts.extend(possibilities[0])
147            else:
148                message = ", ".join(".".join(parts) for parts in possibilities)
149                if raise_on_missing:
150                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
151                return None
152
153        return self.nested_get(parts, raise_on_missing=raise_on_missing)
154
155    def nested_get(
156        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
157    ) -> t.Optional[t.Any]:
158        return nested_get(
159            d or self.mapping,
160            *zip(self.supported_table_args, reversed(parts)),
161            raise_on_missing=raise_on_missing,
162        )

Abstract base class for generic types.

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::

    class Mapping(Generic[KT, VT]):
        def __getitem__(self, key: KT) -> VT:
            ...
        # Etc.

This class can then be used as follows::

    def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
        try:
            return mapping[key]
        except KeyError:
            return default

AbstractMappingSchema(mapping: Optional[Dict] = None)
 97    def __init__(
 98        self,
 99        mapping: t.Optional[t.Dict] = None,
100    ) -> None:
101        self.mapping = mapping or {}
102        self.mapping_trie = new_trie(
103            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self._depth())
104        )
105        self._supported_table_args: t.Tuple[str, ...] = tuple()
mapping
mapping_trie
empty: bool
supported_table_args: Tuple[str, ...]
def table_parts(self, table: sqlglot.expressions.Table) -> List[str]:
128    def table_parts(self, table: exp.Table) -> t.List[str]:
129        if isinstance(table.this, exp.ReadCSV):
130            return [table.this.name]
131        return [table.text(part) for part in TABLE_ARGS if table.text(part)]
def find( self, table: sqlglot.expressions.Table, trie: Optional[Dict] = None, raise_on_missing: bool = True) -> Optional[~T]:
133    def find(
134        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
135    ) -> t.Optional[T]:
136        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
137        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
138
139        if value == TrieResult.FAILED:
140            return None
141
142        if value == TrieResult.PREFIX:
143            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
144
145            if len(possibilities) == 1:
146                parts.extend(possibilities[0])
147            else:
148                message = ", ".join(".".join(parts) for parts in possibilities)
149                if raise_on_missing:
150                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
151                return None
152
153        return self.nested_get(parts, raise_on_missing=raise_on_missing)
def nested_get( self, parts: Sequence[str], d: Optional[Dict] = None, raise_on_missing=True) -> Optional[Any]:
155    def nested_get(
156        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
157    ) -> t.Optional[t.Any]:
158        return nested_get(
159            d or self.mapping,
160            *zip(self.supported_table_args, reversed(parts)),
161            raise_on_missing=raise_on_missing,
162        )
class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        # `dialect` and `normalize` are read by `_normalize`, so they are assigned
        # before the schema is normalized and handed to the base class.
        self.dialect = dialect
        self.visible = visible or {}
        self.normalize = normalize
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}

        super().__init__(self._normalize(schema or {}))

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """
        Build a new `MappingSchema` from an existing one.

        Args:
            mapping_schema: the schema whose mapping, visibility, dialect and
                normalization setting are reused.

        Returns:
            The new schema instance.
        """
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            # Carry the setting over so a schema built with normalize=False is not
            # normalized when it is rebuilt.
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """
        Create a new `MappingSchema` from this one.

        Args:
            **kwargs: constructor arguments that override the values taken from this schema.

        Returns:
            The new schema instance.
        """
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                # Preserved by default; callers can still override it through kwargs.
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
        """
        normalized_table = self._normalize_table(table, dialect=dialect)

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # A table that already has columns is left untouched when no new columns are given.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        # The mapping is keyed by the reversed table parts; the trie takes them as is.
        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
    ) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to restrict the result to the columns listed in `visible`.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.

        Returns:
            The list of column names, or an empty list if `find` returns None.
        """
        normalized_table = self._normalize_table(table, dialect=dialect)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        # With no visibility mapping configured, every column is reported.
        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column,
        dialect: DialectType = None,
    ) -> exp.DataType:
        """
        Get the type of a column.

        Args:
            table: the `Table` expression instance or string representing the table.
            column: the column to look up; a plain string name is also handled.
            dialect: the SQL dialect used to parse `table` and to build the type.

        Returns:
            The column's `DataType`, or the "unknown" type when the table or column
            is not found or the stored type is neither a `DataType` nor a string.
        """
        normalized_table = self._normalize_table(table, dialect=dialect)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                return self._to_data_type(column_type.upper(), dialect=dialect)

        return exp.DataType.build("unknown")

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        normalized_mapping: t.Dict = {}
        for keys in flattened_schema:
            # Each key doubles as its own label in the lookup path.
            columns = nested_get(schema, *zip(keys, keys))
            assert columns is not None

            normalized_keys = [
                self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys
            ]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(self, table: exp.Table | str, dialect: DialectType = None) -> exp.Table:
        """
        Parse `table` if needed and normalize its name parts.

        Args:
            table: the `Table` expression instance or string representing the table.
            dialect: the SQL dialect used for parsing; falls back to `self.dialect`.

        Returns:
            A copy of the table whose `TABLE_ARGS` parts have been normalized.
        """
        normalized_table = exp.maybe_parse(
            table, into=exp.Table, dialect=dialect or self.dialect, copy=True
        )

        for arg in TABLE_ARGS:
            value = normalized_table.args.get(arg)
            if isinstance(value, (str, exp.Identifier)):
                normalized_table.set(
                    arg,
                    exp.to_identifier(self._normalize_name(value, dialect=dialect, is_table=True)),
                )

        return normalized_table

    def _normalize_name(
        self, name: str | exp.Identifier, dialect: DialectType = None, is_table: bool = False
    ) -> str:
        """
        Normalize a single identifier name.

        Args:
            name: the identifier, as a string or an `Identifier` expression.
            dialect: the SQL dialect to use; falls back to `self.dialect`.
            is_table: whether the identifier names a table; recorded in the identifier's meta.

        Returns:
            The identifier's name, normalized by the dialect when `self.normalize` is set.
            If the name cannot be parsed, it is returned as given.
        """
        dialect = dialect or self.dialect

        try:
            identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
        except ParseError:
            return name if isinstance(name, str) else name.name

        name = identifier.name
        if not self.normalize:
            return name

        # This can be useful for normalize_identifier
        identifier.meta["is_table"] = is_table
        return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name

    def _depth(self) -> int:
        # The columns themselves are a mapping, but we don't want to include those
        return super()._depth() - 1

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.

        Raises:
            SchemaError: if the type cannot be built.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError as error:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                # Chain the original error so its traceback is kept.
                raise SchemaError(
                    f"Failed to build type '{schema_type}'{in_dialect}."
                ) from error

        return self._type_mapping_cache[schema_type]

Schema based on a nested mapping.

Arguments:
  • schema: Mapping in one of the following forms:
    1. {table: {col: type}}
    2. {db: {table: {col: type}}}
    3. {catalog: {db: {table: {col: type}}}}
    4. None - Tables will be added later
  • visible: Optional mapping of which columns in the schema are visible. If not provided, all columns are assumed to be visible. The nesting should mirror that of the schema:
    1. {table: set(*cols)}
    2. {db: {table: set(*cols)}}
    3. {catalog: {db: {table: set(*cols)}}}
  • dialect: The dialect to be used for custom type mappings & parsing string arguments.
  • normalize: Whether to normalize identifier names according to the given dialect or not.
MappingSchema( schema: Optional[Dict] = None, visible: Optional[Dict] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: bool = True)
184    def __init__(
185        self,
186        schema: t.Optional[t.Dict] = None,
187        visible: t.Optional[t.Dict] = None,
188        dialect: DialectType = None,
189        normalize: bool = True,
190    ) -> None:
191        self.dialect = dialect
192        self.visible = visible or {}
193        self.normalize = normalize
194        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
195
196        super().__init__(self._normalize(schema or {}))
visible
normalize
@classmethod
def from_mapping_schema( cls, mapping_schema: sqlglot.schema.MappingSchema) -> sqlglot.schema.MappingSchema:
198    @classmethod
199    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
200        return MappingSchema(
201            schema=mapping_schema.mapping,
202            visible=mapping_schema.visible,
203            dialect=mapping_schema.dialect,
204        )
def copy(self, **kwargs) -> sqlglot.schema.MappingSchema:
206    def copy(self, **kwargs) -> MappingSchema:
207        return MappingSchema(
208            **{  # type: ignore
209                "schema": self.mapping.copy(),
210                "visible": self.visible.copy(),
211                "dialect": self.dialect,
212                **kwargs,
213            }
214        )
def ensure_schema( schema: Union[sqlglot.schema.Schema, Dict, NoneType], **kwargs: Any) -> sqlglot.schema.Schema:
def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """
    Coerce `schema` into a `Schema` instance.

    Args:
        schema: an existing `Schema`, a mapping, or None.
        **kwargs: forwarded to `MappingSchema` when a new schema has to be built.

    Returns:
        `schema` itself if it is already a `Schema`, otherwise a new `MappingSchema`.
    """
    return schema if isinstance(schema, Schema) else MappingSchema(schema, **kwargs)
def ensure_column_mapping( mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType]) -> Dict:
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Coerce the supported column mapping representations into a dictionary.

    Args:
        mapping: one of:
            - None, which yields an empty dictionary.
            - a dictionary, which is returned as is.
            - a string of comma-separated `name: type` pairs, e.g. "a: int, b: text".
              Commas nested inside `()`, `<>` or `[]` belong to the type, and only
              the first colon separates the name from the type.
            - an object with a `simpleString` attribute, iterated as fields exposing
              `name` and `dataType.simpleString()`.
            - a list of column names, each mapped to None.

    Returns:
        The column mapping as a dictionary.

    Raises:
        ValueError: if a string entry has no `name: type` separator, or if `mapping`
            is of an unsupported type.
    """
    if mapping is None:
        return {}
    if isinstance(mapping, dict):
        return mapping
    if isinstance(mapping, str):
        # Split on top-level commas only, so that types such as `decimal(10, 2)` or
        # `struct<a: int, b: text>` stay in one piece.
        column_defs: t.List[str] = []
        current: t.List[str] = []
        depth = 0
        for char in mapping:
            if char in "(<[":
                depth += 1
            elif char in ")>]":
                depth = max(depth - 1, 0)

            if char == "," and depth == 0:
                column_defs.append("".join(current))
                current = []
            else:
                current.append(char)
        column_defs.append("".join(current))

        column_mapping: t.Dict[str, str] = {}
        for column_def in column_defs:
            column_def = column_def.strip()
            if not column_def:
                # Empty input or a stray/trailing comma contributes no column.
                continue

            # Only the first colon separates name from type; later ones belong to the type.
            name, separator, column_type = column_def.partition(":")
            if not separator:
                raise ValueError(f"Invalid column definition provided: {column_def}")
            column_mapping[name.strip()] = column_type.strip()

        return column_mapping
    # Check if mapping looks like a DataFrame StructType
    if hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
    if isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema( schema: Dict, depth: int, keys: Optional[List[str]] = None) -> List[List[str]]:
def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """
    Collect the key paths of a nested dictionary down to a given depth.

    Example:
        >>> flatten_schema({"db": {"t1": {"c": "int"}, "t2": {"c": "int"}}}, depth=2)
        [['db', 't1'], ['db', 't2']]

    Args:
        schema: the nested dictionary to flatten.
        depth: how many levels of keys to descend; nothing is collected when it is
            smaller than 1.
        keys: the key path accumulated so far, used as a prefix for every result.

    Returns:
        The list of key paths. A value that is not a dictionary ends its path early
        instead of being descended into.
    """
    tables: t.List[t.List[str]] = []
    prefix = keys or []

    for key, value in schema.items():
        path = prefix + [key]
        if depth >= 2 and isinstance(value, dict):
            tables.extend(flatten_schema(value, depth - 1, path))
        elif depth >= 1:
            # Either the requested depth is reached, or the value cannot be descended into.
            tables.append(path)

    return tables
def nested_get( d: Dict, *path: Tuple[str, str], raise_on_missing: bool = True) -> Optional[Any]:
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.
        raise_on_missing: whether to raise when a key along the path isn't found.

    Returns:
        The value or None if it doesn't exist.

    Raises:
        ValueError: if a key along the path isn't found and `raise_on_missing` is True.
    """
    current: t.Any = d
    for name, key in path:
        # A level that cannot be looked into (e.g. a leaf value reached before the
        # path ends) is treated the same as a missing key.
        getter = getattr(current, "get", None)
        current = getter(key) if callable(getter) else None

        if current is None:
            if raise_on_missing:
                label = "table" if name == "this" else name
                raise ValueError(f"Unknown {label}: {key}")
            return None

    return current

Get a value for a nested dictionary.

Arguments:
  • d: the dictionary to search.
  • *path: tuples of (name, key), where: key is the key in the dictionary to get. name is a string to use in the error if key isn't found.
Returns:

The value or None if it doesn't exist.

def nested_set(d: Dict, keys: Sequence[str], value: Any) -> Dict:
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that make up the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    *parents, leaf = keys

    node = d
    for key in parents:
        # `setdefault` returns the existing child or creates an empty one.
        node = node.setdefault(key, {})

    node[leaf] = value
    return d

In-place set a value for a nested dictionary

Example:
>>> nested_set({}, ["top_key", "second_key"], "value")
{'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
{'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
  • d: dictionary to update.
  • keys: the keys that make up the path to value.
  • value: the value to set in the dictionary for the given key path.
Returns:

The (possibly) updated dictionary.