Edit on GitHub

sqlglot.schema

  1from __future__ import annotations
  2
  3import abc
  4import typing as t
  5
  6import sqlglot
  7from sqlglot import expressions as exp
  8from sqlglot._typing import T
  9from sqlglot.dialects.dialect import Dialect
 10from sqlglot.errors import ParseError, SchemaError
 11from sqlglot.helper import dict_depth
 12from sqlglot.trie import TrieResult, in_trie, new_trie
 13
 14if t.TYPE_CHECKING:
 15    from sqlglot.dataframe.sql.types import StructType
 16    from sqlglot.dialects.dialect import DialectType
 17
 18    ColumnMapping = t.Union[t.Dict, str, StructType, t.List]
 19
 20TABLE_ARGS = ("this", "db", "catalog")
 21
 22
 23class Schema(abc.ABC):
 24    """Abstract base class for database schemas"""
 25
 26    dialect: DialectType
 27
 28    @abc.abstractmethod
 29    def add_table(
 30        self,
 31        table: exp.Table | str,
 32        column_mapping: t.Optional[ColumnMapping] = None,
 33        dialect: DialectType = None,
 34        normalize: t.Optional[bool] = None,
 35        match_depth: bool = True,
 36    ) -> None:
 37        """
 38        Register or update a table. Some implementing classes may require column information to also be provided.
 39        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
 40
 41        Args:
 42            table: the `Table` expression instance or string representing the table.
 43            column_mapping: a column mapping that describes the structure of the table.
 44            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 45            normalize: whether to normalize identifiers according to the dialect of interest.
 46            match_depth: whether to enforce that the table must match the schema's depth or not.
 47        """
 48
 49    @abc.abstractmethod
 50    def column_names(
 51        self,
 52        table: exp.Table | str,
 53        only_visible: bool = False,
 54        dialect: DialectType = None,
 55        normalize: t.Optional[bool] = None,
 56    ) -> t.List[str]:
 57        """
 58        Get the column names for a table.
 59
 60        Args:
 61            table: the `Table` expression instance.
 62            only_visible: whether to include invisible columns.
 63            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 64            normalize: whether to normalize identifiers according to the dialect of interest.
 65
 66        Returns:
 67            The list of column names.
 68        """
 69
 70    @abc.abstractmethod
 71    def get_column_type(
 72        self,
 73        table: exp.Table | str,
 74        column: exp.Column,
 75        dialect: DialectType = None,
 76        normalize: t.Optional[bool] = None,
 77    ) -> exp.DataType:
 78        """
 79        Get the `sqlglot.exp.DataType` type of a column in the schema.
 80
 81        Args:
 82            table: the source table.
 83            column: the target column.
 84            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 85            normalize: whether to normalize identifiers according to the dialect of interest.
 86
 87        Returns:
 88            The resulting column type.
 89        """
 90
 91    @property
 92    @abc.abstractmethod
 93    def supported_table_args(self) -> t.Tuple[str, ...]:
 94        """
 95        Table arguments this schema support, e.g. `("this", "db", "catalog")`
 96        """
 97
 98    @property
 99    def empty(self) -> bool:
