Edit on GitHub

sqlglot.schema

  1from __future__ import annotations
  2
  3import abc
  4import typing as t
  5
  6import sqlglot
  7from sqlglot import expressions as exp
  8from sqlglot.errors import SchemaError
  9from sqlglot.helper import dict_depth
 10from sqlglot.trie import in_trie, new_trie
 11
 12if t.TYPE_CHECKING:
 13    from sqlglot.dataframe.sql.types import StructType
 14    from sqlglot.dialects.dialect import DialectType
 15
 16    ColumnMapping = t.Union[t.Dict, str, StructType, t.List]
 17
# Names of the `exp.Table` args that identify a table, ordered from the table
# name itself ("this") outward to "db" and then "catalog".
TABLE_ARGS = ("this", "db", "catalog")

# Type variable for the value type that `AbstractMappingSchema.find` returns.
T = t.TypeVar("T")
 21
 22
 23class Schema(abc.ABC):
 24    """Abstract base class for database schemas"""
 25
 26    @abc.abstractmethod
 27    def add_table(
 28        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
 29    ) -> None:
 30        """
 31        Register or update a table. Some implementing classes may require column information to also be provided.
 32
 33        Args:
 34            table: table expression instance or string representing the table.
 35            column_mapping: a column mapping that describes the structure of the table.
 36        """
 37
 38    @abc.abstractmethod
 39    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
 40        """
 41        Get the column names for a table.
 42
 43        Args:
 44            table: the `Table` expression instance.
 45            only_visible: whether to include invisible columns.
 46
 47        Returns:
 48            The list of column names.
 49        """
 50
 51    @abc.abstractmethod
 52    def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType:
 53        """
 54        Get the :class:`sqlglot.exp.DataType` type of a column in the schema.
 55
 56        Args:
 57            table: the source table.
 58            column: the target column.
 59
 60        Returns:
 61            The resulting column type.
 62        """
 63
 64    @property
 65    def supported_table_args(self) -> t.Tuple[str, ...]:
 66        """
 67        Table arguments this schema support, e.g. `("this", "db", "catalog")`
 68        """
 69        raise NotImplementedError
 70
 71
 72class AbstractMappingSchema(t.Generic[T]):
 73    def __init__(
 74        self,
 75        mapping: dict | None = None,
 76    ) -> None:
 77        self.mapping = mapping or {}
 78        self.mapping_trie = self._build_trie(self.mapping)
 79        self._supported_table_args: t.Tuple[str, ...] = tuple()
 80
 81    def _build_trie(self, schema: t.Dict) -> t.Dict:
 82        return new_trie(tuple(reversed(t)) for t in flatten_schema(schema, depth=self._depth()))
 83
 84    def _depth(self) -> int:
 85        return dict_depth(self.mapping)
 86
 87    @property
 88    def supported_table_args(self) -> t.Tuple[str, ...]:
 89        if not self._supported_table_args and self.mapping:
 90            depth = self._depth()
 91
 92            if not depth:  # None
 93                self._supported_table_args = tuple()
 94            elif 1 <= depth <= 3:
 95                self._supported_table_args = TABLE_ARGS[:depth]
 96            else:
 97                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
 98
 99        return self._supported_table_args
