diff options
Diffstat (limited to 'sqlglot/schema.py')
-rw-r--r-- | sqlglot/schema.py | 22 |
1 files changed, 11 insertions, 11 deletions
diff --git a/sqlglot/schema.py b/sqlglot/schema.py index f6f303b..8a264a2 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -47,7 +47,7 @@ class Schema(abc.ABC): """ @abc.abstractmethod - def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType.Type: + def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType: """ Get the :class:`sqlglot.exp.DataType` type of a column in the schema. @@ -160,8 +160,8 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): super().__init__(schema) self.visible = visible or {} self.dialect = dialect - self._type_mapping_cache: t.Dict[str, exp.DataType.Type] = { - "STR": exp.DataType.Type.TEXT, + self._type_mapping_cache: t.Dict[str, exp.DataType] = { + "STR": exp.DataType.build("text"), } @classmethod @@ -231,18 +231,18 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): visible = self._nested_get(self.table_parts(table_), self.visible) return [col for col in schema if col in visible] # type: ignore - def get_column_type( - self, table: exp.Table | str, column: exp.Column | str - ) -> exp.DataType.Type: + def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType: column_name = column if isinstance(column, str) else column.name table_ = exp.to_table(table) if table_: - table_schema = self.find(table_) - schema_type = table_schema.get(column_name).upper() # type: ignore - return self._convert_type(schema_type) + table_schema = self.find(table_, raise_on_missing=False) + if table_schema: + schema_type = table_schema.get(column_name).upper() # type: ignore + return self._convert_type(schema_type) + return exp.DataType(this=exp.DataType.Type.UNKNOWN) raise SchemaError(f"Could not convert table '{table}'") - def _convert_type(self, schema_type: str) -> exp.DataType.Type: + def _convert_type(self, schema_type: str) -> exp.DataType: """ Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object. @@ -257,7 +257,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): expression = exp.maybe_parse(schema_type, into=exp.DataType, dialect=self.dialect) if expression is None: raise ValueError(f"Could not parse {schema_type}") - self._type_mapping_cache[schema_type] = expression.this + self._type_mapping_cache[schema_type] = expression # type: ignore except AttributeError: raise SchemaError(f"Failed to convert type {schema_type}") |