diff options
Diffstat (limited to 'sqlglot/optimizer/schema.py')
-rw-r--r-- | sqlglot/optimizer/schema.py | 63 |
1 files changed, 58 insertions, 5 deletions
diff --git a/sqlglot/optimizer/schema.py b/sqlglot/optimizer/schema.py
index 1bbd86a..d7743c9 100644
--- a/sqlglot/optimizer/schema.py
+++ b/sqlglot/optimizer/schema.py
@@ -9,16 +9,28 @@ class Schema(abc.ABC):
     """Abstract base class for database schemas"""
 
     @abc.abstractmethod
-    def column_names(self, table):
+    def column_names(self, table, only_visible=False):
         """
         Get the column names for a table.
-
         Args:
             table (sqlglot.expressions.Table): Table expression instance
+            only_visible (bool): Whether to include invisible columns
         Returns:
             list[str]: list of column names
         """
 
+    @abc.abstractmethod
+    def get_column_type(self, table, column):
+        """
+        Get the exp.DataType type of a column in the schema.
+
+        Args:
+            table (sqlglot.expressions.Table): The source table.
+            column (sqlglot.expressions.Column): The target column.
+        Returns:
+            sqlglot.expressions.DataType.Type: The resulting column type.
+        """
+
 
 class MappingSchema(Schema):
     """
@@ -29,10 +41,19 @@ class MappingSchema(Schema):
             1. {table: {col: type}}
             2. {db: {table: {col: type}}}
             3. {catalog: {db: {table: {col: type}}}}
+        visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns
+            are assumed to be visible. The nesting should mirror that of the schema:
+            1. {table: set(*cols)}}
+            2. {db: {table: set(*cols)}}}
+            3. {catalog: {db: {table: set(*cols)}}}}
+        dialect (str): The dialect to be used for custom type mappings.
     """
 
-    def __init__(self, schema):
+    def __init__(self, schema, visible=None, dialect=None):
         self.schema = schema
+        self.visible = visible
+        self.dialect = dialect
+        self._type_mapping_cache = {}
 
         depth = _dict_depth(schema)
 
@@ -49,7 +70,7 @@ class MappingSchema(Schema):
 
         self.forbidden_args = {"catalog", "db", "this"} - set(self.supported_table_args)
 
-    def column_names(self, table):
+    def column_names(self, table, only_visible=False):
         if not isinstance(table.this, exp.Identifier):
             return fs_get(table)
 
@@ -58,7 +79,39 @@ class MappingSchema(Schema):
         for forbidden in self.forbidden_args:
             if table.text(forbidden):
                 raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}")
-        return list(_nested_get(self.schema, *zip(self.supported_table_args, args)))
+
+        columns = list(_nested_get(self.schema, *zip(self.supported_table_args, args)))
+        if not only_visible or not self.visible:
+            return columns
+
+        visible = _nested_get(self.visible, *zip(self.supported_table_args, args))
+        return [col for col in columns if col in visible]
+
+    def get_column_type(self, table, column):
+        try:
+            schema_type = self.schema.get(table.name, {}).get(column.name).upper()
+            return self._convert_type(schema_type)
+        except:
+            raise OptimizeError(f"Failed to get type for column {column.sql()}")
+
+    def _convert_type(self, schema_type):
+        """
+        Convert a type represented as a string to the corresponding exp.DataType.Type object.
+
+        Args:
+            schema_type (str): The type we want to convert.
+        Returns:
+            sqlglot.expressions.DataType.Type: The resulting expression type.
+        """
+        if schema_type not in self._type_mapping_cache:
+            try:
+                self._type_mapping_cache[schema_type] = exp.maybe_parse(
+                    schema_type, into=exp.DataType, dialect=self.dialect
+                ).this
+            except AttributeError:
+                raise OptimizeError(f"Failed to convert type {schema_type}")
+
+        return self._type_mapping_cache[schema_type]
 
 
 def ensure_schema(schema):