100        """Returns whether or not the schema is empty."""
101        return True
102
103
class AbstractMappingSchema(t.Generic[T]):
    """Maps (catalog, db, table) name paths to values via a nested dict plus a trie index."""

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}
        # The trie is keyed by reversed paths, i.e. (catalog, db, table) order.
        self.mapping_trie = new_trie(
            tuple(reversed(path)) for path in flatten_schema(self.mapping, depth=self.depth())
        )
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        """Whether the underlying mapping contains any entries."""
        return not self.mapping

    def depth(self) -> int:
        """Nesting depth of the underlying mapping."""
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """The subset of TABLE_ARGS this mapping's depth can address (computed lazily)."""
        if not self._supported_table_args and self.mapping:
            depth = self.depth()

            if not depth:  # None
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = TABLE_ARGS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """Return the non-empty name parts of `table`, innermost (table name) first."""
        if isinstance(table.this, exp.ReadCSV):
            return [table.this.name]
        return [table.text(part) for part in TABLE_ARGS if table.text(part)]

    def find(
        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[T]:
        """
        Locate the value stored for `table`, resolving unambiguous partial paths.

        Args:
            table: the table whose mapping entry should be found.
            trie: an optional trie to search instead of the instance's own.
            raise_on_missing: whether an ambiguous or missing entry raises `SchemaError`.

        Returns:
            The stored value, or None if not found (and `raise_on_missing` is False).
        """
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        result, node = in_trie(self.mapping_trie if trie is None else trie, parts)

        if result == TrieResult.FAILED:
            return None

        if result == TrieResult.PREFIX:
            # Partial match: acceptable only if exactly one completion exists.
            possibilities = flatten_schema(node, depth=dict_depth(node) - 1)

            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                if raise_on_missing:
                    message = ", ".join(".".join(candidate) for candidate in possibilities)
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None

        return self.nested_get(parts, raise_on_missing=raise_on_missing)

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
    ) -> t.Optional[t.Any]:
        """Fetch the value at `parts` (table-first order) from `d` or the instance mapping."""
        return nested_get(
            d or self.mapping,
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )
171
172
class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
    """
    Schema backed by a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping describing which columns are visible; when omitted,
            every column is considered visible. Its nesting mirrors `schema`:
            1. {table: set(*cols)}}
            2. {db: {table: set(*cols)}}}
            3. {catalog: {db: {table: set(*cols)}}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = visible or {}
        self.normalize = normalize
        # Memoizes string -> exp.DataType conversions, see _to_data_type.
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        # Lazily computed in depth(); 0 means "not computed yet".
        self._depth = 0

        super().__init__(self._normalize(schema or {}))

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a fresh MappingSchema sharing the state of an existing one."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """Return a copy of this schema; `kwargs` override the copied constructor args."""
        params = {
            "schema": self.mapping.copy(),
            "visible": self.visible.copy(),
            "dialect": self.dialect,
            "normalize": self.normalize,
        }
        params.update(kwargs)
        return MappingSchema(**params)  # type: ignore

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the
        schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        resolved = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(resolved.parts) != self.depth():
            raise SchemaError(
                f"Table {resolved.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        columns = {
            self._normalize_name(name, dialect=dialect, normalize=normalize): column_type
            for name, column_type in ensure_column_mapping(column_mapping).items()
        }

        # An existing entry is only replaced when new column info was supplied.
        if self.find(resolved, raise_on_missing=False) and not columns:
            return

        parts = self.table_parts(resolved)

        nested_set(self.mapping, tuple(reversed(parts)), columns)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """Return the column names of `table`, optionally filtered to visible columns."""
        resolved = self._normalize_table(table, dialect=dialect, normalize=normalize)

        table_schema = self.find(resolved)
        if table_schema is None:
            return []

        if not only_visible or not self.visible:
            return list(table_schema)

        visible = self.nested_get(self.table_parts(resolved), self.visible) or []
        return [name for name in table_schema if name in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """Return the type of `column` in `table`, or the UNKNOWN type if unresolvable."""
        resolved = self._normalize_table(table, dialect=dialect, normalize=normalize)

        column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(resolved, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            if isinstance(column_type, str):
                return self._to_data_type(column_type.upper(), dialect=dialect)

        return exp.DataType.build("unknown")

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalize all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized: t.Dict = {}
        flattened = flatten_schema(schema, depth=dict_depth(schema) - 1)

        for keys in flattened:
            # nested_get takes (name, key) pairs; the names only feed error
            # messages, so the keys double as their own labels here.
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                raise SchemaError(
                    f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened[0])}."
                )

            table_path = [
                self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys
            ]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized,
                    table_path + [self._normalize_name(column_name, dialect=self.dialect)],
                    column_type,
                )

        return normalized

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parse `table` if needed and normalize each of its name parts."""
        parsed = exp.maybe_parse(
            table, into=exp.Table, dialect=dialect or self.dialect, copy=True
        )

        for arg in TABLE_ARGS:
            part = parsed.args.get(arg)
            if isinstance(part, (str, exp.Identifier)):
                parsed.set(
                    arg,
                    exp.to_identifier(
                        self._normalize_name(
                            part, dialect=dialect, is_table=True, normalize=normalize
                        )
                    ),
                )

        return parsed

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Normalize a single identifier, falling back to the instance's dialect/flag."""
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        )

    def depth(self) -> int:
        if not self.empty and not self._depth:
            # The innermost mapping level holds columns, not tables, so exclude it.
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect

            try:
                self._type_mapping_cache[schema_type] = exp.DataType.build(
                    schema_type, dialect=dialect
                )
            except AttributeError:
                # NOTE(review): only AttributeError is caught here (preserved from the
                # original); other parse failures propagate as-is — confirm intended.
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]
410
411
def normalize_name(
    name: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> str:
    """
    Normalize an identifier name, optionally applying dialect-specific casing rules.

    Args:
        name: the identifier (or raw string) to normalize.
        dialect: the SQL dialect used for parsing and normalization.
        is_table: whether the identifier names a table (hint for normalization).
        normalize: when falsy, only the parsed name is returned, unnormalized.

    Returns:
        The normalized name as a string.
    """
    try:
        identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
    except ParseError:
        # Unparseable input: fall back to the raw text.
        return name if isinstance(name, str) else name.name

    name = identifier.name
    if not normalize:
        return name

    # This can be useful for normalize_identifier
    identifier.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name
430
431
def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """Coerce `schema` into a `Schema`, wrapping plain mappings in a `MappingSchema`."""
    return schema if isinstance(schema, Schema) else MappingSchema(schema, **kwargs)
437
438
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Coerce the supported column-mapping shapes into a plain dict.

    Supported inputs:
        - None: yields an empty mapping.
        - dict: returned as-is.
        - str: comma-separated "col1: type1, col2: type2" pairs.
        - DataFrame StructType (duck-typed via its `simpleString` attribute).
        - list of column names (types default to None).

    Args:
        mapping: the column mapping to normalize.

    Returns:
        A dict of column name to column type.

    Raises:
        ValueError: if `mapping` is of an unsupported type.
    """
    if mapping is None:
        return {}
    elif isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        columns = {}
        for pair in mapping.split(","):
            # Split on the FIRST colon only, so types that themselves contain
            # colons (e.g. "map<int:str>") are preserved intact.
            name_type = pair.split(":", 1)
            columns[name_type[0].strip()] = name_type[1].strip()
        return columns
    # Check if mapping looks like a DataFrame StructType
    elif hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
457
458
def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """Return every key path of length `depth` found in the nested `schema` mapping."""
    prefix = keys or []
    paths: t.List[t.List[str]] = []

    for name, nested in schema.items():
        if depth == 1:
            paths.append(prefix + [name])
        elif depth >= 2:
            paths.extend(flatten_schema(nested, depth - 1, prefix + [name]))

    return paths
472
473
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.

    Returns:
        The value or None if it doesn't exist.
    """
    current = d
    for name, key in path:
        current = current.get(key)  # type: ignore
        if current is None:
            if not raise_on_missing:
                return None
            # "this" is the internal arg name for the table itself.
            label = "table" if name == "this" else name
            raise ValueError(f"Unknown {label}: {key}")

    return current
498
499
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary.

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that make up the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    # setdefault walks/creates each intermediate level in a single step.
    node = d
    for key in keys[:-1]:
        node = node.setdefault(key, {})

    node[keys[-1]] = value
    return d
TABLE_ARGS = ('this', 'db', 'catalog')
class Schema(abc.ABC):
 24class Schema(abc.ABC):
 25    """Abstract base class for database schemas"""
 26
 27    dialect: DialectType
 28
 29    @abc.abstractmethod
 30    def add_table(
 31        self,
 32        table: exp.Table | str,
 33        column_mapping: t.Optional[ColumnMapping] = None,
 34        dialect: DialectType = None,
 35        normalize: t.Optional[bool] = None,
 36        match_depth: bool = True,
 37    ) -> None:
 38        """
 39        Register or update a table. Some implementing classes may require column information to also be provided.
 40        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
 41
 42        Args:
 43            table: the `Table` expression instance or string representing the table.
 44            column_mapping: a column mapping that describes the structure of the table.
 45            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 46            normalize: whether to normalize identifiers according to the dialect of interest.
 47            match_depth: whether to enforce that the table must match the schema's depth or not.
 48        """
 49
 50    @abc.abstractmethod
 51    def column_names(
 52        self,
 53        table: exp.Table | str,
 54        only_visible: bool = False,
 55        dialect: DialectType = None,
 56        normalize: t.Optional[bool] = None,
 57    ) -> t.List[str]:
 58        """
 59        Get the column names for a table.
 60
 61        Args:
 62            table: the `Table` expression instance.
 63            only_visible: whether to include invisible columns.
 64            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 65            normalize: whether to normalize identifiers according to the dialect of interest.
 66
 67        Returns:
 68            The list of column names.
 69        """
 70
 71    @abc.abstractmethod
 72    def get_column_type(
 73        self,
 74        table: exp.Table | str,
 75        column: exp.Column,
 76        dialect: DialectType = None,
 77        normalize: t.Optional[bool] = None,
 78    ) -> exp.DataType:
 79        """
 80        Get the `sqlglot.exp.DataType` type of a column in the schema.
 81
 82        Args:
 83            table: the source table.
 84            column: the target column.
 85            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 86            normalize: whether to normalize identifiers according to the dialect of interest.
 87
 88        Returns:
 89            The resulting column type.
 90        """
 91
 92    @property
 93    @abc.abstractmethod
 94    def supported_table_args(self) -> t.Tuple[str, ...]:
 95        """
 96        Table arguments this schema supports, e.g. `("this", "db", "catalog")`
 97        """
 98
 99    @property
100    def empty(self) -> bool:
101        """Returns whether or not the schema is empty."""
102        return True

Abstract base class for database schemas

@abc.abstractmethod
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None, match_depth: bool = True) -> None:
29    @abc.abstractmethod
30    def add_table(
31        self,
32        table: exp.Table | str,
33        column_mapping: t.Optional[ColumnMapping] = None,
34        dialect: DialectType = None,
35        normalize: t.Optional[bool] = None,
36        match_depth: bool = True,
37    ) -> None:
38        """
39        Register or update a table. Some implementing classes may require column information to also be provided.
40        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
41
42        Args:
43            table: the `Table` expression instance or string representing the table.
44            column_mapping: a column mapping that describes the structure of the table.
45            dialect: the SQL dialect that will be used to parse `table` if it's a string.
46            normalize: whether to normalize identifiers according to the dialect of interest.
47            match_depth: whether to enforce that the table must match the schema's depth or not.
48        """

