diff options
Diffstat (limited to 'sqlglot/schema.py')
-rw-r--r-- | sqlglot/schema.py | 79 |
1 file changed, 62 insertions, 17 deletions
diff --git a/sqlglot/schema.py b/sqlglot/schema.py index 12cf0b1..7a3c88b 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -31,14 +31,19 @@ class Schema(abc.ABC): table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None, dialect: DialectType = None, + normalize: t.Optional[bool] = None, + match_depth: bool = True, ) -> None: """ Register or update a table. Some implementing classes may require column information to also be provided. + The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. Args: table: the `Table` expression instance or string representing the table. column_mapping: a column mapping that describes the structure of the table. dialect: the SQL dialect that will be used to parse `table` if it's a string. + normalize: whether to normalize identifiers according to the dialect of interest. + match_depth: whether to enforce that the table must match the schema's depth or not. """ @abc.abstractmethod @@ -47,6 +52,7 @@ class Schema(abc.ABC): table: exp.Table | str, only_visible: bool = False, dialect: DialectType = None, + normalize: t.Optional[bool] = None, ) -> t.List[str]: """ Get the column names for a table. @@ -55,6 +61,7 @@ class Schema(abc.ABC): table: the `Table` expression instance. only_visible: whether to include invisible columns. dialect: the SQL dialect that will be used to parse `table` if it's a string. + normalize: whether to normalize identifiers according to the dialect of interest. Returns: The list of column names. @@ -66,6 +73,7 @@ class Schema(abc.ABC): table: exp.Table | str, column: exp.Column, dialect: DialectType = None, + normalize: t.Optional[bool] = None, ) -> exp.DataType: """ Get the `sqlglot.exp.DataType` type of a column in the schema. @@ -74,6 +82,7 @@ class Schema(abc.ABC): table: the source table. column: the target column. dialect: the SQL dialect that will be used to parse `table` if it's a string. 
+ normalize: whether to normalize identifiers according to the dialect of interest. Returns: The resulting column type. @@ -99,7 +108,7 @@ class AbstractMappingSchema(t.Generic[T]): ) -> None: self.mapping = mapping or {} self.mapping_trie = new_trie( - tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self._depth()) + tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth()) ) self._supported_table_args: t.Tuple[str, ...] = tuple() @@ -107,13 +116,13 @@ class AbstractMappingSchema(t.Generic[T]): def empty(self) -> bool: return not self.mapping - def _depth(self) -> int: + def depth(self) -> int: return dict_depth(self.mapping) @property def supported_table_args(self) -> t.Tuple[str, ...]: if not self._supported_table_args and self.mapping: - depth = self._depth() + depth = self.depth() if not depth: # None self._supported_table_args = tuple() @@ -191,6 +200,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): self.visible = visible or {} self.normalize = normalize self._type_mapping_cache: t.Dict[str, exp.DataType] = {} + self._depth = 0 super().__init__(self._normalize(schema or {})) @@ -200,6 +210,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): schema=mapping_schema.mapping, visible=mapping_schema.visible, dialect=mapping_schema.dialect, + normalize=mapping_schema.normalize, ) def copy(self, **kwargs) -> MappingSchema: @@ -208,6 +219,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): "schema": self.mapping.copy(), "visible": self.visible.copy(), "dialect": self.dialect, + "normalize": self.normalize, **kwargs, } ) @@ -217,19 +229,30 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None, dialect: DialectType = None, + normalize: t.Optional[bool] = None, + match_depth: bool = True, ) -> None: """ Register or update a table. 
Updates are only performed if a new column mapping is provided. + The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. Args: table: the `Table` expression instance or string representing the table. column_mapping: a column mapping that describes the structure of the table. dialect: the SQL dialect that will be used to parse `table` if it's a string. + normalize: whether to normalize identifiers according to the dialect of interest. + match_depth: whether to enforce that the table must match the schema's depth or not. """ - normalized_table = self._normalize_table(table, dialect=dialect) + normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) + + if match_depth and not self.empty and len(normalized_table.parts) != self.depth(): + raise SchemaError( + f"Table {normalized_table.sql(dialect=self.dialect)} must match the " + f"schema's nesting level: {self.depth()}." + ) normalized_column_mapping = { - self._normalize_name(key, dialect=dialect): value + self._normalize_name(key, dialect=dialect, normalize=normalize): value for key, value in ensure_column_mapping(column_mapping).items() } @@ -247,8 +270,9 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): table: exp.Table | str, only_visible: bool = False, dialect: DialectType = None, + normalize: t.Optional[bool] = None, ) -> t.List[str]: - normalized_table = self._normalize_table(table, dialect=dialect) + normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) schema = self.find(normalized_table) if schema is None: @@ -265,11 +289,12 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): table: exp.Table | str, column: exp.Column, dialect: DialectType = None, + normalize: t.Optional[bool] = None, ) -> exp.DataType: - normalized_table = self._normalize_table(table, dialect=dialect) + normalized_table = self._normalize_table(table, dialect=dialect, 
normalize=normalize) normalized_column_name = self._normalize_name( - column if isinstance(column, str) else column.this, dialect=dialect + column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize ) table_schema = self.find(normalized_table, raise_on_missing=False) @@ -293,12 +318,16 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): Returns: The normalized schema mapping. """ + normalized_mapping: t.Dict = {} flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1) - normalized_mapping: t.Dict = {} for keys in flattened_schema: columns = nested_get(schema, *zip(keys, keys)) - assert columns is not None + + if not isinstance(columns, dict): + raise SchemaError( + f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}." + ) normalized_keys = [ self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys @@ -312,7 +341,12 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): return normalized_mapping - def _normalize_table(self, table: exp.Table | str, dialect: DialectType = None) -> exp.Table: + def _normalize_table( + self, + table: exp.Table | str, + dialect: DialectType = None, + normalize: t.Optional[bool] = None, + ) -> exp.Table: normalized_table = exp.maybe_parse( table, into=exp.Table, dialect=dialect or self.dialect, copy=True ) @@ -322,15 +356,24 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): if isinstance(value, (str, exp.Identifier)): normalized_table.set( arg, - exp.to_identifier(self._normalize_name(value, dialect=dialect, is_table=True)), + exp.to_identifier( + self._normalize_name( + value, dialect=dialect, is_table=True, normalize=normalize + ) + ), ) return normalized_table def _normalize_name( - self, name: str | exp.Identifier, dialect: DialectType = None, is_table: bool = False + self, + name: str | exp.Identifier, + dialect: DialectType = None, + is_table: bool = False, + 
normalize: t.Optional[bool] = None, ) -> str: dialect = dialect or self.dialect + normalize = self.normalize if normalize is None else normalize try: identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier) @@ -338,16 +381,18 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): return name if isinstance(name, str) else name.name name = identifier.name - if not self.normalize: + if not normalize: return name # This can be useful for normalize_identifier identifier.meta["is_table"] = is_table return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name - def _depth(self) -> int: - # The columns themselves are a mapping, but we don't want to include those - return super()._depth() - 1 + def depth(self) -> int: + if not self.empty and not self._depth: + # The columns themselves are a mapping, but we don't want to include those + self._depth = super().depth() - 1 + return self._depth def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType: """ |