Edit on GitHub

sqlglot.schema

  1from __future__ import annotations
  2
  3import abc
  4import typing as t
  5
  6from sqlglot import expressions as exp
  7from sqlglot.dialects.dialect import Dialect
  8from sqlglot.errors import SchemaError
  9from sqlglot.helper import dict_depth, first
 10from sqlglot.trie import TrieResult, in_trie, new_trie
 11
if t.TYPE_CHECKING:
    from sqlglot.dialects.dialect import DialectType

    # Accepted shapes of a table's column description: a dict, a "name: type, ..." string,
    # or a list of column names. `ensure_column_mapping` converts each of them into a dict.
    ColumnMapping = t.Union[t.Dict, str, t.List]
 16
 17
 18class Schema(abc.ABC):
 19    """Abstract base class for database schemas"""
 20
 21    dialect: DialectType
 22
 23    @abc.abstractmethod
 24    def add_table(
 25        self,
 26        table: exp.Table | str,
 27        column_mapping: t.Optional[ColumnMapping] = None,
 28        dialect: DialectType = None,
 29        normalize: t.Optional[bool] = None,
 30        match_depth: bool = True,
 31    ) -> None:
 32        """
 33        Register or update a table. Some implementing classes may require column information to also be provided.
 34        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
 35
 36        Args:
 37            table: the `Table` expression instance or string representing the table.
 38            column_mapping: a column mapping that describes the structure of the table.
 39            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 40            normalize: whether to normalize identifiers according to the dialect of interest.
 41            match_depth: whether to enforce that the table must match the schema's depth or not.
 42        """
 43
 44    @abc.abstractmethod
 45    def column_names(
 46        self,
 47        table: exp.Table | str,
 48        only_visible: bool = False,
 49        dialect: DialectType = None,
 50        normalize: t.Optional[bool] = None,
 51    ) -> t.Sequence[str]:
 52        """
 53        Get the column names for a table.
 54
 55        Args:
 56            table: the `Table` expression instance.
 57            only_visible: whether to include invisible columns.
 58            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 59            normalize: whether to normalize identifiers according to the dialect of interest.
 60
 61        Returns:
 62            The sequence of column names.
 63        """
 64
 65    @abc.abstractmethod
 66    def get_column_type(
 67        self,
 68        table: exp.Table | str,
 69        column: exp.Column | str,
 70        dialect: DialectType = None,
 71        normalize: t.Optional[bool] = None,
 72    ) -> exp.DataType:
 73        """
 74        Get the `sqlglot.exp.DataType` type of a column in the schema.
 75
 76        Args:
 77            table: the source table.
 78            column: the target column.
 79            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 80            normalize: whether to normalize identifiers according to the dialect of interest.
 81
 82        Returns:
 83            The resulting column type.
 84        """
 85
 86    def has_column(
 87        self,
 88        table: exp.Table | str,
 89        column: exp.Column | str,
 90        dialect: DialectType = None,
 91        normalize: t.Optional[bool] = None,
 92    ) -> bool:
 93        """
 94        Returns whether `column` appears in `table`'s schema.
 95
 96        Args:
 97            table: the source table.
 98            column: the target column.
 99            dialect: the SQL dialect that will be used to parse `table` if it's a string.
100            normalize: whether to normalize identifiers according to the dialect of interest.
101
102        Returns:
103            True if the column appears in the schema, False otherwise.
104        """
105        name = column if isinstance(column, str) else column.name
106        return name in self.column_names(table, dialect=dialect, normalize=normalize)
107
108    @property
109    @abc.abstractmethod
110    def supported_table_args(self) -> t.Tuple[str, ...]:
111        """
112        Table arguments this schema support, e.g. `("this", "db", "catalog")`
113        """
114
115    @property
116    def empty(self) -> bool:
117        """Returns whether the schema is empty."""
118        return True
119
120
class AbstractMappingSchema:
    """
    Base class for schemas backed by a nested dict, e.g. `{catalog: {db: {table: ...}}}`.

    Alongside the mapping, a trie of table paths is kept in innermost-first order
    (table, db, catalog), matching `table_parts`. This lets `find` resolve partially
    qualified tables when the completion is unambiguous.
    """

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}
        # `flatten_schema` yields outermost-first paths ([catalog, db, table]); the trie stores
        # them reversed so lookups start from the table name, like `table_parts` does.
        self.mapping_trie = new_trie(
            tuple(reversed(path)) for path in flatten_schema(self.mapping, depth=self.depth())
        )
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        """Whether the underlying mapping has no entries."""
        return not self.mapping

    def depth(self) -> int:
        """Nesting depth of the mapping, as computed by `dict_depth`."""
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        The prefix of `exp.TABLE_PARTS` matching the mapping's depth, computed lazily and cached.

        Raises:
            SchemaError: if the depth is larger than the number of supported table parts (3).
        """
        if not self._supported_table_args and self.mapping:
            depth = self.depth()

            if not depth:  # depth 0: no table levels at all
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = exp.TABLE_PARTS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """
        Return the non-empty identifier parts of `table` in `exp.TABLE_PARTS` order
        (table name first). A `ReadCSV` source is represented by its name alone.
        """
        if isinstance(table.this, exp.ReadCSV):
            return [table.this.name]
        return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)]

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Optional[t.Any]:
        """
        Returns the schema of a given table.

        Args:
            table: the target table.
            raise_on_missing: whether to raise in case the schema is not found.
            ensure_data_types: whether to convert `str` types to their `DataType` equivalents.
                Not used by this base implementation; subclasses may honor it.

        Returns:
            The schema of the target table, or None if the table is unknown. None is also
            returned when the match is ambiguous and `raise_on_missing` is False.

        Raises:
            SchemaError: if the table matches several entries and `raise_on_missing` is True.
        """
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        value, trie = in_trie(self.mapping_trie, parts)

        if value == TrieResult.FAILED:
            return None

        if value == TrieResult.PREFIX:
            # The table is only partially qualified: complete it if exactly one path matches.
            possibilities = flatten_schema(trie)

            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                message = ", ".join(".".join(candidate) for candidate in possibilities)
                if raise_on_missing:
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None

        return self.nested_get(parts, raise_on_missing=raise_on_missing)

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
    ) -> t.Optional[t.Any]:
        """
        Look up a table path in a nested dict.

        Args:
            parts: the table path in innermost-first order, as returned by `table_parts`.
            d: the nested dict to search; `self.mapping` is used only when this is None.
            raise_on_missing: whether to raise if a key along the path is missing.

        Returns:
            The value at the path, or None if it is missing and `raise_on_missing` is False.
        """
        # Keys are looked up outermost-first, so walk `parts` in reverse. parts[i] corresponds to
        # TABLE_PARTS[i], so pair each key with its own label: error messages then name the level
        # that is actually missing. The path is capped at the schema's depth, as before.
        path = list(zip(reversed(exp.TABLE_PARTS[: len(parts)]), reversed(parts)))
        return nested_get(
            self.mapping if d is None else d,
            *path[: len(self.supported_table_args)],
            raise_on_missing=raise_on_missing,
        )
199
200
class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = {} if visible is None else visible
        self.normalize = normalize
        # Parsed column types keyed by (dialect, type string). The same string can parse
        # differently depending on the dialect it is built with.
        self._type_mapping_cache: t.Dict[t.Tuple[t.Any, str], exp.DataType] = {}
        # Must be set before the base __init__, which calls `depth()`.
        self._depth = 0
        schema = {} if schema is None else schema

        super().__init__(self._normalize(schema) if self.normalize else schema)

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a new `MappingSchema` from another one's mapping, visibility, dialect and normalize flag."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Optional[t.Any]:
        """
        Return the column mapping of `table`.

        If `ensure_data_types` is set, string column types are converted to `exp.DataType`
        using the schema's dialect. The conversion is done on a new dict, so the stored
        mapping is left untouched.
        """
        schema = super().find(
            table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types
        )
        if ensure_data_types and isinstance(schema, dict):
            schema = {
                col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
                for col, dtype in schema.items()
            }

        return schema

    def copy(self, **kwargs) -> MappingSchema:
        """
        Return a new `MappingSchema` built from shallow copies of the mapping and visibility dicts.

        Keyword arguments override the corresponding constructor arguments.
        """
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.

        Raises:
            SchemaError: if `match_depth` is set, the schema is not empty and the table's
                qualifier count differs from the schema's depth.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # An already-known table is only overwritten when new column information is given.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        # The mapping is nested outermost-first, while `parts` is innermost-first.
        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        # Register the path in the lookup trie used by `find` (new_trie is handed the existing
        # trie, presumably extending it in place).
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """
        Return the column names of `table`, or `[]` if the table is unknown.

        When `only_visible` is set and a `visible` mapping was provided, only the columns
        listed as visible for the table are returned, in schema order.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Return the type of `column` in `table`.

        String types are parsed with `dialect` (falling back to the schema's dialect).
        The `unknown` type is returned when the table or column is not found, or the
        stored type is neither a `DataType` nor a string.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """Return whether the (normalized) `column` is present in `table`'s column mapping."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.

        Raises:
            SchemaError: if a table's entry is not a dict of columns, or its columns are
                themselves nested dicts (i.e. tables at inconsistent nesting levels).
        """
        normalized_mapping: t.Dict = {}
        flattened_schema = flatten_schema(schema)
        error_msg = "Table {} must match the schema's nesting level: {}."

        for keys in flattened_schema:
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                raise SchemaError(error_msg.format(".".join(keys[:-1]), len(flattened_schema[0])))
            if isinstance(first(columns.values()), dict):
                raise SchemaError(
                    error_msg.format(
                        ".".join(keys + flatten_schema(columns)[0]), len(flattened_schema[0])
                    ),
                )

            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parse `table` if needed and normalize each identifier part when normalization is enabled."""
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        # Copy only when the table is about to be mutated below, so that a caller-supplied
        # `Table` is left unchanged.
        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for arg in exp.TABLE_PARTS:
                value = normalized_table.args.get(arg)
                if isinstance(value, exp.Identifier):
                    normalized_table.set(
                        arg,
                        normalize_name(value, dialect=dialect, is_table=True, normalize=normalize),
                    )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Normalize a single identifier, defaulting to the schema's dialect and normalize flag."""
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        ).name

    def depth(self) -> int:
        """Number of table levels in the mapping (1-3), cached once the schema is non-empty."""
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.

        Raises:
            SchemaError: if the type cannot be built.
        """
        dialect = dialect or self.dialect
        # Whether unknown names become user-defined types depends on the dialect, so the
        # cache must be keyed by dialect as well as by the type string.
        cache_key = (dialect, schema_type)

        if cache_key not in self._type_mapping_cache:
            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
            except AttributeError as e:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.") from e

            self._type_mapping_cache[cache_key] = expression

        return self._type_mapping_cache[cache_key]
471
472
def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> exp.Identifier:
    """
    Parse `identifier` if needed and normalize it according to `dialect`.

    Args:
        identifier: a raw identifier string, or an already-built `exp.Identifier`.
        dialect: the dialect used both to parse a string identifier and to normalize it.
        is_table: whether the identifier names a table. It is recorded in the identifier's
            `meta` (in place) so the dialect's `normalize_identifier` can apply table rules.
        normalize: when falsy, the (parsed) identifier is returned without normalization.

    Returns:
        The identifier, normalized when `normalize` is truthy.
    """
    ident = (
        exp.parse_identifier(identifier, dialect=dialect)
        if isinstance(identifier, str)
        else identifier
    )

    if normalize:
        # Consumed by normalize_identifier; BigQuery has special rules for tables.
        ident.meta["is_table"] = is_table
        return Dialect.get_or_raise(dialect).normalize_identifier(ident)

    return ident
488
489
def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """
    Coerce `schema` into a `Schema` instance.

    An existing `Schema` is returned unchanged, and `kwargs` are ignored in that case.
    Anything else (a nested dict or None) is wrapped in a `MappingSchema` built with `kwargs`.
    """
    return schema if isinstance(schema, Schema) else MappingSchema(schema, **kwargs)
495
496
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Convert a column mapping into a `{column_name: column_type}` dict.

    Args:
        mapping: one of
            - None: no columns; `{}` is returned.
            - a dict: returned as-is (not copied).
            - a string of comma-separated `name: type` pairs, e.g. `"a: INT, b: DECIMAL(10, 2)"`.
              Commas nested inside (), <> or [] belong to the type. Only the first `:` of each
              pair separates the name from the type, so types such as `STRUCT<x: INT>` are kept
              intact. Empty entries (e.g. from a trailing comma) are ignored.
            - a list of column names: every column maps to None.

    Returns:
        The column mapping as a dict.

    Raises:
        ValueError: if `mapping` has an unsupported type, or a string entry has no `:`.
    """
    if mapping is None:
        return {}
    if isinstance(mapping, dict):
        return mapping
    if isinstance(mapping, str):
        result: t.Dict[str, str] = {}
        for entry in _split_top_level_commas(mapping):
            entry = entry.strip()
            if not entry:
                continue
            name, sep, type_str = entry.partition(":")
            if not sep:
                raise ValueError(f"Invalid column definition, expected 'name: type': {entry!r}")
            result[name.strip()] = type_str.strip()
        return result
    if isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")


def _split_top_level_commas(text: str) -> t.List[str]:
    """Split `text` on commas that are not nested inside (), <> or [] brackets."""
    pieces: t.List[str] = []
    depth = 0
    start = 0
    for index, char in enumerate(text):
        if char in "(<[":
            depth += 1
        elif char in ")>]":
            depth = max(depth - 1, 0)
        elif char == "," and depth == 0:
            pieces.append(text[start:index])
            start = index + 1
    pieces.append(text[start:])
    return pieces
512
513
def flatten_schema(
    schema: t.Dict, depth: t.Optional[int] = None, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """
    List the key paths of a nested schema dict down to `depth` levels, outermost key first.

    Args:
        schema: the nested dict to walk.
        depth: how many levels to descend; defaults to `dict_depth(schema) - 1`. A
            non-dict value ends its path early.
        keys: a prefix prepended to every returned path.

    Returns:
        One list of keys per leaf path.
    """
    prefix = keys or []
    remaining = dict_depth(schema) - 1 if depth is None else depth
    paths: t.List[t.List[str]] = []

    for name, value in schema.items():
        path = [*prefix, name]
        if remaining == 1 or not isinstance(value, dict):
            paths.append(path)
        elif remaining >= 2:
            paths.extend(flatten_schema(value, remaining - 1, path))

    return paths
528
529
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), walked in order. `key` is looked up at each level, and
            `name` labels that level in the error message when `key` is missing.
        raise_on_missing: whether to raise (rather than return None) when a key is missing.
            A key whose value is None counts as missing.

    Returns:
        The value, or None if it doesn't exist and `raise_on_missing` is False.

    Raises:
        ValueError: if a key is missing and `raise_on_missing` is True.
    """
    current = d
    for label, key in path:
        current = current.get(key)  # type: ignore
        if current is not None:
            continue
        if not raise_on_missing:
            return None
        kind = "table" if label == "this" else label
        raise ValueError(f"Unknown {kind}: {key}")

    return current
554
555
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary, creating missing intermediate dicts.

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that make up the path to `value`; an empty sequence leaves `d` untouched.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The same dictionary `d`, updated in place.
    """
    if not keys:
        return d

    parents, leaf = keys[:-1], keys[-1]
    node = d
    for key in parents:
        # Descend into an existing entry, or create an empty dict for a missing one.
        node = node[key] if key in node else node.setdefault(key, {})

    node[leaf] = value
    return d
class Schema(abc.ABC):
 19class Schema(abc.ABC):
 20    """Abstract base class for database schemas"""
 21
 22    dialect: DialectType
 23
 24    @abc.abstractmethod
 25    def add_table(
 26        self,
 27        table: exp.Table | str,
 28        column_mapping: t.Optional[ColumnMapping] = None,
 29        dialect: DialectType = None,
 30        normalize: t.Optional[bool] = None,
 31        match_depth: bool = True,
 32    ) -> None:
 33        """
 34        Register or update a table. Some implementing classes may require column information to also be provided.
 35        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
 36
 37        Args:
 38            table: the `Table` expression instance or string representing the table.
 39            column_mapping: a column mapping that describes the structure of the table.
 40            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 41            normalize: whether to normalize identifiers according to the dialect of interest.
 42            match_depth: whether to enforce that the table must match the schema's depth or not.
 43        """
 44
 45    @abc.abstractmethod
 46    def column_names(
 47        self,
 48        table: exp.Table | str,
 49        only_visible: bool = False,
 50        dialect: DialectType = None,
 51        normalize: t.Optional[bool] = None,
 52    ) -> t.Sequence[str]:
 53        """
 54        Get the column names for a table.
 55
 56        Args:
 57            table: the `Table` expression instance.
 58            only_visible: whether to include invisible columns.
 59            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 60            normalize: whether to normalize identifiers according to the dialect of interest.
 61
 62        Returns:
 63            The sequence of column names.
 64        """
 65
 66    @abc.abstractmethod
 67    def get_column_type(
 68        self,
 69        table: exp.Table | str,
 70        column: exp.Column | str,
 71        dialect: DialectType = None,
 72        normalize: t.Optional[bool] = None,
 73    ) -> exp.DataType:
 74        """
 75        Get the `sqlglot.exp.DataType` type of a column in the schema.
 76
 77        Args:
 78            table: the source table.
 79            column: the target column.
 80            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 81            normalize: whether to normalize identifiers according to the dialect of interest.
 82
 83        Returns:
 84            The resulting column type.
 85        """
 86
 87    def has_column(
 88        self,
 89        table: exp.Table | str,
 90        column: exp.Column | str,
 91        dialect: DialectType = None,
 92        normalize: t.Optional[bool] = None,
 93    ) -> bool:
 94        """
 95        Returns whether `column` appears in `table`'s schema.
 96
 97        Args:
 98            table: the source table.
 99            column: the target column.
100            dialect: the SQL dialect that will be used to parse `table` if it's a string.
101            normalize: whether to normalize identifiers according to the dialect of interest.
102
103        Returns:
104            True if the column appears in the schema, False otherwise.
105        """
106        name = column if isinstance(column, str) else column.name
107        return name in self.column_names(table, dialect=dialect, normalize=normalize)
108
109    @property
110    @abc.abstractmethod
111    def supported_table_args(self) -> t.Tuple[str, ...]:
112        """
113        Table arguments this schema support, e.g. `("this", "db", "catalog")`
114        """
115
116    @property
117    def empty(self) -> bool:
118        """Returns whether the schema is empty."""
119        return True

Abstract base class for database schemas

@abc.abstractmethod
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, List, NoneType] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None, match_depth: bool = True) -> None:
    @abc.abstractmethod
    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information to also be provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        This is an abstract method; concrete schema classes provide the implementation.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """

Register or update a table. Some implementing classes may require column information to also be provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
  • match_depth: whether to enforce that the table must match the schema's depth or not.
@abc.abstractmethod
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> Sequence[str]:
    @abc.abstractmethod
    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.Sequence[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to return only the visible columns, i.e. exclude invisible ones.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The sequence of column names.
        """

Get the column names for a table.

Arguments:
  • table: the Table expression instance.
  • only_visible: whether to return only the visible columns (invisible columns are excluded).
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The sequence of column names.

@abc.abstractmethod
def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> sqlglot.expressions.DataType:
    @abc.abstractmethod
    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        This is an abstract method; concrete schema classes provide the implementation.

        Args:
            table: the source table, as a `Table` expression or a string.
            column: the target column, as a `Column` expression or a string.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type.
        """

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The resulting column type.

def has_column( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> bool:
 87    def has_column(
 88        self,
 89        table: exp.Table | str,
 90        column: exp.Column | str,
 91        dialect: DialectType = None,
 92        normalize: t.Optional[bool] = None,
 93    ) -> bool:
 94        """
 95        Returns whether `column` appears in `table`'s schema.
 96
 97        Args:
 98            table: the source table.
 99            column: the target column.
100            dialect: the SQL dialect that will be used to parse `table` if it's a string.
101            normalize: whether to normalize identifiers according to the dialect of interest.
102
103        Returns:
104            True if the column appears in the schema, False otherwise.
105        """
106        name = column if isinstance(column, str) else column.name
107        return name in self.column_names(table, dialect=dialect, normalize=normalize)

Returns whether column appears in table's schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

True if the column appears in the schema, False otherwise.

supported_table_args: Tuple[str, ...]
109    @property
110    @abc.abstractmethod
111    def supported_table_args(self) -> t.Tuple[str, ...]:
112        """
113        Table arguments this schema support, e.g. `("this", "db", "catalog")`
114        """

Table arguments this schema supports, e.g. ("this", "db", "catalog")

empty: bool
116    @property
117    def empty(self) -> bool:
118        """Returns whether the schema is empty."""
119        return True

Returns whether the schema is empty.

class AbstractMappingSchema:
122class AbstractMappingSchema:
123    def __init__(
124        self,
125        mapping: t.Optional[t.Dict] = None,
126    ) -> None:
127        self.mapping = mapping or {}
128        self.mapping_trie = new_trie(
129            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
130        )
131        self._supported_table_args: t.Tuple[str, ...] = tuple()
132
133    @property
134    def empty(self) -> bool:
135        return not self.mapping
136
137    def depth(self) -> int:
138        return dict_depth(self.mapping)
139
140    @property
141    def supported_table_args(self) -> t.Tuple[str, ...]:
142        if not self._supported_table_args and self.mapping:
143            depth = self.depth()
144
145            if not depth:  # None
146                self._supported_table_args = tuple()
147            elif 1 <= depth <= 3:
148                self._supported_table_args = exp.TABLE_PARTS[:depth]
149            else:
150                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
151
152        return self._supported_table_args
153
154    def table_parts(self, table: exp.Table) -> t.List[str]:
155        if isinstance(table.this, exp.ReadCSV):
156            return [table.this.name]
157        return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)]
158
159    def find(
160        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
161    ) -> t.Optional[t.Any]:
162        """
163        Returns the schema of a given table.
164
165        Args:
166            table: the target table.
167            raise_on_missing: whether to raise in case the schema is not found.
168            ensure_data_types: whether to convert `str` types to their `DataType` equivalents.
169
170        Returns:
171            The schema of the target table.
172        """
173        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
174        value, trie = in_trie(self.mapping_trie, parts)
175
176        if value == TrieResult.FAILED:
177            return None
178
179        if value == TrieResult.PREFIX:
180            possibilities = flatten_schema(trie)
181
182            if len(possibilities) == 1:
183                parts.extend(possibilities[0])
184            else:
185                message = ", ".join(".".join(parts) for parts in possibilities)
186                if raise_on_missing:
187                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
188                return None
189
190        return self.nested_get(parts, raise_on_missing=raise_on_missing)
191
192    def nested_get(
193        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
194    ) -> t.Optional[t.Any]:
195        return nested_get(
196            d or self.mapping,
197            *zip(self.supported_table_args, reversed(parts)),
198            raise_on_missing=raise_on_missing,
199        )
AbstractMappingSchema(mapping: Optional[Dict] = None)
123    def __init__(
124        self,
125        mapping: t.Optional[t.Dict] = None,
126    ) -> None:
127        self.mapping = mapping or {}
128        self.mapping_trie = new_trie(
129            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
130        )
131        self._supported_table_args: t.Tuple[str, ...] = tuple()
mapping
mapping_trie
empty: bool
133    @property
134    def empty(self) -> bool:
135        return not self.mapping
def depth(self) -> int:
137    def depth(self) -> int:
138        return dict_depth(self.mapping)
supported_table_args: Tuple[str, ...]
140    @property
141    def supported_table_args(self) -> t.Tuple[str, ...]:
142        if not self._supported_table_args and self.mapping:
143            depth = self.depth()
144
145            if not depth:  # None
146                self._supported_table_args = tuple()
147            elif 1 <= depth <= 3:
148                self._supported_table_args = exp.TABLE_PARTS[:depth]
149            else:
150                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
151
152        return self._supported_table_args
def table_parts(self, table: sqlglot.expressions.Table) -> List[str]:
154    def table_parts(self, table: exp.Table) -> t.List[str]:
155        if isinstance(table.this, exp.ReadCSV):
156            return [table.this.name]
157        return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)]
def find( self, table: sqlglot.expressions.Table, raise_on_missing: bool = True, ensure_data_types: bool = False) -> Optional[Any]:
159    def find(
160        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
161    ) -> t.Optional[t.Any]:
162        """
163        Returns the schema of a given table.
164
165        Args:
166            table: the target table.
167            raise_on_missing: whether to raise in case the schema is not found.
168            ensure_data_types: whether to convert `str` types to their `DataType` equivalents.
169
170        Returns:
171            The schema of the target table.
172        """
173        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
174        value, trie = in_trie(self.mapping_trie, parts)
175
176        if value == TrieResult.FAILED:
177            return None
178
179        if value == TrieResult.PREFIX:
180            possibilities = flatten_schema(trie)
181
182            if len(possibilities) == 1:
183                parts.extend(possibilities[0])
184            else:
185                message = ", ".join(".".join(parts) for parts in possibilities)
186                if raise_on_missing:
187                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
188                return None
189
190        return self.nested_get(parts, raise_on_missing=raise_on_missing)

Returns the schema of a given table.

Arguments:
  • table: the target table.
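  • raise_on_missing: whether to raise a SchemaError when the table cannot be resolved unambiguously. A table that matches nothing in the schema returns None.
  • ensure_data_types: whether to convert str types to their DataType equivalents. This base implementation does not use it; MappingSchema.find applies the conversion.
Returns:

The schema of the target table, or None if it could not be found.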

def nested_get( self, parts: Sequence[str], d: Optional[Dict] = None, raise_on_missing=True) -> Optional[Any]:
192    def nested_get(
193        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
194    ) -> t.Optional[t.Any]:
195        return nested_get(
196            d or self.mapping,
197            *zip(self.supported_table_args, reversed(parts)),
198            raise_on_missing=raise_on_missing,
199        )
class MappingSchema(AbstractMappingSchema, Schema):
202class MappingSchema(AbstractMappingSchema, Schema):
203    """
204    Schema based on a nested mapping.
205
206    Args:
207        schema: Mapping in one of the following forms:
208            1. {table: {col: type}}
209            2. {db: {table: {col: type}}}
210            3. {catalog: {db: {table: {col: type}}}}
211            4. None - Tables will be added later
212        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
213            are assumed to be visible. The nesting should mirror that of the schema:
214            1. {table: set(*cols)}}
215            2. {db: {table: set(*cols)}}}
216            3. {catalog: {db: {table: set(*cols)}}}}
217        dialect: The dialect to be used for custom type mappings & parsing string arguments.
218        normalize: Whether to normalize identifier names according to the given dialect or not.
219    """
220
221    def __init__(
222        self,
223        schema: t.Optional[t.Dict] = None,
224        visible: t.Optional[t.Dict] = None,
225        dialect: DialectType = None,
226        normalize: bool = True,
227    ) -> None:
228        self.dialect = dialect
229        self.visible = {} if visible is None else visible
230        self.normalize = normalize
231        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
232        self._depth = 0
233        schema = {} if schema is None else schema
234
235        super().__init__(self._normalize(schema) if self.normalize else schema)
236
237    @classmethod
238    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
239        return MappingSchema(
240            schema=mapping_schema.mapping,
241            visible=mapping_schema.visible,
242            dialect=mapping_schema.dialect,
243            normalize=mapping_schema.normalize,
244        )
245
246    def find(
247        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
248    ) -> t.Optional[t.Any]:
249        schema = super().find(
250            table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types
251        )
252        if ensure_data_types and isinstance(schema, dict):
253            schema = {
254                col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
255                for col, dtype in schema.items()
256            }
257
258        return schema
259
260    def copy(self, **kwargs) -> MappingSchema:
261        return MappingSchema(
262            **{  # type: ignore
263                "schema": self.mapping.copy(),
264                "visible": self.visible.copy(),
265                "dialect": self.dialect,
266                "normalize": self.normalize,
267                **kwargs,
268            }
269        )
270
271    def add_table(
272        self,
273        table: exp.Table | str,
274        column_mapping: t.Optional[ColumnMapping] = None,
275        dialect: DialectType = None,
276        normalize: t.Optional[bool] = None,
277        match_depth: bool = True,
278    ) -> None:
279        """
280        Register or update a table. Updates are only performed if a new column mapping is provided.
281        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
282
283        Args:
284            table: the `Table` expression instance or string representing the table.
285            column_mapping: a column mapping that describes the structure of the table.
286            dialect: the SQL dialect that will be used to parse `table` if it's a string.
287            normalize: whether to normalize identifiers according to the dialect of interest.
288            match_depth: whether to enforce that the table must match the schema's depth or not.
289        """
290        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
291
292        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
293            raise SchemaError(
294                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
295                f"schema's nesting level: {self.depth()}."
296            )
297
298        normalized_column_mapping = {
299            self._normalize_name(key, dialect=dialect, normalize=normalize): value
300            for key, value in ensure_column_mapping(column_mapping).items()
301        }
302
303        schema = self.find(normalized_table, raise_on_missing=False)
304        if schema and not normalized_column_mapping:
305            return
306
307        parts = self.table_parts(normalized_table)
308
309        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
310        new_trie([parts], self.mapping_trie)
311
312    def column_names(
313        self,
314        table: exp.Table | str,
315        only_visible: bool = False,
316        dialect: DialectType = None,
317        normalize: t.Optional[bool] = None,
318    ) -> t.List[str]:
319        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
320
321        schema = self.find(normalized_table)
322        if schema is None:
323            return []
324
325        if not only_visible or not self.visible:
326            return list(schema)
327
328        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
329        return [col for col in schema if col in visible]
330
331    def get_column_type(
332        self,
333        table: exp.Table | str,
334        column: exp.Column | str,
335        dialect: DialectType = None,
336        normalize: t.Optional[bool] = None,
337    ) -> exp.DataType:
338        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
339
340        normalized_column_name = self._normalize_name(
341            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
342        )
343
344        table_schema = self.find(normalized_table, raise_on_missing=False)
345        if table_schema:
346            column_type = table_schema.get(normalized_column_name)
347
348            if isinstance(column_type, exp.DataType):
349                return column_type
350            elif isinstance(column_type, str):
351                return self._to_data_type(column_type, dialect=dialect)
352
353        return exp.DataType.build("unknown")
354
355    def has_column(
356        self,
357        table: exp.Table | str,
358        column: exp.Column | str,
359        dialect: DialectType = None,
360        normalize: t.Optional[bool] = None,
361    ) -> bool:
362        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
363
364        normalized_column_name = self._normalize_name(
365            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
366        )
367
368        table_schema = self.find(normalized_table, raise_on_missing=False)
369        return normalized_column_name in table_schema if table_schema else False
370
371    def _normalize(self, schema: t.Dict) -> t.Dict:
372        """
373        Normalizes all identifiers in the schema.
374
375        Args:
376            schema: the schema to normalize.
377
378        Returns:
379            The normalized schema mapping.
380        """
381        normalized_mapping: t.Dict = {}
382        flattened_schema = flatten_schema(schema)
383        error_msg = "Table {} must match the schema's nesting level: {}."
384
385        for keys in flattened_schema:
386            columns = nested_get(schema, *zip(keys, keys))
387
388            if not isinstance(columns, dict):
389                raise SchemaError(error_msg.format(".".join(keys[:-1]), len(flattened_schema[0])))
390            if isinstance(first(columns.values()), dict):
391                raise SchemaError(
392                    error_msg.format(
393                        ".".join(keys + flatten_schema(columns)[0]), len(flattened_schema[0])
394                    ),
395                )
396
397            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
398            for column_name, column_type in columns.items():
399                nested_set(
400                    normalized_mapping,
401                    normalized_keys + [self._normalize_name(column_name)],
402                    column_type,
403                )
404
405        return normalized_mapping
406
407    def _normalize_table(
408        self,
409        table: exp.Table | str,
410        dialect: DialectType = None,
411        normalize: t.Optional[bool] = None,
412    ) -> exp.Table:
413        dialect = dialect or self.dialect
414        normalize = self.normalize if normalize is None else normalize
415
416        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)
417
418        if normalize:
419            for arg in exp.TABLE_PARTS:
420                value = normalized_table.args.get(arg)
421                if isinstance(value, exp.Identifier):
422                    normalized_table.set(
423                        arg,
424                        normalize_name(value, dialect=dialect, is_table=True, normalize=normalize),
425                    )
426
427        return normalized_table
428
429    def _normalize_name(
430        self,
431        name: str | exp.Identifier,
432        dialect: DialectType = None,
433        is_table: bool = False,
434        normalize: t.Optional[bool] = None,
435    ) -> str:
436        return normalize_name(
437            name,
438            dialect=dialect or self.dialect,
439            is_table=is_table,
440            normalize=self.normalize if normalize is None else normalize,
441        ).name
442
443    def depth(self) -> int:
444        if not self.empty and not self._depth:
445            # The columns themselves are a mapping, but we don't want to include those
446            self._depth = super().depth() - 1
447        return self._depth
448
449    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
450        """
451        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.
452
453        Args:
454            schema_type: the type we want to convert.
455            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.
456
457        Returns:
458            The resulting expression type.
459        """
460        if schema_type not in self._type_mapping_cache:
461            dialect = dialect or self.dialect
462            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES
463
464            try:
465                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
466                self._type_mapping_cache[schema_type] = expression
467            except AttributeError:
468                in_dialect = f" in dialect {dialect}" if dialect else ""
469                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")
470
471        return self._type_mapping_cache[schema_type]

Schema based on a nested mapping.

Arguments:
  • schema: Mapping in one of the following forms:
    1. {table: {col: type}}
    2. {db: {table: {col: type}}}
    3. {catalog: {db: {table: {col: type}}}}
    4. None - Tables will be added later
  • visible: Optional mapping of which columns in the schema are visible. If not provided, all columns are assumed to be visible. The nesting should mirror that of the schema:
    1. {table: set(*cols)}
    2. {db: {table: set(*cols)}}
    3. {catalog: {db: {table: set(*cols)}}}
  • dialect: The dialect to be used for custom type mappings & parsing string arguments.
  • normalize: Whether to normalize identifier names according to the given dialect or not.
MappingSchema( schema: Optional[Dict] = None, visible: Optional[Dict] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: bool = True)
221    def __init__(
222        self,
223        schema: t.Optional[t.Dict] = None,
224        visible: t.Optional[t.Dict] = None,
225        dialect: DialectType = None,
226        normalize: bool = True,
227    ) -> None:
228        self.dialect = dialect
229        self.visible = {} if visible is None else visible
230        self.normalize = normalize
231        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
232        self._depth = 0
233        schema = {} if schema is None else schema
234
235        super().__init__(self._normalize(schema) if self.normalize else schema)
dialect
visible
normalize
@classmethod
def from_mapping_schema( cls, mapping_schema: MappingSchema) -> MappingSchema:
237    @classmethod
238    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
239        return MappingSchema(
240            schema=mapping_schema.mapping,
241            visible=mapping_schema.visible,
242            dialect=mapping_schema.dialect,
243            normalize=mapping_schema.normalize,
244        )
def find( self, table: sqlglot.expressions.Table, raise_on_missing: bool = True, ensure_data_types: bool = False) -> Optional[Any]:
246    def find(
247        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
248    ) -> t.Optional[t.Any]:
249        schema = super().find(
250            table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types
251        )
252        if ensure_data_types and isinstance(schema, dict):
253            schema = {
254                col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
255                for col, dtype in schema.items()
256            }
257
258        return schema

Returns the schema of a given table.

Arguments:
  • table: the target table.
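  • raise_on_missing: whether to raise a SchemaError when the table cannot be resolved unambiguously. A table that matches nothing in the schema returns None.
  • ensure_data_types: whether to convert str types to their DataType equivalents.
Returns:

The schema of the target table, or None if it could not be found.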

def copy(self, **kwargs) -> MappingSchema:
260    def copy(self, **kwargs) -> MappingSchema:
261        return MappingSchema(
262            **{  # type: ignore
263                "schema": self.mapping.copy(),
264                "visible": self.visible.copy(),
265                "dialect": self.dialect,
266                "normalize": self.normalize,
267                **kwargs,
268            }
269        )
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, List, NoneType] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None, match_depth: bool = True) -> None:
271    def add_table(
272        self,
273        table: exp.Table | str,
274        column_mapping: t.Optional[ColumnMapping] = None,
275        dialect: DialectType = None,
276        normalize: t.Optional[bool] = None,
277        match_depth: bool = True,
278    ) -> None:
279        """
280        Register or update a table. Updates are only performed if a new column mapping is provided.
281        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
282
283        Args:
284            table: the `Table` expression instance or string representing the table.
285            column_mapping: a column mapping that describes the structure of the table.
286            dialect: the SQL dialect that will be used to parse `table` if it's a string.
287            normalize: whether to normalize identifiers according to the dialect of interest.
288            match_depth: whether to enforce that the table must match the schema's depth or not.
289        """
290        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
291
292        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
293            raise SchemaError(
294                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
295                f"schema's nesting level: {self.depth()}."
296            )
297
298        normalized_column_mapping = {
299            self._normalize_name(key, dialect=dialect, normalize=normalize): value
300            for key, value in ensure_column_mapping(column_mapping).items()
301        }
302
303        schema = self.find(normalized_table, raise_on_missing=False)
304        if schema and not normalized_column_mapping:
305            return
306
307        parts = self.table_parts(normalized_table)
308
309        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
310        new_trie([parts], self.mapping_trie)

Register or update a table. Updates are only performed if a new column mapping is provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
  • match_depth: whether to enforce that the table must match the schema's depth or not.
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> List[str]:
312    def column_names(
313        self,
314        table: exp.Table | str,
315        only_visible: bool = False,
316        dialect: DialectType = None,
317        normalize: t.Optional[bool] = None,
318    ) -> t.List[str]:
319        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
320
321        schema = self.find(normalized_table)
322        if schema is None:
323            return []
324
325        if not only_visible or not self.visible:
326            return list(schema)
327
328        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
329        return [col for col in schema if col in visible]

Get the column names for a table.

Arguments:
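  • table: the Table expression instance or a string representing the table.
  • only_visible: whether to return only the visible columns (invisible columns are excluded when True). If no visibility mapping is configured, all columns are returned.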
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The sequence of column names.

def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> sqlglot.expressions.DataType:
331    def get_column_type(
332        self,
333        table: exp.Table | str,
334        column: exp.Column | str,
335        dialect: DialectType = None,
336        normalize: t.Optional[bool] = None,
337    ) -> exp.DataType:
338        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
339
340        normalized_column_name = self._normalize_name(
341            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
342        )
343
344        table_schema = self.find(normalized_table, raise_on_missing=False)
345        if table_schema:
346            column_type = table_schema.get(normalized_column_name)
347
348            if isinstance(column_type, exp.DataType):
349                return column_type
350            elif isinstance(column_type, str):
351                return self._to_data_type(column_type, dialect=dialect)
352
353        return exp.DataType.build("unknown")

Get the sqlglot.exp.DataType of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The resulting column type.

def has_column( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> bool:
355    def has_column(
356        self,
357        table: exp.Table | str,
358        column: exp.Column | str,
359        dialect: DialectType = None,
360        normalize: t.Optional[bool] = None,
361    ) -> bool:
362        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
363
364        normalized_column_name = self._normalize_name(
365            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
366        )
367
368        table_schema = self.find(normalized_table, raise_on_missing=False)
369        return normalized_column_name in table_schema if table_schema else False

Returns whether column appears in table's schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

True if the column appears in the schema, False otherwise.

def depth(self) -> int:
443    def depth(self) -> int:
444        if not self.empty and not self._depth:
445            # The columns themselves are a mapping, but we don't want to include those
446            self._depth = super().depth() - 1
447        return self._depth
def normalize_name( identifier: str | sqlglot.expressions.Identifier, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, is_table: bool = False, normalize: Optional[bool] = True) -> sqlglot.expressions.Identifier:
474def normalize_name(
475    identifier: str | exp.Identifier,
476    dialect: DialectType = None,
477    is_table: bool = False,
478    normalize: t.Optional[bool] = True,
479) -> exp.Identifier:
480    if isinstance(identifier, str):
481        identifier = exp.parse_identifier(identifier, dialect=dialect)
482
483    if not normalize:
484        return identifier
485
486    # this is used for normalize_identifier, bigquery has special rules pertaining tables
487    identifier.meta["is_table"] = is_table
488    return Dialect.get_or_raise(dialect).normalize_identifier(identifier)
def ensure_schema( schema: Union[Schema, Dict, NoneType], **kwargs: Any) -> Schema:
491def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
492    if isinstance(schema, Schema):
493        return schema
494
495    return MappingSchema(schema, **kwargs)
def _split_column_defs(text: str) -> t.List[str]:
    """Split `text` on commas that are not nested inside (), <> or [] brackets."""
    parts: t.List[str] = []
    depth = 0
    start = 0
    for i, char in enumerate(text):
        if char in "(<[":
            depth += 1
        elif char in ")>]":
            # Clamp at zero so that a stray closing bracket can't disable splitting.
            depth = max(depth - 1, 0)
        elif char == "," and depth == 0:
            parts.append(text[start:i])
            start = i + 1
    parts.append(text[start:])
    return parts


def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Coerce a column mapping into a dictionary of column name -> column type.

    Args:
        mapping: one of
            - None, which yields an empty dict;
            - a dict, which is returned as is (not copied);
            - a string of comma-separated "name: type" pairs, e.g.
              "a: INT, b: DECIMAL(10, 2)". Commas and colons nested inside (), <> or []
              belong to the type;
            - a list of column names, each mapped to None.

    Returns:
        The column mapping as a dictionary.

    Raises:
        ValueError: if `mapping` has an unsupported type, or if a string entry has no ":".
    """
    if mapping is None:
        return {}
    if isinstance(mapping, dict):
        return mapping
    if isinstance(mapping, str):
        columns = {}
        for entry in _split_column_defs(mapping):
            entry = entry.strip()
            if not entry:
                # Tolerate empty input and stray or trailing commas.
                continue
            name, sep, type_str = entry.partition(":")
            if not sep:
                raise ValueError(f"Invalid column definition, expected 'name: type': {entry!r}")
            # Only the first ":" separates the name from the type, because nested types
            # such as STRUCT<a: INT> may contain more colons.
            columns[name.strip()] = type_str.strip()
        return columns
    if isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema(
    schema: t.Dict, depth: t.Optional[int] = None, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """
    Flatten a nested schema dictionary into a list of key paths.

    Args:
        schema: the nested dictionary to flatten.
        depth: how many key levels make up a path. Defaults to one less than
            `dict_depth(schema)`.
        keys: a prefix prepended to every returned path (used by the recursion).

    Returns:
        One list of keys per path, in dictionary iteration order. A non-dict value ends its
        path early regardless of `depth`. Dict values encountered when `depth` is below 1
        produce no path.
    """
    prefix = keys or []
    if depth is None:
        depth = dict_depth(schema) - 1

    paths: t.List[t.List[str]] = []
    for key, value in schema.items():
        path = prefix + [key]
        if depth == 1 or not isinstance(value, dict):
            paths.append(path)
        elif depth >= 2:
            paths.extend(flatten_schema(value, depth - 1, path))
    return paths
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.
        raise_on_missing: whether to raise if a key along `path` can't be found. If False,
            None is returned instead.

    Returns:
        The value or None if it doesn't exist.

    Raises:
        ValueError: if a key can't be found and `raise_on_missing` is True. A `name` of
            "this" is reported as "table" in the error message.
    """
    for name, key in path:
        # Descending past a leaf (e.g. into a non-dict value) means the key doesn't exist.
        # Treat it as missing instead of failing with an AttributeError on `.get`.
        getter = getattr(d, "get", None)
        d = getter(key) if callable(getter) else None  # type: ignore
        if d is None:
            if raise_on_missing:
                name = "table" if name == "this" else name
                raise ValueError(f"Unknown {name}: {key}")
            return None

    return d

Get a value for a nested dictionary.

Arguments:
  • d: the dictionary to search.
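  • *path: tuples of (name, key), where key is the key to look up in the dictionary and name is the label used in the error message if key isn't found.
  • raise_on_missing: whether to raise a ValueError when a key isn't found; if False, None is returned instead.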
Returns:

The value or None if it doesn't exist.

def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Existing intermediate values along the path are reused. Missing intermediate keys are
    created with empty dictionaries.

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that make up the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary. This is the same object as `d`.
    """
    if not keys:
        return d

    *parents, leaf = keys
    target = d
    for parent in parents:
        # setdefault returns the existing child, or inserts and returns a new empty dict.
        target = target.setdefault(parent, {})
    target[leaf] = value
    return d

In-place set a value for a nested dictionary

Example:
>>> nested_set({}, ["top_key", "second_key"], "value")
{'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
{'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
  • d: dictionary to update.
  • keys: the keys that make up the path to value.
  • value: the value to set in the dictionary for the given key path.
Returns:

The (possibly) updated dictionary.