Edit on GitHub

sqlglot.schema

  1from __future__ import annotations
  2
  3import abc
  4import typing as t
  5
  6import sqlglot
  7from sqlglot import expressions as exp
  8from sqlglot._typing import T
  9from sqlglot.dialects.dialect import Dialect
 10from sqlglot.errors import ParseError, SchemaError
 11from sqlglot.helper import dict_depth
 12from sqlglot.trie import in_trie, new_trie
 13
 14if t.TYPE_CHECKING:
 15    from sqlglot.dataframe.sql.types import StructType
 16    from sqlglot.dialects.dialect import DialectType
 17
 18    ColumnMapping = t.Union[t.Dict, str, StructType, t.List]
 19
# Args of an `exp.Table` that together form a (possibly qualified) table name,
# ordered innermost first: the table name itself ("this"), then "db", then "catalog".
# A schema nested N levels deep resolves the first N of these (see supported_table_args).
TABLE_ARGS = ("this", "db", "catalog")
 21
 22
 23class Schema(abc.ABC):
 24    """Abstract base class for database schemas"""
 25
 26    dialect: DialectType
 27
 28    @abc.abstractmethod
 29    def add_table(
 30        self,
 31        table: exp.Table | str,
 32        column_mapping: t.Optional[ColumnMapping] = None,
 33        dialect: DialectType = None,
 34    ) -> None:
 35        """
 36        Register or update a table. Some implementing classes may require column information to also be provided.
 37
 38        Args:
 39            table: the `Table` expression instance or string representing the table.
 40            column_mapping: a column mapping that describes the structure of the table.
 41            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 42        """
 43
 44    @abc.abstractmethod
 45    def column_names(
 46        self,
 47        table: exp.Table | str,
 48        only_visible: bool = False,
 49        dialect: DialectType = None,
 50    ) -> t.List[str]:
 51        """
 52        Get the column names for a table.
 53
 54        Args:
 55            table: the `Table` expression instance.
 56            only_visible: whether to include invisible columns.
 57            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 58
 59        Returns:
 60            The list of column names.
 61        """
 62
 63    @abc.abstractmethod
 64    def get_column_type(
 65        self,
 66        table: exp.Table | str,
 67        column: exp.Column,
 68        dialect: DialectType = None,
 69    ) -> exp.DataType:
 70        """
 71        Get the `sqlglot.exp.DataType` type of a column in the schema.
 72
 73        Args:
 74            table: the source table.
 75            column: the target column.
 76            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 77
 78        Returns:
 79            The resulting column type.
 80        """
 81
 82    @property
 83    @abc.abstractmethod
 84    def supported_table_args(self) -> t.Tuple[str, ...]:
 85        """
 86        Table arguments this schema support, e.g. `("this", "db", "catalog")`
 87        """
 88
 89    @property
 90    def empty(self) -> bool:
 91        """Returns whether or not the schema is empty."""
 92        return True
 93
 94
 95class AbstractMappingSchema(t.Generic[T]):
 96    def __init__(
 97        self,
 98        mapping: t.Optional[t.Dict] = None,
 99    ) -> None:
