Edit on GitHub

sqlglot.schema

  1from __future__ import annotations
  2
  3import abc
  4import typing as t
  5
  6import sqlglot
  7from sqlglot import expressions as exp
  8from sqlglot._typing import T
  9from sqlglot.dialects.dialect import Dialect
 10from sqlglot.errors import ParseError, SchemaError
 11from sqlglot.helper import dict_depth
 12from sqlglot.trie import TrieResult, in_trie, new_trie
 13
if t.TYPE_CHECKING:
    from sqlglot.dataframe.sql.types import StructType
    from sqlglot.dialects.dialect import DialectType

    # Accepted shapes for describing a table's columns: a {name: type} dict,
    # a "name: type, ..." string, a Spark StructType, or a list of column names.
    ColumnMapping = t.Union[t.Dict, str, StructType, t.List]

# Table expression arg names, ordered from least to most qualified part
# (table name, then database, then catalog).
TABLE_ARGS = ("this", "db", "catalog")
 21
 22
 23class Schema(abc.ABC):
 24    """Abstract base class for database schemas"""
 25
 26    dialect: DialectType
 27
 28    @abc.abstractmethod
 29    def add_table(
 30        self,
 31        table: exp.Table | str,
 32        column_mapping: t.Optional[ColumnMapping] = None,
 33        dialect: DialectType = None,
 34        normalize: t.Optional[bool] = None,
 35        match_depth: bool = True,
 36    ) -> None:
 37        """
 38        Register or update a table. Some implementing classes may require column information to also be provided.
 39        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
 40
 41        Args:
 42            table: the `Table` expression instance or string representing the table.
 43            column_mapping: a column mapping that describes the structure of the table.
 44            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 45            normalize: whether to normalize identifiers according to the dialect of interest.
 46            match_depth: whether to enforce that the table must match the schema's depth or not.
 47        """
 48
 49    @abc.abstractmethod
 50    def column_names(
 51        self,
 52        table: exp.Table | str,
 53        only_visible: bool = False,
 54        dialect: DialectType = None,
 55        normalize: t.Optional[bool] = None,
 56    ) -> t.List[str]:
 57        """
 58        Get the column names for a table.
 59
 60        Args:
 61            table: the `Table` expression instance.
 62            only_visible: whether to include invisible columns.
 63            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 64            normalize: whether to normalize identifiers according to the dialect of interest.
 65
 66        Returns:
 67            The list of column names.
 68        """
 69
 70    @abc.abstractmethod
 71    def get_column_type(
 72        self,
 73        table: exp.Table | str,
 74        column: exp.Column | str,
 75        dialect: DialectType = None,
 76        normalize: t.Optional[bool] = None,
 77    ) -> exp.DataType:
 78        """
 79        Get the `sqlglot.exp.DataType` type of a column in the schema.
 80
 81        Args:
 82            table: the source table.
 83            column: the target column.
 84            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 85            normalize: whether to normalize identifiers according to the dialect of interest.
 86
 87        Returns:
 88            The resulting column type.
 89        """
 90
 91    def has_column(
 92        self,
 93        table: exp.Table | str,
 94        column: exp.Column | str,
 95        dialect: DialectType = None,
 96        normalize: t.Optional[bool] = None,
 97    ) -> bool:
 98        """
 99        Returns whether or not `column` appears in `table`'s schema.
100
101        Args:
102            table: the source table.
103            column: the target column.
104            dialect: the SQL dialect that will be used to parse `table` if it's a string.
105            normalize: whether to normalize identifiers according to the dialect of interest.
106
107        Returns:
108            True if the column appears in the schema, False otherwise.
109        """
110        name = column if isinstance(column, str) else column.name
111        return name in self.column_names(table, dialect=dialect, normalize=normalize)
112
113    @property
114    @abc.abstractmethod
115    def supported_table_args(self) -> t.Tuple[str, ...]:
116        """
117        Table arguments this schema support, e.g. `("this", "db", "catalog")`
118        """
119
120    @property
121    def empty(self) -> bool:
122        """Returns whether or not the schema is empty."""
123        return True
124
125
class AbstractMappingSchema(t.Generic[T]):
    """
    Base implementation for schemas backed by a nested mapping. A trie built from the
    reversed table paths lets lookups succeed with only a partial set of qualifiers.
    """

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}
        # Trie keys are reversed flattened paths, i.e. (table, db, catalog) order,
        # which matches the order produced by `table_parts`.
        self.mapping_trie = new_trie(
            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
        )
        # Lazily computed by `supported_table_args`; empty tuple means "not computed".
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        """Whether the underlying mapping contains no entries."""
        return not self.mapping

    def depth(self) -> int:
        # Nesting level of the mapping: 1 = {table: ...}, 2 = {db: {table: ...}}, etc.
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        The table expression args supported by this schema's nesting level, e.g. a
        depth-2 mapping ({db: {table: ...}}) supports ("this", "db").
        """
        if not self._supported_table_args and self.mapping:
            depth = self.depth()

            if not depth:  # defensive: mapping is non-empty but no depth was computed
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = TABLE_ARGS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """Returns the table's path parts, ordered least to most qualified (table, db, catalog)."""
        if isinstance(table.this, exp.ReadCSV):
            # READ_CSV "tables" are identified by their file name alone
            return [table.this.name]
        return [table.text(part) for part in TABLE_ARGS if table.text(part)]

    def find(
        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[T]:
        """
        Look up the schema entry for `table`.

        Args:
            table: the table expression to look up.
            trie: an optional trie to search instead of this schema's own trie.
            raise_on_missing: whether to raise when the lookup is ambiguous.

        Returns:
            The schema entry, or None if the table wasn't found or the match was
            ambiguous and `raise_on_missing` is False.

        Raises:
            SchemaError: if the match is ambiguous and `raise_on_missing` is True.
        """
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)

        if value == TrieResult.FAILED:
            return None

        if value == TrieResult.PREFIX:
            # The table was only partially qualified: it resolves uniquely only if
            # exactly one completion exists under the matched trie prefix.
            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)

            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                message = ", ".join(".".join(parts) for parts in possibilities)
                if raise_on_missing:
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None

        return self.nested_get(parts, raise_on_missing=raise_on_missing)

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[t.Any]:
        """Gets the value stored under the reversed `parts` path of `d` (or this mapping)."""
        # NOTE(review): `supported_table_args` is in (this, db, catalog) order while the
        # lookup keys are reversed parts (catalog, db, table), so the error labels emitted
        # by `nested_get` may not match the level that actually failed — confirm intent.
        return nested_get(
            d or self.mapping,
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )
193
194
class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}}
            2. {db: {table: set(*cols)}}}
            3. {catalog: {db: {table: set(*cols)}}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = visible or {}
        self.normalize = normalize
        # Cache of raw type string -> parsed DataType, to avoid re-parsing in `_to_data_type`
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        # Lazily computed nesting depth; 0 means "not computed yet" (see `depth`)
        self._depth = 0

        super().__init__(self._normalize(schema or {}))

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Builds a new `MappingSchema` mirroring an existing one."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """Returns a copy of this schema; `kwargs` override the constructor arguments."""
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            # The table is already registered and no new column info was given
            return

        parts = self.table_parts(normalized_table)

        # Store the columns under the reversed path (catalog -> db -> table) and
        # register the table's parts in the lookup trie
        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """See `Schema.column_names`."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        # Filter down to the columns listed as visible for this table
        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """See `Schema.get_column_type`. Returns the "unknown" type when unresolvable."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                # Stored as a raw type string: parse (and cache) it
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """See `Schema.has_column`."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.

        Raises:
            SchemaError: if a table's entry isn't a dict, i.e. the nesting is uneven.
        """
        normalized_mapping: t.Dict = {}
        # depth - 1 flattens down to the table level, excluding the column dicts
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        for keys in flattened_schema:
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                raise SchemaError(
                    f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}."
                )

            normalized_keys = [
                self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys
            ]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parses `table` if needed and normalizes each of its identifier parts."""
        normalized_table = exp.maybe_parse(
            table, into=exp.Table, dialect=dialect or self.dialect, copy=True
        )

        for arg in TABLE_ARGS:
            value = normalized_table.args.get(arg)
            if isinstance(value, (str, exp.Identifier)):
                normalized_table.set(
                    arg,
                    exp.to_identifier(
                        self._normalize_name(
                            value, dialect=dialect, is_table=True, normalize=normalize
                        )
                    ),
                )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Normalizes a single identifier, falling back to this schema's dialect/settings."""
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        )

    def depth(self) -> int:
        """Returns the mapping depth at the table level (column dicts excluded)."""
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.

        Raises:
            SchemaError: if the type string cannot be built into a `DataType`.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect
            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                self._type_mapping_cache[schema_type] = expression
            # NOTE(review): parse failures surface here as AttributeError — confirm
            # that ParseError shouldn't also be handled.
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]
449
450
def normalize_name(
    name: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> str:
    """
    Parse `name` into an identifier and normalize it according to `dialect`.

    Falls back to the raw name when parsing fails, and skips dialect
    normalization entirely when `normalize` is falsy.
    """
    try:
        identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
    except ParseError:
        # Unparseable input: hand back the name text untouched
        if isinstance(name, str):
            return name
        return name.name

    text = identifier.name
    if not normalize:
        return text

    # Hint consumed by the dialect's normalize_identifier implementation
    identifier.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name
469
470
def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """Coerce `schema` into a `Schema`, wrapping dicts (or None) in a `MappingSchema`."""
    return schema if isinstance(schema, Schema) else MappingSchema(schema, **kwargs)
476
477
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Coerce a column mapping of any supported shape into a dict.

    Args:
        mapping: None, a {name: type} dict, a "name: type, ..." string, a Spark
            StructType, or a list of column names.

    Returns:
        A dict of column name to column type (type is None for a plain name list,
        and an empty string for a string entry with no ":" separator).

    Raises:
        ValueError: if `mapping` is of an unsupported type.
    """
    if mapping is None:
        return {}
    elif isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        columns = {}
        for entry in mapping.split(","):
            # Split on the first colon only, so type strings that themselves
            # contain colons (e.g. "map<int:str>") are preserved intact.
            name, _, column_type = entry.partition(":")
            columns[name.strip()] = column_type.strip()
        return columns
    # Check if mapping looks like a DataFrame StructType
    elif hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
496
497
def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """
    Flatten a nested schema dict into a list of key paths, each `depth` levels deep.

    Levels beyond `depth` (e.g. the column dicts) are not descended into.
    """
    prefix = keys or []
    paths: t.List[t.List[str]] = []

    for key, value in schema.items():
        if depth == 1:
            paths.append(prefix + [key])
        elif depth >= 2:
            paths.extend(flatten_schema(value, depth - 1, prefix + [key]))

    return paths
511
512
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Look up a value along a key path in a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where `key` is looked up at each level and
            `name` labels that level in the error message if `key` is absent.
        raise_on_missing: whether a missing key raises instead of returning None.

    Returns:
        The value, or None if it doesn't exist and `raise_on_missing` is False.

    Raises:
        ValueError: if a key along the path is missing and `raise_on_missing` is True.
    """
    current = d
    for name, key in path:
        current = current.get(key)  # type: ignore
        if current is None:
            if not raise_on_missing:
                return None
            # "this" is the table arg internally but reads better as "table"
            name = "table" if name == "this" else name
            raise ValueError(f"Unknown {name}: {key}")

    return current
537
538
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    Set `value` at a nested key path in `d`, creating intermediate dicts as needed (in place).

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: the dictionary to update.
        keys: the key path at which `value` should be placed.
        value: the value to store at the end of the path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    # Walk (and build) all but the last level, then assign at the leaf.
    node = d
    for key in keys[:-1]:
        node = node.setdefault(key, {})

    node[keys[-1]] = value
    return d
TABLE_ARGS = ('this', 'db', 'catalog')
class Schema(abc.ABC):
 24class Schema(abc.ABC):
 25    """Abstract base class for database schemas"""
 26
 27    dialect: DialectType
 28
 29    @abc.abstractmethod
 30    def add_table(
 31        self,
 32        table: exp.Table | str,
 33        column_mapping: t.Optional[ColumnMapping] = None,
 34        dialect: DialectType = None,
 35        normalize: t.Optional[bool] = None,
 36        match_depth: bool = True,
 37    ) -> None:
 38        """
 39        Register or update a table. Some implementing classes may require column information to also be provided.
 40        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
 41
 42        Args:
 43            table: the `Table` expression instance or string representing the table.
 44            column_mapping: a column mapping that describes the structure of the table.
 45            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 46            normalize: whether to normalize identifiers according to the dialect of interest.
 47            match_depth: whether to enforce that the table must match the schema's depth or not.
 48        """
 49
 50    @abc.abstractmethod
 51    def column_names(
 52        self,
 53        table: exp.Table | str,
 54        only_visible: bool = False,
 55        dialect: DialectType = None,
 56        normalize: t.Optional[bool] = None,
 57    ) -> t.List[str]:
 58        """
 59        Get the column names for a table.
 60
 61        Args:
 62            table: the `Table` expression instance.
 63            only_visible: whether to include invisible columns.
 64            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 65            normalize: whether to normalize identifiers according to the dialect of interest.
 66
 67        Returns:
 68            The list of column names.
 69        """
 70
 71    @abc.abstractmethod
 72    def get_column_type(
 73        self,
 74        table: exp.Table | str,
 75        column: exp.Column | str,
 76        dialect: DialectType = None,
 77        normalize: t.Optional[bool] = None,
 78    ) -> exp.DataType:
 79        """
 80        Get the `sqlglot.exp.DataType` type of a column in the schema.
 81
 82        Args:
 83            table: the source table.
 84            column: the target column.
 85            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 86            normalize: whether to normalize identifiers according to the dialect of interest.
 87
 88        Returns:
 89            The resulting column type.
 90        """
 91
 92    def has_column(
 93        self,
 94        table: exp.Table | str,
 95        column: exp.Column | str,
 96        dialect: DialectType = None,
 97        normalize: t.Optional[bool] = None,
 98    ) -> bool:
 99        """
100        Returns whether or not `column` appears in `table`'s schema.
101
102        Args:
103            table: the source table.
104            column: the target column.
105            dialect: the SQL dialect that will be used to parse `table` if it's a string.
106            normalize: whether to normalize identifiers according to the dialect of interest.
107
108        Returns:
109            True if the column appears in the schema, False otherwise.
110        """
111        name = column if isinstance(column, str) else column.name
112        return name in self.column_names(table, dialect=dialect, normalize=normalize)
113
114    @property
115    @abc.abstractmethod
116    def supported_table_args(self) -> t.Tuple[str, ...]:
117        """
 118        Table arguments this schema supports, e.g. `("this", "db", "catalog")`
119        """
120
121    @property
122    def empty(self) -> bool:
123        """Returns whether or not the schema is empty."""
124        return True

Abstract base class for database schemas

@abc.abstractmethod
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None, match_depth: bool = True) -> None:
29    @abc.abstractmethod
30    def add_table(
31        self,
32        table: exp.Table | str,
33        column_mapping: t.Optional[ColumnMapping] = None,
34        dialect: DialectType = None,
35        normalize: t.Optional[bool] = None,
36        match_depth: bool = True,
37    ) -> None:
38        """
39        Register or update a table. Some implementing classes may require column information to also be provided.
40        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
41
42        Args:
43            table: the `Table` expression instance or string representing the table.
44            column_mapping: a column mapping that describes the structure of the table.
45            dialect: the SQL dialect that will be used to parse `table` if it's a string.
46            normalize: whether to normalize identifiers according to the dialect of interest.
47            match_depth: whether to enforce that the table must match the schema's depth or not.
48        """

