  1from __future__ import annotations
  3import abc
  4import typing as t
  6import sqlglot
  7from sqlglot import expressions as exp
  8from sqlglot._typing import T
  9from sqlglot.dialects.dialect import Dialect
 10from sqlglot.errors import ParseError, SchemaError
 11from sqlglot.helper import dict_depth
 12from sqlglot.trie import TrieResult, in_trie, new_trie
 15    from sqlglot.dataframe.sql.types import StructType
 16    from sqlglot.dialects.dialect import DialectType
 18    ColumnMapping = t.Union[t.Dict, str, StructType, t.List]
# Names of the `exp.Table` args that locate a table, ordered from the table
# itself ("this") outwards to its db and then its catalog. The order matters:
# schemas slice this tuple by nesting depth and build table-first lookup paths from it.
TABLE_ARGS = ("this", "db", "catalog")
class Schema(abc.ABC):
    """Abstract base class for database schemas"""

    # Default dialect of the schema; implementations in this module fall back
    # to it when a method is called without an explicit `dialect`.
    dialect: DialectType

    @abc.abstractmethod
    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information to also be provided.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
        """

    @abc.abstractmethod
    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
    ) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to return only the visible columns (when False,
                all columns are returned).
            dialect: the SQL dialect that will be used to parse `table` if it's a string.

        Returns:
            The list of column names.
        """

    @abc.abstractmethod
    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column,
        dialect: DialectType = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table, as a `Table` expression or a string.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.

        Returns:
            The resulting column type.
        """

    @property
    @abc.abstractmethod
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        Table arguments this schema supports, e.g. `("this", "db", "catalog")`
        """

    @property
    def empty(self) -> bool:
        """
        Returns whether or not the schema is empty.

        This base implementation always reports True; implementations that
        actually hold table data are expected to override it.
        """
        return True
 95class AbstractMappingSchema(t.Generic[T]):
 96    def __init__(
 97        self,
 98        mapping: t.Optional[t.Dict] = None,
 99    ) -> None:
100        self.mapping = mapping or {}
101        self.mapping_trie = new_trie(
102            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self._depth())
103        )
104        self._supported_table_args: t.Tuple[str, ...] = tuple()
106    @property
107    def empty(self) -> bool:
108        return not self.mapping
110    def _depth(self) -> int:
111        return dict_depth(self.mapping)
113    @property
114    def supported_table_args(self) -> t.Tuple[str, ...]:
115        if not self._supported_table_args and self.mapping:
116            depth = self._depth()
118            if not depth:  # None
119                self._supported_table_args = tuple()
120            elif 1 <= depth <= 3:
121                self._supported_table_args = TABLE_ARGS[:depth]
122            else:
123                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
125        return self._supported_table_args
127    def table_parts(self, table: exp.Table) -> t.List[str]:
128        if isinstance(table.this, exp.ReadCSV):
129            return [table.this.name]
130        return [table.text(part) for part in TABLE_ARGS if table.text(part)]
132    def find(
133        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
134    ) -> t.Optional[T]:
135        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
136        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
138        if value == TrieResult.FAILED:
139            return None
141        if value == TrieResult.PREFIX:
142            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
144            if len(possibilities) == 1:
145                parts.extend(possibilities[0])
146            else:
147                message = ", ".join(".".join(parts) for parts in possibilities)
148                if raise_on_missing:
149                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
150                return None
152        return self.nested_get(parts, raise_on_missing=raise_on_missing)
154    def nested_get(
155        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
156    ) -> t.Optional[t.Any]:
157        return nested_get(
158            d or self.mapping,
159            *zip(self.supported_table_args, reversed(parts)),
160            raise_on_missing=raise_on_missing,
161        )
class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = visible or {}
        self.normalize = normalize
        # Memoizes string type -> DataType conversions (keyed by the type string only).
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}

        # `_normalize` reads `self.dialect` and `self.normalize`, so they must be set first.
        super().__init__(self._normalize(schema or {}))

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """
        Build a new `MappingSchema` from the settings of an existing one.

        The `normalize` flag is carried over so that a schema created with
        `normalize=False` does not get its identifiers rewritten.
        """
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """
        Return a new schema built from shallow copies of this one's mappings.

        Args:
            **kwargs: constructor arguments that override the copied settings
                (`schema`, `visible`, `dialect`, `normalize`).
        """
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                # Preserve the normalization setting; callers may still override it.
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
        """
        normalized_table = self._normalize_table(table, dialect=dialect)

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # An already-known table is left alone unless new column info is supplied.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        # The mapping nests outermost-first, whereas the trie is keyed table-first.
        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
    ) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to return only the columns listed in `visible`.
                Ignored when no `visible` mapping was provided.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.

        Returns:
            The list of column names, or an empty list if the table isn't found.
        """
        normalized_table = self._normalize_table(table, dialect=dialect)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column,
        dialect: DialectType = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column (a plain string name is also handled).
            dialect: the SQL dialect that will be used to parse `table` if it's a string.

        Returns:
            The column's type, or the "unknown" type when the table or column
            isn't registered or its stored type is neither a `DataType` nor a string.
        """
        normalized_table = self._normalize_table(table, dialect=dialect)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                return self._to_data_type(column_type.upper(), dialect=dialect)

        return exp.DataType.build("unknown")

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        # One level is subtracted so that the paths stop at tables rather than columns.
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        normalized_mapping: t.Dict = {}
        for keys in flattened_schema:
            # `nested_get` takes (name, key) pairs; the name only appears in error messages.
            columns = nested_get(schema, *zip(keys, keys))
            assert columns is not None

            normalized_keys = [
                self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys
            ]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(self, table: exp.Table | str, dialect: DialectType = None) -> exp.Table:
        """
        Parse `table` if needed and normalize each of its name parts.

        The input is copied (`copy=True`), so a passed-in expression is not mutated.
        """
        normalized_table = exp.maybe_parse(
            table, into=exp.Table, dialect=dialect or self.dialect, copy=True
        )

        for arg in TABLE_ARGS:
            value = normalized_table.args.get(arg)
            if isinstance(value, (str, exp.Identifier)):
                normalized_table.set(
                    arg,
                    exp.to_identifier(self._normalize_name(value, dialect=dialect, is_table=True)),
                )

        return normalized_table

    def _normalize_name(
        self, name: str | exp.Identifier, dialect: DialectType = None, is_table: bool = False
    ) -> str:
        """
        Return the (optionally dialect-normalized) plain name of an identifier.

        Args:
            name: the identifier, as a string or an `Identifier` expression.
            dialect: the dialect to parse and normalize with; defaults to `self.dialect`.
            is_table: whether the identifier names a table part rather than a column.

        Returns:
            The resulting name. If `name` can't be parsed it is returned as-is.
        """
        dialect = dialect or self.dialect

        try:
            identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
        except ParseError:
            return name if isinstance(name, str) else name.name

        name = identifier.name
        if not self.normalize:
            return name

        # This can be useful for normalize_identifier
        identifier.meta["is_table"] = is_table
        return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name

    def _depth(self) -> int:
        # The columns themselves are a mapping, but we don't want to include those
        return super()._depth() - 1

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.

        Raises:
            SchemaError: if the type can't be built.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]
def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """
    Coerce `schema` into a `Schema` instance.

    Args:
        schema: an existing `Schema` (returned unchanged), or a mapping / None
            from which a `MappingSchema` is built.
        **kwargs: extra arguments for the `MappingSchema` constructor; unused
            when `schema` already is a `Schema`.

    Returns:
        The schema instance.
    """
    return schema if isinstance(schema, Schema) else MappingSchema(schema, **kwargs)
def _split_top_level(text: str, sep: str = ",") -> t.List[str]:
    """Split `text` on `sep`, ignoring separators nested inside (), [] or <>."""
    pieces: t.List[str] = []
    depth = 0
    start = 0

    for index, char in enumerate(text):
        if char in "([<":
            depth += 1
        elif char in ")]>":
            # Clamp at zero so that a stray closing bracket can't disable later splits.
            depth = max(depth - 1, 0)
        elif char == sep and depth == 0:
            pieces.append(text[start:index])
            start = index + 1

    pieces.append(text[start:])
    return pieces


def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Coerce the supported column mapping representations into a dictionary.

    Args:
        mapping: one of
            - None, yielding an empty mapping;
            - a dict, returned as-is;
            - a string such as "a: int, b: struct<c: int, d: decimal(10, 2)>".
              Entries are separated by top-level commas and each entry is split
              on its first colon, so nested types may contain both characters;
            - an object exposing `simpleString` (duck-typed DataFrame StructType),
              whose fields provide the names and type strings;
            - a list of column names, which map to a None type.

    Returns:
        A dictionary of column name to column type.

    Raises:
        IndexError: if an entry of a string mapping has no "name: type" colon.
        ValueError: if `mapping` is of an unsupported type.
    """
    if mapping is None:
        return {}
    if isinstance(mapping, dict):
        return mapping
    if isinstance(mapping, str):
        column_types = {}
        for entry in _split_top_level(mapping):
            # Only the first colon separates name from type; the type itself may
            # contain more (e.g. struct fields).
            name_and_type = entry.split(":", 1)
            column_types[name_and_type[0].strip()] = name_and_type[1].strip()
        return column_types
    # Check if mapping looks like a DataFrame StructType
    if hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
    if isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """
    List the key paths found `depth` levels into a nested dictionary.

    Args:
        schema: the nested dictionary to walk.
        depth: how many levels of keys each returned path should span.
        keys: path prefix accumulated so far (used by the recursion).

    Returns:
        One list of keys per path, in dictionary iteration order. The result is
        empty when `depth` is less than 1.
    """
    prefix = keys or []
    items = schema.items()

    if depth >= 2:
        return [
            path
            for name, nested in items
            for path in flatten_schema(nested, depth - 1, prefix + [name])
        ]

    if depth == 1:
        return [prefix + [name] for name, _ in items]

    return []
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Walk a nested dictionary along a path of keys.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key to look up at that level.
            `name` labels that level in the error raised if `key` is missing
            (the name "this" is reported as "table").
        raise_on_missing: whether a missing key raises instead of yielding None.

    Returns:
        The value at the end of the path (`d` itself for an empty path), or
        None if a key along the way is missing or maps to None.

    Raises:
        ValueError: if a key is missing and `raise_on_missing` is set.
    """
    current = d

    for name, key in path:
        current = current.get(key)  # type: ignore
        if current is not None:
            continue

        if not raise_on_missing:
            return None

        label = "table" if name == "this" else name
        raise ValueError(f"Unknown {label}: {key}")

    return current
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    Assign a value inside a nested dictionary, in place.

    Intermediate dictionaries are created for any key along the path that does
    not exist yet.

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that make up the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The same dictionary `d`, left untouched when `keys` is empty.
    """
    if not keys:
        return d

    node = d
    for key in keys[:-1]:
        # Descend into the existing child, creating it first if necessary.
        node = node.setdefault(key, {})

    node[keys[-1]] = value
    return d
