diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/schema.py | 51 |
1 files changed, 44 insertions, 7 deletions
diff --git a/sqlglot/schema.py b/sqlglot/schema.py index f0b279b..778378c 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -5,7 +5,6 @@ import typing as t import sqlglot from sqlglot import expressions as exp -from sqlglot._typing import T from sqlglot.dialects.dialect import Dialect from sqlglot.errors import ParseError, SchemaError from sqlglot.helper import dict_depth @@ -71,7 +70,7 @@ class Schema(abc.ABC): def get_column_type( self, table: exp.Table | str, - column: exp.Column, + column: exp.Column | str, dialect: DialectType = None, normalize: t.Optional[bool] = None, ) -> exp.DataType: @@ -88,6 +87,28 @@ class Schema(abc.ABC): The resulting column type. """ + def has_column( + self, + table: exp.Table | str, + column: exp.Column | str, + dialect: DialectType = None, + normalize: t.Optional[bool] = None, + ) -> bool: + """ + Returns whether or not `column` appears in `table`'s schema. + + Args: + table: the source table. + column: the target column. + dialect: the SQL dialect that will be used to parse `table` if it's a string. + normalize: whether to normalize identifiers according to the dialect of interest. + + Returns: + True if the column appears in the schema, False otherwise. + """ + name = column if isinstance(column, str) else column.name + return name in self.column_names(table, dialect=dialect, normalize=normalize) + @property @abc.abstractmethod def supported_table_args(self) -> t.Tuple[str, ...]: @@ -101,7 +122,7 @@ class Schema(abc.ABC): return True -class AbstractMappingSchema(t.Generic[T]): +class AbstractMappingSchema: def __init__( self, mapping: t.Optional[t.Dict] = None, @@ -140,7 +161,7 @@ class AbstractMappingSchema(t.Generic[T]): def find( self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True - ) -> t.Optional[T]: + ) -> t.Optional[t.Any]: parts = self.table_parts(table)[0 : len(self.supported_table_args)] value, trie = in_trie(self.mapping_trie if trie is None else trie, parts) @@ -170,7 +191,7 @@ class AbstractMappingSchema(t.Generic[T]): ) -class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): +class MappingSchema(AbstractMappingSchema, Schema): """ Schema based on a nested mapping. @@ -287,7 +308,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): def get_column_type( self, table: exp.Table | str, - column: exp.Column, + column: exp.Column | str, dialect: DialectType = None, normalize: t.Optional[bool] = None, ) -> exp.DataType: @@ -304,10 +325,26 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): if isinstance(column_type, exp.DataType): return column_type elif isinstance(column_type, str): - return self._to_data_type(column_type.upper(), dialect=dialect) + return self._to_data_type(column_type, dialect=dialect) return exp.DataType.build("unknown") + def has_column( + self, + table: exp.Table | str, + column: exp.Column | str, + dialect: DialectType = None, + normalize: t.Optional[bool] = None, + ) -> bool: + normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) + + normalized_column_name = self._normalize_name( + column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize + ) + + table_schema = self.find(normalized_table, raise_on_missing=False) + return normalized_column_name in table_schema if table_schema else False + def _normalize(self, schema: t.Dict) -> t.Dict: """ Normalizes all identifiers in the schema. |