Register or update a table. Some implementing classes may require column information to also be provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
  • match_depth: whether to enforce that the table must match the schema's depth or not.
@abc.abstractmethod
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> List[str]:
50    @abc.abstractmethod
51    def column_names(
52        self,
53        table: exp.Table | str,
54        only_visible: bool = False,
55        dialect: DialectType = None,
56        normalize: t.Optional[bool] = None,
57    ) -> t.List[str]:
58        """
59        Get the column names for a table.
60
61        Args:
62            table: the `Table` expression instance.
63            only_visible: whether to include invisible columns.
64            dialect: the SQL dialect that will be used to parse `table` if it's a string.
65            normalize: whether to normalize identifiers according to the dialect of interest.
66
67        Returns:
68            The list of column names.
69        """

Get the column names for a table.

Arguments:
  • table: the Table expression instance.
  • only_visible: whether to include invisible columns.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The list of column names.

@abc.abstractmethod
def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> sqlglot.expressions.DataType:
71    @abc.abstractmethod
72    def get_column_type(
73        self,
74        table: exp.Table | str,
75        column: exp.Column | str,
76        dialect: DialectType = None,
77        normalize: t.Optional[bool] = None,
78    ) -> exp.DataType:
79        """
80        Get the `sqlglot.exp.DataType` type of a column in the schema.
81
82        Args:
83            table: the source table.
84            column: the target column.
85            dialect: the SQL dialect that will be used to parse `table` if it's a string.
86            normalize: whether to normalize identifiers according to the dialect of interest.
87
88        Returns:
89            The resulting column type.
90        """

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The resulting column type.