Register or update a table. Some implementing classes may require column information to also be provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
  • match_depth: whether to enforce that the table must match the schema's depth or not.
@abc.abstractmethod
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> List[str]:
50    @abc.abstractmethod
51    def column_names(
52        self,
53        table: exp.Table | str,
54        only_visible: bool = False,
55        dialect: DialectType = None,
56        normalize: t.Optional[bool] = None,
57    ) -> t.List[str]:
58        """
59        Get the column names for a table.
60
61        Args:
62            table: the `Table` expression instance.
63            only_visible: whether to include invisible columns.
64            dialect: the SQL dialect that will be used to parse `table` if it's a string.
65            normalize: whether to normalize identifiers according to the dialect of interest.
66
67        Returns:
68            The list of column names.
69        """

Get the column names for a table.

Arguments:
  • table: the Table expression instance.
  • only_visible: whether to include invisible columns.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The list of column names.

@abc.abstractmethod
def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> sqlglot.expressions.DataType:
71    @abc.abstractmethod
72    def get_column_type(
73        self,
74        table: exp.Table | str,
75        column: exp.Column,
76        dialect: DialectType = None,
77        normalize: t.Optional[bool] = None,
78    ) -> exp.DataType:
79        """
80        Get the `sqlglot.exp.DataType` type of a column in the schema.
81
82        Args:
83            table: the source table.
84            column: the target column.
85            dialect: the SQL dialect that will be used to parse `table` if it's a string.
86            normalize: whether to normalize identifiers according to the dialect of interest.
87
88        Returns:
89            The resulting column type.
90        """

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The resulting column type.

