sqlglot.schema

  1from __future__ import annotations
  2
  3import abc
  4import typing as t
  5
  6import sqlglot
  7from sqlglot import expressions as exp
  8from sqlglot._typing import T
  9from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE
 10from sqlglot.errors import ParseError, SchemaError
 11from sqlglot.helper import dict_depth
 12from sqlglot.trie import in_trie, new_trie
 13
 14if t.TYPE_CHECKING:
 15    from sqlglot.dataframe.sql.types import StructType
 16    from sqlglot.dialects.dialect import DialectType
 17
 18    ColumnMapping = t.Union[t.Dict, str, StructType, t.List]
 19
 20TABLE_ARGS = ("this", "db", "catalog")
 21
 22
 23class Schema(abc.ABC):
 24    """Abstract base class for database schemas"""
 25
 26    dialect: DialectType
 27
 28    @abc.abstractmethod
 29    def add_table(
 30        self,
 31        table: exp.Table | str,
 32        column_mapping: t.Optional[ColumnMapping] = None,
 33        dialect: DialectType = None,
 34    ) -> None:
 35        """
 36        Register or update a table. Some implementing classes may require column information to also be provided.
 37
 38        Args:
 39            table: the `Table` expression instance or string representing the table.
 40            column_mapping: a column mapping that describes the structure of the table.
 41            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 42        """
 43
 44    @abc.abstractmethod
 45    def column_names(
 46        self,
 47        table: exp.Table | str,
 48        only_visible: bool = False,
 49        dialect: DialectType = None,
 50    ) -> t.List[str]:
 51        """
 52        Get the column names for a table.
 53
 54        Args:
 55            table: the `Table` expression instance or string representing the table.
 56            only_visible: whether to restrict the results to visible columns only.
 57            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 58
 59        Returns:
 60            The list of column names.
 61        """
 62
 63    @abc.abstractmethod
 64    def get_column_type(
 65        self,
 66        table: exp.Table | str,
 67        column: exp.Column,
 68        dialect: DialectType = None,
 69    ) -> exp.DataType:
 70        """
 71        Get the `sqlglot.exp.DataType` type of a column in the schema.
 72
 73        Args:
 74            table: the source table.
 75            column: the target column.
 76            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 77
 78        Returns:
 79            The resulting column type.
 80        """
 81
 82    @property
 83    @abc.abstractmethod
 84    def supported_table_args(self) -> t.Tuple[str, ...]:
 85        """
 86        Table arguments this schema supports, e.g. `("this", "db", "catalog")`
 87        """
 88
 89    @property
 90    def empty(self) -> bool:
 91        """Returns whether or not the schema is empty."""
 92        return True
 93
 94
 95class AbstractMappingSchema(t.Generic[T]):
 96    def __init__(
 97        self,
 98        mapping: t.Optional[t.Dict] = None,
 99    ) -> None:
100        self.mapping = mapping or {}
101        self.mapping_trie = new_trie(
102            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self._depth())
103        )
104        self._supported_table_args: t.Tuple[str, ...] = tuple()
105
106    @property
107    def empty(self) -> bool:
108        return not self.mapping
109
110    def _depth(self) -> int:
111        return dict_depth(self.mapping)
112
113    @property
114    def supported_table_args(self) -> t.Tuple[str, ...]:
115        if not self._supported_table_args and self.mapping:
116            depth = self._depth()
117
118            if not depth:  # None
119                self._supported_table_args = tuple()
120            elif 1 <= depth <= 3:
121                self._supported_table_args = TABLE_ARGS[:depth]
122            else:
123                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
124
125        return self._supported_table_args
126
127    def table_parts(self, table: exp.Table) -> t.List[str]:
128        if isinstance(table.this, exp.ReadCSV):
129            return [table.this.name]
130        return [table.text(part) for part in TABLE_ARGS if table.text(part)]
131
132    def find(
133        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
134    ) -> t.Optional[T]:
135        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
136        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
137
138        if value == 0:
139            return None
140
141        if value == 1:
142            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
143
144            if len(possibilities) == 1:
145                parts.extend(possibilities[0])
146            else:
147                message = ", ".join(".".join(parts) for parts in possibilities)
148                if raise_on_missing:
149                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
150                return None
151
152        return self.nested_get(parts, raise_on_missing=raise_on_missing)
153
154    def nested_get(
155        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
156    ) -> t.Optional[t.Any]:
157        return nested_get(
158            d or self.mapping,
159            *zip(self.supported_table_args, reversed(parts)),
160            raise_on_missing=raise_on_missing,
161        )
162
163
164class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
165    """
166    Schema based on a nested mapping.
167
168    Args:
169        schema: Mapping in one of the following forms:
170            1. {table: {col: type}}
171            2. {db: {table: {col: type}}}
172            3. {catalog: {db: {table: {col: type}}}}
173            4. None - Tables will be added later
174        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
175            are assumed to be visible. The nesting should mirror that of the schema:
176            1. {table: set(*cols)}
177            2. {db: {table: set(*cols)}}
178            3. {catalog: {db: {table: set(*cols)}}}
179        dialect: The dialect to be used for custom type mappings & parsing string arguments.
180        normalize: Whether to normalize identifier names according to the given dialect or not.
181    """
182
183    def __init__(
184        self,
185        schema: t.Optional[t.Dict] = None,
186        visible: t.Optional[t.Dict] = None,
187        dialect: DialectType = None,
188        normalize: bool = True,
189    ) -> None:
190        self.dialect = dialect
191        self.visible = visible or {}
192        self.normalize = normalize
193        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
194
195        super().__init__(self._normalize(schema or {}))
196
197    @classmethod
198    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
199        return MappingSchema(
200            schema=mapping_schema.mapping,
201            visible=mapping_schema.visible,
202            dialect=mapping_schema.dialect,
203        )
204
205    def copy(self, **kwargs) -> MappingSchema:
206        return MappingSchema(
207            **{  # type: ignore
208                "schema": self.mapping.copy(),
209                "visible": self.visible.copy(),
210                "dialect": self.dialect,
211                **kwargs,
212            }
213        )
214
215    def add_table(
216        self,
217        table: exp.Table | str,
218        column_mapping: t.Optional[ColumnMapping] = None,
219        dialect: DialectType = None,
220    ) -> None:
221        """
222        Register or update a table. Updates are only performed if a new column mapping is provided.
223
224        Args:
225            table: the `Table` expression instance or string representing the table.
226            column_mapping: a column mapping that describes the structure of the table.
227            dialect: the SQL dialect that will be used to parse `table` if it's a string.
228        """
229        normalized_table = self._normalize_table(
230            self._ensure_table(table, dialect=dialect), dialect=dialect
231        )
232        normalized_column_mapping = {
233            self._normalize_name(key, dialect=dialect): value
234            for key, value in ensure_column_mapping(column_mapping).items()
235        }
236
237        schema = self.find(normalized_table, raise_on_missing=False)
238        if schema and not normalized_column_mapping:
239            return
240
241        parts = self.table_parts(normalized_table)
242
243        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
244        new_trie([parts], self.mapping_trie)
245
246    def column_names(
247        self,
248        table: exp.Table | str,
249        only_visible: bool = False,
250        dialect: DialectType = None,
251    ) -> t.List[str]:
252        normalized_table = self._normalize_table(
253            self._ensure_table(table, dialect=dialect), dialect=dialect
254        )
255
256        schema = self.find(normalized_table)
257        if schema is None:
258            return []
259
260        if not only_visible or not self.visible:
261            return list(schema)
262
263        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
264        return [col for col in schema if col in visible]
265
266    def get_column_type(
267        self,
268        table: exp.Table | str,
269        column: exp.Column,
270        dialect: DialectType = None,
271    ) -> exp.DataType:
272        normalized_table = self._normalize_table(
273            self._ensure_table(table, dialect=dialect), dialect=dialect
274        )
275        normalized_column_name = self._normalize_name(
276            column if isinstance(column, str) else column.this, dialect=dialect
277        )
278
279        table_schema = self.find(normalized_table, raise_on_missing=False)
280        if table_schema:
281            column_type = table_schema.get(normalized_column_name)
282
283            if isinstance(column_type, exp.DataType):
284                return column_type
285            elif isinstance(column_type, str):
286                return self._to_data_type(column_type.upper(), dialect=dialect)
287
288            raise SchemaError(f"Unknown column type '{column_type}'")
289
290        return exp.DataType.build("unknown")
291
292    def _normalize(self, schema: t.Dict) -> t.Dict:
293        """
294        Normalizes all identifiers in the schema according to the dialect, unless they're quoted.
295
296        Args:
297            schema: the schema to normalize.
298
299        Returns:
300            The normalized schema mapping.
301        """
302        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)
303
304        normalized_mapping: t.Dict = {}
305        for keys in flattened_schema:
306            columns = nested_get(schema, *zip(keys, keys))
307            assert columns is not None
308
309            normalized_keys = [self._normalize_name(key, dialect=self.dialect) for key in keys]
310            for column_name, column_type in columns.items():
311                nested_set(
312                    normalized_mapping,
313                    normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)],
314                    column_type,
315                )
316
317        return normalized_mapping
318
319    def _normalize_table(self, table: exp.Table, dialect: DialectType = None) -> exp.Table:
320        normalized_table = table.copy()
321
322        for arg in TABLE_ARGS:
323            value = normalized_table.args.get(arg)
324            if isinstance(value, (str, exp.Identifier)):
325                normalized_table.set(
326                    arg, exp.to_identifier(self._normalize_name(value, dialect=dialect))
327                )
328
329        return normalized_table
330
331    def _normalize_name(self, name: str | exp.Identifier, dialect: DialectType = None) -> str:
332        dialect = dialect or self.dialect
333
334        try:
335            identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
336        except ParseError:
337            return name if isinstance(name, str) else name.name
338
339        name = identifier.name
340
341        if not self.normalize or identifier.quoted:
342            return name
343
344        return name.upper() if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE else name.lower()
345
346    def _depth(self) -> int:
347        # The columns themselves are a mapping, but we don't want to include those
348        return super()._depth() - 1
349
350    def _ensure_table(self, table: exp.Table | str, dialect: DialectType = None) -> exp.Table:
351        return exp.maybe_parse(table, into=exp.Table, dialect=dialect or self.dialect)
352
353    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
354        """
355        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.
356
357        Args:
358            schema_type: the type we want to convert.
359            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.
360
361        Returns:
362            The resulting expression type.
363        """
364        if schema_type not in self._type_mapping_cache:
365            dialect = dialect or self.dialect
366
367            try:
368                expression = exp.DataType.build(schema_type, dialect=dialect)
369                self._type_mapping_cache[schema_type] = expression
370            except AttributeError:
371                in_dialect = f" in dialect {dialect}" if dialect else ""
372                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")
373
374        return self._type_mapping_cache[schema_type]
375
376
377def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
378    if isinstance(schema, Schema):
379        return schema
380
381    return MappingSchema(schema, **kwargs)
382
383
384def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
385    if mapping is None:
386        return {}
387    elif isinstance(mapping, dict):
388        return mapping
389    elif isinstance(mapping, str):
390        col_name_type_strs = [x.strip() for x in mapping.split(",")]
391        return {
392            name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip()
393            for name_type_str in col_name_type_strs
394        }
395    # Check if mapping looks like a DataFrame StructType
396    elif hasattr(mapping, "simpleString"):
397        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
398    elif isinstance(mapping, list):
399        return {x.strip(): None for x in mapping}
400
401    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
402
403
404def flatten_schema(
405    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
406) -> t.List[t.List[str]]:
407    tables = []
408    keys = keys or []
409
410    for k, v in schema.items():
411        if depth >= 2:
412            tables.extend(flatten_schema(v, depth - 1, keys + [k]))
413        elif depth == 1:
414            tables.append(keys + [k])
415
416    return tables
417
418
419def nested_get(
420    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
421) -> t.Optional[t.Any]:
422    """
423    Get a value for a nested dictionary.
424
425    Args:
426        d: the dictionary to search.
427        *path: tuples of (name, key), where:
428            `key` is the key in the dictionary to get.
429            `name` is a string to use in the error if `key` isn't found.
430
431    Returns:
432        The value or None if it doesn't exist.
433    """
434    for name, key in path:
435        d = d.get(key)  # type: ignore
436        if d is None:
437            if raise_on_missing:
438                name = "table" if name == "this" else name
439                raise ValueError(f"Unknown {name}: {key}")
440            return None
441
442    return d
443
444
445def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
446    """
447    In-place set a value for a nested dictionary
448
449    Example:
450        >>> nested_set({}, ["top_key", "second_key"], "value")
451        {'top_key': {'second_key': 'value'}}
452
453        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
454        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
455
456    Args:
457        d: dictionary to update.
458        keys: the keys that makeup the path to `value`.
459        value: the value to set in the dictionary for the given key path.
460
461    Returns:
462        The (possibly) updated dictionary.
463    """
464    if not keys:
465        return d
466
467    if len(keys) == 1:
468        d[keys[0]] = value
469        return d
470
471    subd = d
472    for key in keys[:-1]:
473        if key not in subd:
474            subd = subd.setdefault(key, {})
475        else:
476            subd = subd[key]
477
478    subd[keys[-1]] = value
479    return d
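
As a quick orientation, here is a minimal usage sketch (not part of the module source above; the values in the comments assume the default normalization):

    from sqlglot import expressions as exp
    from sqlglot.schema import MappingSchema

    # Build a schema from a nested {db: {table: {column: type}}} mapping.
    schema = MappingSchema({"db": {"users": {"id": "int", "name": "varchar"}}})

    schema.column_names("db.users")                              # ['id', 'name']
    schema.get_column_type("db.users", exp.column("id")).sql()   # 'INT'

    # Tables can also be registered after construction.
    schema.add_table("db.orders", {"order_id": "bigint", "user_id": "int"})
    schema.column_names("db.orders")                             # ['order_id', 'user_id']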
class Schema(abc.ABC):
24class Schema(abc.ABC):
25    """Abstract base class for database schemas"""
26
27    dialect: DialectType
28
29    @abc.abstractmethod
30    def add_table(
31        self,
32        table: exp.Table | str,
33        column_mapping: t.Optional[ColumnMapping] = None,
34        dialect: DialectType = None,
35    ) -> None:
36        """
37        Register or update a table. Some implementing classes may require column information to also be provided.
38
39        Args:
40            table: the `Table` expression instance or string representing the table.
41            column_mapping: a column mapping that describes the structure of the table.
42            dialect: the SQL dialect that will be used to parse `table` if it's a string.
43        """
44
45    @abc.abstractmethod
46    def column_names(
47        self,
48        table: exp.Table | str,
49        only_visible: bool = False,
50        dialect: DialectType = None,
51    ) -> t.List[str]:
52        """
53        Get the column names for a table.
54
55        Args:
56            table: the `Table` expression instance or string representing the table.
57            only_visible: whether to restrict the results to visible columns only.
58            dialect: the SQL dialect that will be used to parse `table` if it's a string.
59
60        Returns:
61            The list of column names.
62        """
63
64    @abc.abstractmethod
65    def get_column_type(
66        self,
67        table: exp.Table | str,
68        column: exp.Column,
69        dialect: DialectType = None,
70    ) -> exp.DataType:
71        """
72        Get the `sqlglot.exp.DataType` type of a column in the schema.
73
74        Args:
75            table: the source table.
76            column: the target column.
77            dialect: the SQL dialect that will be used to parse `table` if it's a string.
78
79        Returns:
80            The resulting column type.
81        """
82
83    @property
84    @abc.abstractmethod
85    def supported_table_args(self) -> t.Tuple[str, ...]:
86        """
87        Table arguments this schema supports, e.g. `("this", "db", "catalog")`
88        """
89
90    @property
91    def empty(self) -> bool:
92        """Returns whether or not the schema is empty."""
93        return True

Abstract base class for database schemas

@abc.abstractmethod
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> None:
29    @abc.abstractmethod
30    def add_table(
31        self,
32        table: exp.Table | str,
33        column_mapping: t.Optional[ColumnMapping] = None,
34        dialect: DialectType = None,
35    ) -> None:
36        """
37        Register or update a table. Some implementing classes may require column information to also be provided.
38
39        Args:
40            table: the `Table` expression instance or string representing the table.
41            column_mapping: a column mapping that describes the structure of the table.
42            dialect: the SQL dialect that will be used to parse `table` if it's a string.
43        """

Register or update a table. Some implementing classes may require column information to also be provided.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
@abc.abstractmethod
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> List[str]:
45    @abc.abstractmethod
46    def column_names(
47        self,
48        table: exp.Table | str,
49        only_visible: bool = False,
50        dialect: DialectType = None,
51    ) -> t.List[str]:
52        """
53        Get the column names for a table.
54
55        Args:
56            table: the `Table` expression instance or string representing the table.
57            only_visible: whether to restrict the results to visible columns only.
58            dialect: the SQL dialect that will be used to parse `table` if it's a string.
59
60        Returns:
61            The list of column names.
62        """

Get the column names for a table.

Arguments:
  • table: the Table expression instance or string representing the table.
  • only_visible: whether to restrict the results to visible columns only.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
Returns:

The list of column names.

@abc.abstractmethod
def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> sqlglot.expressions.DataType:
64    @abc.abstractmethod
65    def get_column_type(
66        self,
67        table: exp.Table | str,
68        column: exp.Column,
69        dialect: DialectType = None,
70    ) -> exp.DataType:
71        """
72        Get the `sqlglot.exp.DataType` type of a column in the schema.
73
74        Args:
75            table: the source table.
76            column: the target column.
77            dialect: the SQL dialect that will be used to parse `table` if it's a string.
78
79        Returns:
80            The resulting column type.
81        """

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
Returns:

The resulting column type.

supported_table_args: Tuple[str, ...]

Table arguments this schema supports, e.g. ("this", "db", "catalog")
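
For instance (a sketch using the concrete MappingSchema documented further below), the supported parts follow directly from the nesting depth of the mapping:

    from sqlglot.schema import MappingSchema

    MappingSchema({"t": {"c": "int"}}).supported_table_args          # ('this',)
    MappingSchema({"db": {"t": {"c": "int"}}}).supported_table_args  # ('this', 'db')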

empty: bool

Returns whether or not the schema is empty.

class AbstractMappingSchema(typing.Generic[~T]):
 96class AbstractMappingSchema(t.Generic[T]):
 97    def __init__(
 98        self,
 99        mapping: t.Optional[t.Dict] = None,
100    ) -> None:
101        self.mapping = mapping or {}
102        self.mapping_trie = new_trie(
103            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self._depth())
104        )
105        self._supported_table_args: t.Tuple[str, ...] = tuple()
106
107    @property
108    def empty(self) -> bool:
109        return not self.mapping
110
111    def _depth(self) -> int:
112        return dict_depth(self.mapping)
113
114    @property
115    def supported_table_args(self) -> t.Tuple[str, ...]:
116        if not self._supported_table_args and self.mapping:
117            depth = self._depth()
118
119            if not depth:  # None
120                self._supported_table_args = tuple()
121            elif 1 <= depth <= 3:
122                self._supported_table_args = TABLE_ARGS[:depth]
123            else:
124                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
125
126        return self._supported_table_args
127
128    def table_parts(self, table: exp.Table) -> t.List[str]:
129        if isinstance(table.this, exp.ReadCSV):
130            return [table.this.name]
131        return [table.text(part) for part in TABLE_ARGS if table.text(part)]
132
133    def find(
134        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
135    ) -> t.Optional[T]:
136        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
137        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
138
139        if value == 0:
140            return None
141
142        if value == 1:
143            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
144
145            if len(possibilities) == 1:
146                parts.extend(possibilities[0])
147            else:
148                message = ", ".join(".".join(parts) for parts in possibilities)
149                if raise_on_missing:
150                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
151                return None
152
153        return self.nested_get(parts, raise_on_missing=raise_on_missing)
154
155    def nested_get(
156        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
157    ) -> t.Optional[t.Any]:
158        return nested_get(
159            d or self.mapping,
160            *zip(self.supported_table_args, reversed(parts)),
161            raise_on_missing=raise_on_missing,
162        )

Abstract base class for generic types.

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::

    class Mapping(Generic[KT, VT]):
        def __getitem__(self, key: KT) -> VT:
            ...
            # Etc.

This class can then be used as follows::

    def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
        try:
            return mapping[key]
        except KeyError:
            return default

AbstractMappingSchema(mapping: Optional[Dict] = None)
 97    def __init__(
 98        self,
 99        mapping: t.Optional[t.Dict] = None,
100    ) -> None:
101        self.mapping = mapping or {}
102        self.mapping_trie = new_trie(
103            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self._depth())
104        )
105        self._supported_table_args: t.Tuple[str, ...] = tuple()
def table_parts(self, table: sqlglot.expressions.Table) -> List[str]:
128    def table_parts(self, table: exp.Table) -> t.List[str]:
129        if isinstance(table.this, exp.ReadCSV):
130            return [table.this.name]
131        return [table.text(part) for part in TABLE_ARGS if table.text(part)]
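
As an illustration (not part of the source), the parts are returned innermost first, mirroring TABLE_ARGS = ("this", "db", "catalog"):

    from sqlglot import expressions as exp
    from sqlglot.schema import AbstractMappingSchema

    table = exp.to_table("my_catalog.my_db.my_table")
    AbstractMappingSchema().table_parts(table)   # ['my_table', 'my_db', 'my_catalog']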
def find( self, table: sqlglot.expressions.Table, trie: Optional[Dict] = None, raise_on_missing: bool = True) -> Optional[~T]:
133    def find(
134        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
135    ) -> t.Optional[T]:
136        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
137        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
138
139        if value == 0:
140            return None
141
142        if value == 1:
143            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
144
145            if len(possibilities) == 1:
146                parts.extend(possibilities[0])
147            else:
148                message = ", ".join(".".join(parts) for parts in possibilities)
149                if raise_on_missing:
150                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
151                return None
152
153        return self.nested_get(parts, raise_on_missing=raise_on_missing)
def nested_get( self, parts: Sequence[str], d: Optional[Dict] = None, raise_on_missing=True) -> Optional[Any]:
155    def nested_get(
156        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
157    ) -> t.Optional[t.Any]:
158        return nested_get(
159            d or self.mapping,
160            *zip(self.supported_table_args, reversed(parts)),
161            raise_on_missing=raise_on_missing,
162        )
165class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
166    """
167    Schema based on a nested mapping.
168
169    Args:
170        schema: Mapping in one of the following forms:
171            1. {table: {col: type}}
172            2. {db: {table: {col: type}}}
173            3. {catalog: {db: {table: {col: type}}}}
174            4. None - Tables will be added later
175        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
176            are assumed to be visible. The nesting should mirror that of the schema:
177            1. {table: set(*cols)}
178            2. {db: {table: set(*cols)}}
179            3. {catalog: {db: {table: set(*cols)}}}
180        dialect: The dialect to be used for custom type mappings & parsing string arguments.
181        normalize: Whether to normalize identifier names according to the given dialect or not.
182    """
183
184    def __init__(
185        self,
186        schema: t.Optional[t.Dict] = None,
187        visible: t.Optional[t.Dict] = None,
188        dialect: DialectType = None,
189        normalize: bool = True,
190    ) -> None:
191        self.dialect = dialect
192        self.visible = visible or {}
193        self.normalize = normalize
194        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
195
196        super().__init__(self._normalize(schema or {}))
197
198    @classmethod
199    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
200        return MappingSchema(
201            schema=mapping_schema.mapping,
202            visible=mapping_schema.visible,
203            dialect=mapping_schema.dialect,
204        )
205
206    def copy(self, **kwargs) -> MappingSchema:
207        return MappingSchema(
208            **{  # type: ignore
209                "schema": self.mapping.copy(),
210                "visible": self.visible.copy(),
211                "dialect": self.dialect,
212                **kwargs,
213            }
214        )
215
216    def add_table(
217        self,
218        table: exp.Table | str,
219        column_mapping: t.Optional[ColumnMapping] = None,
220        dialect: DialectType = None,
221    ) -> None:
222        """
223        Register or update a table. Updates are only performed if a new column mapping is provided.
224
225        Args:
226            table: the `Table` expression instance or string representing the table.
227            column_mapping: a column mapping that describes the structure of the table.
228            dialect: the SQL dialect that will be used to parse `table` if it's a string.
229        """
230        normalized_table = self._normalize_table(
231            self._ensure_table(table, dialect=dialect), dialect=dialect
232        )
233        normalized_column_mapping = {
234            self._normalize_name(key, dialect=dialect): value
235            for key, value in ensure_column_mapping(column_mapping).items()
236        }
237
238        schema = self.find(normalized_table, raise_on_missing=False)
239        if schema and not normalized_column_mapping:
240            return
241
242        parts = self.table_parts(normalized_table)
243
244        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
245        new_trie([parts], self.mapping_trie)
246
247    def column_names(
248        self,
249        table: exp.Table | str,
250        only_visible: bool = False,
251        dialect: DialectType = None,
252    ) -> t.List[str]:
253        normalized_table = self._normalize_table(
254            self._ensure_table(table, dialect=dialect), dialect=dialect
255        )
256
257        schema = self.find(normalized_table)
258        if schema is None:
259            return []
260
261        if not only_visible or not self.visible:
262            return list(schema)
263
264        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
265        return [col for col in schema if col in visible]
266
267    def get_column_type(
268        self,
269        table: exp.Table | str,
270        column: exp.Column,
271        dialect: DialectType = None,
272    ) -> exp.DataType:
273        normalized_table = self._normalize_table(
274            self._ensure_table(table, dialect=dialect), dialect=dialect
275        )
276        normalized_column_name = self._normalize_name(
277            column if isinstance(column, str) else column.this, dialect=dialect
278        )
279
280        table_schema = self.find(normalized_table, raise_on_missing=False)
281        if table_schema:
282            column_type = table_schema.get(normalized_column_name)
283
284            if isinstance(column_type, exp.DataType):
285                return column_type
286            elif isinstance(column_type, str):
287                return self._to_data_type(column_type.upper(), dialect=dialect)
288
289            raise SchemaError(f"Unknown column type '{column_type}'")
290
291        return exp.DataType.build("unknown")
292
293    def _normalize(self, schema: t.Dict) -> t.Dict:
294        """
295        Normalizes all identifiers in the schema according to the dialect, unless they're quoted.
296
297        Args:
298            schema: the schema to normalize.
299
300        Returns:
301            The normalized schema mapping.
302        """
303        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)
304
305        normalized_mapping: t.Dict = {}
306        for keys in flattened_schema:
307            columns = nested_get(schema, *zip(keys, keys))
308            assert columns is not None
309
310            normalized_keys = [self._normalize_name(key, dialect=self.dialect) for key in keys]
311            for column_name, column_type in columns.items():
312                nested_set(
313                    normalized_mapping,
314                    normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)],
315                    column_type,
316                )
317
318        return normalized_mapping
319
320    def _normalize_table(self, table: exp.Table, dialect: DialectType = None) -> exp.Table:
321        normalized_table = table.copy()
322
323        for arg in TABLE_ARGS:
324            value = normalized_table.args.get(arg)
325            if isinstance(value, (str, exp.Identifier)):
326                normalized_table.set(
327                    arg, exp.to_identifier(self._normalize_name(value, dialect=dialect))
328                )
329
330        return normalized_table
331
332    def _normalize_name(self, name: str | exp.Identifier, dialect: DialectType = None) -> str:
333        dialect = dialect or self.dialect
334
335        try:
336            identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
337        except ParseError:
338            return name if isinstance(name, str) else name.name
339
340        name = identifier.name
341
342        if not self.normalize or identifier.quoted:
343            return name
344
345        return name.upper() if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE else name.lower()
346
347    def _depth(self) -> int:
348        # The columns themselves are a mapping, but we don't want to include those
349        return super()._depth() - 1
350
351    def _ensure_table(self, table: exp.Table | str, dialect: DialectType = None) -> exp.Table:
352        return exp.maybe_parse(table, into=exp.Table, dialect=dialect or self.dialect)
353
354    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
355        """
356        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.
357
358        Args:
359            schema_type: the type we want to convert.
360            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.
361
362        Returns:
363            The resulting expression type.
364        """
365        if schema_type not in self._type_mapping_cache:
366            dialect = dialect or self.dialect
367
368            try:
369                expression = exp.DataType.build(schema_type, dialect=dialect)
370                self._type_mapping_cache[schema_type] = expression
371            except AttributeError:
372                in_dialect = f" in dialect {dialect}" if dialect else ""
373                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")
374
375        return self._type_mapping_cache[schema_type]

Schema based on a nested mapping.

Arguments:
  • schema: Mapping in one of the following forms:
    1. {table: {col: type}}
    2. {db: {table: {col: type}}}
    3. {catalog: {db: {table: {col: type}}}}
    4. None - Tables will be added later
  • visible: Optional mapping of which columns in the schema are visible. If not provided, all columns are assumed to be visible. The nesting should mirror that of the schema:
    1. {table: set(*cols)}
    2. {db: {table: set(*cols)}}
    3. {catalog: {db: {table: set(*cols)}}}
  • dialect: The dialect to be used for custom type mappings & parsing string arguments.
  • normalize: Whether to normalize identifier names according to the given dialect or not.
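
A construction sketch (not part of the source) showing the schema/visible nesting and dialect-driven normalization; the expected values in the comments assume the default lowercase normalization:

    from sqlglot.schema import MappingSchema

    schema = MappingSchema(
        schema={"db": {"users": {"id": "int", "NAME": "text"}}},
        visible={"db": {"users": {"id"}}},
        dialect="duckdb",
    )

    # Unquoted identifiers are normalized, so "NAME" is stored as "name".
    schema.column_names("db.users")                     # ['id', 'name']
    schema.column_names("db.users", only_visible=True)  # ['id']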
MappingSchema( schema: Optional[Dict] = None, visible: Optional[Dict] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: bool = True)
184    def __init__(
185        self,
186        schema: t.Optional[t.Dict] = None,
187        visible: t.Optional[t.Dict] = None,
188        dialect: DialectType = None,
189        normalize: bool = True,
190    ) -> None:
191        self.dialect = dialect
192        self.visible = visible or {}
193        self.normalize = normalize
194        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
195
196        super().__init__(self._normalize(schema or {}))
@classmethod
def from_mapping_schema( cls, mapping_schema: sqlglot.schema.MappingSchema) -> sqlglot.schema.MappingSchema:
198    @classmethod
199    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
200        return MappingSchema(
201            schema=mapping_schema.mapping,
202            visible=mapping_schema.visible,
203            dialect=mapping_schema.dialect,
204        )
def copy(self, **kwargs) -> sqlglot.schema.MappingSchema:
206    def copy(self, **kwargs) -> MappingSchema:
207        return MappingSchema(
208            **{  # type: ignore
209                "schema": self.mapping.copy(),
210                "visible": self.visible.copy(),
211                "dialect": self.dialect,
212                **kwargs,
213            }
214        )
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> None:
216    def add_table(
217        self,
218        table: exp.Table | str,
219        column_mapping: t.Optional[ColumnMapping] = None,
220        dialect: DialectType = None,
221    ) -> None:
222        """
223        Register or update a table. Updates are only performed if a new column mapping is provided.
224
225        Args:
226            table: the `Table` expression instance or string representing the table.
227            column_mapping: a column mapping that describes the structure of the table.
228            dialect: the SQL dialect that will be used to parse `table` if it's a string.
229        """
230        normalized_table = self._normalize_table(
231            self._ensure_table(table, dialect=dialect), dialect=dialect
232        )
233        normalized_column_mapping = {
234            self._normalize_name(key, dialect=dialect): value
235            for key, value in ensure_column_mapping(column_mapping).items()
236        }
237
238        schema = self.find(normalized_table, raise_on_missing=False)
239        if schema and not normalized_column_mapping:
240            return
241
242        parts = self.table_parts(normalized_table)
243
244        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
245        new_trie([parts], self.mapping_trie)

Register or update a table. Updates are only performed if a new column mapping is provided.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
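
For example (a sketch, not from the source), all of the column mapping shapes accepted by ensure_column_mapping can be passed here, including the comma-separated "name: type" string form:

    from sqlglot.schema import MappingSchema

    schema = MappingSchema()

    schema.add_table("db.users", {"id": "int", "name": "text"})
    schema.add_table("db.events", "event_id: bigint, ts: timestamp")
    schema.column_names("db.events")   # ['event_id', 'ts']

    # Re-adding a known table without a new column mapping leaves it unchanged.
    schema.add_table("db.users")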
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> List[str]:
247    def column_names(
248        self,
249        table: exp.Table | str,
250        only_visible: bool = False,
251        dialect: DialectType = None,
252    ) -> t.List[str]:
253        normalized_table = self._normalize_table(
254            self._ensure_table(table, dialect=dialect), dialect=dialect
255        )
256
257        schema = self.find(normalized_table)
258        if schema is None:
259            return []
260
261        if not only_visible or not self.visible:
262            return list(schema)
263
264        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
265        return [col for col in schema if col in visible]

Get the column names for a table.

Arguments:
  • table: the Table expression instance or string representing the table.
  • only_visible: whether to restrict the results to visible columns only.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
Returns:

The list of column names.
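
A lookup sketch (not from the source): partially qualified names are resolved through the internal trie when they are unambiguous, and unknown tables simply produce an empty list:

    from sqlglot.schema import MappingSchema

    schema = MappingSchema({"db": {"users": {"id": "int"}}})

    schema.column_names("db.users")    # ['id']
    schema.column_names("users")       # ['id']  (resolved to db.users)
    schema.column_names("db.missing")  # []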

def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> sqlglot.expressions.DataType:
267    def get_column_type(
268        self,
269        table: exp.Table | str,
270        column: exp.Column,
271        dialect: DialectType = None,
272    ) -> exp.DataType:
273        normalized_table = self._normalize_table(
274            self._ensure_table(table, dialect=dialect), dialect=dialect
275        )
276        normalized_column_name = self._normalize_name(
277            column if isinstance(column, str) else column.this, dialect=dialect
278        )
279
280        table_schema = self.find(normalized_table, raise_on_missing=False)
281        if table_schema:
282            column_type = table_schema.get(normalized_column_name)
283
284            if isinstance(column_type, exp.DataType):
285                return column_type
286            elif isinstance(column_type, str):
287                return self._to_data_type(column_type.upper(), dialect=dialect)
288
289            raise SchemaError(f"Unknown column type '{column_type}'")
290
291        return exp.DataType.build("unknown")

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
Returns:

The resulting column type.
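
A sketch (not from the source) of the type resolution: string types are parsed and cached, and tables the schema does not know about resolve to the UNKNOWN type:

    from sqlglot import expressions as exp
    from sqlglot.schema import MappingSchema

    schema = MappingSchema({"users": {"id": "int", "name": "text"}})

    schema.get_column_type("users", exp.column("id")).sql()   # 'INT'
    schema.get_column_type("missing", exp.column("id"))       # equivalent to exp.DataType.build("unknown")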

def ensure_schema( schema: Union[sqlglot.schema.Schema, Dict, NoneType], **kwargs: Any) -> sqlglot.schema.Schema:
378def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
379    if isinstance(schema, Schema):
380        return schema
381
382    return MappingSchema(schema, **kwargs)
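
A sketch of the pass-through behavior (not from the source):

    from sqlglot.schema import MappingSchema, ensure_schema

    wrapped = ensure_schema({"users": {"id": "int"}})   # becomes a MappingSchema
    existing = MappingSchema({"users": {"id": "int"}})
    ensure_schema(existing) is existing                 # True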
def ensure_column_mapping( mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType]) -> Dict:
385def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
386    if mapping is None:
387        return {}
388    elif isinstance(mapping, dict):
389        return mapping
390    elif isinstance(mapping, str):
391        col_name_type_strs = [x.strip() for x in mapping.split(",")]
392        return {
393            name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip()
394            for name_type_str in col_name_type_strs
395        }
396    # Check if mapping looks like a DataFrame StructType
397    elif hasattr(mapping, "simpleString"):
398        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
399    elif isinstance(mapping, list):
400        return {x.strip(): None for x in mapping}
401
402    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
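
For example (a sketch, not from the source), each supported shape is coerced into a plain dict:

    from sqlglot.schema import ensure_column_mapping

    ensure_column_mapping("id: int, name: text")   # {'id': 'int', 'name': 'text'}
    ensure_column_mapping(["id", "name"])          # {'id': None, 'name': None}
    ensure_column_mapping(None)                    # {}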
def flatten_schema( schema: Dict, depth: int, keys: Optional[List[str]] = None) -> List[List[str]]:
405def flatten_schema(
406    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
407) -> t.List[t.List[str]]:
408    tables = []
409    keys = keys or []
410
411    for k, v in schema.items():
412        if depth >= 2:
413            tables.extend(flatten_schema(v, depth - 1, keys + [k]))
414        elif depth == 1:
415            tables.append(keys + [k])
416
417    return tables
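
An illustration (not from the source) of how depth controls which nesting levels are flattened into key paths:

    from sqlglot.schema import flatten_schema

    schema = {"catalog": {"db": {"users": {"id": "int"}}}}

    flatten_schema(schema, depth=3)   # [['catalog', 'db', 'users']]
    flatten_schema(schema, depth=2)   # [['catalog', 'db']]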
def nested_get( d: Dict, *path: Tuple[str, str], raise_on_missing: bool = True) -> Optional[Any]:
420def nested_get(
421    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
422) -> t.Optional[t.Any]:
423    """
424    Get a value for a nested dictionary.
425
426    Args:
427        d: the dictionary to search.
428        *path: tuples of (name, key), where:
429            `key` is the key in the dictionary to get.
430            `name` is a string to use in the error if `key` isn't found.
431
432    Returns:
433        The value or None if it doesn't exist.
434    """
435    for name, key in path:
436        d = d.get(key)  # type: ignore
437        if d is None:
438            if raise_on_missing:
439                name = "table" if name == "this" else name
440                raise ValueError(f"Unknown {name}: {key}")
441            return None
442
443    return d

Get a value for a nested dictionary.

Arguments:
  • d: the dictionary to search.
  • *path: tuples of (name, key), where:
      key is the key in the dictionary to get.
      name is a string to use in the error if key isn't found.
Returns:

The value or None if it doesn't exist.
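
A sketch (not from the source) of the (name, key) path convention and the raise_on_missing switch:

    from sqlglot.schema import nested_get

    mapping = {"db": {"users": {"id": "int"}}}

    nested_get(mapping, ("db", "db"), ("table", "users"))              # {'id': 'int'}
    nested_get(mapping, ("table", "missing"), raise_on_missing=False)  # None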

def nested_set(d: Dict, keys: Sequence[str], value: Any) -> Dict:
446def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
447    """
448    In-place set a value for a nested dictionary
449
450    Example:
451        >>> nested_set({}, ["top_key", "second_key"], "value")
452        {'top_key': {'second_key': 'value'}}
453
454        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
455        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
456
457    Args:
458        d: dictionary to update.
459        keys: the keys that make up the path to `value`.
460        value: the value to set in the dictionary for the given key path.
461
462    Returns:
463        The (possibly) updated dictionary.
464    """
465    if not keys:
466        return d
467
468    if len(keys) == 1:
469        d[keys[0]] = value
470        return d
471
472    subd = d
473    for key in keys[:-1]:
474        if key not in subd:
475            subd = subd.setdefault(key, {})
476        else:
477            subd = subd[key]
478
479    subd[keys[-1]] = value
480    return d

In-place set a value for a nested dictionary

Example:
>>> nested_set({}, ["top_key", "second_key"], "value")
{'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
{'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
  • d: dictionary to update.
  • keys: the keys that make up the path to value.
  • value: the value to set in the dictionary for the given key path.
Returns:

The (possibly) updated dictionary.