def has_column( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> bool:
 92    def has_column(
 93        self,
 94        table: exp.Table | str,
 95        column: exp.Column | str,
 96        dialect: DialectType = None,
 97        normalize: t.Optional[bool] = None,
 98    ) -> bool:
 99        """
100        Returns whether or not `column` appears in `table`'s schema.
101
102        Args:
103            table: the source table.
104            column: the target column.
105            dialect: the SQL dialect that will be used to parse `table` if it's a string.
106            normalize: whether to normalize identifiers according to the dialect of interest.
107
108        Returns:
109            True if the column appears in the schema, False otherwise.
110        """
111        name = column if isinstance(column, str) else column.name
112        return name in self.column_names(table, dialect=dialect, normalize=normalize)

Returns whether or not column appears in table's schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

True if the column appears in the schema, False otherwise.

supported_table_args: Tuple[str, ...]

Table arguments this schema supports, e.g. ("this", "db", "catalog")

empty: bool

Returns whether or not the schema is empty.

class AbstractMappingSchema(typing.Generic[~T]):
class AbstractMappingSchema(t.Generic[T]):
    """Shared machinery for schemas backed by a nested dict, with a trie for table lookups."""

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}
        self.mapping_trie = new_trie(
            tuple(reversed(table_path))
            for table_path in flatten_schema(self.mapping, depth=self.depth())
        )
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        """Whether the schema contains no tables at all."""
        return not self.mapping

    def depth(self) -> int:
        """Nesting depth of the underlying mapping."""
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """The table path parts (e.g. this/db/catalog) this schema can resolve, computed lazily."""
        if self.mapping and not self._supported_table_args:
            depth = self.depth()

            if depth == 0:
                self._supported_table_args = tuple()
            elif depth in (1, 2, 3):
                self._supported_table_args = TABLE_ARGS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """The non-empty name parts of `table`, in TABLE_ARGS order."""
        if isinstance(table.this, exp.ReadCSV):
            return [table.this.name]

        parts = []
        for arg in TABLE_ARGS:
            text = table.text(arg)
            if text:
                parts.append(text)

        return parts

    def find(
        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[T]:
        """Look up `table` in the schema, resolving an unambiguous partial path automatically."""
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        result, subtrie = in_trie(self.mapping_trie if trie is None else trie, parts)

        if result == TrieResult.FAILED:
            return None

        if result == TrieResult.PREFIX:
            # Partial match: only usable when exactly one continuation exists in the trie.
            possibilities = flatten_schema(subtrie, depth=dict_depth(subtrie) - 1)

            if len(possibilities) != 1:
                message = ", ".join(".".join(p) for p in possibilities)
                if raise_on_missing:
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None

            parts.extend(possibilities[0])

        return self.nested_get(parts, raise_on_missing=raise_on_missing)

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
    ) -> t.Optional[t.Any]:
        """Fetch the value under the reversed `parts` path in `d` (defaults to the schema mapping)."""
        mapping = d or self.mapping
        path = zip(self.supported_table_args, reversed(parts))
        return nested_get(mapping, *path, raise_on_missing=raise_on_missing)