supported_table_args: Tuple[str, ...]

Table arguments this schema supports, e.g. ("this", "db", "catalog")

empty: bool

Returns whether or not the schema is empty.

class AbstractMappingSchema(typing.Generic[~T]):
105class AbstractMappingSchema(t.Generic[T]):
106    def __init__(
107        self,
108        mapping: t.Optional[t.Dict] = None,
109    ) -> None:
110        self.mapping = mapping or {}
111        self.mapping_trie = new_trie(
112            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
113        )
114        self._supported_table_args: t.Tuple[str, ...] = tuple()
115
116    @property
117    def empty(self) -> bool:
118        return not self.mapping
119
120    def depth(self) -> int:
121        return dict_depth(self.mapping)
122
123    @property
124    def supported_table_args(self) -> t.Tuple[str, ...]:
125        if not self._supported_table_args and self.mapping:
126            depth = self.depth()
127
128            if not depth:  # None
129                self._supported_table_args = tuple()
130            elif 1 <= depth <= 3:
131                self._supported_table_args = TABLE_ARGS[:depth]
132            else:
133                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
134
135        return self._supported_table_args
136
137    def table_parts(self, table: exp.Table) -> t.List[str]:
138        if isinstance(table.this, exp.ReadCSV):
139            return [table.this.name]
140        return [table.text(part) for part in TABLE_ARGS if table.text(part)]
141
142    def find(
143        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
144    ) -> t.Optional[T]:
145        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
146        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
147
148        if value == TrieResult.FAILED:
149            return None
150
151        if value == TrieResult.PREFIX:
152            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
153
154            if len(possibilities) == 1:
155                parts.extend(possibilities[0])
156            else:
157                message = ", ".join(".".join(parts) for parts in possibilities)
158                if raise_on_missing:
159                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
160                return None
161
162        return self.nested_get(parts, raise_on_missing=raise_on_missing)
163
164    def nested_get(
165        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
166    ) -> t.Optional[t.Any]:
167        return nested_get(
168            d or self.mapping,
169            *zip(self.supported_table_args, reversed(parts)),
170            raise_on_missing=raise_on_missing,
171        )

