diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/schema.py | 298 |
1 files changed, 192 insertions, 106 deletions
diff --git a/sqlglot/schema.py b/sqlglot/schema.py index c916330..fcf7291 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -1,44 +1,60 @@ +from __future__ import annotations + import abc +import typing as t from sqlglot import expressions as exp -from sqlglot.errors import OptimizeError +from sqlglot.errors import SchemaError from sqlglot.helper import csv_reader +from sqlglot.trie import in_trie, new_trie + +if t.TYPE_CHECKING: + from sqlglot.dataframe.sql.types import StructType + + ColumnMapping = t.Union[t.Dict, str, StructType, t.List] + +TABLE_ARGS = ("this", "db", "catalog") class Schema(abc.ABC): """Abstract base class for database schemas""" @abc.abstractmethod - def add_table(self, table, column_mapping=None): + def add_table( + self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None + ) -> None: """ - Register or update a table. Some implementing classes may require column information to also be provided + Register or update a table. Some implementing classes may require column information to also be provided. Args: - table (sqlglot.expressions.Table|str): Table expression instance or string representing the table - column_mapping (dict|str|sqlglot.dataframe.sql.types.StructType|list): A column mapping that describes the structure of the table + table: table expression instance or string representing the table. + column_mapping: a column mapping that describes the structure of the table. """ @abc.abstractmethod - def column_names(self, table, only_visible=False): + def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]: """ Get the column names for a table. + Args: - table (sqlglot.expressions.Table): Table expression instance - only_visible (bool): Whether to include invisible columns + table: the `Table` expression instance. + only_visible: whether to include invisible columns. + Returns: - list[str]: list of column names + The list of column names. """ @abc.abstractmethod - def get_column_type(self, table, column): + def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType.Type: """ - Get the exp.DataType type of a column in the schema. + Get the :class:`sqlglot.exp.DataType` type of a column in the schema. Args: - table (sqlglot.expressions.Table): The source table. - column (sqlglot.expressions.Column): The target column. + table: the source table. + column: the target column. + Returns: - sqlglot.expressions.DataType.Type: The resulting column type. + The resulting column type. """ @@ -60,132 +76,179 @@ class MappingSchema(Schema): dialect (str): The dialect to be used for custom type mappings. """ - def __init__(self, schema=None, visible=None, dialect=None): + def __init__( + self, + schema: t.Optional[t.Dict] = None, + visible: t.Optional[t.Dict] = None, + dialect: t.Optional[str] = None, + ) -> None: self.schema = schema or {} - self.visible = visible + self.visible = visible or {} + self.schema_trie = self._build_trie(self.schema) self.dialect = dialect - self._type_mapping_cache = {} - self.supported_table_args = [] - self.forbidden_table_args = set() - if self.schema: - self._initialize_supported_args() + self._type_mapping_cache: t.Dict[str, exp.DataType.Type] = {} + self._supported_table_args: t.Tuple[str, ...] = tuple() @classmethod - def from_mapping_schema(cls, mapping_schema): + def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema: + return MappingSchema( + schema=mapping_schema.schema, + visible=mapping_schema.visible, + dialect=mapping_schema.dialect, + ) + + def copy(self, **kwargs) -> MappingSchema: return MappingSchema( - schema=mapping_schema.schema, visible=mapping_schema.visible, dialect=mapping_schema.dialect + **{ # type: ignore + "schema": self.schema.copy(), + "visible": self.visible.copy(), + "dialect": self.dialect, + **kwargs, + } ) - def copy(self, **kwargs): - return MappingSchema(**{"schema": self.schema.copy(), **kwargs}) + @property + def supported_table_args(self): + if not self._supported_table_args and self.schema: + depth = _dict_depth(self.schema) - def add_table(self, table, column_mapping=None): + if not depth or depth == 1: # {} + self._supported_table_args = tuple() + elif 2 <= depth <= 4: + self._supported_table_args = TABLE_ARGS[: depth - 1] + else: + raise SchemaError(f"Invalid schema shape. Depth: {depth}") + + return self._supported_table_args + + def add_table( + self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None + ) -> None: """ Register or update a table. Updates are only performed if a new column mapping is provided. Args: - table (sqlglot.expressions.Table|str): Table expression instance or string representing the table - column_mapping (dict|str|sqlglot.dataframe.sql.types.StructType|list): A column mapping that describes the structure of the table + table: the `Table` expression instance or string representing the table. + column_mapping: a column mapping that describes the structure of the table. """ - table = exp.to_table(table) - self._validate_table(table) + table_ = self._ensure_table(table) column_mapping = ensure_column_mapping(column_mapping) - table_args = [table.text(p) for p in self.supported_table_args or self._get_table_args_from_table(table)] - existing_column_mapping = _nested_get( - self.schema, *zip(self.supported_table_args, table_args), raise_on_missing=False - ) - if existing_column_mapping and not column_mapping: + schema = self.find_schema(table_, raise_on_missing=False) + + if schema and not column_mapping: return + _nested_set( self.schema, - [table.text(p) for p in self.supported_table_args or self._get_table_args_from_table(table)], + list(reversed(self.table_parts(table_))), column_mapping, ) - self._initialize_supported_args() + self.schema_trie = self._build_trie(self.schema) - def _get_table_args_from_table(self, table): - if table.args.get("catalog") is not None: - return "catalog", "db", "this" - if table.args.get("db") is not None: - return "db", "this" - return ("this",) + def _ensure_table(self, table: exp.Table | str) -> exp.Table: + table_ = exp.to_table(table) - def _validate_table(self, table): - if not self.supported_table_args and isinstance(table, exp.Table): - return - for forbidden in self.forbidden_table_args: - if table.text(forbidden): - raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}") - for expected in self.supported_table_args: - if not table.text(expected): - raise ValueError(f"Table is expected to have {expected}. Received: {table.sql()} ") + if not table_: + raise SchemaError(f"Not a valid table '{table}'") + + return table_ + + def table_parts(self, table: exp.Table) -> t.List[str]: + return [table.text(part) for part in TABLE_ARGS if table.text(part)] + + def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]: + table_ = self._ensure_table(table) - def column_names(self, table, only_visible=False): - table = exp.to_table(table) - if not isinstance(table.this, exp.Identifier): - return fs_get(table) + if not isinstance(table_.this, exp.Identifier): + return fs_get(table) # type: ignore - args = tuple(table.text(p) for p in self.supported_table_args) + schema = self.find_schema(table_) - for forbidden in self.forbidden_table_args: - if table.text(forbidden): - raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}") + if schema is None: + raise SchemaError(f"Could not find table schema {table}") - columns = list(_nested_get(self.schema, *zip(self.supported_table_args, args))) if not only_visible or not self.visible: - return columns + return list(schema) - visible = _nested_get(self.visible, *zip(self.supported_table_args, args)) - return [col for col in columns if col in visible] + visible = self._nested_get(self.table_parts(table_), self.visible) + return [col for col in schema if col in visible] # type: ignore - def get_column_type(self, table, column): - try: - schema_type = self.schema.get(table.name, {}).get(column.name).upper() + def find_schema( + self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True + ) -> t.Optional[t.Dict[str, str]]: + parts = self.table_parts(table)[0 : len(self.supported_table_args)] + value, trie = in_trie(self.schema_trie if trie is None else trie, parts) + + if value == 0: + if raise_on_missing: + raise SchemaError(f"Cannot find schema for {table}.") + else: + return None + elif value == 1: + possibilities = flatten_schema(trie) + if len(possibilities) == 1: + parts.extend(possibilities[0]) + else: + message = ", ".join(".".join(parts) for parts in possibilities) + if raise_on_missing: + raise SchemaError(f"Ambiguous schema for {table}: {message}.") + return None + + return self._nested_get(parts, raise_on_missing=raise_on_missing) + + def get_column_type( + self, table: exp.Table | str, column: exp.Column | str + ) -> exp.DataType.Type: + column_name = column if isinstance(column, str) else column.name + table_ = exp.to_table(table) + if table_: + table_schema = self.find_schema(table_) + schema_type = table_schema.get(column_name).upper() # type: ignore return self._convert_type(schema_type) - except: - raise OptimizeError(f"Failed to get type for column {column.sql()}") + raise SchemaError(f"Could not convert table '{table}'") - def _convert_type(self, schema_type): + def _convert_type(self, schema_type: str) -> exp.DataType.Type: """ - Convert a type represented as a string to the corresponding exp.DataType.Type object. + Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object. + Args: - schema_type (str): The type we want to convert. + schema_type: the type we want to convert. + Returns: - sqlglot.expressions.DataType.Type: The resulting expression type. + The resulting expression type. """ if schema_type not in self._type_mapping_cache: try: - self._type_mapping_cache[schema_type] = exp.maybe_parse( - schema_type, into=exp.DataType, dialect=self.dialect - ).this + expression = exp.maybe_parse(schema_type, into=exp.DataType, dialect=self.dialect) + if expression is None: + raise ValueError(f"Could not parse {schema_type}") + self._type_mapping_cache[schema_type] = expression.this except AttributeError: - raise OptimizeError(f"Failed to convert type {schema_type}") + raise SchemaError(f"Failed to convert type {schema_type}") return self._type_mapping_cache[schema_type] - def _initialize_supported_args(self): - if not self.supported_table_args: - depth = _dict_depth(self.schema) - - all_args = ["this", "db", "catalog"] - if not depth or depth == 1: # {} - self.supported_table_args = [] - elif 2 <= depth <= 4: - self.supported_table_args = tuple(reversed(all_args[: depth - 1])) - else: - raise OptimizeError(f"Invalid schema shape. Depth: {depth}") + def _build_trie(self, schema: t.Dict): + return new_trie(tuple(reversed(t)) for t in flatten_schema(schema)) - self.forbidden_table_args = {"catalog", "db", "this"} - set(self.supported_table_args) + def _nested_get( + self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True + ) -> t.Optional[t.Any]: + return _nested_get( + d or self.schema, + *zip(self.supported_table_args, reversed(parts)), + raise_on_missing=raise_on_missing, + ) -def ensure_schema(schema): +def ensure_schema(schema: t.Any) -> Schema: if isinstance(schema, Schema): return schema return MappingSchema(schema) -def ensure_column_mapping(mapping): +def ensure_column_mapping(mapping: t.Optional[ColumnMapping]): if isinstance(mapping, dict): return mapping elif isinstance(mapping, str): @@ -196,7 +259,7 @@ def ensure_column_mapping(mapping): } # Check if mapping looks like a DataFrame StructType elif hasattr(mapping, "simpleString"): - return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping} + return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping} # type: ignore elif isinstance(mapping, list): return {x.strip(): None for x in mapping} elif mapping is None: @@ -204,7 +267,20 @@ def ensure_column_mapping(mapping): raise ValueError(f"Invalid mapping provided: {type(mapping)}") -def fs_get(table): +def flatten_schema(schema: t.Dict, keys: t.Optional[t.List[str]] = None) -> t.List[t.List[str]]: + tables = [] + keys = keys or [] + depth = _dict_depth(schema) + + for k, v in schema.items(): + if depth >= 3: + tables.extend(flatten_schema(v, keys + [k])) + elif depth == 2: + tables.append(keys + [k]) + return tables + + +def fs_get(table: exp.Table) -> t.List[str]: name = table.this.name if name.upper() == "READ_CSV": @@ -214,21 +290,23 @@ def fs_get(table): raise ValueError(f"Cannot read schema for {table}") -def _nested_get(d, *path, raise_on_missing=True): +def _nested_get( + d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True +) -> t.Optional[t.Any]: """ Get a value for a nested dictionary. Args: - d (dict): dictionary - *path (tuple[str, str]): tuples of (name, key) + d: the dictionary to search. + *path: tuples of (name, key), where: `key` is the key in the dictionary to get. `name` is a string to use in the error if `key` isn't found. Returns: - The value or None if it doesn't exist + The value or None if it doesn't exist. """ for name, key in path: - d = d.get(key) + d = d.get(key) # type: ignore if d is None: if raise_on_missing: name = "table" if name == "this" else name @@ -237,36 +315,44 @@ def _nested_get(d, *path, raise_on_missing=True): return d -def _nested_set(d, keys, value): +def _nested_set(d: t.Dict, keys: t.List[str], value: t.Any) -> t.Dict: """ In-place set a value for a nested dictionary - Ex: + Example: >>> _nested_set({}, ["top_key", "second_key"], "value") {'top_key': {'second_key': 'value'}} + >>> _nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") {'top_key': {'third_key': 'third_value', 'second_key': 'value'}} - d (dict): dictionary - keys (Iterable[str]): ordered iterable of keys that makeup path to value - value (Any): The value to set in the dictionary for the given key path + Args: + d: dictionary to update. + keys: the keys that makeup the path to `value`. + value: the value to set in the dictionary for the given key path. + + Returns: + The (possibly) updated dictionary. """ if not keys: - return + return d + if len(keys) == 1: d[keys[0]] = value - return + return d + subd = d for key in keys[:-1]: if key not in subd: subd = subd.setdefault(key, {}) else: subd = subd[key] + subd[keys[-1]] = value return d -def _dict_depth(d): +def _dict_depth(d: t.Dict) -> int: """ Get the nesting depth of a dictionary. |