Abstract base class for generic types.

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::

    class Mapping(Generic[KT, VT]):
        def __getitem__(self, key: KT) -> VT:
            ...
            # Etc.

This class can then be used as follows::

    def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
        try:
            return mapping[key]
        except KeyError:
            return default

AbstractMappingSchema(mapping: Optional[Dict] = None)
128    def __init__(
129        self,
130        mapping: t.Optional[t.Dict] = None,
131    ) -> None:
132        self.mapping = mapping or {}
133        self.mapping_trie = new_trie(
134            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
135        )
136        self._supported_table_args: t.Tuple[str, ...] = tuple()
mapping
mapping_trie
empty: bool
def depth(self) -> int:
142    def depth(self) -> int:
143        return dict_depth(self.mapping)
supported_table_args: Tuple[str, ...]
def table_parts(self, table: sqlglot.expressions.Table) -> List[str]:
159    def table_parts(self, table: exp.Table) -> t.List[str]:
160        if isinstance(table.this, exp.ReadCSV):
161            return [table.this.name]
162        return [table.text(part) for part in TABLE_ARGS if table.text(part)]
def find( self, table: sqlglot.expressions.Table, trie: Optional[Dict] = None, raise_on_missing: bool = True) -> Optional[~T]:
164    def find(
165        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
166    ) -> t.Optional[T]:
167        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
168        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
169
170        if value == TrieResult.FAILED:
171            return None
172
173        if value == TrieResult.PREFIX:
174            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
175
176            if len(possibilities) == 1:
177                parts.extend(possibilities[0])
178            else:
179                message = ", ".join(".".join(parts) for parts in possibilities)
180                if raise_on_missing:
181                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
182                return None
183
184        return self.nested_get(parts, raise_on_missing=raise_on_missing)
def nested_get( self, parts: Sequence[str], d: Optional[Dict] = None, raise_on_missing=True) -> Optional[Any]:
186    def nested_get(
187        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
188    ) -> t.Optional[t.Any]:
189        return nested_get(
190            d or self.mapping,
191            *zip(self.supported_table_args, reversed(parts)),
192            raise_on_missing=raise_on_missing,
193        )
class MappingSchema(typing.Generic[~T]):
class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}}
            2. {db: {table: set(*cols)}}}
            3. {catalog: {db: {table: set(*cols)}}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = visible or {}
        self.normalize = normalize
        # Caches string type -> parsed exp.DataType so each distinct type is only parsed once.
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        # Lazily computed nesting depth; 0 means "not computed yet" (see depth()).
        self._depth = 0

        # Identifiers are normalized up front so all later lookups are consistent.
        super().__init__(self._normalize(schema or {}))

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a new `MappingSchema` from the state of an existing one (shares the mappings)."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """Return a copy of this schema; `kwargs` can override any constructor argument."""
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        schema = self.find(normalized_table, raise_on_missing=False)
        # Known table and no new column information provided -> nothing to update.
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        # Keep the lookup trie in sync with the updated mapping.
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to include invisible columns.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The list of column names.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type, or the "unknown" type if it can't be resolved.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            # Stored types may be either pre-built DataType expressions or raw strings.
            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """
        Returns whether or not `column` appears in `table`'s schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            True if the column appears in the schema, False otherwise.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        # Depth - 1 so that the leaf {col: type} dicts are not flattened into paths.
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        for keys in flattened_schema:
            # `keys` doubles as both the error label and the lookup key for nested_get.
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                raise SchemaError(
                    f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}."
                )

            normalized_keys = [
                self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys
            ]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parse `table` (if it's a string) and normalize each of its name parts on a copy."""
        normalized_table = exp.maybe_parse(
            table, into=exp.Table, dialect=dialect or self.dialect, copy=True
        )

        for arg in TABLE_ARGS:
            value = normalized_table.args.get(arg)
            if isinstance(value, (str, exp.Identifier)):
                normalized_table.set(
                    arg,
                    exp.to_identifier(
                        self._normalize_name(
                            value, dialect=dialect, is_table=True, normalize=normalize
                        )
                    ),
                )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Normalize a single identifier, defaulting to this schema's dialect and normalize setting."""
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        )

    def depth(self) -> int:
        """Nesting depth of the schema, excluding the leaf {col: type} level; cached after first call."""
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect
            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]

