diff options
Diffstat (limited to 'sqlglot/schema.py')
-rw-r--r-- | sqlglot/schema.py | 45 |
1 files changed, 43 insertions, 2 deletions
diff --git a/sqlglot/schema.py b/sqlglot/schema.py index d9a4004..a0d69a7 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -3,6 +3,7 @@ from __future__ import annotations import abc import typing as t +import sqlglot from sqlglot import expressions as exp from sqlglot.errors import SchemaError from sqlglot.helper import dict_depth @@ -157,10 +158,10 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): visible: t.Optional[t.Dict] = None, dialect: t.Optional[str] = None, ) -> None: - super().__init__(schema) - self.visible = visible or {} self.dialect = dialect + self.visible = visible or {} self._type_mapping_cache: t.Dict[str, exp.DataType] = {} + super().__init__(self._normalize(schema or {})) @classmethod def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema: @@ -180,6 +181,33 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): } ) + def _normalize(self, schema: t.Dict) -> t.Dict: + """ + Converts all identifiers in the schema into lowercase, unless they're quoted. + + Args: + schema: the schema to normalize. + + Returns: + The normalized schema mapping. + """ + flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1) + + normalized_mapping: t.Dict = {} + for keys in flattened_schema: + columns = _nested_get(schema, *zip(keys, keys)) + assert columns is not None + + normalized_keys = [self._normalize_name(key) for key in keys] + for column_name, column_type in columns.items(): + _nested_set( + normalized_mapping, + normalized_keys + [self._normalize_name(column_name)], + column_type, + ) + + return normalized_mapping + def add_table( self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None ) -> None: @@ -204,6 +232,19 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): ) self.mapping_trie = self._build_trie(self.mapping) + def _normalize_name(self, name: str) -> str: + try: + identifier: t.Optional[exp.Expression] = sqlglot.parse_one( + name, read=self.dialect, into=exp.Identifier + ) + except: + identifier = exp.to_identifier(name) + assert isinstance(identifier, exp.Identifier) + + if identifier.quoted: + return identifier.name + return identifier.name.lower() + def _depth(self) -> int: # The columns themselves are a mapping, but we don't want to include those return super()._depth() - 1 |