Abstract base class for generic types.

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::

class Mapping(Generic[KT, VT]): def __getitem__(self, key: KT) -> VT: ... # Etc.

This class can then be used as follows::

def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT: try: return mapping[key] except KeyError: return default

AbstractMappingSchema(mapping: Optional[Dict] = None)
106    def __init__(
107        self,
108        mapping: t.Optional[t.Dict] = None,
109    ) -> None:
110        self.mapping = mapping or {}
111        self.mapping_trie = new_trie(
112            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
113        )
114        self._supported_table_args: t.Tuple[str, ...] = tuple()
mapping
mapping_trie
empty: bool
def depth(self) -> int:
120    def depth(self) -> int:
121        return dict_depth(self.mapping)
supported_table_args: Tuple[str, ...]
def table_parts(self, table: sqlglot.expressions.Table) -> List[str]:
137    def table_parts(self, table: exp.Table) -> t.List[str]:
138        if isinstance(table.this, exp.ReadCSV):
139            return [table.this.name]
140        return [table.text(part) for part in TABLE_ARGS if table.text(part)]
def find( self, table: sqlglot.expressions.Table, trie: Optional[Dict] = None, raise_on_missing: bool = True) -> Optional[~T]:
def find(
    self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
) -> t.Optional[T]:
    """
    Look up the value stored for `table` via the mapping trie.

    Args:
        table: the table expression to resolve.
        trie: the trie to search; defaults to this schema's mapping trie.
        raise_on_missing: whether to raise when the table path is ambiguous.

    Returns:
        The mapped value for the table, or None if it cannot be resolved.
    """
    # Keep only as many qualifiers as this schema actually supports
    # (e.g. drop the catalog when the schema is just {table: ...}).
    parts = self.table_parts(table)[0 : len(self.supported_table_args)]
    value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)

    if value == TrieResult.FAILED:
        return None

    if value == TrieResult.PREFIX:
        # The path only partially qualifies a table; enumerate possible
        # completions from the remaining sub-trie.
        possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)

        if len(possibilities) == 1:
            # Exactly one completion: extend the path and resolve it.
            parts.extend(possibilities[0])
        else:
            message = ", ".join(".".join(parts) for parts in possibilities)
            if raise_on_missing:
                raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
            return None

    return self.nested_get(parts, raise_on_missing=raise_on_missing)
def nested_get( self, parts: Sequence[str], d: Optional[Dict] = None, raise_on_missing=True) -> Optional[Any]:
def nested_get(
    self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
) -> t.Optional[t.Any]:
    """
    Resolve `parts` (most-specific first, e.g. [table, db, catalog]) in a nested mapping.

    Args:
        parts: the table path components to look up.
        d: the mapping to search; defaults to this schema's mapping.
        raise_on_missing: whether to raise if a path component is not found.

    Returns:
        The nested value, or None if missing and `raise_on_missing` is False.
    """
    # Guard with an explicit None test; `d or self.mapping` would wrongly replace
    # an intentionally-empty dict with the full schema mapping.
    target = self.mapping if d is None else d
    return nested_get(
        target,
        *zip(self.supported_table_args, reversed(parts)),
        raise_on_missing=raise_on_missing,
    )