Schema based on a nested mapping.

Arguments:
  • schema: Mapping in one of the following forms:
    1. {table: {col: type}}
    2. {db: {table: {col: type}}}
    3. {catalog: {db: {table: {col: type}}}}
    4. None - Tables will be added later
  • visible: Optional mapping of which columns in the schema are visible. If not provided, all columns are assumed to be visible. The nesting should mirror that of the schema:
    1. {table: set(*cols)}
    2. {db: {table: set(*cols)}}
    3. {catalog: {db: {table: set(*cols)}}}
  • dialect: The dialect to be used for custom type mappings & parsing string arguments.
  • normalize: Whether to normalize identifier names according to the given dialect or not.
MappingSchema( schema: Optional[Dict] = None, visible: Optional[Dict] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: bool = True)
215    def __init__(
216        self,
217        schema: t.Optional[t.Dict] = None,
218        visible: t.Optional[t.Dict] = None,
219        dialect: DialectType = None,
220        normalize: bool = True,
221    ) -> None:
222        self.dialect = dialect
223        self.visible = visible or {}
224        self.normalize = normalize
225        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
226        self._depth = 0
227
228        super().__init__(self._normalize(schema or {}))
dialect
visible
normalize
@classmethod
def from_mapping_schema( cls, mapping_schema: MappingSchema) -> MappingSchema:
230    @classmethod
231    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
232        return MappingSchema(
233            schema=mapping_schema.mapping,
234            visible=mapping_schema.visible,
235            dialect=mapping_schema.dialect,
236            normalize=mapping_schema.normalize,
237        )
def copy(self, **kwargs) -> MappingSchema:
239    def copy(self, **kwargs) -> MappingSchema:
240        return MappingSchema(
241            **{  # type: ignore
242                "schema": self.mapping.copy(),
243                "visible": self.visible.copy(),
244                "dialect": self.dialect,
245                "normalize": self.normalize,
246                **kwargs,
247            }
248        )
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None, match_depth: bool = True) -> None:
250    def add_table(
251        self,
252        table: exp.Table | str,
253        column_mapping: t.Optional[ColumnMapping] = None,
254        dialect: DialectType = None,
255        normalize: t.Optional[bool] = None,
256        match_depth: bool = True,
257    ) -> None:
258        """
259        Register or update a table. Updates are only performed if a new column mapping is provided.
260        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
261
262        Args:
263            table: the `Table` expression instance or string representing the table.
264            column_mapping: a column mapping that describes the structure of the table.
265            dialect: the SQL dialect that will be used to parse `table` if it's a string.
266            normalize: whether to normalize identifiers according to the dialect of interest.
267            match_depth: whether to enforce that the table must match the schema's depth or not.
268        """
269        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
270
271        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
272            raise SchemaError(
273                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
274                f"schema's nesting level: {self.depth()}."
275            )
276
277        normalized_column_mapping = {
278            self._normalize_name(key, dialect=dialect, normalize=normalize): value
279            for key, value in ensure_column_mapping(column_mapping).items()
280        }
281
282        schema = self.find(normalized_table, raise_on_missing=False)
283        if schema and not normalized_column_mapping:
284            return
285
286        parts = self.table_parts(normalized_table)
287
288        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
289        new_trie([parts], self.mapping_trie)

