Edit on GitHub

sqlglot.schema

  1from __future__ import annotations
  2
  3import abc
  4import typing as t
  5
  6import sqlglot
  7from sqlglot import expressions as exp
  8from sqlglot.errors import ParseError, SchemaError
  9from sqlglot.helper import dict_depth
 10from sqlglot.trie import in_trie, new_trie
 11
if t.TYPE_CHECKING:
    # Imported for static type checking only; nothing in this branch runs at runtime.
    from sqlglot.dataframe.sql.types import StructType
    from sqlglot.dialects.dialect import DialectType

    # The representations of a table's columns accepted by `ensure_column_mapping`.
    # It is defined only under TYPE_CHECKING, so it can appear in annotations but not in runtime code.
    ColumnMapping = t.Union[t.Dict, str, StructType, t.List]

# Names of the `exp.Table` args that identify a table, listed from the table name outwards.
# A schema of nesting depth N supports the first N of them.
TABLE_ARGS = ("this", "db", "catalog")

# Generic parameter of `AbstractMappingSchema`: the type of the value `find` returns for a table.
T = t.TypeVar("T")
 21
 22
class Schema(abc.ABC):
    """Abstract base class for database schemas"""

    @abc.abstractmethod
    def add_table(
        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information to also be provided.

        Args:
            table: table expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
        """

    @abc.abstractmethod
    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or a string representing the table.
            only_visible: whether to return only the visible columns instead of all columns.

        Returns:
            The list of column names.
        """

    @abc.abstractmethod
    def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType:
        """
        Get the :class:`sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.

        Returns:
            The resulting column type.
        """

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        Table arguments this schema supports, e.g. `("this", "db", "catalog")`.

        Raises:
            NotImplementedError: always, unless a subclass overrides this property.
        """
        raise NotImplementedError
 70
 71
 72class AbstractMappingSchema(t.Generic[T]):
 73    def __init__(
 74        self,
 75        mapping: dict | None = None,
 76    ) -> None:
 77        self.mapping = mapping or {}
 78        self.mapping_trie = new_trie(
 79            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self._depth())
 80        )
 81        self._supported_table_args: t.Tuple[str, ...] = tuple()
 82
 83    def _depth(self) -> int:
 84        return dict_depth(self.mapping)
 85
 86    @property
 87    def supported_table_args(self) -> t.Tuple[str, ...]:
 88        if not self._supported_table_args and self.mapping:
 89            depth = self._depth()
 90
 91            if not depth:  # None
 92                self._supported_table_args = tuple()
 93            elif 1 <= depth <= 3:
 94                self._supported_table_args = TABLE_ARGS[:depth]
 95            else:
 96                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
 97
 98        return self._supported_table_args
 99
100    def table_parts(self, table: exp.Table) -> t.List[str]:
101        if isinstance(table.this, exp.ReadCSV):
102            return [table.this.name]
103        return [table.text(part) for part in TABLE_ARGS if table.text(part)]
104
105    def find(
106        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
107    ) -> t.Optional[T]:
108        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
109        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
110
111        if value == 0:
112            return None
113        elif value == 1:
114            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
115            if len(possibilities) == 1:
116                parts.extend(possibilities[0])
117            else:
118                message = ", ".join(".".join(parts) for parts in possibilities)
119                if raise_on_missing:
120                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
121                return None
122        return self._nested_get(parts, raise_on_missing=raise_on_missing)
123
124    def _nested_get(
125        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
126    ) -> t.Optional[t.Any]:
127        return _nested_get(
128            d or self.mapping,
129            *zip(self.supported_table_args, reversed(parts)),
130            raise_on_missing=raise_on_missing,
131        )
132
133
class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema (dict): Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect (str): The dialect to be used for custom type mappings.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
    ) -> None:
        # `dialect` must be set before `_normalize` runs, because name normalization reads it.
        self.dialect = dialect
        self.visible = visible or {}
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        super().__init__(self._normalize(schema or {}))

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """
        Build a new `MappingSchema` from the mapping, visibility and dialect of another one.

        The source mapping is passed through `_normalize` again by `__init__`.
        """
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """
        Create a new `MappingSchema` from this one.

        Args:
            **kwargs: constructor arguments (`schema`, `visible`, `dialect`) that override the
                values taken from this instance.
        """
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                **kwargs,
            }
        )

    def add_table(
        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
        """
        normalized_table = self._normalize_table(self._ensure_table(table))
        normalized_column_mapping = {
            self._normalize_name(key): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # A table that already has columns is left untouched unless new columns are supplied.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        # The mapping is keyed outermost-qualifier-first, hence the reversal; the trie is keyed
        # table-name-first, which is the order `table_parts` already returns.
        _nested_set(
            self.mapping,
            tuple(reversed(parts)),
            normalized_column_mapping,
        )
        new_trie([parts], self.mapping_trie)

    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to return only the columns listed in `visible` for this table.
                It has no effect when no visibility mapping was provided.

        Returns:
            The list of column names, or an empty list if the table could not be found.
        """
        table_ = self._normalize_table(self._ensure_table(table))
        schema = self.find(table_)

        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self._nested_get(self.table_parts(table_), self.visible)
        return [col for col in schema if col in visible]  # type: ignore

    def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
        """
        Get the :class:`sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column, as an expression or as a name.

        Returns:
            The column type. The `unknown` type is returned when the table is not found, when the
            column is not part of the table, or when the column was registered without a type
            (e.g. through a list of column names).

        Raises:
            SchemaError: if the type stored for the column is neither a string nor a `DataType`.
        """
        column_name = self._normalize_name(column if isinstance(column, str) else column.this)
        table_ = self._normalize_table(self._ensure_table(table))

        table_schema = self.find(table_, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            if isinstance(column_type, str):
                return self._to_data_type(column_type.upper())
            if column_type is not None:
                raise SchemaError(f"Unknown column type '{column_type}'")
            # A None type means "no type information": fall through to the `unknown` type, just
            # like for an unknown table, instead of failing with "Unknown column type 'None'".

        return exp.DataType.build("unknown")

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Converts all identifiers in the schema into lowercase, unless they're quoted.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        # Every path stops one level above the columns, i.e. at the table.
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        normalized_mapping: t.Dict = {}
        for keys in flattened_schema:
            columns = _nested_get(schema, *zip(keys, keys))
            assert columns is not None

            normalized_keys = [self._normalize_name(key) for key in keys]
            for column_name, column_type in columns.items():
                _nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(self, table: exp.Table) -> exp.Table:
        """Return a copy of `table` whose string or identifier parts have been normalized."""
        normalized_table = table.copy()
        for arg in TABLE_ARGS:
            value = normalized_table.args.get(arg)
            if isinstance(value, (str, exp.Identifier)):
                normalized_table.set(arg, self._normalize_name(value))

        return normalized_table

    def _normalize_name(self, name: str | exp.Identifier) -> str:
        """
        Normalize an identifier: quoted identifiers keep their case, all others are lowercased.

        A name that cannot be parsed as an identifier is returned unchanged.
        """
        try:
            identifier = sqlglot.maybe_parse(name, dialect=self.dialect, into=exp.Identifier)
        except ParseError:
            return name if isinstance(name, str) else name.name

        return identifier.name if identifier.quoted else identifier.name.lower()

    def _depth(self) -> int:
        # The columns themselves are a mapping, but we don't want to include those
        return super()._depth() - 1

    def _ensure_table(self, table: exp.Table | str) -> exp.Table:
        """
        Return `table` as an `exp.Table`, parsing it with the schema's dialect if it is a string.

        Raises:
            SchemaError: if parsing the string yields nothing.
        """
        if isinstance(table, exp.Table):
            return table

        table_ = sqlglot.parse_one(table, read=self.dialect, into=exp.Table)
        if not table_:
            raise SchemaError(f"Not a valid table '{table}'")

        return table_

    def _to_data_type(self, schema_type: str) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object.

        Results are cached per type string.

        Args:
            schema_type: the type we want to convert.

        Returns:
            The resulting expression type.

        Raises:
            SchemaError: if the conversion fails with an `AttributeError`, which is kept as the cause.
            ValueError: if parsing yields nothing.
        """
        if schema_type not in self._type_mapping_cache:
            try:
                expression = exp.maybe_parse(schema_type, into=exp.DataType, dialect=self.dialect)
            except AttributeError as error:
                raise SchemaError(f"Failed to convert type {schema_type}") from error

            if expression is None:
                raise ValueError(f"Could not parse {schema_type}")
            self._type_mapping_cache[schema_type] = expression  # type: ignore

        return self._type_mapping_cache[schema_type]
317
318
def ensure_schema(schema: t.Any, dialect: DialectType = None) -> Schema:
    """
    Coerce `schema` into a `Schema` instance.

    Args:
        schema: an existing `Schema`, or a value to build a `MappingSchema` from.
        dialect: passed to `MappingSchema` when a new schema has to be built.

    Returns:
        `schema` itself if it already is a `Schema`, otherwise a new `MappingSchema` built from it.
    """
    return schema if isinstance(schema, Schema) else MappingSchema(schema, dialect=dialect)
324
325
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Convert any supported column mapping representation into a `{column name: type}` dictionary.

    Args:
        mapping: one of the following:
            - a dict, which is returned as-is;
            - a string such as `"a: int, b: text"`. Entries are separated by commas and each name
              is separated from its type by the first colon, so the type itself may contain
              colons (e.g. `"s: struct<x:int>"`). Empty entries are ignored;
            - an iterable object that has a `simpleString` attribute (e.g. a DataFrame StructType),
              whose fields provide `name` and `dataType.simpleString()`;
            - a list of column names, each of which is mapped to a `None` type;
            - None, which yields an empty mapping.

    Returns:
        The column mapping as a dictionary.

    Raises:
        ValueError: if the mapping has an unsupported type, or if a string entry is not of the
            form `name: type`.
    """
    if isinstance(mapping, dict):
        return mapping
    if isinstance(mapping, str):
        # NOTE: entries are split on every comma, so types that contain commas themselves
        # (e.g. "decimal(10, 2)") cannot be expressed through the string form.
        column_types = {}
        for entry in mapping.split(","):
            if not entry.strip():
                # Tolerate empty input and stray or trailing commas.
                continue

            # Split on the first colon only, so that colons inside the type are preserved.
            name, separator, type_ = entry.partition(":")
            if not separator:
                raise ValueError(
                    f"Invalid column mapping entry (expected 'name: type'): '{entry.strip()}'"
                )
            column_types[name.strip()] = type_.strip()
        return column_types
    # Check if mapping looks like a DataFrame StructType
    if hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}  # type: ignore
    if isinstance(mapping, list):
        return {x.strip(): None for x in mapping}
    if mapping is None:
        return {}
    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
343
344
def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """
    List the key paths of a nested mapping down to a given depth.

    Args:
        schema: the nested mapping to walk.
        depth: how many levels of keys each path should contain; values below 1 yield no paths.
        keys: a prefix that is prepended to every returned path.

    Returns:
        One list of keys per entry found `depth` levels down, in iteration order.
    """
    prefix = keys or []

    if depth >= 2:
        # Descend one level and collect the paths found underneath each key.
        return [
            path
            for key, nested in schema.items()
            for path in flatten_schema(nested, depth - 1, prefix + [key])
        ]

    # Last level: every key completes a path (nothing is produced for depth < 1).
    return [prefix + [key] for key, _ in schema.items() if depth == 1]
357
358
def _nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Look up a value in a nested dictionary by following a path of keys.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key to look up at that level of the dictionary.
            `name` describes the key in the error raised when it isn't found
            (`"this"` is reported as `"table"`).
        raise_on_missing: whether a missing key raises instead of returning None.

    Returns:
        The value found at the end of the path, or None if a key is missing
        and `raise_on_missing` is False.

    Raises:
        ValueError: if a key is missing and `raise_on_missing` is True.
    """
    current = d
    for label, key in path:
        current = current.get(key)  # type: ignore
        if current is not None:
            continue

        if not raise_on_missing:
            return None

        label = "table" if label == "this" else label
        raise ValueError(f"Unknown {label}: {key}")

    return current
382
383
def _nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    Set a value in a nested dictionary in place, creating intermediate dictionaries as needed.

    Example:
        >>> _nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> _nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: the dictionary to update.
        keys: the keys that make up the path to `value`; nothing is changed if it is empty.
        value: the value to store at the end of the key path.

    Returns:
        The same dictionary `d`, after the update.
    """
    if not keys:
        return d

    *parents, leaf = keys

    # Walk down to the dictionary that holds the final key, creating levels that don't exist yet.
    node = d
    for parent in parents:
        node = node.setdefault(parent, {})

    node[leaf] = value
    return d
class Schema(abc.ABC):
24class Schema(abc.ABC):
25    """Abstract base class for database schemas"""
26
27    @abc.abstractmethod
28    def add_table(
29        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
30    ) -> None:
31        """
32        Register or update a table. Some implementing classes may require column information to also be provided.
33
34        Args:
35            table: table expression instance or string representing the table.
36            column_mapping: a column mapping that describes the structure of the table.
37        """
38
39    @abc.abstractmethod
40    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
41        """
42        Get the column names for a table.
43
44        Args:
45            table: the `Table` expression instance.
46            only_visible: whether to include invisible columns.
47
48        Returns:
49            The list of column names.
50        """
51
52    @abc.abstractmethod
53    def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType:
54        """
55        Get the :class:`sqlglot.exp.DataType` type of a column in the schema.
56
57        Args:
58            table: the source table.
59            column: the target column.
60
61        Returns:
62            The resulting column type.
63        """
64
65    @property
66    def supported_table_args(self) -> t.Tuple[str, ...]:
67        """
68        Table arguments this schema support, e.g. `("this", "db", "catalog")`
69        """
70        raise NotImplementedError

Abstract base class for database schemas

@abc.abstractmethod
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType] = None) -> None:
27    @abc.abstractmethod
28    def add_table(
29        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
30    ) -> None:
31        """
32        Register or update a table. Some implementing classes may require column information to also be provided.
33
34        Args:
35            table: table expression instance or string representing the table.
36            column_mapping: a column mapping that describes the structure of the table.
37        """

Register or update a table. Some implementing classes may require column information to also be provided.

Arguments:
  • table: table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
@abc.abstractmethod
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False) -> List[str]:
39    @abc.abstractmethod
40    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
41        """
42        Get the column names for a table.
43
44        Args:
45            table: the `Table` expression instance.
46            only_visible: whether to include invisible columns.
47
48        Returns:
49            The list of column names.
50        """

Get the column names for a table.

Arguments:
  • table: the Table expression instance.
  • only_visible: whether to return only the visible columns instead of all columns.
Returns:

The list of column names.

@abc.abstractmethod
def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column) -> sqlglot.expressions.DataType:
52    @abc.abstractmethod
53    def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType:
54        """
55        Get the :class:`sqlglot.exp.DataType` type of a column in the schema.
56
57        Args:
58            table: the source table.
59            column: the target column.
60
61        Returns:
62            The resulting column type.
63        """

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
Returns:

The resulting column type.

supported_table_args: Tuple[str, ...]

Table arguments this schema supports, e.g. ("this", "db", "catalog")

class AbstractMappingSchema(typing.Generic[~T]):
 73class AbstractMappingSchema(t.Generic[T]):
 74    def __init__(
 75        self,
 76        mapping: dict | None = None,
 77    ) -> None:
 78        self.mapping = mapping or {}
 79        self.mapping_trie = new_trie(
 80            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self._depth())
 81        )
 82        self._supported_table_args: t.Tuple[str, ...] = tuple()
 83
 84    def _depth(self) -> int:
 85        return dict_depth(self.mapping)
 86
 87    @property
 88    def supported_table_args(self) -> t.Tuple[str, ...]:
 89        if not self._supported_table_args and self.mapping:
 90            depth = self._depth()
 91
 92            if not depth:  # None
 93                self._supported_table_args = tuple()
 94            elif 1 <= depth <= 3:
 95                self._supported_table_args = TABLE_ARGS[:depth]
 96            else:
 97                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
 98
 99        return self._supported_table_args
100
101    def table_parts(self, table: exp.Table) -> t.List[str]:
102        if isinstance(table.this, exp.ReadCSV):
103            return [table.this.name]
104        return [table.text(part) for part in TABLE_ARGS if table.text(part)]
105
106    def find(
107        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
108    ) -> t.Optional[T]:
109        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
110        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
111
112        if value == 0:
113            return None
114        elif value == 1:
115            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
116            if len(possibilities) == 1:
117                parts.extend(possibilities[0])
118            else:
119                message = ", ".join(".".join(parts) for parts in possibilities)
120                if raise_on_missing:
121                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
122                return None
123        return self._nested_get(parts, raise_on_missing=raise_on_missing)
124
125    def _nested_get(
126        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
127    ) -> t.Optional[t.Any]:
128        return _nested_get(
129            d or self.mapping,
130            *zip(self.supported_table_args, reversed(parts)),
131            raise_on_missing=raise_on_missing,
132        )

Abstract base class for generic types.

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::

    class Mapping(Generic[KT, VT]):
        def __getitem__(self, key: KT) -> VT:
            ...
        # Etc.

This class can then be used as follows::

    def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
        try:
            return mapping[key]
        except KeyError:
            return default

AbstractMappingSchema(mapping: dict | None = None)
74    def __init__(
75        self,
76        mapping: dict | None = None,
77    ) -> None:
78        self.mapping = mapping or {}
79        self.mapping_trie = new_trie(
80            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self._depth())
81        )
82        self._supported_table_args: t.Tuple[str, ...] = tuple()
def table_parts(self, table: sqlglot.expressions.Table) -> List[str]:
101    def table_parts(self, table: exp.Table) -> t.List[str]:
102        if isinstance(table.this, exp.ReadCSV):
103            return [table.this.name]
104        return [table.text(part) for part in TABLE_ARGS if table.text(part)]
def find( self, table: sqlglot.expressions.Table, trie: Optional[Dict] = None, raise_on_missing: bool = True) -> Optional[~T]:
106    def find(
107        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
108    ) -> t.Optional[T]:
109        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
110        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
111
112        if value == 0:
113            return None
114        elif value == 1:
115            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
116            if len(possibilities) == 1:
117                parts.extend(possibilities[0])
118            else:
119                message = ", ".join(".".join(parts) for parts in possibilities)
120                if raise_on_missing:
121                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
122                return None
123        return self._nested_get(parts, raise_on_missing=raise_on_missing)
135class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
136    """
137    Schema based on a nested mapping.
138
139    Args:
140        schema (dict): Mapping in one of the following forms:
141            1. {table: {col: type}}
142            2. {db: {table: {col: type}}}
143            3. {catalog: {db: {table: {col: type}}}}
144            4. None - Tables will be added later
145        visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns
146            are assumed to be visible. The nesting should mirror that of the schema:
147            1. {table: set(*cols)}}
148            2. {db: {table: set(*cols)}}}
149            3. {catalog: {db: {table: set(*cols)}}}}
150        dialect (str): The dialect to be used for custom type mappings.
151    """
152
153    def __init__(
154        self,
155        schema: t.Optional[t.Dict] = None,
156        visible: t.Optional[t.Dict] = None,
157        dialect: DialectType = None,
158    ) -> None:
159        self.dialect = dialect
160        self.visible = visible or {}
161        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
162        super().__init__(self._normalize(schema or {}))
163
164    @classmethod
165    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
166        return MappingSchema(
167            schema=mapping_schema.mapping,
168            visible=mapping_schema.visible,
169            dialect=mapping_schema.dialect,
170        )
171
172    def copy(self, **kwargs) -> MappingSchema:
173        return MappingSchema(
174            **{  # type: ignore
175                "schema": self.mapping.copy(),
176                "visible": self.visible.copy(),
177                "dialect": self.dialect,
178                **kwargs,
179            }
180        )
181
182    def add_table(
183        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
184    ) -> None:
185        """
186        Register or update a table. Updates are only performed if a new column mapping is provided.
187
188        Args:
189            table: the `Table` expression instance or string representing the table.
190            column_mapping: a column mapping that describes the structure of the table.
191        """
192        normalized_table = self._normalize_table(self._ensure_table(table))
193        normalized_column_mapping = {
194            self._normalize_name(key): value
195            for key, value in ensure_column_mapping(column_mapping).items()
196        }
197
198        schema = self.find(normalized_table, raise_on_missing=False)
199        if schema and not normalized_column_mapping:
200            return
201
202        parts = self.table_parts(normalized_table)
203
204        _nested_set(
205            self.mapping,
206            tuple(reversed(parts)),
207            normalized_column_mapping,
208        )
209        new_trie([parts], self.mapping_trie)
210
211    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
212        table_ = self._normalize_table(self._ensure_table(table))
213        schema = self.find(table_)
214
215        if schema is None:
216            return []
217
218        if not only_visible or not self.visible:
219            return list(schema)
220
221        visible = self._nested_get(self.table_parts(table_), self.visible)
222        return [col for col in schema if col in visible]  # type: ignore
223
224    def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
225        column_name = self._normalize_name(column if isinstance(column, str) else column.this)
226        table_ = self._normalize_table(self._ensure_table(table))
227
228        table_schema = self.find(table_, raise_on_missing=False)
229        if table_schema:
230            column_type = table_schema.get(column_name)
231
232            if isinstance(column_type, exp.DataType):
233                return column_type
234            elif isinstance(column_type, str):
235                return self._to_data_type(column_type.upper())
236            raise SchemaError(f"Unknown column type '{column_type}'")
237
238        return exp.DataType.build("unknown")
239
240    def _normalize(self, schema: t.Dict) -> t.Dict:
241        """
242        Converts all identifiers in the schema into lowercase, unless they're quoted.
243
244        Args:
245            schema: the schema to normalize.
246
247        Returns:
248            The normalized schema mapping.
249        """
250        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)
251
252        normalized_mapping: t.Dict = {}
253        for keys in flattened_schema:
254            columns = _nested_get(schema, *zip(keys, keys))
255            assert columns is not None
256
257            normalized_keys = [self._normalize_name(key) for key in keys]
258            for column_name, column_type in columns.items():
259                _nested_set(
260                    normalized_mapping,
261                    normalized_keys + [self._normalize_name(column_name)],
262                    column_type,
263                )
264
265        return normalized_mapping
266
267    def _normalize_table(self, table: exp.Table) -> exp.Table:
268        normalized_table = table.copy()
269        for arg in TABLE_ARGS:
270            value = normalized_table.args.get(arg)
271            if isinstance(value, (str, exp.Identifier)):
272                normalized_table.set(arg, self._normalize_name(value))
273
274        return normalized_table
275
276    def _normalize_name(self, name: str | exp.Identifier) -> str:
277        try:
278            identifier = sqlglot.maybe_parse(name, dialect=self.dialect, into=exp.Identifier)
279        except ParseError:
280            return name if isinstance(name, str) else name.name
281
282        return identifier.name if identifier.quoted else identifier.name.lower()
283
284    def _depth(self) -> int:
285        # The columns themselves are a mapping, but we don't want to include those
286        return super()._depth() - 1
287
288    def _ensure_table(self, table: exp.Table | str) -> exp.Table:
289        if isinstance(table, exp.Table):
290            return table
291
292        table_ = sqlglot.parse_one(table, read=self.dialect, into=exp.Table)
293        if not table_:
294            raise SchemaError(f"Not a valid table '{table}'")
295
296        return table_
297
298    def _to_data_type(self, schema_type: str) -> exp.DataType:
299        """
300        Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object.
301
302        Args:
303            schema_type: the type we want to convert.
304
305        Returns:
306            The resulting expression type.
307        """
308        if schema_type not in self._type_mapping_cache:
309            try:
310                expression = exp.maybe_parse(schema_type, into=exp.DataType, dialect=self.dialect)
311                if expression is None:
312                    raise ValueError(f"Could not parse {schema_type}")
313                self._type_mapping_cache[schema_type] = expression  # type: ignore
314            except AttributeError:
315                raise SchemaError(f"Failed to convert type {schema_type}")
316
317        return self._type_mapping_cache[schema_type]

Schema based on a nested mapping.

Arguments:
  • schema (dict): Mapping in one of the following forms:
    1. {table: {col: type}}
    2. {db: {table: {col: type}}}
    3. {catalog: {db: {table: {col: type}}}}
    4. None - Tables will be added later
  • visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns are assumed to be visible. The nesting should mirror that of the schema:
    1. {table: set(*cols)}
    2. {db: {table: set(*cols)}}
    3. {catalog: {db: {table: set(*cols)}}}
  • dialect (str): The dialect to be used for custom type mappings.
MappingSchema( schema: Optional[Dict] = None, visible: Optional[Dict] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None)
153    def __init__(
154        self,
155        schema: t.Optional[t.Dict] = None,
156        visible: t.Optional[t.Dict] = None,
157        dialect: DialectType = None,
158    ) -> None:
159        self.dialect = dialect
160        self.visible = visible or {}
161        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
162        super().__init__(self._normalize(schema or {}))
@classmethod
def from_mapping_schema( cls, mapping_schema: sqlglot.schema.MappingSchema) -> sqlglot.schema.MappingSchema:
164    @classmethod
165    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
166        return MappingSchema(
167            schema=mapping_schema.mapping,
168            visible=mapping_schema.visible,
169            dialect=mapping_schema.dialect,
170        )
def copy(self, **kwargs) -> sqlglot.schema.MappingSchema:
172    def copy(self, **kwargs) -> MappingSchema:
173        return MappingSchema(
174            **{  # type: ignore
175                "schema": self.mapping.copy(),
176                "visible": self.visible.copy(),
177                "dialect": self.dialect,
178                **kwargs,
179            }
180        )
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType] = None) -> None:
182    def add_table(
183        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
184    ) -> None:
185        """
186        Register or update a table. Updates are only performed if a new column mapping is provided.
187
188        Args:
189            table: the `Table` expression instance or string representing the table.
190            column_mapping: a column mapping that describes the structure of the table.
191        """
192        normalized_table = self._normalize_table(self._ensure_table(table))
193        normalized_column_mapping = {
194            self._normalize_name(key): value
195            for key, value in ensure_column_mapping(column_mapping).items()
196        }
197
198        schema = self.find(normalized_table, raise_on_missing=False)
199        if schema and not normalized_column_mapping:
200            return
201
202        parts = self.table_parts(normalized_table)
203
204        _nested_set(
205            self.mapping,
206            tuple(reversed(parts)),
207            normalized_column_mapping,
208        )
209        new_trie([parts], self.mapping_trie)

Register or update a table. Updates are only performed if a new column mapping is provided.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False) -> List[str]:
211    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
212        table_ = self._normalize_table(self._ensure_table(table))
213        schema = self.find(table_)
214
215        if schema is None:
216            return []
217
218        if not only_visible or not self.visible:
219            return list(schema)
220
221        visible = self._nested_get(self.table_parts(table_), self.visible)
222        return [col for col in schema if col in visible]  # type: ignore

Get the column names for a table.

Arguments:
  • table: the Table expression instance.
  • only_visible: whether to return only the visible columns.
Returns:

The list of column names.

def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str) -> sqlglot.expressions.DataType:
224    def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
225        column_name = self._normalize_name(column if isinstance(column, str) else column.this)
226        table_ = self._normalize_table(self._ensure_table(table))
227
228        table_schema = self.find(table_, raise_on_missing=False)
229        if table_schema:
230            column_type = table_schema.get(column_name)
231
232            if isinstance(column_type, exp.DataType):
233                return column_type
234            elif isinstance(column_type, str):
235                return self._to_data_type(column_type.upper())
236            raise SchemaError(f"Unknown column type '{column_type}'")
237
238        return exp.DataType.build("unknown")

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
Returns:

The resulting column type.

def ensure_schema( schema: Any, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> sqlglot.schema.Schema:
def ensure_schema(schema: t.Any, dialect: DialectType = None) -> Schema:
    """
    Coerce the given value into a `Schema`.

    Args:
        schema: a `Schema` instance, or a value accepted by the `MappingSchema` constructor.
        dialect: the dialect passed to `MappingSchema` when a new schema is built.

    Returns:
        `schema` itself if it is already a `Schema`, otherwise a new `MappingSchema` built from it.
    """
    already_schema = isinstance(schema, Schema)
    return schema if already_schema else MappingSchema(schema, dialect=dialect)
def ensure_column_mapping( mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType]) -> Dict:
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Convert the supported column mapping representations into a dictionary.

    Args:
        mapping: one of the following:
            - a dict, which is returned as-is;
            - a string of comma-separated `name: type` pairs, e.g. `"a: int, b: text"`;
            - an object exposing `simpleString` (looks like a DataFrame `StructType`),
              whose fields provide `name` and `dataType.simpleString()`;
            - a list of column names, each of which is mapped to `None`;
            - `None`, which yields an empty dict.

    Returns:
        A dictionary mapping column names to their types.

    Raises:
        ValueError: if `mapping` is of an unsupported type.
        IndexError: if an entry of a string mapping contains no `:` separator.
    """
    if isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        column_mapping = {}
        for name_type_str in mapping.split(","):
            # Split on the first colon only, so that a type which itself contains
            # a colon (e.g. `struct<x:int>`) is kept intact instead of truncated.
            name_and_type = name_type_str.split(":", 1)
            column_mapping[name_and_type[0].strip()] = name_and_type[1].strip()
        return column_mapping
    # Check if mapping looks like a DataFrame StructType
    elif hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}  # type: ignore
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}
    elif mapping is None:
        return {}
    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema( schema: Dict, depth: int, keys: Optional[List[str]] = None) -> List[List[str]]:
def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """
    Collect the key paths of a nested mapping down to the given depth.

    Args:
        schema: the nested mapping to walk.
        depth: how many levels of keys to descend; each returned path has this many
            keys in addition to `keys`. Nothing is collected when `depth` is below 1.
        keys: the keys already traversed, prepended to every returned path.

    Returns:
        A list of key paths, one per entry found at the requested depth.
    """
    prefix = keys if keys else []
    paths = []

    for name, nested in schema.items():
        path = prefix + [name]
        if depth >= 2:
            # Not at the target level yet: descend one level further.
            paths += flatten_schema(nested, depth - 1, path)
        elif depth == 1:
            paths.append(path)

    return paths