class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = visible or {}
        self.normalize = normalize
        # Lazily-filled cache of raw type strings -> parsed exp.DataType objects.
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        # Computed lazily by depth(); 0 means "not computed yet".
        self._depth = 0

        # Normalize all identifiers up front so later lookups are consistent.
        super().__init__(self._normalize(schema or {}))

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a new MappingSchema reusing an existing instance's mapping and settings."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """Return a copy of this schema; `kwargs` override the copied constructor arguments."""
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            # Table already registered and no new column information was given.
            return

        parts = self.table_parts(normalized_table)

        # Store the columns under the reversed path (catalog, db, table) and
        # index the forward path in the trie for prefix lookups.
        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """Return the column names of `table`, optionally restricted to visible columns."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        # Filter by the visibility mapping, which mirrors the schema's nesting.
        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """Return the type of `column` in `table`, or an "unknown" DataType if unresolvable."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                # Type stored as a string: parse it (with caching) into a DataType.
                return self._to_data_type(column_type.upper(), dialect=dialect)

        return exp.DataType.build("unknown")

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        # One entry per table path; the last nesting level (the columns) is excluded.
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        for keys in flattened_schema:
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                raise SchemaError(
                    f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}."
                )

            # Table-path identifiers are normalized as table names ...
            normalized_keys = [
                self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys
            ]
            # ... while column identifiers are normalized as plain names.
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parse `table` if needed and normalize each of its qualifier identifiers."""
        normalized_table = exp.maybe_parse(
            table, into=exp.Table, dialect=dialect or self.dialect, copy=True
        )

        for arg in TABLE_ARGS:
            value = normalized_table.args.get(arg)
            if isinstance(value, (str, exp.Identifier)):
                normalized_table.set(
                    arg,
                    exp.to_identifier(
                        self._normalize_name(
                            value, dialect=dialect, is_table=True, normalize=normalize
                        )
                    ),
                )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Normalize a single identifier, falling back to this schema's dialect/setting."""
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        )

    def depth(self) -> int:
        """Nesting depth of the schema, excluding the trailing column mapping level."""
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]

Schema based on a nested mapping.

Arguments:
  • schema: Mapping in one of the following forms:
    1. {table: {col: type}}
    2. {db: {table: {col: type}}}
    3. {catalog: {db: {table: {col: type}}}}
    4. None - Tables will be added later
  • visible: Optional mapping of which columns in the schema are visible. If not provided, all columns are assumed to be visible. The nesting should mirror that of the schema:
    1. {table: set(*cols)}
    2. {db: {table: set(*cols)}}
    3. {catalog: {db: {table: set(*cols)}}}
  • dialect: The dialect to be used for custom type mappings & parsing string arguments.
  • normalize: Whether to normalize identifier names according to the given dialect or not.
MappingSchema( schema: Optional[Dict] = None, visible: Optional[Dict] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: bool = True)
def __init__(
    self,
    schema: t.Optional[t.Dict] = None,
    visible: t.Optional[t.Dict] = None,
    dialect: DialectType = None,
    normalize: bool = True,
) -> None:
    self.dialect = dialect
    self.visible = visible or {}
    self.normalize = normalize
    # Lazily-filled cache of raw type strings -> parsed exp.DataType objects.
    self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
    # Computed lazily by depth(); 0 means "not computed yet".
    self._depth = 0

    # Normalize all identifiers up front so later lookups are consistent.
    super().__init__(self._normalize(schema or {}))
visible
normalize
@classmethod
def from_mapping_schema( cls, mapping_schema: sqlglot.schema.MappingSchema) -> sqlglot.schema.MappingSchema:
@classmethod
def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
    """Build a new MappingSchema reusing an existing instance's mapping and settings."""
    return MappingSchema(
        schema=mapping_schema.mapping,
        visible=mapping_schema.visible,
        dialect=mapping_schema.dialect,
        normalize=mapping_schema.normalize,
    )
def copy(self, **kwargs) -> sqlglot.schema.MappingSchema:
def copy(self, **kwargs) -> MappingSchema:
    """Return a copy of this schema; `kwargs` override the copied constructor arguments."""
    return MappingSchema(
        **{  # type: ignore
            "schema": self.mapping.copy(),
            "visible": self.visible.copy(),
            "dialect": self.dialect,
            "normalize": self.normalize,
            **kwargs,
        }
    )
