summaryrefslogtreecommitdiffstats
path: root/sqlglot/schema.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/schema.py')
-rw-r--r--sqlglot/schema.py51
1 files changed, 44 insertions, 7 deletions
diff --git a/sqlglot/schema.py b/sqlglot/schema.py
index f0b279b..778378c 100644
--- a/sqlglot/schema.py
+++ b/sqlglot/schema.py
@@ -5,7 +5,6 @@ import typing as t
import sqlglot
from sqlglot import expressions as exp
-from sqlglot._typing import T
from sqlglot.dialects.dialect import Dialect
from sqlglot.errors import ParseError, SchemaError
from sqlglot.helper import dict_depth
@@ -71,7 +70,7 @@ class Schema(abc.ABC):
def get_column_type(
self,
table: exp.Table | str,
- column: exp.Column,
+ column: exp.Column | str,
dialect: DialectType = None,
normalize: t.Optional[bool] = None,
) -> exp.DataType:
@@ -88,6 +87,28 @@ class Schema(abc.ABC):
The resulting column type.
"""
+ def has_column(
+ self,
+ table: exp.Table | str,
+ column: exp.Column | str,
+ dialect: DialectType = None,
+ normalize: t.Optional[bool] = None,
+ ) -> bool:
+ """
+ Returns whether or not `column` appears in `table`'s schema.
+
+ Args:
+ table: the source table.
+ column: the target column.
+ dialect: the SQL dialect that will be used to parse `table` if it's a string.
+ normalize: whether to normalize identifiers according to the dialect of interest.
+
+ Returns:
+ True if the column appears in the schema, False otherwise.
+ """
+ name = column if isinstance(column, str) else column.name
+ return name in self.column_names(table, dialect=dialect, normalize=normalize)
+
@property
@abc.abstractmethod
def supported_table_args(self) -> t.Tuple[str, ...]:
@@ -101,7 +122,7 @@ class Schema(abc.ABC):
return True
-class AbstractMappingSchema(t.Generic[T]):
+class AbstractMappingSchema:
def __init__(
self,
mapping: t.Optional[t.Dict] = None,
@@ -140,7 +161,7 @@ class AbstractMappingSchema(t.Generic[T]):
def find(
self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
- ) -> t.Optional[T]:
+ ) -> t.Optional[t.Any]:
parts = self.table_parts(table)[0 : len(self.supported_table_args)]
value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
@@ -170,7 +191,7 @@ class AbstractMappingSchema(t.Generic[T]):
)
-class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
+class MappingSchema(AbstractMappingSchema, Schema):
"""
Schema based on a nested mapping.
@@ -287,7 +308,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
def get_column_type(
self,
table: exp.Table | str,
- column: exp.Column,
+ column: exp.Column | str,
dialect: DialectType = None,
normalize: t.Optional[bool] = None,
) -> exp.DataType:
@@ -304,10 +325,26 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
if isinstance(column_type, exp.DataType):
return column_type
elif isinstance(column_type, str):
- return self._to_data_type(column_type.upper(), dialect=dialect)
+ return self._to_data_type(column_type, dialect=dialect)
return exp.DataType.build("unknown")
+ def has_column(
+ self,
+ table: exp.Table | str,
+ column: exp.Column | str,
+ dialect: DialectType = None,
+ normalize: t.Optional[bool] = None,
+ ) -> bool:
+ normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
+
+ normalized_column_name = self._normalize_name(
+ column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
+ )
+
+ table_schema = self.find(normalized_table, raise_on_missing=False)
+ return normalized_column_name in table_schema if table_schema else False
+
def _normalize(self, schema: t.Dict) -> t.Dict:
"""
Normalizes all identifiers in the schema.