diff options
Diffstat (limited to 'sqlglot/schema.py')
-rw-r--r-- | sqlglot/schema.py | 297 |
1 files changed, 297 insertions, 0 deletions
diff --git a/sqlglot/schema.py b/sqlglot/schema.py new file mode 100644 index 0000000..c916330 --- /dev/null +++ b/sqlglot/schema.py @@ -0,0 +1,297 @@ +import abc + +from sqlglot import expressions as exp +from sqlglot.errors import OptimizeError +from sqlglot.helper import csv_reader + + +class Schema(abc.ABC): + """Abstract base class for database schemas""" + + @abc.abstractmethod + def add_table(self, table, column_mapping=None): + """ + Register or update a table. Some implementing classes may require column information to also be provided + + Args: + table (sqlglot.expressions.Table|str): Table expression instance or string representing the table + column_mapping (dict|str|sqlglot.dataframe.sql.types.StructType|list): A column mapping that describes the structure of the table + """ + + @abc.abstractmethod + def column_names(self, table, only_visible=False): + """ + Get the column names for a table. + Args: + table (sqlglot.expressions.Table): Table expression instance + only_visible (bool): Whether to include invisible columns + Returns: + list[str]: list of column names + """ + + @abc.abstractmethod + def get_column_type(self, table, column): + """ + Get the exp.DataType type of a column in the schema. + + Args: + table (sqlglot.expressions.Table): The source table. + column (sqlglot.expressions.Column): The target column. + Returns: + sqlglot.expressions.DataType.Type: The resulting column type. + """ + + +class MappingSchema(Schema): + """ + Schema based on a nested mapping. + + Args: + schema (dict): Mapping in one of the following forms: + 1. {table: {col: type}} + 2. {db: {table: {col: type}}} + 3. {catalog: {db: {table: {col: type}}}} + 4. None - Tables will be added later + visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns + are assumed to be visible. The nesting should mirror that of the schema: + 1. {table: set(*cols)}} + 2. {db: {table: set(*cols)}}} + 3. {catalog: {db: {table: set(*cols)}}}} + dialect (str): The dialect to be used for custom type mappings. + """ + + def __init__(self, schema=None, visible=None, dialect=None): + self.schema = schema or {} + self.visible = visible + self.dialect = dialect + self._type_mapping_cache = {} + self.supported_table_args = [] + self.forbidden_table_args = set() + if self.schema: + self._initialize_supported_args() + + @classmethod + def from_mapping_schema(cls, mapping_schema): + return MappingSchema( + schema=mapping_schema.schema, visible=mapping_schema.visible, dialect=mapping_schema.dialect + ) + + def copy(self, **kwargs): + return MappingSchema(**{"schema": self.schema.copy(), **kwargs}) + + def add_table(self, table, column_mapping=None): + """ + Register or update a table. Updates are only performed if a new column mapping is provided. + + Args: + table (sqlglot.expressions.Table|str): Table expression instance or string representing the table + column_mapping (dict|str|sqlglot.dataframe.sql.types.StructType|list): A column mapping that describes the structure of the table + """ + table = exp.to_table(table) + self._validate_table(table) + column_mapping = ensure_column_mapping(column_mapping) + table_args = [table.text(p) for p in self.supported_table_args or self._get_table_args_from_table(table)] + existing_column_mapping = _nested_get( + self.schema, *zip(self.supported_table_args, table_args), raise_on_missing=False + ) + if existing_column_mapping and not column_mapping: + return + _nested_set( + self.schema, + [table.text(p) for p in self.supported_table_args or self._get_table_args_from_table(table)], + column_mapping, + ) + self._initialize_supported_args() + + def _get_table_args_from_table(self, table): + if table.args.get("catalog") is not None: + return "catalog", "db", "this" + if table.args.get("db") is not None: + return "db", "this" + return ("this",) + + def _validate_table(self, table): + if not self.supported_table_args and isinstance(table, exp.Table): + return + for forbidden in self.forbidden_table_args: + if table.text(forbidden): + raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}") + for expected in self.supported_table_args: + if not table.text(expected): + raise ValueError(f"Table is expected to have {expected}. Received: {table.sql()} ") + + def column_names(self, table, only_visible=False): + table = exp.to_table(table) + if not isinstance(table.this, exp.Identifier): + return fs_get(table) + + args = tuple(table.text(p) for p in self.supported_table_args) + + for forbidden in self.forbidden_table_args: + if table.text(forbidden): + raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}") + + columns = list(_nested_get(self.schema, *zip(self.supported_table_args, args))) + if not only_visible or not self.visible: + return columns + + visible = _nested_get(self.visible, *zip(self.supported_table_args, args)) + return [col for col in columns if col in visible] + + def get_column_type(self, table, column): + try: + schema_type = self.schema.get(table.name, {}).get(column.name).upper() + return self._convert_type(schema_type) + except: + raise OptimizeError(f"Failed to get type for column {column.sql()}") + + def _convert_type(self, schema_type): + """ + Convert a type represented as a string to the corresponding exp.DataType.Type object. + Args: + schema_type (str): The type we want to convert. + Returns: + sqlglot.expressions.DataType.Type: The resulting expression type. + """ + if schema_type not in self._type_mapping_cache: + try: + self._type_mapping_cache[schema_type] = exp.maybe_parse( + schema_type, into=exp.DataType, dialect=self.dialect + ).this + except AttributeError: + raise OptimizeError(f"Failed to convert type {schema_type}") + + return self._type_mapping_cache[schema_type] + + def _initialize_supported_args(self): + if not self.supported_table_args: + depth = _dict_depth(self.schema) + + all_args = ["this", "db", "catalog"] + if not depth or depth == 1: # {} + self.supported_table_args = [] + elif 2 <= depth <= 4: + self.supported_table_args = tuple(reversed(all_args[: depth - 1])) + else: + raise OptimizeError(f"Invalid schema shape. Depth: {depth}") + + self.forbidden_table_args = {"catalog", "db", "this"} - set(self.supported_table_args) + + +def ensure_schema(schema): + if isinstance(schema, Schema): + return schema + + return MappingSchema(schema) + + +def ensure_column_mapping(mapping): + if isinstance(mapping, dict): + return mapping + elif isinstance(mapping, str): + col_name_type_strs = [x.strip() for x in mapping.split(",")] + return { + name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip() + for name_type_str in col_name_type_strs + } + # Check if mapping looks like a DataFrame StructType + elif hasattr(mapping, "simpleString"): + return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping} + elif isinstance(mapping, list): + return {x.strip(): None for x in mapping} + elif mapping is None: + return {} + raise ValueError(f"Invalid mapping provided: {type(mapping)}") + + +def fs_get(table): + name = table.this.name + + if name.upper() == "READ_CSV": + with csv_reader(table) as reader: + return next(reader) + + raise ValueError(f"Cannot read schema for {table}") + + +def _nested_get(d, *path, raise_on_missing=True): + """ + Get a value for a nested dictionary. + + Args: + d (dict): dictionary + *path (tuple[str, str]): tuples of (name, key) + `key` is the key in the dictionary to get. + `name` is a string to use in the error if `key` isn't found. + + Returns: + The value or None if it doesn't exist + """ + for name, key in path: + d = d.get(key) + if d is None: + if raise_on_missing: + name = "table" if name == "this" else name + raise ValueError(f"Unknown {name}") + return None + return d + + +def _nested_set(d, keys, value): + """ + In-place set a value for a nested dictionary + + Ex: + >>> _nested_set({}, ["top_key", "second_key"], "value") + {'top_key': {'second_key': 'value'}} + >>> _nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") + {'top_key': {'third_key': 'third_value', 'second_key': 'value'}} + + d (dict): dictionary + keys (Iterable[str]): ordered iterable of keys that makeup path to value + value (Any): The value to set in the dictionary for the given key path + """ + if not keys: + return + if len(keys) == 1: + d[keys[0]] = value + return + subd = d + for key in keys[:-1]: + if key not in subd: + subd = subd.setdefault(key, {}) + else: + subd = subd[key] + subd[keys[-1]] = value + return d + + +def _dict_depth(d): + """ + Get the nesting depth of a dictionary. + + For example: + >>> _dict_depth(None) + 0 + >>> _dict_depth({}) + 1 + >>> _dict_depth({"a": "b"}) + 1 + >>> _dict_depth({"a": {}}) + 2 + >>> _dict_depth({"a": {"b": {}}}) + 3 + + Args: + d (dict): dictionary + Returns: + int: depth + """ + try: + return 1 + _dict_depth(next(iter(d.values()))) + except AttributeError: + # d doesn't have attribute "values" + return 0 + except StopIteration: + # d.values() returns an empty sequence + return 1 |