def depth(self) -> int:
def depth(self) -> int:
    """Nesting depth of the schema, excluding the trailing column mapping level."""
    if not self.empty and not self._depth:
        # The columns themselves are a mapping, but we don't want to include those
        self._depth = super().depth() - 1
    return self._depth
def normalize_name( name: str | sqlglot.expressions.Identifier, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, is_table: bool = False, normalize: Optional[bool] = True) -> str:
def normalize_name(
    name: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> str:
    """Parse `name` into an identifier and optionally normalize it for `dialect`."""
    try:
        identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
    except ParseError:
        # Unparseable input: fall back to the raw string / identifier name.
        return name if isinstance(name, str) else name.name

    if not normalize:
        return identifier.name

    # This can be useful for normalize_identifier
    identifier.meta["is_table"] = is_table
    normalized = Dialect.get_or_raise(dialect).normalize_identifier(identifier)
    return normalized.name
def ensure_schema( schema: Union[sqlglot.schema.Schema, Dict, NoneType], **kwargs: Any) -> sqlglot.schema.Schema:
def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """Return `schema` unchanged if it already is a Schema, else wrap it in a MappingSchema."""
    if not isinstance(schema, Schema):
        return MappingSchema(schema, **kwargs)
    return schema
def ensure_column_mapping( mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType]) -> Dict:
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Coerce the supported column-mapping shapes into a {column: type} dict.

    Args:
        mapping: None, a dict, a "col: type, ..." string, a StructType-like object
            (anything exposing `simpleString`), or a list of column names.

    Returns:
        A dict of column name to type; types are None for plain column lists.

    Raises:
        ValueError: if `mapping` is of an unsupported type.
    """
    if mapping is None:
        return {}
    elif isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        columns = {}
        for name_type_str in mapping.split(","):
            # Split on the first colon only, so types that themselves contain
            # colons (e.g. "struct<a:int>") are preserved intact.
            name, _, data_type = name_type_str.partition(":")
            columns[name.strip()] = data_type.strip()
        return columns
    # Check if mapping looks like a DataFrame StructType
    elif hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema( schema: Dict, depth: int, keys: Optional[List[str]] = None) -> List[List[str]]:
def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """
    Flatten a nested schema into a list of key paths down to `depth` levels.

    Args:
        schema: the nested mapping to flatten.
        depth: how many levels of nesting remain below `schema`.
        keys: the path accumulated so far (used internally by the recursion).

    Returns:
        A list of key paths, each ordered from outermost to innermost key.
    """
    prefix = keys or []
    paths: t.List[t.List[str]] = []

    for key, value in schema.items():
        if depth == 1:
            # Bottom of the requested range: record the completed path.
            paths.append(prefix + [key])
        elif depth >= 2:
            paths.extend(flatten_schema(value, depth - 1, prefix + [key]))

    return paths
def nested_get( d: Dict, *path: Tuple[str, str], raise_on_missing: bool = True) -> Optional[Any]:
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.

    Returns:
        The value or None if it doesn't exist.
    """
    current: t.Optional[t.Any] = d
    for name, key in path:
        current = current.get(key)  # type: ignore
        if current is not None:
            continue
        if not raise_on_missing:
            return None
        # "this" is the internal arg name for a table; surface it as "table".
        name = "table" if name == "this" else name
        raise ValueError(f"Unknown {name}: {key}")

    return current

Get a value for a nested dictionary.

Arguments:
  • d: the dictionary to search.
  • *path: tuples of (name, key), where `key` is the key in the dictionary to get and `name` is a string to use in the error if `key` isn't found.
Returns:

The value or None if it doesn't exist.

def nested_set(d: Dict, keys: Sequence[str], value: Any) -> Dict:
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary.

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that make up the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        # Nothing to set without a path.
        return d

    *ancestors, last = keys
    target = d
    for key in ancestors:
        # setdefault both creates missing intermediate dicts and descends into
        # existing ones, in a single step.
        target = target.setdefault(key, {})

    target[last] = value
    return d

In-place set a value for a nested dictionary

Example:
>>> nested_set({}, ["top_key", "second_key"], "value")
{'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
{'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
  • d: dictionary to update.
  • keys: the keys that make up the path to value.
  • value: the value to set in the dictionary for the given key path.
Returns:

The (possibly) updated dictionary.