100
101    def table_parts(self, table: exp.Table) -> t.List[str]:
102        if isinstance(table.this, exp.ReadCSV):
103            return [table.this.name]
104        return [table.text(part) for part in TABLE_ARGS if table.text(part)]
105
106    def find(
107        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
108    ) -> t.Optional[T]:
109        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
110        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
111
112        if value == 0:
113            return None
114        elif value == 1:
115            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
116            if len(possibilities) == 1:
117                parts.extend(possibilities[0])
118            else:
119                message = ", ".join(".".join(parts) for parts in possibilities)
120                if raise_on_missing:
121                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
122                return None
123        return self._nested_get(parts, raise_on_missing=raise_on_missing)
124
125    def _nested_get(
126        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
127    ) -> t.Optional[t.Any]:
128        return _nested_get(
129            d or self.mapping,
130            *zip(self.supported_table_args, reversed(parts)),
131            raise_on_missing=raise_on_missing,
132        )
133
134
class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema (dict): Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect (str): The dialect used when parsing identifiers and type strings.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
    ) -> None:
        # `dialect` must be set before `_normalize` runs, since name
        # normalization parses identifiers with it.
        self.dialect = dialect
        self.visible = visible or {}
        # Cache of already parsed type strings, keyed by the type string.
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        super().__init__(self._normalize(schema or {}))

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a new `MappingSchema` from the mapping, visibility and dialect of another one."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """
        Create a new `MappingSchema` from this one.

        Args:
            **kwargs: constructor arguments (`schema`, `visible`, `dialect`) that override
                the values taken from this instance.
        """
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                **kwargs,
            }
        )

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Converts all identifiers in the schema into lowercase, unless they're quoted.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        # One level is subtracted so that each flattened path ends at a table
        # and the columns dict is what `_nested_get` returns for it.
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        normalized_mapping: t.Dict = {}
        for keys in flattened_schema:
            columns = _nested_get(schema, *zip(keys, keys))
            assert columns is not None

            normalized_keys = [self._normalize_name(key) for key in keys]
            for column_name, column_type in columns.items():
                _nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def add_table(
        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
        """
        table_ = self._ensure_table(table)
        column_mapping = ensure_column_mapping(column_mapping)
        schema = self.find(table_, raise_on_missing=False)

        # The table is already known and no new columns were given: keep it as is.
        if schema and not column_mapping:
            return

        _nested_set(
            self.mapping,
            list(reversed(self.table_parts(table_))),
            column_mapping,
        )
        # The trie is derived from the mapping, so it has to be rebuilt after a change.
        self.mapping_trie = self._build_trie(self.mapping)

    def _normalize_name(self, name: str) -> str:
        """
        Normalize an identifier name: lowercase it unless it is quoted.

        Args:
            name: the raw identifier as written in the schema mapping.

        Returns:
            The identifier's name, lowercased when the identifier is not quoted.
        """
        try:
            identifier: t.Optional[exp.Expression] = sqlglot.parse_one(
                name, read=self.dialect, into=exp.Identifier
            )
        except Exception:
            # Best effort: names that fail to parse are wrapped as-is. Only
            # `Exception` is caught so KeyboardInterrupt/SystemExit still propagate.
            identifier = exp.to_identifier(name)
        assert isinstance(identifier, exp.Identifier)

        if identifier.quoted:
            return identifier.name
        return identifier.name.lower()

    def _depth(self) -> int:
        # The columns themselves are a mapping, but we don't want to include those
        return super()._depth() - 1

    def _ensure_table(self, table: exp.Table | str) -> exp.Table:
        """
        Convert `table` into an `exp.Table`.

        Raises:
            SchemaError: if `exp.to_table` returns a falsy value.
        """
        table_ = exp.to_table(table)

        if not table_:
            raise SchemaError(f"Not a valid table '{table}'")

        return table_

    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to return only the columns listed in `self.visible`.
                Ignored when no visibility mapping was provided.

        Returns:
            The list of column names, or an empty list if the table is not found.
        """
        table_ = self._ensure_table(table)
        schema = self.find(table_)

        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self._nested_get(self.table_parts(table_), self.visible)
        return [col for col in schema if col in visible]  # type: ignore

    def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
        """
        Get the :class:`sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column, as a `Column` expression or a column name.

        Returns:
            The column type. A type of `UNKNOWN` is returned when no (non-empty)
            mapping is found for the table.

        Raises:
            SchemaError: if the table can't be converted, or if the value stored for the
                column is neither a `DataType` nor a string (this includes a missing column).
        """
        column_name = column if isinstance(column, str) else column.name
        table_ = exp.to_table(table)
        if table_:
            table_schema = self.find(table_, raise_on_missing=False)
            if table_schema:
                column_type = table_schema.get(column_name)

                if isinstance(column_type, exp.DataType):
                    return column_type
                elif isinstance(column_type, str):
                    # Upper-casing keeps the parsed-type cache keyed consistently.
                    return self._to_data_type(column_type.upper())
                raise SchemaError(f"Unknown column type '{column_type}'")
            return exp.DataType(this=exp.DataType.Type.UNKNOWN)
        raise SchemaError(f"Could not convert table '{table}'")

    def _to_data_type(self, schema_type: str) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.

        Returns:
            The resulting expression type.

        Raises:
            SchemaError: if parsing the type fails with an `AttributeError`.
            ValueError: if the parser returns None for the type.
        """
        if schema_type not in self._type_mapping_cache:
            try:
                expression = exp.maybe_parse(schema_type, into=exp.DataType, dialect=self.dialect)
                if expression is None:
                    raise ValueError(f"Could not parse {schema_type}")
                self._type_mapping_cache[schema_type] = expression  # type: ignore
            except AttributeError as err:
                # Keep the original failure attached as the explicit cause.
                raise SchemaError(f"Failed to convert type {schema_type}") from err

        return self._type_mapping_cache[schema_type]
307
308
def ensure_schema(schema: t.Any, dialect: DialectType = None) -> Schema:
    """
    Coerce `schema` into a `Schema` instance.

    Args:
        schema: an existing `Schema`, or a value accepted by `MappingSchema`.
        dialect: the dialect handed to `MappingSchema` when a new schema is built.

    Returns:
        `schema` itself when it already is a `Schema`, otherwise a new `MappingSchema`.
    """
    already_schema = isinstance(schema, Schema)
    return schema if already_schema else MappingSchema(schema, dialect=dialect)
314
315
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Convert the supported column mapping representations into a dictionary.

    Args:
        mapping: one of:
            - a dict, which is returned unchanged;
            - a string of comma-separated `name: type` pairs, e.g. `"a: int, b: text"`;
            - an object with a `simpleString` attribute (looks like a DataFrame StructType),
              iterated as fields exposing `name` and `dataType.simpleString()`;
            - a list of column names, each mapped to None;
            - None, which yields an empty dict.

    Returns:
        The column mapping as a dictionary.

    Raises:
        ValueError: if `mapping` is of an unsupported type.
        IndexError: if a pair in a string mapping has no `:` separator.
    """
    if isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        column_mapping = {}
        for name_type_str in mapping.split(","):
            # Split on the first colon only, so that types which themselves
            # contain a colon (e.g. "struct<b:int>") are kept intact.
            # NOTE: types containing commas are still split by the outer loop.
            name_and_type = name_type_str.split(":", 1)
            column_mapping[name_and_type[0].strip()] = name_and_type[1].strip()
        return column_mapping
    # Check if mapping looks like a DataFrame StructType
    elif hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}  # type: ignore
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}
    elif mapping is None:
        return {}
    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
333
334
def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """
    Collect the key paths of a nested dictionary down to a given depth.

    Args:
        schema: the nested dictionary to walk.
        depth: how many levels of keys each returned path should contain.
        keys: the key path leading to `schema`, prepended to every returned path.

    Returns:
        A list of key paths; empty when `depth` is less than 1.
    """
    prefix = keys or []
    paths = []

    for name, child in schema.items():
        path = prefix + [name]
        if depth == 1:
            # Reached the requested level: record the path itself.
            paths.append(path)
        elif depth >= 2:
            paths.extend(flatten_schema(child, depth - 1, path))

    return paths
347
348
def _nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.
        raise_on_missing: whether to raise when a key along the path isn't found.

    Returns:
        The value or None if it doesn't exist.

    Raises:
        ValueError: if a key is missing (or maps to None) and `raise_on_missing` is set.
    """
    current = d
    for name, key in path:
        current = current.get(key)  # type: ignore
        if current is not None:
            continue

        if not raise_on_missing:
            return None

        # The "this" arg is reported as "table" in the error message.
        label = "table" if name == "this" else name
        raise ValueError(f"Unknown {label}: {key}")

    return current
372
373
def _nested_set(d: t.Dict, keys: t.List[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> _nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> _nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that make up the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if keys:
        *parents, leaf = keys

        # Descend through the parent keys, creating intermediate dicts as needed.
        node = d
        for key in parents:
            node = node.setdefault(key, {})

        node[leaf] = value

    return d
class Schema(abc.ABC):
24class Schema(abc.ABC):
25    """Abstract base class for database schemas"""
26
27    @abc.abstractmethod
28    def add_table(
29        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
30    ) -> None:
31        """
32        Register or update a table. Some implementing classes may require column information to also be provided.
33
34        Args:
35            table: table expression instance or string representing the table.
36            column_mapping: a column mapping that describes the structure of the table.
37        """
38
39    @abc.abstractmethod
40    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
41        """
42        Get the column names for a table.
43
44        Args:
45            table: the `Table` expression instance.
46            only_visible: whether to include invisible columns.
47
48        Returns:
49            The list of column names.
50        """
51
52    @abc.abstractmethod
53    def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType:
54        """
55        Get the :class:`sqlglot.exp.DataType` type of a column in the schema.
56
57        Args:
58            table: the source table.
59            column: the target column.
60
61        Returns:
62            The resulting column type.
63        """
64
65    @property
66    def supported_table_args(self) -> t.Tuple[str, ...]:
67        """
68        Table arguments this schema support, e.g. `("this", "db", "catalog")`
69        """
70        raise NotImplementedError

Abstract base class for database schemas

@abc.abstractmethod
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType] = None) -> None:
27    @abc.abstractmethod
28    def add_table(
29        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
30    ) -> None:
31        """
32        Register or update a table. Some implementing classes may require column information to also be provided.
33
34        Args:
35            table: table expression instance or string representing the table.
36            column_mapping: a column mapping that describes the structure of the table.
37        """

Register or update a table. Some implementing classes may require column information to also be provided.

Arguments:
  • table: table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
@abc.abstractmethod
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False) -> List[str]:
39    @abc.abstractmethod
40    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
41        """
42        Get the column names for a table.
43
44        Args:
45            table: the `Table` expression instance.
46            only_visible: whether to include invisible columns.
47
48        Returns:
49            The list of column names.
50        """

Get the column names for a table.

Arguments:
  • table: the Table expression instance.
  • only_visible: whether to return only the visible columns.
Returns:

The list of column names.

@abc.abstractmethod
def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column) -> sqlglot.expressions.DataType:
52    @abc.abstractmethod
53    def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType:
54        """
55        Get the :class:`sqlglot.exp.DataType` type of a column in the schema.
56
57        Args:
58            table: the source table.
59            column: the target column.
60
61        Returns:
62            The resulting column type.
63        """

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
Returns:

The resulting column type.

supported_table_args: Tuple[str, ...]

Table arguments this schema supports, e.g. ("this", "db", "catalog")

class AbstractMappingSchema(typing.Generic[~T]):
 73class AbstractMappingSchema(t.Generic[T]):
 74    def __init__(
 75        self,
 76        mapping: dict | None = None,
 77    ) -> None:
 78        self.mapping = mapping or {}
 79        self.mapping_trie = self._build_trie(self.mapping)
 80        self._supported_table_args: t.Tuple[str, ...] = tuple()
 81
 82    def _build_trie(self, schema: t.Dict) -> t.Dict:
 83        return new_trie(tuple(reversed(t)) for t in flatten_schema(schema, depth=self._depth()))
 84
 85    def _depth(self) -> int:
 86        return dict_depth(self.mapping)
 87
 88    @property
 89    def supported_table_args(self) -> t.Tuple[str, ...]:
 90        if not self._supported_table_args and self.mapping:
 91            depth = self._depth()
 92
 93            if not depth:  # None
 94                self._supported_table_args = tuple()
 95            elif 1 <= depth <= 3:
 96                self._supported_table_args = TABLE_ARGS[:depth]
 97            else:
 98                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
 99
100        return self._supported_table_args
101
102    def table_parts(self, table: exp.Table) -> t.List[str]:
103        if isinstance(table.this, exp.ReadCSV):
104            return [table.this.name]
105        return [table.text(part) for part in TABLE_ARGS if table.text(part)]
106
107    def find(
108        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
109    ) -> t.Optional[T]:
110        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
111        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
112
113        if value == 0:
114            return None
115        elif value == 1:
116            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
117            if len(possibilities) == 1:
118                parts.extend(possibilities[0])
119            else:
120                message = ", ".join(".".join(parts) for parts in possibilities)
121                if raise_on_missing:
122                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
123                return None
124        return self._nested_get(parts, raise_on_missing=raise_on_missing)
125
126    def _nested_get(
127        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
128    ) -> t.Optional[t.Any]:
129        return _nested_get(
130            d or self.mapping,
131            *zip(self.supported_table_args, reversed(parts)),
132            raise_on_missing=raise_on_missing,
133        )

Abstract base class for generic types.

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::

    class Mapping(Generic[KT, VT]):
        def __getitem__(self, key: KT) -> VT:
            ...
        # Etc.

This class can then be used as follows::

    def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
        try:
            return mapping[key]
        except KeyError:
            return default

AbstractMappingSchema(mapping: dict | None = None)
74    def __init__(
75        self,
76        mapping: dict | None = None,
77    ) -> None:
78        self.mapping = mapping or {}
79        self.mapping_trie = self._build_trie(self.mapping)
80        self._supported_table_args: t.Tuple[str, ...] = tuple()
def table_parts(self, table: sqlglot.expressions.Table) -> List[str]:
102    def table_parts(self, table: exp.Table) -> t.List[str]:
103        if isinstance(table.this, exp.ReadCSV):
104            return [table.this.name]
105        return [table.text(part) for part in TABLE_ARGS if table.text(part)]
def find( self, table: sqlglot.expressions.Table, trie: Optional[Dict] = None, raise_on_missing: bool = True) -> Optional[~T]:
107    def find(
108        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
109    ) -> t.Optional[T]:
110        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
111        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
112
113        if value == 0:
114            return None
115        elif value == 1:
116            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
117            if len(possibilities) == 1:
118                parts.extend(possibilities[0])
119            else:
120                message = ", ".join(".".join(parts) for parts in possibilities)
121                if raise_on_missing:
122                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
123                return None
124        return self._nested_get(parts, raise_on_missing=raise_on_missing)
136class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
137    """
138    Schema based on a nested mapping.
139
140    Args:
141        schema (dict): Mapping in one of the following forms:
142            1. {table: {col: type}}
143            2. {db: {table: {col: type}}}
144            3. {catalog: {db: {table: {col: type}}}}
145            4. None - Tables will be added later
146        visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns
147            are assumed to be visible. The nesting should mirror that of the schema:
148            1. {table: set(*cols)}}
149            2. {db: {table: set(*cols)}}}
150            3. {catalog: {db: {table: set(*cols)}}}}
151        dialect (str): The dialect to be used for custom type mappings.
152    """
153
154    def __init__(
155        self,
156        schema: t.Optional[t.Dict] = None,
157        visible: t.Optional[t.Dict] = None,
158        dialect: DialectType = None,
159    ) -> None:
160        self.dialect = dialect
161        self.visible = visible or {}
162        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
163        super().__init__(self._normalize(schema or {}))
164
165    @classmethod
166    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
167        return MappingSchema(
168            schema=mapping_schema.mapping,
169            visible=mapping_schema.visible,
170            dialect=mapping_schema.dialect,
171        )
172
173    def copy(self, **kwargs) -> MappingSchema:
174        return MappingSchema(
175            **{  # type: ignore
176                "schema": self.mapping.copy(),
177                "visible": self.visible.copy(),
178                "dialect": self.dialect,
179                **kwargs,
180            }
181        )
182
183    def _normalize(self, schema: t.Dict) -> t.Dict:
184        """
185        Converts all identifiers in the schema into lowercase, unless they're quoted.
186
187        Args:
188            schema: the schema to normalize.
189
190        Returns:
191            The normalized schema mapping.
192        """
193        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)
194
195        normalized_mapping: t.Dict = {}
196        for keys in flattened_schema:
197            columns = _nested_get(schema, *zip(keys, keys))
198            assert columns is not None
199
200            normalized_keys = [self._normalize_name(key) for key in keys]
201            for column_name, column_type in columns.items():
202                _nested_set(
203                    normalized_mapping,
204                    normalized_keys + [self._normalize_name(column_name)],
205                    column_type,
206                )
207
208        return normalized_mapping
209
210    def add_table(
211        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
212    ) -> None:
213        """
214        Register or update a table. Updates are only performed if a new column mapping is provided.
215
216        Args:
217            table: the `Table` expression instance or string representing the table.
218            column_mapping: a column mapping that describes the structure of the table.
219        """
220        table_ = self._ensure_table(table)
221        column_mapping = ensure_column_mapping(column_mapping)
222        schema = self.find(table_, raise_on_missing=False)
223
224        if schema and not column_mapping:
225            return
226
227        _nested_set(
228            self.mapping,
229            list(reversed(self.table_parts(table_))),
230            column_mapping,
231        )
232        self.mapping_trie = self._build_trie(self.mapping)
233
234    def _normalize_name(self, name: str) -> str:
235        try:
236            identifier: t.Optional[exp.Expression] = sqlglot.parse_one(
237                name, read=self.dialect, into=exp.Identifier
238            )
239        except:
240            identifier = exp.to_identifier(name)
241        assert isinstance(identifier, exp.Identifier)
242
243        if identifier.quoted:
244            return identifier.name
245        return identifier.name.lower()
246
247    def _depth(self) -> int:
248        # The columns themselves are a mapping, but we don't want to include those
249        return super()._depth() - 1
250
251    def _ensure_table(self, table: exp.Table | str) -> exp.Table:
252        table_ = exp.to_table(table)
253
254        if not table_:
255            raise SchemaError(f"Not a valid table '{table}'")
256
257        return table_
258
259    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
260        table_ = self._ensure_table(table)
261        schema = self.find(table_)
262
263        if schema is None:
264            return []
265
266        if not only_visible or not self.visible:
267            return list(schema)
268
269        visible = self._nested_get(self.table_parts(table_), self.visible)
270        return [col for col in schema if col in visible]  # type: ignore
271
272    def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
273        column_name = column if isinstance(column, str) else column.name
274        table_ = exp.to_table(table)
275        if table_:
276            table_schema = self.find(table_, raise_on_missing=False)
277            if table_schema:
278                column_type = table_schema.get(column_name)
279
280                if isinstance(column_type, exp.DataType):
281                    return column_type
282                elif isinstance(column_type, str):
283                    return self._to_data_type(column_type.upper())
284                raise SchemaError(f"Unknown column type '{column_type}'")
285            return exp.DataType(this=exp.DataType.Type.UNKNOWN)
286        raise SchemaError(f"Could not convert table '{table}'")
287
288    def _to_data_type(self, schema_type: str) -> exp.DataType:
289        """
290        Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object.
291
292        Args:
293            schema_type: the type we want to convert.
294
295        Returns:
296            The resulting expression type.
297        """
298        if schema_type not in self._type_mapping_cache:
299            try:
300                expression = exp.maybe_parse(schema_type, into=exp.DataType, dialect=self.dialect)
301                if expression is None:
302                    raise ValueError(f"Could not parse {schema_type}")
303                self._type_mapping_cache[schema_type] = expression  # type: ignore
304            except AttributeError:
305                raise SchemaError(f"Failed to convert type {schema_type}")
306
307        return self._type_mapping_cache[schema_type]

Schema based on a nested mapping.

Arguments:
  • schema (dict): Mapping in one of the following forms:
    1. {table: {col: type}}
    2. {db: {table: {col: type}}}
    3. {catalog: {db: {table: {col: type}}}}
    4. None - Tables will be added later
  • visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns are assumed to be visible. The nesting should mirror that of the schema:
    1. {table: set(*cols)}
    2. {db: {table: set(*cols)}}
    3. {catalog: {db: {table: set(*cols)}}}
  • dialect (str): The dialect to be used for custom type mappings.
MappingSchema( schema: Optional[Dict] = None, visible: Optional[Dict] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None)
154    def __init__(
155        self,
156        schema: t.Optional[t.Dict] = None,
157        visible: t.Optional[t.Dict] = None,
158        dialect: DialectType = None,
159    ) -> None:
160        self.dialect = dialect
161        self.visible = visible or {}
162        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
163        super().__init__(self._normalize(schema or {}))
@classmethod
def from_mapping_schema( cls, mapping_schema: sqlglot.schema.MappingSchema) -> sqlglot.schema.MappingSchema:
165    @classmethod
166    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
167        return MappingSchema(
168            schema=mapping_schema.mapping,
169            visible=mapping_schema.visible,
170            dialect=mapping_schema.dialect,
171        )
def copy(self, **kwargs) -> sqlglot.schema.MappingSchema:
173    def copy(self, **kwargs) -> MappingSchema:
174        return MappingSchema(
175            **{  # type: ignore
176                "schema": self.mapping.copy(),
177                "visible": self.visible.copy(),
178                "dialect": self.dialect,
179                **kwargs,
180            }
181        )
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType] = None) -> None:
210    def add_table(
211        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
212    ) -> None:
213        """
214        Register or update a table. Updates are only performed if a new column mapping is provided.
215
216        Args:
217            table: the `Table` expression instance or string representing the table.
218            column_mapping: a column mapping that describes the structure of the table.
219        """
220        table_ = self._ensure_table(table)
221        column_mapping = ensure_column_mapping(column_mapping)
222        schema = self.find(table_, raise_on_missing=False)
223
224        if schema and not column_mapping:
225            return
226
227        _nested_set(
228            self.mapping,
229            list(reversed(self.table_parts(table_))),
230            column_mapping,
231        )
232        self.mapping_trie = self._build_trie(self.mapping)

Register or update a table. Updates are only performed if a new column mapping is provided.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False) -> List[str]:
259    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
260        table_ = self._ensure_table(table)
261        schema = self.find(table_)
262
263        if schema is None:
264            return []
265
266        if not only_visible or not self.visible:
267            return list(schema)
268
269        visible = self._nested_get(self.table_parts(table_), self.visible)
270        return [col for col in schema if col in visible]  # type: ignore

Get the column names for a table.

Arguments:
  • table: the Table expression instance.
  • only_visible: whether to return only the visible columns (invisible columns are excluded when True).
Returns:

The list of column names.

def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str) -> sqlglot.expressions.DataType:
272    def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
273        column_name = column if isinstance(column, str) else column.name
274        table_ = exp.to_table(table)
275        if table_:
276            table_schema = self.find(table_, raise_on_missing=False)
277            if table_schema:
278                column_type = table_schema.get(column_name)
279
280                if isinstance(column_type, exp.DataType):
281                    return column_type
282                elif isinstance(column_type, str):
283                    return self._to_data_type(column_type.upper())
284                raise SchemaError(f"Unknown column type '{column_type}'")
285            return exp.DataType(this=exp.DataType.Type.UNKNOWN)
286        raise SchemaError(f"Could not convert table '{table}'")

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
Returns:

The resulting column type.

def ensure_schema( schema: Any, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> sqlglot.schema.Schema:
def ensure_schema(schema: t.Any, dialect: DialectType = None) -> Schema:
    """
    Coerce the input into a `Schema` instance.

    Args:
        schema: an existing `Schema`, or a value accepted by the `MappingSchema` constructor.
        dialect: the dialect forwarded to `MappingSchema`; ignored when `schema` is already a `Schema`.

    Returns:
        `schema` itself if it is a `Schema`, otherwise a new `MappingSchema` wrapping it.
    """
    return schema if isinstance(schema, Schema) else MappingSchema(schema, dialect=dialect)
def ensure_column_mapping( mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType]):
def _split_top_level(text: str, separator: str = ",") -> t.List[str]:
    """Split `text` on `separator`, ignoring separators nested inside (), <> or []."""
    pieces = []
    current = []
    depth = 0

    for char in text:
        if char in "(<[":
            depth += 1
        elif char in ")>]":
            # Clamp at zero so a stray closing bracket cannot hide later separators.
            depth = max(depth - 1, 0)

        if char == separator and depth == 0:
            pieces.append("".join(current))
            current = []
        else:
            current.append(char)

    pieces.append("".join(current))
    return pieces


def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Convert the supported column mapping representations into a dictionary.

    Args:
        mapping: one of the following:
            - a dict, which is returned as is;
            - a string of comma-separated `name: type` pairs, e.g. `"a: int, b: decimal(10, 2)"`;
            - an object exposing `simpleString` (e.g. a DataFrame `StructType`), whose fields
              are iterated to build a `{name: type string}` mapping;
            - a list of column names, each mapped to `None`;
            - `None`, which produces an empty mapping.

    Returns:
        The column mapping as a dictionary.

    Raises:
        IndexError: if a `name: type` pair in a string mapping has no `:` separator.
        ValueError: if `mapping` is of an unsupported type.
    """
    if isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        column_mapping = {}
        # Only commas outside of brackets separate columns, so that types such as
        # `decimal(10, 2)` or `struct<a:int,b:string>` are kept in one piece.
        for column_def in _split_top_level(mapping):
            # Split on the first colon only: the type itself may contain colons.
            parts = column_def.split(":", 1)
            column_mapping[parts[0].strip()] = parts[1].strip()
        return column_mapping
    # Check if mapping looks like a DataFrame StructType
    elif hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}  # type: ignore
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}
    elif mapping is None:
        return {}
    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema( schema: Dict, depth: int, keys: Optional[List[str]] = None) -> List[List[str]]:
def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """
    Collect the key paths found at a given nesting depth of a mapping.

    Args:
        schema: the nested mapping to walk.
        depth: how many levels of keys to descend; a path is emitted once `depth` reaches 1.
        keys: the path of keys accumulated so far, used as a prefix for every emitted path.

    Returns:
        A list of key paths, each a list of keys from the outermost level inwards.
    """
    prefix = keys or []
    paths = []

    for name, child in schema.items():
        path = prefix + [name]
        if depth == 1:
            paths.append(path)
        elif depth >= 2:
            paths.extend(flatten_schema(child, depth - 1, path))

    return paths