Edit on GitHub

sqlglot.schema

  1from __future__ import annotations
  2
  3import abc
  4import typing as t
  5
  6import sqlglot
  7from sqlglot import expressions as exp
  8from sqlglot._typing import T
  9from sqlglot.errors import ParseError, SchemaError
 10from sqlglot.helper import dict_depth
 11from sqlglot.trie import in_trie, new_trie
 12
# Names below are needed for annotations only; `t.TYPE_CHECKING` is False at runtime,
# so these imports (and the alias) exist solely for static type checkers.
if t.TYPE_CHECKING:
    from sqlglot.dataframe.sql.types import StructType
    from sqlglot.dialects.dialect import DialectType

    # Every form in which a table's columns may be described (see `ensure_column_mapping`).
    ColumnMapping = t.Union[t.Dict, str, StructType, t.List]

# Argument names of a table expression that identify it: the table itself, its db, its catalog.
TABLE_ARGS = ("this", "db", "catalog")
 20
 21
 22class Schema(abc.ABC):
 23    """Abstract base class for database schemas"""
 24
 25    @abc.abstractmethod
 26    def add_table(
 27        self,
 28        table: exp.Table | str,
 29        column_mapping: t.Optional[ColumnMapping] = None,
 30        dialect: DialectType = None,
 31    ) -> None:
 32        """
 33        Register or update a table. Some implementing classes may require column information to also be provided.
 34
 35        Args:
 36            table: the `Table` expression instance or string representing the table.
 37            column_mapping: a column mapping that describes the structure of the table.
 38            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 39        """
 40
 41    @abc.abstractmethod
 42    def column_names(
 43        self,
 44        table: exp.Table | str,
 45        only_visible: bool = False,
 46        dialect: DialectType = None,
 47    ) -> t.List[str]:
 48        """
 49        Get the column names for a table.
 50
 51        Args:
 52            table: the `Table` expression instance.
 53            only_visible: whether to include invisible columns.
 54            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 55
 56        Returns:
 57            The list of column names.
 58        """
 59
 60    @abc.abstractmethod
 61    def get_column_type(
 62        self,
 63        table: exp.Table | str,
 64        column: exp.Column,
 65        dialect: DialectType = None,
 66    ) -> exp.DataType:
 67        """
 68        Get the `sqlglot.exp.DataType` type of a column in the schema.
 69
 70        Args:
 71            table: the source table.
 72            column: the target column.
 73            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 74
 75        Returns:
 76            The resulting column type.
 77        """
 78
 79    @property
 80    @abc.abstractmethod
 81    def supported_table_args(self) -> t.Tuple[str, ...]:
 82        """
 83        Table arguments this schema support, e.g. `("this", "db", "catalog")`
 84        """
 85
 86    @property
 87    def empty(self) -> bool:
 88        """Returns whether or not the schema is empty."""
 89        return True
 90
 91
 92class AbstractMappingSchema(t.Generic[T]):
 93    def __init__(
 94        self,
 95        mapping: t.Optional[t.Dict] = None,
 96    ) -> None:
 97        self.mapping = mapping or {}
 98        self.mapping_trie = new_trie(
 99            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self._depth())
100        )
101        self._supported_table_args: t.Tuple[str, ...] = tuple()
102
103    @property
104    def empty(self) -> bool:
105        return not self.mapping
106
107    def _depth(self) -> int:
108        return dict_depth(self.mapping)
109
110    @property
111    def supported_table_args(self) -> t.Tuple[str, ...]:
112        if not self._supported_table_args and self.mapping:
113            depth = self._depth()
114
115            if not depth:  # None
116                self._supported_table_args = tuple()
117            elif 1 <= depth <= 3:
118                self._supported_table_args = TABLE_ARGS[:depth]
119            else:
120                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
121
122        return self._supported_table_args
123
124    def table_parts(self, table: exp.Table) -> t.List[str]:
125        if isinstance(table.this, exp.ReadCSV):
126            return [table.this.name]
127        return [table.text(part) for part in TABLE_ARGS if table.text(part)]
128
129    def find(
130        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
131    ) -> t.Optional[T]:
132        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
133        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
134
135        if value == 0:
136            return None
137
138        if value == 1:
139            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
140
141            if len(possibilities) == 1:
142                parts.extend(possibilities[0])
143            else:
144                message = ", ".join(".".join(parts) for parts in possibilities)
145                if raise_on_missing:
146                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
147                return None
148
149        return self.nested_get(parts, raise_on_missing=raise_on_missing)
150
151    def nested_get(
152        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
153    ) -> t.Optional[t.Any]:
154        return nested_get(
155            d or self.mapping,
156            *zip(self.supported_table_args, reversed(parts)),
157            raise_on_missing=raise_on_missing,
158        )
159
160
class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
    """
    A schema that is described by a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
    ) -> None:
        # These attributes must exist before `_normalize` runs, since it reads `self.dialect`.
        self.dialect = dialect
        self.visible = {} if not visible else visible
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}

        normalized_schema = self._normalize(schema or {})
        super().__init__(normalized_schema)

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a new `MappingSchema` from the mapping, visibility and dialect of another one."""
        source = mapping_schema
        return MappingSchema(
            schema=source.mapping,
            visible=source.visible,
            dialect=source.dialect,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """
        Create a new `MappingSchema` from shallow copies of the mapping and visibility dicts.

        Args:
            **kwargs: constructor arguments that override the copied ones.
        """
        init_args: t.Dict[str, t.Any] = {
            "schema": self.mapping.copy(),
            "visible": self.visible.copy(),
            "dialect": self.dialect,
        }
        init_args.update(kwargs)
        return MappingSchema(**init_args)  # type: ignore

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
    ) -> None:
        """
        Add a table to the schema, or update it when a new column mapping is given.

        A table that is already known is left untouched if no columns are supplied.

        Args:
            table: a `Table` expression, or a string naming the table.
            column_mapping: description of the columns the table is made of.
            dialect: the SQL dialect used to parse `table` when it is passed as a string.
        """
        table_expr = self._ensure_table(table, dialect=dialect)
        normalized_table = self._normalize_table(table_expr, dialect=dialect)

        normalized_column_mapping = {}
        for name, column_type in ensure_column_mapping(column_mapping).items():
            normalized_column_mapping[self._normalize_name(name, dialect=dialect)] = column_type

        existing = self.find(normalized_table, raise_on_missing=False)
        if existing and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        # The mapping is keyed outermost-first, whereas the trie starts at the table name.
        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
    ) -> t.List[str]:
        """
        List the column names of a table.

        Args:
            table: a `Table` expression, or a string naming the table.
            only_visible: when true (and a visibility mapping exists), keep visible columns only.
            dialect: the SQL dialect used to parse `table` when it is passed as a string.

        Returns:
            The column names; an empty list if `find` yields nothing for the table.
        """
        table_expr = self._ensure_table(table, dialect=dialect)
        normalized_table = self._normalize_table(table_expr, dialect=dialect)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if only_visible and self.visible:
            visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
            return [col for col in schema if col in visible]

        return list(schema)

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column,
        dialect: DialectType = None,
    ) -> exp.DataType:
        """
        Look up the `sqlglot.exp.DataType` of one column of a table.

        Args:
            table: the table that owns the column.
            column: the column whose type is requested.
            dialect: the SQL dialect used to parse `table` and the stored type, if needed.

        Returns:
            The column type, or the "unknown" type when the table has no (non-empty) entry.

        Raises:
            SchemaError: if the value stored for the column is neither a `DataType` nor a string.
        """
        table_expr = self._ensure_table(table, dialect=dialect)
        normalized_table = self._normalize_table(table_expr, dialect=dialect)

        raw_name = column if isinstance(column, str) else column.this
        normalized_column_name = self._normalize_name(raw_name, dialect=dialect)

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if not table_schema:
            return exp.DataType.build("unknown")

        column_type = table_schema.get(normalized_column_name)
        if isinstance(column_type, exp.DataType):
            return column_type
        if isinstance(column_type, str):
            return self._to_data_type(column_type.upper(), dialect=dialect)

        raise SchemaError(f"Unknown column type '{column_type}'")

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Lowercase every identifier of the schema, except for quoted ones.

        Args:
            schema: the schema to normalize.

        Returns:
            A new mapping that holds the normalized identifiers.
        """
        table_paths = flatten_schema(schema, depth=dict_depth(schema) - 1)

        normalized_mapping: t.Dict = {}
        for path in table_paths:
            columns = nested_get(schema, *zip(path, path))
            assert columns is not None

            normalized_path = [self._normalize_name(part, dialect=self.dialect) for part in path]
            for column_name, column_type in columns.items():
                normalized_column = self._normalize_name(column_name, dialect=self.dialect)
                nested_set(normalized_mapping, [*normalized_path, normalized_column], column_type)

        return normalized_mapping

    def _normalize_table(self, table: exp.Table, dialect: DialectType = None) -> exp.Table:
        """Return a copy of `table` whose string/identifier name parts are normalized."""
        result = table.copy()

        for arg in TABLE_ARGS:
            value = result.args.get(arg)
            if not isinstance(value, (str, exp.Identifier)):
                continue
            result.set(arg, self._normalize_name(value, dialect=dialect))

        return result

    def _normalize_name(self, name: str | exp.Identifier, dialect: DialectType = None) -> str:
        """
        Normalize one identifier: quoted names are kept as is, others are lowercased.

        A name that cannot be parsed as an identifier is returned without changes.
        """
        try:
            identifier = sqlglot.maybe_parse(
                name, dialect=dialect or self.dialect, into=exp.Identifier
            )
        except ParseError:
            if isinstance(name, str):
                return name
            return name.name

        if identifier.quoted:
            return identifier.name
        return identifier.name.lower()

    def _depth(self) -> int:
        # The innermost level maps columns to types, so it does not count as table nesting.
        return super()._depth() - 1

    def _ensure_table(self, table: exp.Table | str, dialect: DialectType = None) -> exp.Table:
        """
        Return `table` as a `Table` expression, parsing it first if it is a string.

        Raises:
            SchemaError: if parsing yields nothing.
        """
        if isinstance(table, exp.Table):
            return table

        effective_dialect = dialect or self.dialect
        parsed = sqlglot.parse_one(table, read=effective_dialect, into=exp.Table)
        if parsed:
            return parsed

        suffix = f" in dialect {effective_dialect}" if effective_dialect else ""
        raise SchemaError(f"Failed to parse table '{table}'{suffix}.")

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Turn a type given as a string into a `sqlglot.exp.DataType`, caching the result.

        Args:
            schema_type: the type to convert.
            dialect: the SQL dialect used to parse `schema_type`, if needed.

        Returns:
            The corresponding type expression.

        Raises:
            SchemaError: if building the type fails with an `AttributeError`.
        """
        cache = self._type_mapping_cache
        if schema_type in cache:
            return cache[schema_type]

        effective_dialect = dialect or self.dialect
        try:
            cache[schema_type] = exp.DataType.build(schema_type, dialect=effective_dialect)
        except AttributeError:
            suffix = f" in dialect {effective_dialect}" if effective_dialect else ""
            raise SchemaError(f"Failed to build type '{schema_type}'{suffix}.")

        return cache[schema_type]
372
373
def ensure_schema(schema: t.Any, dialect: DialectType = None) -> Schema:
    """
    Coerce `schema` into a `Schema` instance.

    Args:
        schema: an existing `Schema`, or whatever `MappingSchema` accepts as its mapping.
        dialect: the dialect handed to `MappingSchema` when a new schema is created.

    Returns:
        `schema` itself if it already is a `Schema`, otherwise a new `MappingSchema`.
    """
    already_schema = isinstance(schema, Schema)
    return schema if already_schema else MappingSchema(schema, dialect=dialect)
379
380
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Convert any supported column description into a `{column name: type}` dictionary.

    Args:
        mapping: one of
            - None, which yields an empty dictionary;
            - a dict, which is returned as is;
            - a string such as `"a: int, b: struct<x:int,y:string>"`;
            - an object exposing `simpleString` (a DataFrame `StructType`), whose fields are used;
            - a list of column names, which are mapped to None.

    Returns:
        The column mapping as a dictionary.

    Raises:
        ValueError: if `mapping` has an unsupported type, or a string entry lacks a `:`.
    """
    if mapping is None:
        return {}
    if isinstance(mapping, dict):
        return mapping
    if isinstance(mapping, str):
        return _parse_column_mapping_str(mapping)
    # Check if mapping looks like a DataFrame StructType
    if hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
    if isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")


def _parse_column_mapping_str(mapping: str) -> t.Dict[str, str]:
    """Parse `"name: type, ..."` into a dict, honouring commas/colons nested inside types."""
    # Split on top-level commas only, so that types such as `decimal(10, 2)` or
    # `struct<a:int,b:string>` stay in one piece.
    entries = []
    depth = 0
    start = 0
    for index, char in enumerate(mapping):
        if char in "(<[":
            depth += 1
        elif char in ")>]":
            depth = max(depth - 1, 0)
        elif char == "," and depth == 0:
            entries.append(mapping[start:index])
            start = index + 1
    entries.append(mapping[start:])

    columns: t.Dict[str, str] = {}
    for entry in entries:
        entry = entry.strip()
        if not entry:
            # Tolerate an empty string as well as stray or trailing commas.
            continue

        # Only the first colon separates the name from the type; the type may contain more.
        name, separator, column_type = entry.partition(":")
        if not separator:
            raise ValueError(f"Invalid column mapping entry, expected 'name: type': '{entry}'")
        columns[name.strip()] = column_type.strip()

    return columns
399
400
def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """
    Collect the key paths found `depth` levels into a nested dictionary.

    Args:
        schema: the nested dictionary to walk.
        depth: how many levels of keys make up one path.
        keys: keys that are prepended to every collected path.

    Returns:
        One list of keys per path; nothing is collected when `depth` is below 1.
    """
    prefix = keys if keys else []
    paths: t.List[t.List[str]] = []

    for name, child in schema.items():
        path = [*prefix, name]
        if depth == 1:
            paths.append(path)
        elif depth >= 2:
            paths += flatten_schema(child, depth - 1, path)

    return paths
414
415
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Walk a nested dictionary along a path of keys and return what is found there.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is looked up at the current nesting level.
            `name` labels the key in the error message if the lookup fails.
        raise_on_missing: whether a missing key raises instead of yielding None.

    Returns:
        The value at the end of the path, or None if a key is missing.

    Raises:
        ValueError: if a key is missing and `raise_on_missing` is set.
    """
    current = d
    for name, key in path:
        current = current.get(key)  # type: ignore
        if current is not None:
            continue

        if not raise_on_missing:
            return None

        # The "this" argument of a table expression is its table name.
        label = "table" if name == "this" else name
        raise ValueError(f"Unknown {label}: {key}")

    return current
440
441
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    Store a value inside a nested dictionary, modifying the dictionary in place.

    Intermediate dictionaries that do not exist yet are created along the way.

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that make up the path to `value`.
        value: the value to store at the end of the key path.

    Returns:
        The same dictionary `d`, which is left unchanged when `keys` is empty.
    """
    if not keys:
        return d

    node = d
    for key in keys[:-1]:
        # Reuse the existing level or create an empty one.
        node = node.setdefault(key, {})

    node[keys[-1]] = value
    return d
class Schema(abc.ABC):
class Schema(abc.ABC):
    """Interface that every database schema implementation has to provide."""

    @abc.abstractmethod
    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
    ) -> None:
        """
        Add a table to the schema, or update one that is already registered.

        Concrete schemas are free to require that `column_mapping` is supplied.

        Args:
            table: a `Table` expression, or a string naming the table.
            column_mapping: description of the columns the table is made of.
            dialect: the SQL dialect used to parse `table` when it is passed as a string.
        """

    @abc.abstractmethod
    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
    ) -> t.List[str]:
        """
        List the names of the columns that belong to a table.

        Args:
            table: a `Table` expression, or a string naming the table.
            only_visible: when true, restrict the result to visible columns.
            dialect: the SQL dialect used to parse `table` when it is passed as a string.

        Returns:
            The column names, as a list.
        """

    @abc.abstractmethod
    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column,
        dialect: DialectType = None,
    ) -> exp.DataType:
        """
        Look up the `sqlglot.exp.DataType` of one column of a table.

        Args:
            table: the table that owns the column.
            column: the column whose type is requested.
            dialect: the SQL dialect used to parse `table` when it is passed as a string.

        Returns:
            The type of the column.
        """

    @property
    @abc.abstractmethod
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        The table arguments this schema supports, e.g. `("this", "db", "catalog")`.
        """

    @property
    def empty(self) -> bool:
        """Whether the schema holds no tables; this base implementation always answers True."""
        return True

Abstract base class for database schemas

@abc.abstractmethod
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> None:
26    @abc.abstractmethod
27    def add_table(
28        self,
29        table: exp.Table | str,
30        column_mapping: t.Optional[ColumnMapping] = None,
31        dialect: DialectType = None,
32    ) -> None:
33        """
34        Register or update a table. Some implementing classes may require column information to also be provided.
35
36        Args:
37            table: the `Table` expression instance or string representing the table.
38            column_mapping: a column mapping that describes the structure of the table.
39            dialect: the SQL dialect that will be used to parse `table` if it's a string.
40        """

Register or update a table. Some implementing classes may require column information to also be provided.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
@abc.abstractmethod
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> List[str]:
42    @abc.abstractmethod
43    def column_names(
44        self,
45        table: exp.Table | str,
46        only_visible: bool = False,
47        dialect: DialectType = None,
48    ) -> t.List[str]:
49        """
50        Get the column names for a table.
51
52        Args:
53            table: the `Table` expression instance.
54            only_visible: whether to include invisible columns.
55            dialect: the SQL dialect that will be used to parse `table` if it's a string.
56
57        Returns:
58            The list of column names.
59        """

Get the column names for a table.

Arguments:
  • table: the Table expression instance.
  • only_visible: whether to return only the visible columns.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
Returns:

The list of column names.

@abc.abstractmethod
def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> sqlglot.expressions.DataType:
61    @abc.abstractmethod
62    def get_column_type(
63        self,
64        table: exp.Table | str,
65        column: exp.Column,
66        dialect: DialectType = None,
67    ) -> exp.DataType:
68        """
69        Get the `sqlglot.exp.DataType` type of a column in the schema.
70
71        Args:
72            table: the source table.
73            column: the target column.
74            dialect: the SQL dialect that will be used to parse `table` if it's a string.
75
76        Returns:
77            The resulting column type.
78        """

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
Returns:

The resulting column type.

supported_table_args: Tuple[str, ...]

Table arguments this schema supports, e.g. ("this", "db", "catalog")

empty: bool

Returns whether or not the schema is empty.

class AbstractMappingSchema(typing.Generic[~T]):
 93class AbstractMappingSchema(t.Generic[T]):
 94    def __init__(
 95        self,
 96        mapping: t.Optional[t.Dict] = None,
 97    ) -> None:
 98        self.mapping = mapping or {}
 99        self.mapping_trie = new_trie(
100            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self._depth())
101        )
102        self._supported_table_args: t.Tuple[str, ...] = tuple()
103
104    @property
105    def empty(self) -> bool:
106        return not self.mapping
107
108    def _depth(self) -> int:
109        return dict_depth(self.mapping)
110
111    @property
112    def supported_table_args(self) -> t.Tuple[str, ...]:
113        if not self._supported_table_args and self.mapping:
114            depth = self._depth()
115
116            if not depth:  # None
117                self._supported_table_args = tuple()
118            elif 1 <= depth <= 3:
119                self._supported_table_args = TABLE_ARGS[:depth]
120            else:
121                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
122
123        return self._supported_table_args
124
125    def table_parts(self, table: exp.Table) -> t.List[str]:
126        if isinstance(table.this, exp.ReadCSV):
127            return [table.this.name]
128        return [table.text(part) for part in TABLE_ARGS if table.text(part)]
129
130    def find(
131        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
132    ) -> t.Optional[T]:
133        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
134        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
135
136        if value == 0:
137            return None
138
139        if value == 1:
140            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
141
142            if len(possibilities) == 1:
143                parts.extend(possibilities[0])
144            else:
145                message = ", ".join(".".join(parts) for parts in possibilities)
146                if raise_on_missing:
147                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
148                return None
149
150        return self.nested_get(parts, raise_on_missing=raise_on_missing)
151
152    def nested_get(
153        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
154    ) -> t.Optional[t.Any]:
155        return nested_get(
156            d or self.mapping,
157            *zip(self.supported_table_args, reversed(parts)),
158            raise_on_missing=raise_on_missing,
159        )

Abstract base class for generic types.

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::

    class Mapping(Generic[KT, VT]):
        def __getitem__(self, key: KT) -> VT: ...
        # Etc.

This class can then be used as follows::

    def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
        try:
            return mapping[key]
        except KeyError:
            return default

AbstractMappingSchema(mapping: Optional[Dict] = None)
 94    def __init__(
 95        self,
 96        mapping: t.Optional[t.Dict] = None,
 97    ) -> None:
 98        self.mapping = mapping or {}
 99        self.mapping_trie = new_trie(
100            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self._depth())
101        )
102        self._supported_table_args: t.Tuple[str, ...] = tuple()
def table_parts(self, table: sqlglot.expressions.Table) -> List[str]:
125    def table_parts(self, table: exp.Table) -> t.List[str]:
126        if isinstance(table.this, exp.ReadCSV):
127            return [table.this.name]
128        return [table.text(part) for part in TABLE_ARGS if table.text(part)]
def find( self, table: sqlglot.expressions.Table, trie: Optional[Dict] = None, raise_on_missing: bool = True) -> Optional[~T]:
130    def find(
131        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
132    ) -> t.Optional[T]:
133        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
134        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
135
136        if value == 0:
137            return None
138
139        if value == 1:
140            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
141
142            if len(possibilities) == 1:
143                parts.extend(possibilities[0])
144            else:
145                message = ", ".join(".".join(parts) for parts in possibilities)
146                if raise_on_missing:
147                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
148                return None
149
150        return self.nested_get(parts, raise_on_missing=raise_on_missing)
def nested_get( self, parts: Sequence[str], d: Optional[Dict] = None, raise_on_missing=True) -> Optional[Any]:
152    def nested_get(
153        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
154    ) -> t.Optional[t.Any]:
155        return nested_get(
156            d or self.mapping,
157            *zip(self.supported_table_args, reversed(parts)),
158            raise_on_missing=raise_on_missing,
159        )
162class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
163    """
164    Schema based on a nested mapping.
165
166    Args:
167        schema: Mapping in one of the following forms:
168            1. {table: {col: type}}
169            2. {db: {table: {col: type}}}
170            3. {catalog: {db: {table: {col: type}}}}
171            4. None - Tables will be added later
172        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
173            are assumed to be visible. The nesting should mirror that of the schema:
174            1. {table: set(*cols)}}
175            2. {db: {table: set(*cols)}}}
176            3. {catalog: {db: {table: set(*cols)}}}}
177        dialect: The dialect to be used for custom type mappings & parsing string arguments.
178    """
179
180    def __init__(
181        self,
182        schema: t.Optional[t.Dict] = None,
183        visible: t.Optional[t.Dict] = None,
184        dialect: DialectType = None,
185    ) -> None:
186        self.dialect = dialect
187        self.visible = visible or {}
188        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
189
190        super().__init__(self._normalize(schema or {}))
191
192    @classmethod
193    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
194        return MappingSchema(
195            schema=mapping_schema.mapping,
196            visible=mapping_schema.visible,
197            dialect=mapping_schema.dialect,
198        )
199
200    def copy(self, **kwargs) -> MappingSchema:
201        return MappingSchema(
202            **{  # type: ignore
203                "schema": self.mapping.copy(),
204                "visible": self.visible.copy(),
205                "dialect": self.dialect,
206                **kwargs,
207            }
208        )
209
210    def add_table(
211        self,
212        table: exp.Table | str,
213        column_mapping: t.Optional[ColumnMapping] = None,
214        dialect: DialectType = None,
215    ) -> None:
216        """
217        Register or update a table. Updates are only performed if a new column mapping is provided.
218
219        Args:
220            table: the `Table` expression instance or string representing the table.
221            column_mapping: a column mapping that describes the structure of the table.
222            dialect: the SQL dialect that will be used to parse `table` if it's a string.
223        """
224        normalized_table = self._normalize_table(
225            self._ensure_table(table, dialect=dialect), dialect=dialect
226        )
227        normalized_column_mapping = {
228            self._normalize_name(key, dialect=dialect): value
229            for key, value in ensure_column_mapping(column_mapping).items()
230        }
231
232        schema = self.find(normalized_table, raise_on_missing=False)
233        if schema and not normalized_column_mapping:
234            return
235
236        parts = self.table_parts(normalized_table)
237
238        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
239        new_trie([parts], self.mapping_trie)
240
241    def column_names(
242        self,
243        table: exp.Table | str,
244        only_visible: bool = False,
245        dialect: DialectType = None,
246    ) -> t.List[str]:
247        normalized_table = self._normalize_table(
248            self._ensure_table(table, dialect=dialect), dialect=dialect
249        )
250
251        schema = self.find(normalized_table)
252        if schema is None:
253            return []
254
255        if not only_visible or not self.visible:
256            return list(schema)
257
258        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
259        return [col for col in schema if col in visible]
260
261    def get_column_type(
262        self,
263        table: exp.Table | str,
264        column: exp.Column,
265        dialect: DialectType = None,
266    ) -> exp.DataType:
267        normalized_table = self._normalize_table(
268            self._ensure_table(table, dialect=dialect), dialect=dialect
269        )
270        normalized_column_name = self._normalize_name(
271            column if isinstance(column, str) else column.this, dialect=dialect
272        )
273
274        table_schema = self.find(normalized_table, raise_on_missing=False)
275        if table_schema:
276            column_type = table_schema.get(normalized_column_name)
277
278            if isinstance(column_type, exp.DataType):
279                return column_type
280            elif isinstance(column_type, str):
281                return self._to_data_type(column_type.upper(), dialect=dialect)
282
283            raise SchemaError(f"Unknown column type '{column_type}'")
284
285        return exp.DataType.build("unknown")
286
287    def _normalize(self, schema: t.Dict) -> t.Dict:
288        """
289        Converts all identifiers in the schema into lowercase, unless they're quoted.
290
291        Args:
292            schema: the schema to normalize.
293
294        Returns:
295            The normalized schema mapping.
296        """
297        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)
298
299        normalized_mapping: t.Dict = {}
300        for keys in flattened_schema:
301            columns = nested_get(schema, *zip(keys, keys))
302            assert columns is not None
303
304            normalized_keys = [self._normalize_name(key, dialect=self.dialect) for key in keys]
305            for column_name, column_type in columns.items():
306                nested_set(
307                    normalized_mapping,
308                    normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)],
309                    column_type,
310                )
311
312        return normalized_mapping
313
314    def _normalize_table(self, table: exp.Table, dialect: DialectType = None) -> exp.Table:
315        normalized_table = table.copy()
316
317        for arg in TABLE_ARGS:
318            value = normalized_table.args.get(arg)
319            if isinstance(value, (str, exp.Identifier)):
320                normalized_table.set(arg, self._normalize_name(value, dialect=dialect))
321
322        return normalized_table
323
324    def _normalize_name(self, name: str | exp.Identifier, dialect: DialectType = None) -> str:
325        dialect = dialect or self.dialect
326
327        try:
328            identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
329        except ParseError:
330            return name if isinstance(name, str) else name.name
331
332        return identifier.name if identifier.quoted else identifier.name.lower()
333
334    def _depth(self) -> int:
335        # The columns themselves are a mapping, but we don't want to include those
336        return super()._depth() - 1
337
338    def _ensure_table(self, table: exp.Table | str, dialect: DialectType = None) -> exp.Table:
339        if isinstance(table, exp.Table):
340            return table
341
342        dialect = dialect or self.dialect
343        parsed_table = sqlglot.parse_one(table, read=dialect, into=exp.Table)
344
345        if not parsed_table:
346            in_dialect = f" in dialect {dialect}" if dialect else ""
347            raise SchemaError(f"Failed to parse table '{table}'{in_dialect}.")
348
349        return parsed_table
350
351    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
352        """
353        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.
354
355        Args:
356            schema_type: the type we want to convert.
357            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.
358
359        Returns:
360            The resulting expression type.
361        """
362        if schema_type not in self._type_mapping_cache:
363            dialect = dialect or self.dialect
364
365            try:
366                expression = exp.DataType.build(schema_type, dialect=dialect)
367                self._type_mapping_cache[schema_type] = expression
368            except AttributeError:
369                in_dialect = f" in dialect {dialect}" if dialect else ""
370                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")
371
372        return self._type_mapping_cache[schema_type]

Schema based on a nested mapping.

Arguments:
  • schema: Mapping in one of the following forms:
    1. {table: {col: type}}
    2. {db: {table: {col: type}}}
    3. {catalog: {db: {table: {col: type}}}}
    4. None - Tables will be added later
  • visible: Optional mapping of which columns in the schema are visible. If not provided, all columns are assumed to be visible. The nesting should mirror that of the schema:
    1. {table: set(*cols)}
    2. {db: {table: set(*cols)}}
    3. {catalog: {db: {table: set(*cols)}}}
  • dialect: The dialect to be used for custom type mappings & parsing string arguments.
MappingSchema( schema: Optional[Dict] = None, visible: Optional[Dict] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None)
180    def __init__(
181        self,
182        schema: t.Optional[t.Dict] = None,
183        visible: t.Optional[t.Dict] = None,
184        dialect: DialectType = None,
185    ) -> None:
186        self.dialect = dialect
187        self.visible = visible or {}
188        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
189
190        super().__init__(self._normalize(schema or {}))
@classmethod
def from_mapping_schema( cls, mapping_schema: sqlglot.schema.MappingSchema) -> sqlglot.schema.MappingSchema:
192    @classmethod
193    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
194        return MappingSchema(
195            schema=mapping_schema.mapping,
196            visible=mapping_schema.visible,
197            dialect=mapping_schema.dialect,
198        )
def copy(self, **kwargs) -> sqlglot.schema.MappingSchema:
200    def copy(self, **kwargs) -> MappingSchema:
201        return MappingSchema(
202            **{  # type: ignore
203                "schema": self.mapping.copy(),
204                "visible": self.visible.copy(),
205                "dialect": self.dialect,
206                **kwargs,
207            }
208        )
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> None:
210    def add_table(
211        self,
212        table: exp.Table | str,
213        column_mapping: t.Optional[ColumnMapping] = None,
214        dialect: DialectType = None,
215    ) -> None:
216        """
217        Register or update a table. Updates are only performed if a new column mapping is provided.
218
219        Args:
220            table: the `Table` expression instance or string representing the table.
221            column_mapping: a column mapping that describes the structure of the table.
222            dialect: the SQL dialect that will be used to parse `table` if it's a string.
223        """
224        normalized_table = self._normalize_table(
225            self._ensure_table(table, dialect=dialect), dialect=dialect
226        )
227        normalized_column_mapping = {
228            self._normalize_name(key, dialect=dialect): value
229            for key, value in ensure_column_mapping(column_mapping).items()
230        }
231
232        schema = self.find(normalized_table, raise_on_missing=False)
233        if schema and not normalized_column_mapping:
234            return
235
236        parts = self.table_parts(normalized_table)
237
238        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
239        new_trie([parts], self.mapping_trie)

Register or update a table. Updates are only performed if a new column mapping is provided.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> List[str]:
241    def column_names(
242        self,
243        table: exp.Table | str,
244        only_visible: bool = False,
245        dialect: DialectType = None,
246    ) -> t.List[str]:
247        normalized_table = self._normalize_table(
248            self._ensure_table(table, dialect=dialect), dialect=dialect
249        )
250
251        schema = self.find(normalized_table)
252        if schema is None:
253            return []
254
255        if not only_visible or not self.visible:
256            return list(schema)
257
258        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
259        return [col for col in schema if col in visible]

Get the column names for a table.

Arguments:
  • table: the Table expression instance or string representing the table.
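  • only_visible: whether to return only the visible columns. Has no effect when no visibility mapping was provided, in which case all columns are returned.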
  • dialect: the SQL dialect that will be used to parse table if it's a string.
Returns:

The list of column names.

def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> sqlglot.expressions.DataType:
261    def get_column_type(
262        self,
263        table: exp.Table | str,
264        column: exp.Column,
265        dialect: DialectType = None,
266    ) -> exp.DataType:
267        normalized_table = self._normalize_table(
268            self._ensure_table(table, dialect=dialect), dialect=dialect
269        )
270        normalized_column_name = self._normalize_name(
271            column if isinstance(column, str) else column.this, dialect=dialect
272        )
273
274        table_schema = self.find(normalized_table, raise_on_missing=False)
275        if table_schema:
276            column_type = table_schema.get(normalized_column_name)
277
278            if isinstance(column_type, exp.DataType):
279                return column_type
280            elif isinstance(column_type, str):
281                return self._to_data_type(column_type.upper(), dialect=dialect)
282
283            raise SchemaError(f"Unknown column type '{column_type}'")
284
285        return exp.DataType.build("unknown")

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
Returns:

The resulting column type.

def ensure_schema( schema: Any, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> sqlglot.schema.Schema:
375def ensure_schema(schema: t.Any, dialect: DialectType = None) -> Schema:
376    if isinstance(schema, Schema):
377        return schema
378
379    return MappingSchema(schema, dialect=dialect)
def ensure_column_mapping( mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType]) -> Dict:
382def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
383    if mapping is None:
384        return {}
385    elif isinstance(mapping, dict):
386        return mapping
387    elif isinstance(mapping, str):
388        col_name_type_strs = [x.strip() for x in mapping.split(",")]
389        return {
390            name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip()
391            for name_type_str in col_name_type_strs
392        }
393    # Check if mapping looks like a DataFrame StructType
394    elif hasattr(mapping, "simpleString"):
395        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
396    elif isinstance(mapping, list):
397        return {x.strip(): None for x in mapping}
398
399    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema( schema: Dict, depth: int, keys: Optional[List[str]] = None) -> List[List[str]]:
402def flatten_schema(
403    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
404) -> t.List[t.List[str]]:
405    tables = []
406    keys = keys or []
407
408    for k, v in schema.items():
409        if depth >= 2:
410            tables.extend(flatten_schema(v, depth - 1, keys + [k]))
411        elif depth == 1:
412            tables.append(keys + [k])
413
414    return tables
def nested_get( d: Dict, *path: Tuple[str, str], raise_on_missing: bool = True) -> Optional[Any]:
417def nested_get(
418    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
419) -> t.Optional[t.Any]:
420    """
421    Get a value for a nested dictionary.
422
423    Args:
424        d: the dictionary to search.
425        *path: tuples of (name, key), where:
426            `key` is the key in the dictionary to get.
427            `name` is a string to use in the error if `key` isn't found.
428
429    Returns:
430        The value or None if it doesn't exist.
431    """
432    for name, key in path:
433        d = d.get(key)  # type: ignore
434        if d is None:
435            if raise_on_missing:
436                name = "table" if name == "this" else name
437                raise ValueError(f"Unknown {name}: {key}")
438            return None
439
440    return d

Get a value for a nested dictionary.

Arguments:
  • d: the dictionary to search.
  • *path: tuples of (name, key), where: key is the key in the dictionary to get. name is a string to use in the error if key isn't found.
Returns:

The value or None if it doesn't exist.

def nested_set(d: Dict, keys: Sequence[str], value: Any) -> Dict:
443def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
444    """
445    In-place set a value for a nested dictionary
446
447    Example:
448        >>> nested_set({}, ["top_key", "second_key"], "value")
449        {'top_key': {'second_key': 'value'}}
450
451        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
452        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
453
454    Args:
455        d: dictionary to update.
456        keys: the keys that makeup the path to `value`.
457        value: the value to set in the dictionary for the given key path.
458
459    Returns:
460        The (possibly) updated dictionary.
461    """
462    if not keys:
463        return d
464
465    if len(keys) == 1:
466        d[keys[0]] = value
467        return d
468
469    subd = d
470    for key in keys[:-1]:
471        if key not in subd:
472            subd = subd.setdefault(key, {})
473        else:
474            subd = subd[key]
475
476    subd[keys[-1]] = value
477    return d

In-place set a value for a nested dictionary

Example:
>>> nested_set({}, ["top_key", "second_key"], "value")
{'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
{'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
  • d: dictionary to update.
  • keys: the keys that make up the path to value.
  • value: the value to set in the dictionary for the given key path.
Returns:

The (possibly) updated dictionary.