Register or update a table. Updates are only performed if a new column mapping is provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
  • match_depth: whether to enforce that the table must match the schema's depth or not.
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> List[str]:
291    def column_names(
292        self,
293        table: exp.Table | str,
294        only_visible: bool = False,
295        dialect: DialectType = None,
296        normalize: t.Optional[bool] = None,
297    ) -> t.List[str]:
298        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
299
300        schema = self.find(normalized_table)
301        if schema is None:
302            return []
303
304        if not only_visible or not self.visible:
305            return list(schema)
306
307        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
308        return [col for col in schema if col in visible]

Get the column names for a table.

Arguments:
  • table: the Table expression instance.
  • only_visible: whether to include invisible columns.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The list of column names.

def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> sqlglot.expressions.DataType:
310    def get_column_type(
311        self,
312        table: exp.Table | str,
313        column: exp.Column | str,
314        dialect: DialectType = None,
315        normalize: t.Optional[bool] = None,
316    ) -> exp.DataType:
317        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
318
319        normalized_column_name = self._normalize_name(
320            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
321        )
322
323        table_schema = self.find(normalized_table, raise_on_missing=False)
324        if table_schema:
325            column_type = table_schema.get(normalized_column_name)
326
327            if isinstance(column_type, exp.DataType):
328                return column_type
329            elif isinstance(column_type, str):
330                return self._to_data_type(column_type, dialect=dialect)
331
332        return exp.DataType.build("unknown")

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The resulting column type.

