Edit on GitHub

sqlglot.schema

  1from __future__ import annotations
  2
  3import abc
  4import typing as t
  5
  6import sqlglot
  7from sqlglot import expressions as exp
  8from sqlglot.dialects.dialect import Dialect
  9from sqlglot.errors import ParseError, SchemaError
 10from sqlglot.helper import dict_depth
 11from sqlglot.trie import TrieResult, in_trie, new_trie
 12
 13if t.TYPE_CHECKING:
 14    from sqlglot.dataframe.sql.types import StructType
 15    from sqlglot.dialects.dialect import DialectType
 16
 17    ColumnMapping = t.Union[t.Dict, str, StructType, t.List]
 18
 19TABLE_ARGS = ("this", "db", "catalog")
 20
 21
 22class Schema(abc.ABC):
 23    """Abstract base class for database schemas"""
 24
 25    dialect: DialectType
 26
 27    @abc.abstractmethod
 28    def add_table(
 29        self,
 30        table: exp.Table | str,
 31        column_mapping: t.Optional[ColumnMapping] = None,
 32        dialect: DialectType = None,
 33        normalize: t.Optional[bool] = None,
 34        match_depth: bool = True,
 35    ) -> None:
 36        """
 37        Register or update a table. Some implementing classes may require column information to also be provided.
 38        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
 39
 40        Args:
 41            table: the `Table` expression instance or string representing the table.
 42            column_mapping: a column mapping that describes the structure of the table.
 43            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 44            normalize: whether to normalize identifiers according to the dialect of interest.
 45            match_depth: whether to enforce that the table must match the schema's depth or not.
 46        """
 47
 48    @abc.abstractmethod
 49    def column_names(
 50        self,
 51        table: exp.Table | str,
 52        only_visible: bool = False,
 53        dialect: DialectType = None,
 54        normalize: t.Optional[bool] = None,
 55    ) -> t.List[str]:
 56        """
 57        Get the column names for a table.
 58
 59        Args:
 60            table: the `Table` expression instance.
 61            only_visible: whether to include invisible columns.
 62            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 63            normalize: whether to normalize identifiers according to the dialect of interest.
 64
 65        Returns:
 66            The list of column names.
 67        """
 68
 69    @abc.abstractmethod
 70    def get_column_type(
 71        self,
 72        table: exp.Table | str,
 73        column: exp.Column | str,
 74        dialect: DialectType = None,
 75        normalize: t.Optional[bool] = None,
 76    ) -> exp.DataType:
 77        """
 78        Get the `sqlglot.exp.DataType` type of a column in the schema.
 79
 80        Args:
 81            table: the source table.
 82            column: the target column.
 83            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 84            normalize: whether to normalize identifiers according to the dialect of interest.
 85
 86        Returns:
 87            The resulting column type.
 88        """
 89
 90    def has_column(
 91        self,
 92        table: exp.Table | str,
 93        column: exp.Column | str,
 94        dialect: DialectType = None,
 95        normalize: t.Optional[bool] = None,
 96    ) -> bool:
 97        """
 98        Returns whether or not `column` appears in `table`'s schema.
 99
100        Args:
101            table: the source table.
102            column: the target column.
103            dialect: the SQL dialect that will be used to parse `table` if it's a string.
104            normalize: whether to normalize identifiers according to the dialect of interest.
105
106        Returns:
107            True if the column appears in the schema, False otherwise.
108        """
109        name = column if isinstance(column, str) else column.name
110        return name in self.column_names(table, dialect=dialect, normalize=normalize)
111
112    @property
113    @abc.abstractmethod
114    def supported_table_args(self) -> t.Tuple[str, ...]:
115        """
116        Table arguments this schema support, e.g. `("this", "db", "catalog")`
117        """
118
119    @property
120    def empty(self) -> bool:
121        """Returns whether or not the schema is empty."""
122        return True
123
124
class AbstractMappingSchema:
    """
    Base class for schemas backed by a nested dictionary.

    `mapping` is nested from the outermost qualifier inwards (e.g. `{db: {table: ...}}`), while
    `mapping_trie` indexes the same key paths in reverse (innermost first), which is the order
    produced by `table_parts`.
    """

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}
        # Reverse each outermost-first key path so the trie is keyed innermost-first.
        self.mapping_trie = new_trie(
            tuple(reversed(path)) for path in flatten_schema(self.mapping, depth=self.depth())
        )
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        """Whether the mapping has no entries."""
        return not self.mapping

    def depth(self) -> int:
        """The nesting depth of `mapping`."""
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        The `exp.Table` args this schema can resolve, derived lazily from `depth()`.

        Raises:
            SchemaError: if the mapping is deeper than catalog/db/table allows.
        """
        if not self._supported_table_args and self.mapping:
            depth = self.depth()

            if not depth:  # None
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = TABLE_ARGS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """
        Return the table's non-empty parts innermost first (name, db, catalog).

        A table whose `this` is a `ReadCSV` yields a single part, the `ReadCSV`'s `name`.
        """
        if isinstance(table.this, exp.ReadCSV):
            return [table.this.name]
        return [table.text(part) for part in TABLE_ARGS if table.text(part)]

    def find(
        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[t.Any]:
        """
        Look up the value stored in `mapping` for `table`.

        Args:
            table: the table to look up; only as many parts as the schema supports are used.
            trie: a trie to resolve the table's parts against instead of `self.mapping_trie`.
            raise_on_missing: whether to raise (rather than return None) when the table is
                ambiguous or one of its keys is missing from the mapping.

        Returns:
            The stored value, or None if the table could not be resolved.

        Raises:
            SchemaError: if a partially qualified table matches several entries and
                `raise_on_missing` is set.
            ValueError: propagated from `nested_get` when a key is missing and
                `raise_on_missing` is set.
        """
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)

        if value == TrieResult.FAILED:
            return None

        if value == TrieResult.PREFIX:
            # The reference is only partially qualified: it resolves only if exactly one
            # complete path in the trie extends it.
            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)

            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                message = ", ".join(".".join(candidate) for candidate in possibilities)
                if raise_on_missing:
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None

        return self.nested_get(parts, raise_on_missing=raise_on_missing)

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
    ) -> t.Optional[t.Any]:
        """
        Look up table `parts` (innermost first, as returned by `table_parts`) in a nested dict.

        Args:
            parts: the table's parts, innermost first.
            d: the dictionary to search; `self.mapping` when omitted. An explicitly passed
                empty dict is searched as-is instead of falling back to `self.mapping`.
            raise_on_missing: whether to raise `ValueError` when a key is missing.

        Returns:
            The nested value, or None if it is missing and `raise_on_missing` is False.
        """
        # The dict is nested outermost-first, so walk the parts in reverse, descending at most
        # as many levels as the schema supports.
        keys = tuple(reversed(parts))[: len(self.supported_table_args)]
        # Label every key with the Table arg it came from (parts[i] is TABLE_ARGS[i]) so that
        # missing-key errors name the right qualifier, e.g. "Unknown catalog: c".
        labels = tuple(reversed(TABLE_ARGS[: len(parts)]))
        return nested_get(
            self.mapping if d is None else d,
            *zip(labels, keys),
            raise_on_missing=raise_on_missing,
        )
188            d or self.mapping,
189            *zip(self.supported_table_args, reversed(parts)),
190            raise_on_missing=raise_on_missing,
191        )
192
193
194class MappingSchema(AbstractMappingSchema, Schema):
195    """
196    Schema based on a nested mapping.
197
198    Args:
199        schema: Mapping in one of the following forms:
200            1. {table: {col: type}}
201            2. {db: {table: {col: type}}}
202            3. {catalog: {db: {table: {col: type}}}}
203            4. None - Tables will be added later
204        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
205            are assumed to be visible. The nesting should mirror that of the schema:
206            1. {table: set(*cols)}}
207            2. {db: {table: set(*cols)}}}
208            3. {catalog: {db: {table: set(*cols)}}}}
209        dialect: The dialect to be used for custom type mappings & parsing string arguments.
210        normalize: Whether to normalize identifier names according to the given dialect or not.
211    """
212
213    def __init__(
214        self,
215        schema: t.Optional[t.Dict] = None,
216        visible: t.Optional[t.Dict] = None,
217        dialect: DialectType = None,
218        normalize: bool = True,
219    ) -> None:
220        self.dialect = dialect
221        self.visible = visible or {}
222        self.normalize = normalize
223        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
224        self._depth = 0
225
226        super().__init__(self._normalize(schema or {}))
227
228    @classmethod
229    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
230        return MappingSchema(
231            schema=mapping_schema.mapping,
232            visible=mapping_schema.visible,
233            dialect=mapping_schema.dialect,
234            normalize=mapping_schema.normalize,
235        )
236
237    def copy(self, **kwargs) -> MappingSchema:
238        return MappingSchema(
239            **{  # type: ignore
240                "schema": self.mapping.copy(),
241                "visible": self.visible.copy(),
242                "dialect": self.dialect,
243                "normalize": self.normalize,
244                **kwargs,
245            }
246        )
247
248    def add_table(
249        self,
250        table: exp.Table | str,
251        column_mapping: t.Optional[ColumnMapping] = None,
252        dialect: DialectType = None,
253        normalize: t.Optional[bool] = None,
254        match_depth: bool = True,
255    ) -> None:
256        """
257        Register or update a table. Updates are only performed if a new column mapping is provided.
258        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
259
260        Args:
261            table: the `Table` expression instance or string representing the table.
262            column_mapping: a column mapping that describes the structure of the table.
263            dialect: the SQL dialect that will be used to parse `table` if it's a string.
264            normalize: whether to normalize identifiers according to the dialect of interest.
265            match_depth: whether to enforce that the table must match the schema's depth or not.
266        """
267        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
268
269        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
270            raise SchemaError(
271                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
272                f"schema's nesting level: {self.depth()}."
273            )
274
275        normalized_column_mapping = {
276            self._normalize_name(key, dialect=dialect, normalize=normalize): value
277            for key, value in ensure_column_mapping(column_mapping).items()
278        }
279
280        schema = self.find(normalized_table, raise_on_missing=False)
281        if schema and not normalized_column_mapping:
282            return
283
284        parts = self.table_parts(normalized_table)
285
286        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
287        new_trie([parts], self.mapping_trie)
288
289    def column_names(
290        self,
291        table: exp.Table | str,
292        only_visible: bool = False,
293        dialect: DialectType = None,
294        normalize: t.Optional[bool] = None,
295    ) -> t.List[str]:
296        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
297
298        schema = self.find(normalized_table)
299        if schema is None:
300            return []
301
302        if not only_visible or not self.visible:
303            return list(schema)
304
305        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
306        return [col for col in schema if col in visible]
307
308    def get_column_type(
309        self,
310        table: exp.Table | str,
311        column: exp.Column | str,
312        dialect: DialectType = None,
313        normalize: t.Optional[bool] = None,
314    ) -> exp.DataType:
315        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
316
317        normalized_column_name = self._normalize_name(
318            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
319        )
320
321        table_schema = self.find(normalized_table, raise_on_missing=False)
322        if table_schema:
323            column_type = table_schema.get(normalized_column_name)
324
325            if isinstance(column_type, exp.DataType):
326                return column_type
327            elif isinstance(column_type, str):
328                return self._to_data_type(column_type, dialect=dialect)
329
330        return exp.DataType.build("unknown")
331
332    def has_column(
333        self,
334        table: exp.Table | str,
335        column: exp.Column | str,
336        dialect: DialectType = None,
337        normalize: t.Optional[bool] = None,
338    ) -> bool:
339        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
340
341        normalized_column_name = self._normalize_name(
342            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
343        )
344
345        table_schema = self.find(normalized_table, raise_on_missing=False)
346        return normalized_column_name in table_schema if table_schema else False
347
348    def _normalize(self, schema: t.Dict) -> t.Dict:
349        """
350        Normalizes all identifiers in the schema.
351
352        Args:
353            schema: the schema to normalize.
354
355        Returns:
356            The normalized schema mapping.
357        """
358        normalized_mapping: t.Dict = {}
359        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)
360
361        for keys in flattened_schema:
362            columns = nested_get(schema, *zip(keys, keys))
363
364            if not isinstance(columns, dict):
365                raise SchemaError(
366                    f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}."
367                )
368
369            normalized_keys = [
370                self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys
371            ]
372            for column_name, column_type in columns.items():
373                nested_set(
374                    normalized_mapping,
375                    normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)],
376                    column_type,
377                )
378
379        return normalized_mapping
380
381    def _normalize_table(
382        self,
383        table: exp.Table | str,
384        dialect: DialectType = None,
385        normalize: t.Optional[bool] = None,
386    ) -> exp.Table:
387        normalized_table = exp.maybe_parse(
388            table, into=exp.Table, dialect=dialect or self.dialect, copy=True
389        )
390
391        for arg in TABLE_ARGS:
392            value = normalized_table.args.get(arg)
393            if isinstance(value, (str, exp.Identifier)):
394                normalized_table.set(
395                    arg,
396                    exp.to_identifier(
397                        self._normalize_name(
398                            value, dialect=dialect, is_table=True, normalize=normalize
399                        )
400                    ),
401                )
402
403        return normalized_table
404
405    def _normalize_name(
406        self,
407        name: str | exp.Identifier,
408        dialect: DialectType = None,
409        is_table: bool = False,
410        normalize: t.Optional[bool] = None,
411    ) -> str:
412        return normalize_name(
413            name,
414            dialect=dialect or self.dialect,
415            is_table=is_table,
416            normalize=self.normalize if normalize is None else normalize,
417        )
418
419    def depth(self) -> int:
420        if not self.empty and not self._depth:
421            # The columns themselves are a mapping, but we don't want to include those
422            self._depth = super().depth() - 1
423        return self._depth
424
425    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
426        """
427        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.
428
429        Args:
430            schema_type: the type we want to convert.
431            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.
432
433        Returns:
434            The resulting expression type.
435        """
436        if schema_type not in self._type_mapping_cache:
437            dialect = dialect or self.dialect
438            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES
439
440            try:
441                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
442                self._type_mapping_cache[schema_type] = expression
443            except AttributeError:
444                in_dialect = f" in dialect {dialect}" if dialect else ""
445                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")
446
447        return self._type_mapping_cache[schema_type]
448
449
450def normalize_name(
451    name: str | exp.Identifier,
452    dialect: DialectType = None,
453    is_table: bool = False,
454    normalize: t.Optional[bool] = True,
455) -> str:
456    try:
457        identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
458    except ParseError:
459        return name if isinstance(name, str) else name.name
460
461    name = identifier.name
462    if not normalize:
463        return name
464
465    # This can be useful for normalize_identifier
466    identifier.meta["is_table"] = is_table
467    return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name
468
469
470def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
471    if isinstance(schema, Schema):
472        return schema
473
474    return MappingSchema(schema, **kwargs)
475
476
477def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
478    if mapping is None:
479        return {}
480    elif isinstance(mapping, dict):
481        return mapping
482    elif isinstance(mapping, str):
483        col_name_type_strs = [x.strip() for x in mapping.split(",")]
484        return {
485            name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip()
486            for name_type_str in col_name_type_strs
487        }
488    # Check if mapping looks like a DataFrame StructType
489    elif hasattr(mapping, "simpleString"):
490        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
491    elif isinstance(mapping, list):
492        return {x.strip(): None for x in mapping}
493
494    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
495
496
497def flatten_schema(
498    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
499) -> t.List[t.List[str]]:
500    tables = []
501    keys = keys or []
502
503    for k, v in schema.items():
504        if depth >= 2:
505            tables.extend(flatten_schema(v, depth - 1, keys + [k]))
506        elif depth == 1:
507            tables.append(keys + [k])
508
509    return tables
510
511
512def nested_get(
513    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
514) -> t.Optional[t.Any]:
515    """
516    Get a value for a nested dictionary.
517
518    Args:
519        d: the dictionary to search.
520        *path: tuples of (name, key), where:
521            `key` is the key in the dictionary to get.
522            `name` is a string to use in the error if `key` isn't found.
523
524    Returns:
525        The value or None if it doesn't exist.
526    """
527    for name, key in path:
528        d = d.get(key)  # type: ignore
529        if d is None:
530            if raise_on_missing:
531                name = "table" if name == "this" else name
532                raise ValueError(f"Unknown {name}: {key}")
533            return None
534
535    return d
536
537
538def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
539    """
540    In-place set a value for a nested dictionary
541
542    Example:
543        >>> nested_set({}, ["top_key", "second_key"], "value")
544        {'top_key': {'second_key': 'value'}}
545
546        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
547        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
548
549    Args:
550        d: dictionary to update.
551        keys: the keys that makeup the path to `value`.
552        value: the value to set in the dictionary for the given key path.
553
554    Returns:
555        The (possibly) updated dictionary.
556    """
557    if not keys:
558        return d
559
560    if len(keys) == 1:
561        d[keys[0]] = value
562        return d
563
564    subd = d
565    for key in keys[:-1]:
566        if key not in subd:
567            subd = subd.setdefault(key, {})
568        else:
569            subd = subd[key]
570
571    subd[keys[-1]] = value
572    return d
TABLE_ARGS = ('this', 'db', 'catalog')
class Schema(abc.ABC):
 23class Schema(abc.ABC):
 24    """Abstract base class for database schemas"""
 25
 26    dialect: DialectType
 27
 28    @abc.abstractmethod
 29    def add_table(
 30        self,
 31        table: exp.Table | str,
 32        column_mapping: t.Optional[ColumnMapping] = None,
 33        dialect: DialectType = None,
 34        normalize: t.Optional[bool] = None,
 35        match_depth: bool = True,
 36    ) -> None:
 37        """
 38        Register or update a table. Some implementing classes may require column information to also be provided.
 39        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
 40
 41        Args:
 42            table: the `Table` expression instance or string representing the table.
 43            column_mapping: a column mapping that describes the structure of the table.
 44            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 45            normalize: whether to normalize identifiers according to the dialect of interest.
 46            match_depth: whether to enforce that the table must match the schema's depth or not.
 47        """
 48
 49    @abc.abstractmethod
 50    def column_names(
 51        self,
 52        table: exp.Table | str,
 53        only_visible: bool = False,
 54        dialect: DialectType = None,
 55        normalize: t.Optional[bool] = None,
 56    ) -> t.List[str]:
 57        """
 58        Get the column names for a table.
 59
 60        Args:
 61            table: the `Table` expression instance.
 62            only_visible: whether to include invisible columns.
 63            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 64            normalize: whether to normalize identifiers according to the dialect of interest.
 65
 66        Returns:
 67            The list of column names.
 68        """
 69
 70    @abc.abstractmethod
 71    def get_column_type(
 72        self,
 73        table: exp.Table | str,
 74        column: exp.Column | str,
 75        dialect: DialectType = None,
 76        normalize: t.Optional[bool] = None,
 77    ) -> exp.DataType:
 78        """
 79        Get the `sqlglot.exp.DataType` type of a column in the schema.
 80
 81        Args:
 82            table: the source table.
 83            column: the target column.
 84            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 85            normalize: whether to normalize identifiers according to the dialect of interest.
 86
 87        Returns:
 88            The resulting column type.
 89        """
 90
 91    def has_column(
 92        self,
 93        table: exp.Table | str,
 94        column: exp.Column | str,
 95        dialect: DialectType = None,
 96        normalize: t.Optional[bool] = None,
 97    ) -> bool:
 98        """
 99        Returns whether or not `column` appears in `table`'s schema.
100
101        Args:
102            table: the source table.
103            column: the target column.
104            dialect: the SQL dialect that will be used to parse `table` if it's a string.
105            normalize: whether to normalize identifiers according to the dialect of interest.
106
107        Returns:
108            True if the column appears in the schema, False otherwise.
109        """
110        name = column if isinstance(column, str) else column.name
111        return name in self.column_names(table, dialect=dialect, normalize=normalize)
112
113    @property
114    @abc.abstractmethod
115    def supported_table_args(self) -> t.Tuple[str, ...]:
116        """
117        Table arguments this schema support, e.g. `("this", "db", "catalog")`
118        """
119
120    @property
121    def empty(self) -> bool:
122        """Returns whether or not the schema is empty."""
123        return True

Abstract base class for database schemas

@abc.abstractmethod
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None, match_depth: bool = True) -> None:
28    @abc.abstractmethod
29    def add_table(
30        self,
31        table: exp.Table | str,
32        column_mapping: t.Optional[ColumnMapping] = None,
33        dialect: DialectType = None,
34        normalize: t.Optional[bool] = None,
35        match_depth: bool = True,
36    ) -> None:
37        """
38        Register or update a table. Some implementing classes may require column information to also be provided.
39        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
40
41        Args:
42            table: the `Table` expression instance or string representing the table.
43            column_mapping: a column mapping that describes the structure of the table.
44            dialect: the SQL dialect that will be used to parse `table` if it's a string.
45            normalize: whether to normalize identifiers according to the dialect of interest.
46            match_depth: whether to enforce that the table must match the schema's depth or not.
47        """

Register or update a table. Some implementing classes may require column information to also be provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
  • match_depth: whether to enforce that the table must match the schema's depth or not.
@abc.abstractmethod
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> List[str]:
49    @abc.abstractmethod
50    def column_names(
51        self,
52        table: exp.Table | str,
53        only_visible: bool = False,
54        dialect: DialectType = None,
55        normalize: t.Optional[bool] = None,
56    ) -> t.List[str]:
57        """
58        Get the column names for a table.
59
60        Args:
61            table: the `Table` expression instance.
62            only_visible: whether to include invisible columns.
63            dialect: the SQL dialect that will be used to parse `table` if it's a string.
64            normalize: whether to normalize identifiers according to the dialect of interest.
65
66        Returns:
67            The list of column names.
68        """

Get the column names for a table.

Arguments:
  • table: the Table expression instance.
  • only_visible: whether to return only the visible columns (i.e. exclude invisible ones).
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The list of column names.

@abc.abstractmethod
def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> sqlglot.expressions.DataType:
70    @abc.abstractmethod
71    def get_column_type(
72        self,
73        table: exp.Table | str,
74        column: exp.Column | str,
75        dialect: DialectType = None,
76        normalize: t.Optional[bool] = None,
77    ) -> exp.DataType:
78        """
79        Get the `sqlglot.exp.DataType` type of a column in the schema.
80
81        Args:
82            table: the source table.
83            column: the target column.
84            dialect: the SQL dialect that will be used to parse `table` if it's a string.
85            normalize: whether to normalize identifiers according to the dialect of interest.
86
87        Returns:
88            The resulting column type.
89        """

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The resulting column type.

def has_column( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> bool:
 91    def has_column(
 92        self,
 93        table: exp.Table | str,
 94        column: exp.Column | str,
 95        dialect: DialectType = None,
 96        normalize: t.Optional[bool] = None,
 97    ) -> bool:
 98        """
 99        Returns whether or not `column` appears in `table`'s schema.
100
101        Args:
102            table: the source table.
103            column: the target column.
104            dialect: the SQL dialect that will be used to parse `table` if it's a string.
105            normalize: whether to normalize identifiers according to the dialect of interest.
106
107        Returns:
108            True if the column appears in the schema, False otherwise.
109        """
110        name = column if isinstance(column, str) else column.name
111        return name in self.column_names(table, dialect=dialect, normalize=normalize)

Returns whether or not column appears in table's schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

True if the column appears in the schema, False otherwise.

supported_table_args: Tuple[str, ...]

Table arguments this schema supports, e.g. ("this", "db", "catalog").

empty: bool

Returns whether or not the schema is empty.

class AbstractMappingSchema:
class AbstractMappingSchema:
    """
    Base class for schemas backed by a nested dictionary keyed by table qualifiers.

    Alongside the mapping, a trie of the table paths is kept. The paths are stored
    innermost-first (table, db, catalog) so that `find` can resolve unqualified or partially
    qualified table references.
    """

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}
        # `flatten_schema` yields outermost-first paths; reverse them for the lookup trie.
        self.mapping_trie = new_trie(
            tuple(reversed(path)) for path in flatten_schema(self.mapping, depth=self.depth())
        )
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        """Whether the mapping has no entries."""
        return not self.mapping

    def depth(self) -> int:
        """The nesting depth of the mapping, as computed by `dict_depth`."""
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        The prefix of ("this", "db", "catalog") that matches the mapping's depth.

        The result is cached once it is non-empty.

        Raises:
            SchemaError: if the depth is non-zero and outside the 1-3 range.
        """
        if not self._supported_table_args and self.mapping:
            depth = self.depth()

            if not depth:  # None
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = TABLE_ARGS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """
        Return the non-empty name parts of `table`, ordered table name, db, catalog.

        A table wrapping a `ReadCSV` expression is identified by that expression's name alone.
        """
        if isinstance(table.this, exp.ReadCSV):
            return [table.this.name]
        return [table.text(part) for part in TABLE_ARGS if table.text(part)]

    def find(
        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[t.Any]:
        """
        Look up the mapping entry registered for `table`.

        Only as many name parts as `supported_table_args` allows are used. If they match
        exactly one registered path only partially (e.g. an unqualified table name), the
        missing qualifiers are filled in from the trie.

        Args:
            table: the table to look up.
            trie: the trie to search; defaults to `self.mapping_trie`.
            raise_on_missing: whether an ambiguous match (SchemaError) or a missing key in the
                final lookup (ValueError) raises instead of returning None.

        Returns:
            The entry for the table, or None if it is not found.
        """
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)

        if value == TrieResult.FAILED:
            return None

        if value == TrieResult.PREFIX:
            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)

            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                message = ", ".join(".".join(parts) for parts in possibilities)
                if raise_on_missing:
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None

        return self.nested_get(parts, raise_on_missing=raise_on_missing)

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
    ) -> t.Optional[t.Any]:
        """
        Walk `d` (default: `self.mapping`) along the table path given by `parts`.

        Args:
            parts: the table's name parts ordered table, db, catalog (as from `table_parts`).
            d: the nested dictionary to search. Only `None` selects `self.mapping`; an
                explicitly passed empty dict is searched as-is.
            raise_on_missing: whether to raise a ValueError when a key is missing instead of
                returning None.

        Returns:
            The value at the end of the path, or None.
        """
        # `parts` is innermost-first while the mapping nests outermost-first, so walk it
        # reversed, using at most one key per supported nesting level.
        keys = list(reversed(parts))[: len(self.supported_table_args)]
        # parts[i] corresponds to TABLE_ARGS[i]; label each key accordingly so that, e.g.,
        # a missing db is not reported as an unknown table.
        labels = reversed(TABLE_ARGS[: len(parts)])
        return nested_get(
            self.mapping if d is None else d,
            *zip(labels, keys),
            raise_on_missing=raise_on_missing,
        )
AbstractMappingSchema(mapping: Optional[Dict] = None)
127    def __init__(
128        self,
129        mapping: t.Optional[t.Dict] = None,
130    ) -> None:
131        self.mapping = mapping or {}
132        self.mapping_trie = new_trie(
133            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
134        )
135        self._supported_table_args: t.Tuple[str, ...] = tuple()
mapping
mapping_trie
empty: bool
def depth(self) -> int:
141    def depth(self) -> int:
142        return dict_depth(self.mapping)
supported_table_args: Tuple[str, ...]
def table_parts(self, table: sqlglot.expressions.Table) -> List[str]:
158    def table_parts(self, table: exp.Table) -> t.List[str]:
159        if isinstance(table.this, exp.ReadCSV):
160            return [table.this.name]
161        return [table.text(part) for part in TABLE_ARGS if table.text(part)]
def find( self, table: sqlglot.expressions.Table, trie: Optional[Dict] = None, raise_on_missing: bool = True) -> Optional[Any]:
163    def find(
164        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
165    ) -> t.Optional[t.Any]:
166        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
167        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
168
169        if value == TrieResult.FAILED:
170            return None
171
172        if value == TrieResult.PREFIX:
173            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
174
175            if len(possibilities) == 1:
176                parts.extend(possibilities[0])
177            else:
178                message = ", ".join(".".join(parts) for parts in possibilities)
179                if raise_on_missing:
180                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
181                return None
182
183        return self.nested_get(parts, raise_on_missing=raise_on_missing)
def nested_get( self, parts: Sequence[str], d: Optional[Dict] = None, raise_on_missing=True) -> Optional[Any]:
185    def nested_get(
186        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
187    ) -> t.Optional[t.Any]:
188        return nested_get(
189            d or self.mapping,
190            *zip(self.supported_table_args, reversed(parts)),
191            raise_on_missing=raise_on_missing,
192        )
class MappingSchema(AbstractMappingSchema, Schema):
class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        # These are assigned before the base initializer runs: `_normalize` reads `dialect`
        # and `normalize`, and the overridden `depth()` reads `_depth`.
        self.dialect = dialect
        self.visible = visible or {}
        self.normalize = normalize
        # Type strings already converted by `_to_data_type`, keyed by the raw string.
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        self._depth = 0

        super().__init__(self._normalize(schema or {}))

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a new MappingSchema from another one's mapping, visibility, dialect and normalize flag."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """
        Return a new MappingSchema built from shallow copies of this schema's mapping and
        visibility; any keyword argument overrides the corresponding constructor argument.
        """
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.

        Raises:
            SchemaError: if `match_depth` is set, the schema is non-empty and the table's
                number of qualifiers differs from the schema depth.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # An already-registered table is left untouched unless new columns were supplied.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        # The mapping nests outermost-first, while `parts` is ordered table, db, catalog.
        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        # The return value is discarded, so this presumably inserts the path into the existing
        # trie in place, keeping `find` in sync with the mapping.
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """
        Return the column names registered for `table` ([] if the table is not found).

        When `only_visible` is set and a visibility mapping was configured, only columns listed
        as visible for the table are returned. `find` is called with its default
        `raise_on_missing=True`, so an ambiguous reference raises SchemaError.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Return the DataType of `column` in `table`.

        A stored DataType is returned as-is, and a stored string is converted via
        `_to_data_type`. Anything else, including an unknown table or column, yields the
        "unknown" type.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """Return whether the normalized `column` name is a key of `table`'s column mapping."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.

        Raises:
            SchemaError: if a table path does not lead to a dict of columns (inconsistent nesting).
        """
        normalized_mapping: t.Dict = {}
        # Paths down to (but excluding) the column level, e.g. [db, table].
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        for keys in flattened_schema:
            # Each key doubles as its own label for error messages.
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                raise SchemaError(
                    f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}."
                )

            normalized_keys = [
                self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys
            ]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """
        Parse `table` into a copied Table expression and normalize its this/db/catalog names.
        """
        normalized_table = exp.maybe_parse(
            table, into=exp.Table, dialect=dialect or self.dialect, copy=True
        )

        for arg in TABLE_ARGS:
            value = normalized_table.args.get(arg)
            # Only plain names are rewritten; other expressions in these slots are left as-is.
            if isinstance(value, (str, exp.Identifier)):
                normalized_table.set(
                    arg,
                    exp.to_identifier(
                        self._normalize_name(
                            value, dialect=dialect, is_table=True, normalize=normalize
                        )
                    ),
                )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Call `normalize_name`, falling back to this schema's dialect and normalize setting."""
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        )

    def depth(self) -> int:
        """Return the number of table qualifier levels, cached once the schema is non-empty."""
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.

        Raises:
            SchemaError: if building the type raises AttributeError.
        """
        # NOTE(review): the cache is keyed by the type string only, so a string first parsed
        # under one dialect is reused for later calls with a different dialect; the same
        # DataType object is also shared between callers. Confirm both are intended.
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect
            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]

Schema based on a nested mapping.

Arguments:
  • schema: Mapping in one of the following forms:
    1. {table: {col: type}}
    2. {db: {table: {col: type}}}
    3. {catalog: {db: {table: {col: type}}}}
    4. None - Tables will be added later
  • visible: Optional mapping of which columns in the schema are visible. If not provided, all columns are assumed to be visible. The nesting should mirror that of the schema:
    1. {table: set(*cols)}
    2. {db: {table: set(*cols)}}
    3. {catalog: {db: {table: set(*cols)}}}
  • dialect: The dialect to be used for custom type mappings & parsing string arguments.
  • normalize: Whether to normalize identifier names according to the given dialect or not.
MappingSchema( schema: Optional[Dict] = None, visible: Optional[Dict] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: bool = True)
214    def __init__(
215        self,
216        schema: t.Optional[t.Dict] = None,
217        visible: t.Optional[t.Dict] = None,
218        dialect: DialectType = None,
219        normalize: bool = True,
220    ) -> None:
221        self.dialect = dialect
222        self.visible = visible or {}
223        self.normalize = normalize
224        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
225        self._depth = 0
226
227        super().__init__(self._normalize(schema or {}))
dialect
visible
normalize
@classmethod
def from_mapping_schema( cls, mapping_schema: MappingSchema) -> MappingSchema:
229    @classmethod
230    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
231        return MappingSchema(
232            schema=mapping_schema.mapping,
233            visible=mapping_schema.visible,
234            dialect=mapping_schema.dialect,
235            normalize=mapping_schema.normalize,
236        )
def copy(self, **kwargs) -> MappingSchema:
238    def copy(self, **kwargs) -> MappingSchema:
239        return MappingSchema(
240            **{  # type: ignore
241                "schema": self.mapping.copy(),
242                "visible": self.visible.copy(),
243                "dialect": self.dialect,
244                "normalize": self.normalize,
245                **kwargs,
246            }
247        )
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None, match_depth: bool = True) -> None:
249    def add_table(
250        self,
251        table: exp.Table | str,
252        column_mapping: t.Optional[ColumnMapping] = None,
253        dialect: DialectType = None,
254        normalize: t.Optional[bool] = None,
255        match_depth: bool = True,
256    ) -> None:
257        """
258        Register or update a table. Updates are only performed if a new column mapping is provided.
259        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
260
261        Args:
262            table: the `Table` expression instance or string representing the table.
263            column_mapping: a column mapping that describes the structure of the table.
264            dialect: the SQL dialect that will be used to parse `table` if it's a string.
265            normalize: whether to normalize identifiers according to the dialect of interest.
266            match_depth: whether to enforce that the table must match the schema's depth or not.
267        """
268        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
269
270        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
271            raise SchemaError(
272                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
273                f"schema's nesting level: {self.depth()}."
274            )
275
276        normalized_column_mapping = {
277            self._normalize_name(key, dialect=dialect, normalize=normalize): value
278            for key, value in ensure_column_mapping(column_mapping).items()
279        }
280
281        schema = self.find(normalized_table, raise_on_missing=False)
282        if schema and not normalized_column_mapping:
283            return
284
285        parts = self.table_parts(normalized_table)
286
287        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
288        new_trie([parts], self.mapping_trie)

Register or update a table. Updates are only performed if a new column mapping is provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
  • match_depth: whether to enforce that the table must match the schema's depth or not.
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> List[str]:
290    def column_names(
291        self,
292        table: exp.Table | str,
293        only_visible: bool = False,
294        dialect: DialectType = None,
295        normalize: t.Optional[bool] = None,
296    ) -> t.List[str]:
297        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
298
299        schema = self.find(normalized_table)
300        if schema is None:
301            return []
302
303        if not only_visible or not self.visible:
304            return list(schema)
305
306        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
307        return [col for col in schema if col in visible]

Get the column names for a table.

Arguments:
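  • table: the Table expression instance or string representing the table.
  • only_visible: whether to return only the visible columns; when False (or when no visibility mapping was configured), all columns are returned.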
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The list of column names.

def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> sqlglot.expressions.DataType:
309    def get_column_type(
310        self,
311        table: exp.Table | str,
312        column: exp.Column | str,
313        dialect: DialectType = None,
314        normalize: t.Optional[bool] = None,
315    ) -> exp.DataType:
316        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
317
318        normalized_column_name = self._normalize_name(
319            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
320        )
321
322        table_schema = self.find(normalized_table, raise_on_missing=False)
323        if table_schema:
324            column_type = table_schema.get(normalized_column_name)
325
326            if isinstance(column_type, exp.DataType):
327                return column_type
328            elif isinstance(column_type, str):
329                return self._to_data_type(column_type, dialect=dialect)
330
331        return exp.DataType.build("unknown")

Get the sqlglot.exp.DataType of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The resulting column type.

def has_column( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> bool:
333    def has_column(
334        self,
335        table: exp.Table | str,
336        column: exp.Column | str,
337        dialect: DialectType = None,
338        normalize: t.Optional[bool] = None,
339    ) -> bool:
340        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
341
342        normalized_column_name = self._normalize_name(
343            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
344        )
345
346        table_schema = self.find(normalized_table, raise_on_missing=False)
347        return normalized_column_name in table_schema if table_schema else False

Returns whether or not column appears in table's schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

True if the column appears in the schema, False otherwise.

def depth(self) -> int:
420    def depth(self) -> int:
421        if not self.empty and not self._depth:
422            # The columns themselves are a mapping, but we don't want to include those
423            self._depth = super().depth() - 1
424        return self._depth
def normalize_name( name: str | sqlglot.expressions.Identifier, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, is_table: bool = False, normalize: Optional[bool] = True) -> str:
def normalize_name(
    name: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> str:
    """
    Return the plain name of `name`, normalized per `dialect` unless `normalize` is falsy.

    Args:
        name: a raw name or an Identifier expression.
        dialect: the dialect used for parsing and normalization.
        is_table: recorded in the identifier's meta before normalization.
        normalize: whether to apply the dialect's identifier normalization.

    Returns:
        The (possibly normalized) name. If `name` cannot be parsed as an identifier, its raw
        text is returned unchanged.
    """
    try:
        identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
    except ParseError:
        return name.name if not isinstance(name, str) else name

    if not normalize:
        return identifier.name

    # Exposed so a dialect's normalize_identifier can treat table names differently.
    identifier.meta["is_table"] = is_table
    normalized = Dialect.get_or_raise(dialect).normalize_identifier(identifier)
    return normalized.name
def ensure_schema( schema: Union[Schema, Dict, NoneType], **kwargs: Any) -> Schema:
def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """
    Return `schema` unchanged if it is already a Schema; otherwise wrap it in a MappingSchema.

    Args:
        schema: an existing Schema, a nested mapping, or None.
        **kwargs: forwarded to MappingSchema when a new schema is built.
    """
    if not isinstance(schema, Schema):
        return MappingSchema(schema, **kwargs)
    return schema
def ensure_column_mapping( mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType]) -> Dict:
def _split_top_level_commas(text: str) -> t.List[str]:
    """Split `text` on commas that are not nested inside (), <> or [] brackets."""
    pieces: t.List[str] = []
    depth = 0
    start = 0
    for index, char in enumerate(text):
        if char in "(<[":
            depth += 1
        elif char in ")>]":
            # Clamp so a stray closing bracket cannot make later commas look nested.
            depth = max(depth - 1, 0)
        elif char == "," and depth == 0:
            pieces.append(text[start:index])
            start = index + 1
    pieces.append(text[start:])
    return pieces


def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Coerce the supported column-mapping representations into a `{column: type}` dict.

    Args:
        mapping: one of
            - None: an empty dict is returned.
            - dict: returned as-is.
            - str: comma-separated `name: type` pairs, e.g. "a: INT, b: DECIMAL(10, 2)".
              Commas nested in (), <> or [] and any colon after the first belong to the type.
              Blank entries (e.g. from a trailing comma) are ignored.
            - an object with a `simpleString` attribute (e.g. a DataFrame StructType): iterated
              as fields exposing `.name` and `.dataType.simpleString()`.
            - list: column names only, each mapped to None.

    Returns:
        The column mapping as a dictionary.

    Raises:
        ValueError: if `mapping` has an unsupported type, or a string entry has no `:` type.
    """
    if mapping is None:
        return {}
    if isinstance(mapping, dict):
        return mapping
    if isinstance(mapping, str):
        columns: t.Dict = {}
        for entry in _split_top_level_commas(mapping):
            if not entry.strip():
                continue
            # Split on the first colon only: nested types such as struct<a:int> contain colons.
            name, separator, type_str = entry.partition(":")
            if not separator:
                raise ValueError(
                    f"Invalid column mapping entry (expected 'name: type'): {entry.strip()!r}"
                )
            columns[name.strip()] = type_str.strip()
        return columns
    # Check if mapping looks like a DataFrame StructType
    if hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
    if isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema( schema: Dict, depth: int, keys: Optional[List[str]] = None) -> List[List[str]]:
def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """
    Collect every key path that descends `depth` levels into the nested dict `schema`.

    Args:
        schema: the nested dictionary to walk.
        depth: how many levels of keys each returned path contains (below `keys`).
        keys: a prefix prepended to every path.

    Returns:
        The list of key paths, in the dictionary's iteration order; [] when `depth` < 1.
    """
    prefix = keys or []
    entries = schema.items()

    if depth < 1:
        return []
    if depth == 1:
        return [[*prefix, key] for key, _ in entries]
    return [
        path
        for key, child in entries
        for path in flatten_schema(child, depth - 1, [*prefix, key])
    ]
def nested_get( d: Dict, *path: Tuple[str, str], raise_on_missing: bool = True) -> Optional[Any]:
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Follow a sequence of keys into a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: `(name, key)` pairs. `key` is looked up at each level, and `name` labels it in the
            error message when it is missing ("this" is reported as "table").
        raise_on_missing: whether a missing key raises ValueError instead of returning None.

    Returns:
        The value at the end of the path, or None if it doesn't exist.
    """
    current = d
    for label, key in path:
        current = current.get(key)  # type: ignore
        if current is not None:
            continue
        if not raise_on_missing:
            return None
        kind = "table" if label == "this" else label
        raise ValueError(f"Unknown {kind}: {key}")

    return current

Get a value for a nested dictionary.

Arguments:
  • d: the dictionary to search.
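  • *path: tuples of (name, key), where key is the dictionary key to look up at each level and name is the label used in the error message if key isn't found ("this" is reported as "table").
  • raise_on_missing: whether to raise a ValueError when a key is missing instead of returning None.
Returns:

The value, or None if it doesn't exist and raise_on_missing is False.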

def nested_set(d: Dict, keys: Sequence[str], value: Any) -> Dict:
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    Store `value` at the nested key path `keys` inside `d`, creating intermediate dicts as needed.

    `d` is modified in place and also returned. An empty `keys` leaves `d` untouched.

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that make up the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    *parents, last = keys
    target = d
    for key in parents:
        target = target.setdefault(key, {})
    target[last] = value
    return d

Set a value in a nested dictionary, in place.

Example:
>>> nested_set({}, ["top_key", "second_key"], "value")
{'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
{'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
  • d: dictionary to update.
  • keys: the keys that make up the path to value.
  • value: the value to set in the dictionary for the given key path.
Returns:

The (possibly) updated dictionary.