100        self.mapping = mapping or {}
101        self.mapping_trie = new_trie(
102            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self._depth())
103        )
104        self._supported_table_args: t.Tuple[str, ...] = tuple()
105
106    @property
107    def empty(self) -> bool:
108        return not self.mapping
109
110    def _depth(self) -> int:
111        return dict_depth(self.mapping)
112
113    @property
114    def supported_table_args(self) -> t.Tuple[str, ...]:
115        if not self._supported_table_args and self.mapping:
116            depth = self._depth()
117
118            if not depth:  # None
119                self._supported_table_args = tuple()
120            elif 1 <= depth <= 3:
121                self._supported_table_args = TABLE_ARGS[:depth]
122            else:
123                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
124
125        return self._supported_table_args
126
127    def table_parts(self, table: exp.Table) -> t.List[str]:
128        if isinstance(table.this, exp.ReadCSV):
129            return [table.this.name]
130        return [table.text(part) for part in TABLE_ARGS if table.text(part)]
131
132    def find(
133        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
134    ) -> t.Optional[T]:
135        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
136        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
137
138        if value == 0:
139            return None
140
141        if value == 1:
142            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
143
144            if len(possibilities) == 1:
145                parts.extend(possibilities[0])
146            else:
147                message = ", ".join(".".join(parts) for parts in possibilities)
148                if raise_on_missing:
149                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
150                return None
151
152        return self.nested_get(parts, raise_on_missing=raise_on_missing)
153
154    def nested_get(
155        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
156    ) -> t.Optional[t.Any]:
157        return nested_get(
158            d or self.mapping,
159            *zip(self.supported_table_args, reversed(parts)),
160            raise_on_missing=raise_on_missing,
161        )
162
163
class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = visible or {}
        self.normalize = normalize
        # Parsed column types keyed by type string; only filled for `self.dialect`.
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}

        super().__init__(self._normalize(schema or {}))

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a new MappingSchema with the same mapping, visibility, dialect and normalize flag."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            # Carry the flag over; otherwise a normalize=False schema would be re-normalized.
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """
        Return a new MappingSchema built from this one's settings.

        Any keyword argument accepted by `MappingSchema.__init__` overrides the copied value.
        """
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
        """
        normalized_table = self._normalize_table(
            self._ensure_table(table, dialect=dialect), dialect=dialect
        )
        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # An already-known table is left untouched unless new columns were supplied.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        # The mapping is nested outermost-first; the trie stores parts innermost-first.
        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
    ) -> t.List[str]:
        """
        Get the column names for `table`, or [] if the table is not in the schema.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to return only the columns listed in `visible` for this table;
                ignored when no `visible` mapping was provided.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.

        Raises:
            SchemaError: if `table` is ambiguous.
            ValueError: if a path segment is missing (this may also come from the `visible` lookup).
        """
        normalized_table = self._normalize_table(
            self._ensure_table(table, dialect=dialect), dialect=dialect
        )

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column,
        dialect: DialectType = None,
    ) -> exp.DataType:
        """
        Get the type of `column` in `table`.

        Returns:
            The stored `DataType`, the parsed type if it is stored as a string, or an "unknown"
            `DataType` when the table or column cannot be found.
        """
        normalized_table = self._normalize_table(
            self._ensure_table(table, dialect=dialect), dialect=dialect
        )
        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                return self._to_data_type(column_type.upper(), dialect=dialect)

        return exp.DataType.build("unknown")

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        normalized_mapping: t.Dict = {}
        for keys in flattened_schema:
            columns = nested_get(schema, *zip(keys, keys))
            assert columns is not None

            normalized_keys = [
                self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys
            ]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(self, table: exp.Table, dialect: DialectType = None) -> exp.Table:
        """Return a copy of `table` whose name parts are normalized identifiers."""
        normalized_table = table.copy()

        for arg in TABLE_ARGS:
            value = normalized_table.args.get(arg)
            if isinstance(value, (str, exp.Identifier)):
                normalized_table.set(
                    arg,
                    exp.to_identifier(self._normalize_name(value, dialect=dialect, is_table=True)),
                )

        return normalized_table

    def _normalize_name(
        self, name: str | exp.Identifier, dialect: DialectType = None, is_table: bool = False
    ) -> str:
        """
        Normalize a single identifier name using `dialect` (default: the schema's dialect).

        If `name` cannot be parsed as an identifier it is returned unchanged. When `normalize`
        is False, only the parsed identifier's name is returned.
        """
        dialect = dialect or self.dialect

        try:
            identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
        except ParseError:
            return name if isinstance(name, str) else name.name

        name = identifier.name
        if not self.normalize:
            return name

        # This can be useful for normalize_identifier
        identifier.meta["is_table"] = is_table
        return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name

    def _depth(self) -> int:
        # The columns themselves are a mapping, but we don't want to include those
        return super()._depth() - 1

    def _ensure_table(self, table: exp.Table | str, dialect: DialectType = None) -> exp.Table:
        """Parse `table` into an `exp.Table` if it is a string."""
        return exp.maybe_parse(table, into=exp.Table, dialect=dialect or self.dialect)

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Results are cached by type string, but only when parsing with the schema's own dialect:
        the cache key does not include the dialect, so caching other dialects' results would
        leak them into later lookups.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.

        Raises:
            SchemaError: if the type cannot be built.
        """
        dialect = dialect or self.dialect
        cacheable = dialect == self.dialect

        if cacheable and schema_type in self._type_mapping_cache:
            return self._type_mapping_cache[schema_type]

        try:
            expression = exp.DataType.build(schema_type, dialect=dialect)
        except AttributeError:
            in_dialect = f" in dialect {dialect}" if dialect else ""
            raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        if cacheable:
            self._type_mapping_cache[schema_type] = expression
        return expression
379
380
def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """
    Coerce `schema` into a `Schema` instance.

    Args:
        schema: an existing `Schema` (returned unchanged), a nested mapping dict, or None.
        **kwargs: forwarded to `MappingSchema`; ignored when `schema` is already a `Schema`.

    Returns:
        `schema` itself, or a new `MappingSchema` wrapping it.
    """
    if not isinstance(schema, Schema):
        schema = MappingSchema(schema, **kwargs)
    return schema
386
387
def _split_top_level(text: str, sep: str = ",") -> t.List[str]:
    """
    Split `text` on `sep`, ignoring separators nested inside (), <> or [] brackets.

    This keeps parameterized types such as `decimal(10, 2)` or `struct<a: int, b: int>` intact.
    """
    pieces = []
    depth = 0
    start = 0
    for index, char in enumerate(text):
        if char in "(<[":
            depth += 1
        elif char in ")>]":
            # Clamp so a stray closing bracket can't disable splitting for the rest of the text.
            depth = max(depth - 1, 0)
        elif char == sep and depth == 0:
            pieces.append(text[start:index])
            start = index + 1
    pieces.append(text[start:])
    return pieces


def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Coerce the supported column-mapping representations into a `{column: type}` dict.

    Accepted inputs:
        - None: an empty dict.
        - dict: returned as-is (not copied).
        - str: comma-separated `name: type` pairs, e.g. "a: int, b: decimal(10, 2)". Commas
          and colons nested inside (), <> or [] belong to the type; empty entries are skipped.
        - an object with a `simpleString` attribute (e.g. a DataFrame `StructType`): iterated
          as fields, mapping each `field.name` to `field.dataType.simpleString()`.
        - list: column names, each mapped to None.

    Raises:
        ValueError: if the mapping type is unsupported or a string entry has no ':' separator.
    """
    if mapping is None:
        return {}
    elif isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        columns = {}
        for raw_entry in _split_top_level(mapping):
            entry = raw_entry.strip()
            if not entry:
                continue
            # Split on the first colon only, so e.g. `struct<x: int>` keeps its inner colons.
            name, separator, type_str = entry.partition(":")
            if not separator:
                raise ValueError(f"Invalid column mapping entry (expected 'name: type'): {entry!r}")
            columns[name.strip()] = type_str.strip()
        return columns
    # Check if mapping looks like a DataFrame StructType
    elif hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
406
407
def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """
    Collect the key paths of a nested dict down to a fixed depth.

    Args:
        schema: the nested mapping to walk.
        depth: how many levels of keys each returned path contains; anything nested deeper
            is not visited. A depth below 1 yields no paths.
        keys: a prefix prepended to every returned path (it is not mutated).

    Returns:
        The key paths, outermost key first, in the mapping's iteration order.
    """
    prefix = keys or []
    entries = schema.items()

    if depth >= 2:
        return [
            path
            for key, value in entries
            for path in flatten_schema(value, depth - 1, prefix + [key])
        ]
    if depth == 1:
        return [prefix + [key] for key, _ in entries]
    return []
421
422
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Walk a nested dictionary along `path` and return the value found there.

    Args:
        d: the dictionary to search.
        *path: (name, key) pairs applied in order; `key` is looked up at each level, and
            `name` labels that level in the error message ("this" is reported as "table").
        raise_on_missing: whether a missing (or None-valued) key raises instead of returning None.

    Returns:
        The value at the end of `path` (`d` itself for an empty path), or None if a key is
        missing and `raise_on_missing` is False.

    Raises:
        ValueError: if a key is missing and `raise_on_missing` is True.
    """
    current = d
    for label, key in path:
        current = current.get(key)  # type: ignore
        if current is not None:
            continue
        if not raise_on_missing:
            return None
        name = "table" if label == "this" else label
        raise ValueError(f"Unknown {name}: {key}")

    return current
447
448
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary, creating intermediate dicts as needed.

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that make up the path to `value`; an empty sequence leaves `d` untouched.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The same dictionary object `d`, (possibly) updated.
    """
    if not keys:
        return d

    *parents, leaf = keys
    target = d
    for key in parents:
        # Reuse an existing sub-dict, or insert an empty one and descend into it.
        target = target.setdefault(key, {})

    target[leaf] = value
    return d
class Schema(abc.ABC):
24class Schema(abc.ABC):
25    """Abstract base class for database schemas"""
26
27    dialect: DialectType
28
29    @abc.abstractmethod
30    def add_table(
31        self,
32        table: exp.Table | str,
33        column_mapping: t.Optional[ColumnMapping] = None,
34        dialect: DialectType = None,
35    ) -> None:
36        """
37        Register or update a table. Some implementing classes may require column information to also be provided.
38
39        Args:
40            table: the `Table` expression instance or string representing the table.
41            column_mapping: a column mapping that describes the structure of the table.
42            dialect: the SQL dialect that will be used to parse `table` if it's a string.
43        """
44
45    @abc.abstractmethod
46    def column_names(
47        self,
48        table: exp.Table | str,
49        only_visible: bool = False,
50        dialect: DialectType = None,
51    ) -> t.List[str]:
52        """
53        Get the column names for a table.
54
55        Args:
56            table: the `Table` expression instance.
57            only_visible: whether to include invisible columns.
58            dialect: the SQL dialect that will be used to parse `table` if it's a string.
59
60        Returns:
61            The list of column names.
62        """
63
64    @abc.abstractmethod
65    def get_column_type(
66        self,
67        table: exp.Table | str,
68        column: exp.Column,
69        dialect: DialectType = None,
70    ) -> exp.DataType:
71        """
72        Get the `sqlglot.exp.DataType` type of a column in the schema.
73
74        Args:
75            table: the source table.
76            column: the target column.
77            dialect: the SQL dialect that will be used to parse `table` if it's a string.
78
79        Returns:
80            The resulting column type.
81        """
82
83    @property
84    @abc.abstractmethod
85    def supported_table_args(self) -> t.Tuple[str, ...]:
86        """
87        Table arguments this schema support, e.g. `("this", "db", "catalog")`
88        """
89
90    @property
91    def empty(self) -> bool:
92        """Returns whether or not the schema is empty."""
93        return True

Abstract base class for database schemas

@abc.abstractmethod
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> None:
29    @abc.abstractmethod
30    def add_table(
31        self,
32        table: exp.Table | str,
33        column_mapping: t.Optional[ColumnMapping] = None,
34        dialect: DialectType = None,
35    ) -> None:
36        """
37        Register or update a table. Some implementing classes may require column information to also be provided.
38
39        Args:
40            table: the `Table` expression instance or string representing the table.
41            column_mapping: a column mapping that describes the structure of the table.
42            dialect: the SQL dialect that will be used to parse `table` if it's a string.
43        """

Register or update a table. Some implementing classes may require column information to also be provided.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
@abc.abstractmethod
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> List[str]:
45    @abc.abstractmethod
46    def column_names(
47        self,
48        table: exp.Table | str,
49        only_visible: bool = False,
50        dialect: DialectType = None,
51    ) -> t.List[str]:
52        """
53        Get the column names for a table.
54
55        Args:
56            table: the `Table` expression instance.
57            only_visible: whether to include invisible columns.
58            dialect: the SQL dialect that will be used to parse `table` if it's a string.
59
60        Returns:
61            The list of column names.
62        """

Get the column names for a table.

Arguments:
  • table: the Table expression instance.
  • only_visible: whether to return only the visible columns; when false, invisible columns are returned as well.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
Returns:

The list of column names.

@abc.abstractmethod
def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> sqlglot.expressions.DataType:
64    @abc.abstractmethod
65    def get_column_type(
66        self,
67        table: exp.Table | str,
68        column: exp.Column,
69        dialect: DialectType = None,
70    ) -> exp.DataType:
71        """
72        Get the `sqlglot.exp.DataType` type of a column in the schema.
73
74        Args:
75            table: the source table.
76            column: the target column.
77            dialect: the SQL dialect that will be used to parse `table` if it's a string.
78
79        Returns:
80            The resulting column type.
81        """

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
Returns:

The resulting column type.

supported_table_args: Tuple[str, ...]

Table arguments this schema supports, e.g. ("this", "db", "catalog")

empty: bool

Returns whether or not the schema is empty.

class AbstractMappingSchema(typing.Generic[~T]):
 96class AbstractMappingSchema(t.Generic[T]):
 97    def __init__(
 98        self,
 99        mapping: t.Optional[t.Dict] = None,
100    ) -> None:
101        self.mapping = mapping or {}
102        self.mapping_trie = new_trie(
103            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self._depth())
104        )
105        self._supported_table_args: t.Tuple[str, ...] = tuple()
106
107    @property
108    def empty(self) -> bool:
109        return not self.mapping
110
111    def _depth(self) -> int:
112        return dict_depth(self.mapping)
113
114    @property
115    def supported_table_args(self) -> t.Tuple[str, ...]:
116        if not self._supported_table_args and self.mapping:
117            depth = self._depth()
118
119            if not depth:  # None
120                self._supported_table_args = tuple()
121            elif 1 <= depth <= 3:
122                self._supported_table_args = TABLE_ARGS[:depth]
123            else:
124                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
125
126        return self._supported_table_args
127
128    def table_parts(self, table: exp.Table) -> t.List[str]:
129        if isinstance(table.this, exp.ReadCSV):
130            return [table.this.name]
131        return [table.text(part) for part in TABLE_ARGS if table.text(part)]
132
133    def find(
134        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
135    ) -> t.Optional[T]:
136        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
137        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
138
139        if value == 0:
140            return None
141
142        if value == 1:
143            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
144
145            if len(possibilities) == 1:
146                parts.extend(possibilities[0])
147            else:
148                message = ", ".join(".".join(parts) for parts in possibilities)
149                if raise_on_missing:
150                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
151                return None
152
153        return self.nested_get(parts, raise_on_missing=raise_on_missing)
154
155    def nested_get(
156        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
157    ) -> t.Optional[t.Any]:
158        return nested_get(
159            d or self.mapping,
160            *zip(self.supported_table_args, reversed(parts)),
161            raise_on_missing=raise_on_missing,
162        )

Abstract base class for generic types.

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::

    class Mapping(Generic[KT, VT]):
        def __getitem__(self, key: KT) -> VT:
            ...
        # Etc.

This class can then be used as follows::

    def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
        try:
            return mapping[key]
        except KeyError:
            return default

AbstractMappingSchema(mapping: Optional[Dict] = None)
 97    def __init__(
 98        self,
 99        mapping: t.Optional[t.Dict] = None,
100    ) -> None:
101        self.mapping = mapping or {}
102        self.mapping_trie = new_trie(
103            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self._depth())
104        )
105        self._supported_table_args: t.Tuple[str, ...] = tuple()
def table_parts(self, table: sqlglot.expressions.Table) -> List[str]:
128    def table_parts(self, table: exp.Table) -> t.List[str]:
129        if isinstance(table.this, exp.ReadCSV):
130            return [table.this.name]
131        return [table.text(part) for part in TABLE_ARGS if table.text(part)]
def find( self, table: sqlglot.expressions.Table, trie: Optional[Dict] = None, raise_on_missing: bool = True) -> Optional[~T]:
133    def find(
134        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
135    ) -> t.Optional[T]:
136        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
137        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
138
139        if value == 0:
140            return None
141
142        if value == 1:
143            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
144
145            if len(possibilities) == 1:
146                parts.extend(possibilities[0])
147            else:
148                message = ", ".join(".".join(parts) for parts in possibilities)
149                if raise_on_missing:
150                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
151                return None
152
153        return self.nested_get(parts, raise_on_missing=raise_on_missing)
def nested_get( self, parts: Sequence[str], d: Optional[Dict] = None, raise_on_missing=True) -> Optional[Any]:
155    def nested_get(
156        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
157    ) -> t.Optional[t.Any]:
158        return nested_get(
159            d or self.mapping,
160            *zip(self.supported_table_args, reversed(parts)),
161            raise_on_missing=raise_on_missing,
162        )
class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        # `_normalize` (called below) reads `self.dialect` and `self.normalize`,
        # so both must be set before the mapping is normalized.
        self.dialect = dialect
        self.normalize = normalize
        self.visible = visible or {}
        # Cache of parsed types, keyed by the type string passed to `_to_data_type`.
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}

        normalized_schema = self._normalize(schema or {})
        super().__init__(normalized_schema)

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """
        Build a new `MappingSchema` that has the same mapping, visibility, dialect and
        normalization setting as `mapping_schema`.
        """
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """
        Return a new `MappingSchema` built from this one's settings.

        Any keyword argument accepted by `MappingSchema.__init__` overrides the
        corresponding value taken from this instance.
        """
        # NOTE(review): `__init__` re-runs `_normalize` on the already-normalized
        # mapping; confirm this is idempotent for quoted, case-sensitive identifiers.
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                # Without this, a copy of a schema created with normalize=False would
                # fall back to the default normalize=True and re-normalize identifiers.
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table
                (any form accepted by `ensure_column_mapping`).
            dialect: the SQL dialect that will be used to parse `table` if it's a string;
                it is also used to normalize the table and column names.
        """
        normalized_table = self._normalize_table(
            self._ensure_table(table, dialect=dialect), dialect=dialect
        )

        normalized_column_mapping = {}
        for column_name, column_type in ensure_column_mapping(column_mapping).items():
            normalized_column_mapping[self._normalize_name(column_name, dialect=dialect)] = (
                column_type
            )

        existing = self.find(normalized_table, raise_on_missing=False)
        if existing and not normalized_column_mapping:
            # Already registered and no new columns were supplied: nothing to update.
            return

        parts = self.table_parts(normalized_table)
        # `parts` is innermost-first, whereas the mapping nests outermost-first.
        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        # Add the path to the existing lookup trie as well.
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
    ) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to return only the columns listed in the `visible` mapping
                (ignored when no `visible` mapping was provided).
            dialect: the SQL dialect that will be used to parse `table` if it's a string.

        Returns:
            The list of column names, or an empty list if the table is not found.
        """
        normalized_table = self._normalize_table(
            self._ensure_table(table, dialect=dialect), dialect=dialect
        )

        columns = self.find(normalized_table)
        if columns is None:
            return []

        if only_visible and self.visible:
            visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
            return [name for name in columns if name in visible]

        return list(columns)

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column,
        dialect: DialectType = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column, as a `Column` expression or a plain string.
            dialect: the SQL dialect used to parse `table` and string column types.

        Returns:
            The resulting column type, or an "unknown" type when the table or column is
            not found or the stored type is neither a string nor a `DataType`.
        """
        normalized_table = self._normalize_table(
            self._ensure_table(table, dialect=dialect), dialect=dialect
        )
        column_ref = column if isinstance(column, str) else column.this
        normalized_column_name = self._normalize_name(column_ref, dialect=dialect)

        table_schema = self.find(normalized_table, raise_on_missing=False)
        column_type = table_schema.get(normalized_column_name) if table_schema else None

        if isinstance(column_type, exp.DataType):
            return column_type
        if isinstance(column_type, str):
            return self._to_data_type(column_type.upper(), dialect=dialect)
        return exp.DataType.build("unknown")

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping. Tables declared with an empty column
            mapping are kept (with an empty column mapping).
        """
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        normalized_mapping: t.Dict = {}
        for keys in flattened_schema:
            columns = nested_get(schema, *zip(keys, keys))
            assert columns is not None

            normalized_keys = [
                self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys
            ]

            if not columns:
                # The column loop below would never create an entry for this table,
                # so register it explicitly rather than silently dropping it.
                nested_set(normalized_mapping, normalized_keys, {})
                continue

            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(self, table: exp.Table, dialect: DialectType = None) -> exp.Table:
        """Return a copy of `table` with its table/db/catalog parts replaced by normalized identifiers."""
        normalized_table = table.copy()

        for arg in TABLE_ARGS:
            value = normalized_table.args.get(arg)
            if isinstance(value, (str, exp.Identifier)):
                normalized_table.set(
                    arg,
                    exp.to_identifier(self._normalize_name(value, dialect=dialect, is_table=True)),
                )

        return normalized_table

    def _normalize_name(
        self, name: str | exp.Identifier, dialect: DialectType = None, is_table: bool = False
    ) -> str:
        """
        Return the identifier name of `name`, normalized for `dialect` when `self.normalize` is set.

        If `name` cannot be parsed as an identifier, its raw text is returned unchanged.
        """
        dialect = dialect or self.dialect

        try:
            identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
        except ParseError:
            return name if isinstance(name, str) else name.name

        name = identifier.name
        if not self.normalize:
            return name

        # Passed through the identifier's meta, presumably so that dialects which
        # normalize table and column names differently can tell them apart.
        identifier.meta["is_table"] = is_table
        return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name

    def _depth(self) -> int:
        # The columns themselves are a mapping, but we don't want to include those
        return super()._depth() - 1

    def _ensure_table(self, table: exp.Table | str, dialect: DialectType = None) -> exp.Table:
        """Parse `table` into a `Table` expression if it is a string."""
        return exp.maybe_parse(table, into=exp.Table, dialect=dialect or self.dialect)

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.

        Raises:
            SchemaError: if building the type raises an `AttributeError`.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError as e:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.") from e

        return self._type_mapping_cache[schema_type]

Schema based on a nested mapping.

Arguments:
  • schema: Mapping in one of the following forms:
    1. {table: {col: type}}
    2. {db: {table: {col: type}}}
    3. {catalog: {db: {table: {col: type}}}}
    4. None - Tables will be added later
  • visible: Optional mapping of which columns in the schema are visible. If not provided, all columns are assumed to be visible. The nesting should mirror that of the schema:
    1. {table: set(*cols)}
    2. {db: {table: set(*cols)}}
    3. {catalog: {db: {table: set(*cols)}}}
  • dialect: The dialect to be used for custom type mappings & parsing string arguments.
  • normalize: Whether to normalize identifier names according to the given dialect or not.
MappingSchema( schema: Optional[Dict] = None, visible: Optional[Dict] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: bool = True)
184    def __init__(
185        self,
186        schema: t.Optional[t.Dict] = None,
187        visible: t.Optional[t.Dict] = None,
188        dialect: DialectType = None,
189        normalize: bool = True,
190    ) -> None:
191        self.dialect = dialect
192        self.visible = visible or {}
193        self.normalize = normalize
194        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
195
196        super().__init__(self._normalize(schema or {}))
@classmethod
def from_mapping_schema( cls, mapping_schema: sqlglot.schema.MappingSchema) -> sqlglot.schema.MappingSchema:
198    @classmethod
199    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
200        return MappingSchema(
201            schema=mapping_schema.mapping,
202            visible=mapping_schema.visible,
203            dialect=mapping_schema.dialect,
204        )
def copy(self, **kwargs) -> sqlglot.schema.MappingSchema:
206    def copy(self, **kwargs) -> MappingSchema:
207        return MappingSchema(
208            **{  # type: ignore
209                "schema": self.mapping.copy(),
210                "visible": self.visible.copy(),
211                "dialect": self.dialect,
212                **kwargs,
213            }
214        )
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> None:
216    def add_table(
217        self,
218        table: exp.Table | str,
219        column_mapping: t.Optional[ColumnMapping] = None,
220        dialect: DialectType = None,
221    ) -> None:
222        """
223        Register or update a table. Updates are only performed if a new column mapping is provided.
224
225        Args:
226            table: the `Table` expression instance or string representing the table.
227            column_mapping: a column mapping that describes the structure of the table.
228            dialect: the SQL dialect that will be used to parse `table` if it's a string.
229        """
230        normalized_table = self._normalize_table(
231            self._ensure_table(table, dialect=dialect), dialect=dialect
232        )
233        normalized_column_mapping = {
234            self._normalize_name(key, dialect=dialect): value
235            for key, value in ensure_column_mapping(column_mapping).items()
236        }
237
238        schema = self.find(normalized_table, raise_on_missing=False)
239        if schema and not normalized_column_mapping:
240            return
241
242        parts = self.table_parts(normalized_table)
243
244        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
245        new_trie([parts], self.mapping_trie)

Register or update a table. Updates are only performed if a new column mapping is provided.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> List[str]:
247    def column_names(
248        self,
249        table: exp.Table | str,
250        only_visible: bool = False,
251        dialect: DialectType = None,
252    ) -> t.List[str]:
253        normalized_table = self._normalize_table(
254            self._ensure_table(table, dialect=dialect), dialect=dialect
255        )
256
257        schema = self.find(normalized_table)
258        if schema is None:
259            return []
260
261        if not only_visible or not self.visible:
262            return list(schema)
263
264        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
265        return [col for col in schema if col in visible]

Get the column names for a table.

Arguments:
  • table: the Table expression instance or string representing the table.
  • only_visible: whether to return only the columns marked as visible (ignored if no visibility mapping was provided).
  • dialect: the SQL dialect that will be used to parse table if it's a string.
Returns:

The list of column names.

def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> sqlglot.expressions.DataType:
267    def get_column_type(
268        self,
269        table: exp.Table | str,
270        column: exp.Column,
271        dialect: DialectType = None,
272    ) -> exp.DataType:
273        normalized_table = self._normalize_table(
274            self._ensure_table(table, dialect=dialect), dialect=dialect
275        )
276        normalized_column_name = self._normalize_name(
277            column if isinstance(column, str) else column.this, dialect=dialect
278        )
279
280        table_schema = self.find(normalized_table, raise_on_missing=False)
281        if table_schema:
282            column_type = table_schema.get(normalized_column_name)
283
284            if isinstance(column_type, exp.DataType):
285                return column_type
286            elif isinstance(column_type, str):
287                return self._to_data_type(column_type.upper(), dialect=dialect)
288
289        return exp.DataType.build("unknown")

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
Returns:

The resulting column type.

def ensure_schema( schema: Union[sqlglot.schema.Schema, Dict, NoneType], **kwargs: Any) -> sqlglot.schema.Schema:
def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """
    Return `schema` unchanged if it is already a `Schema`; otherwise wrap it in a `MappingSchema`.

    Args:
        schema: a `Schema` instance, a nested mapping dict, or None.
        **kwargs: extra keyword arguments forwarded to `MappingSchema`
            (ignored when `schema` is already a `Schema`).

    Returns:
        A `Schema` instance.
    """
    if not isinstance(schema, Schema):
        schema = MappingSchema(schema, **kwargs)
    return schema
def ensure_column_mapping( mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType]) -> Dict:
def _split_top_level(text: str, separator: str) -> t.List[str]:
    """
    Split `text` on `separator`, ignoring separators nested inside (), <> or [].

    Keeps parameterized types such as "DECIMAL(10, 2)" or "STRUCT<a: INT, b: TEXT>"
    intact when splitting a "name: type, ..." string.
    """
    pieces: t.List[str] = []
    depth = 0
    start = 0
    for index, char in enumerate(text):
        if char in "(<[":
            depth += 1
        elif char in ")>]":
            # Clamp at zero so a stray closer doesn't disable splitting for the rest.
            depth = max(depth - 1, 0)
        elif char == separator and depth == 0:
            pieces.append(text[start:index])
            start = index + 1
    pieces.append(text[start:])
    return pieces


def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Convert any supported column mapping representation into a dict.

    Args:
        mapping: one of
            - None: no columns (an empty dict is returned);
            - a dict: returned as-is;
            - a string of comma-separated "name: type" pairs, e.g. "a: INT, b: DECIMAL(10, 2)";
            - an object with a `simpleString` attribute whose iteration yields fields with
              `name` and `dataType` attributes (e.g. a DataFrame `StructType`);
            - a list of column names, whose types are left as None.

    Returns:
        A dict mapping column names to their types (None when the type is unknown).

    Raises:
        ValueError: if `mapping` has an unsupported type, or a string entry lacks a ":".
    """
    if mapping is None:
        return {}
    elif isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        column_mapping = {}
        # Split on top-level commas only, so commas inside type parameters survive.
        for entry in _split_top_level(mapping, ","):
            entry = entry.strip()
            if not entry:
                # Tolerate an empty string and trailing commas.
                continue
            # Only the first ":" separates the name; the type may contain more,
            # e.g. "STRUCT<a: INT>".
            name, sep, type_str = entry.partition(":")
            if not sep:
                raise ValueError(f"Invalid column definition (expected 'name: type'): {entry!r}")
            column_mapping[name.strip()] = type_str.strip()
        return column_mapping
    # Check if mapping looks like a DataFrame StructType
    elif hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema( schema: Dict, depth: int, keys: Optional[List[str]] = None) -> List[List[str]]:
def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """
    List the key paths of a nested dict down to `depth` levels.

    Args:
        schema: the nested dictionary to flatten.
        depth: how many levels of keys make up each path; values below 1 yield no paths.
        keys: a prefix prepended to every returned path.

    Returns:
        One list of keys per path, outermost key first.
    """
    prefix = keys or []
    entries = schema.items()

    if depth == 1:
        return [prefix + [name] for name, _ in entries]
    if depth < 1:
        return []

    paths: t.List[t.List[str]] = []
    for name, child in entries:
        paths.extend(flatten_schema(child, depth - 1, prefix + [name]))
    return paths
def nested_get( d: Dict, *path: Tuple[str, str], raise_on_missing: bool = True) -> Optional[Any]:
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found
            ("this" is reported as "table").
        raise_on_missing: whether a missing (or None-valued) key raises instead of returning None.

    Returns:
        The value or None if it doesn't exist.

    Raises:
        ValueError: if a key is missing and `raise_on_missing` is set.
    """
    current: t.Any = d
    for name, key in path:
        current = current.get(key)
        if current is not None:
            continue
        if not raise_on_missing:
            return None
        name = "table" if name == "this" else name
        raise ValueError(f"Unknown {name}: {key}")

    return current

Get a value for a nested dictionary.

Arguments:
  • d: the dictionary to search.
  • *path: tuples of (name, key), where: key is the key in the dictionary to get. name is a string to use in the error if key isn't found.
Returns:

The value or None if it doesn't exist.

def nested_set(d: Dict, keys: Sequence[str], value: Any) -> Dict:
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Intermediate dictionaries are created when missing and reused when present.

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that make up the path to `value`; if empty, `d` is left unchanged.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary, i.e. `d` itself.
    """
    if not keys:
        return d

    # `setdefault` both creates missing levels and returns existing ones, so no
    # membership test or single-key special case is needed.
    subd = d
    for key in keys[:-1]:
        subd = subd.setdefault(key, {})

    subd[keys[-1]] = value
    return d

In-place set a value for a nested dictionary

Example:
>>> nested_set({}, ["top_key", "second_key"], "value")
{'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
{'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
  • d: dictionary to update.
  • keys: the keys that make up the path to value.
  • value: the value to set in the dictionary for the given key path.
Returns:

The (possibly) updated dictionary.