def has_column( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> bool:
334    def has_column(
335        self,
336        table: exp.Table | str,
337        column: exp.Column | str,
338        dialect: DialectType = None,
339        normalize: t.Optional[bool] = None,
340    ) -> bool:
341        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
342
343        normalized_column_name = self._normalize_name(
344            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
345        )
346
347        table_schema = self.find(normalized_table, raise_on_missing=False)
348        return normalized_column_name in table_schema if table_schema else False

Returns whether or not column appears in table's schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

True if the column appears in the schema, False otherwise.

def depth(self) -> int:
421    def depth(self) -> int:
422        if not self.empty and not self._depth:
423            # The columns themselves are a mapping, but we don't want to include those
424            self._depth = super().depth() - 1
425        return self._depth
def normalize_name( name: str | sqlglot.expressions.Identifier, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, is_table: bool = False, normalize: Optional[bool] = True) -> str:
def normalize_name(
    name: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> str:
    """Parse `name` into an identifier and optionally normalize its casing for `dialect`."""
    try:
        identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
    except ParseError:
        # Unparseable input: fall back to the raw string / identifier name unchanged.
        if isinstance(name, str):
            return name
        return name.name

    if not normalize:
        return identifier.name

    # Hint consumed by dialect-specific normalize_identifier implementations.
    identifier.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name
def ensure_schema( schema: Union[Schema, Dict, NoneType], **kwargs: Any) -> Schema:
def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """Coerce `schema` into a `Schema`, wrapping plain mappings (or None) in a `MappingSchema`."""
    if not isinstance(schema, Schema):
        schema = MappingSchema(schema, **kwargs)

    return schema
def ensure_column_mapping( mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType]) -> Dict:
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Normalize any supported column mapping form into a plain dict.

    Accepts a dict (returned as-is), a "col: type, ..." string, a DataFrame
    StructType-like object, a list of column names (types become None), or None.

    Raises:
        ValueError: if `mapping` is of an unsupported type.
    """
    if mapping is None:
        return {}
    if isinstance(mapping, dict):
        return mapping
    if isinstance(mapping, str):
        columns = {}
        for entry in mapping.split(","):
            pieces = entry.strip().split(":")
            columns[pieces[0].strip()] = pieces[1].strip()
        return columns
    # Duck-typed check for a DataFrame StructType; must precede the list check,
    # since a StructType is also iterable.
    if hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
    if isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema( schema: Dict, depth: int, keys: Optional[List[str]] = None) -> List[List[str]]:
def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """Flatten a nested schema into a list of key paths, descending `depth` levels."""
    prefix = keys or []
    paths: t.List[t.List[str]] = []

    for key, value in schema.items():
        if depth >= 2:
            # Recurse one level down, extending the current path prefix.
            paths.extend(flatten_schema(value, depth - 1, prefix + [key]))
        elif depth == 1:
            paths.append(prefix + [key])

    return paths
def nested_get( d: Dict, *path: Tuple[str, str], raise_on_missing: bool = True) -> Optional[Any]:
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.
        raise_on_missing: whether a missing key raises ValueError instead of returning None.

    Returns:
        The value or None if it doesn't exist.
    """
    current = d
    for name, key in path:
        current = current.get(key)  # type: ignore
        if current is None:
            if not raise_on_missing:
                return None
            # "this" is the internal arg name for the table part; report it as "table".
            label = "table" if name == "this" else name
            raise ValueError(f"Unknown {label}: {key}")

    return current

Get a value for a nested dictionary.

Arguments:
  • d: the dictionary to search.
  • *path: tuples of (name, key), where: key is the key in the dictionary to get. name is a string to use in the error if key isn't found.
Returns:

The value or None if it doesn't exist.

def nested_set(d: Dict, keys: Sequence[str], value: Any) -> Dict:
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that makeup the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    *parent_keys, last_key = keys

    # Walk (and create, where missing) every intermediate dict along the path.
    node = d
    for key in parent_keys:
        node = node.setdefault(key, {})

    node[last_key] = value
    return d

In-place set a value for a nested dictionary

Example:
>>> nested_set({}, ["top_key", "second_key"], "value")
{'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
{'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
  • d: dictionary to update.
  • keys: the keys that makeup the path to value.
  • value: the value to set in the dictionary for the given key path.
Returns:

The (possibly) updated dictionary.