summary | refs | log | tree | commit | diff | stats
path: root/sqlglot/schema.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/schema.py')
-rw-r--r-- sqlglot/schema.py 45
1 files changed, 43 insertions, 2 deletions
diff --git a/sqlglot/schema.py b/sqlglot/schema.py
index d9a4004..a0d69a7 100644
--- a/sqlglot/schema.py
+++ b/sqlglot/schema.py
@@ -3,6 +3,7 @@ from __future__ import annotations
import abc
import typing as t
+import sqlglot
from sqlglot import expressions as exp
from sqlglot.errors import SchemaError
from sqlglot.helper import dict_depth
@@ -157,10 +158,10 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
visible: t.Optional[t.Dict] = None,
dialect: t.Optional[str] = None,
) -> None:
- super().__init__(schema)
- self.visible = visible or {}
self.dialect = dialect
+ self.visible = visible or {}
self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
+ super().__init__(self._normalize(schema or {}))
@classmethod
def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
@@ -180,6 +181,33 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
}
)
def _normalize(self, schema: t.Dict) -> t.Dict:
    """
    Builds a new schema mapping whose identifiers are lowercased, unless they're quoted.

    Every key path returned by `flatten_schema` is looked up in the given schema,
    and both the keys along that path and the column names found under it are
    passed through `self._normalize_name` before being stored in the result.

    Args:
        schema: the schema to normalize.

    Returns:
        The normalized schema mapping.
    """
    key_paths = flatten_schema(schema, depth=dict_depth(schema) - 1)
    result: t.Dict = {}

    for path in key_paths:
        # Each key is paired with itself to form the (name, key) tuples
        # that are handed to _nested_get.
        column_types = _nested_get(schema, *zip(path, path))
        assert column_types is not None

        normalized_path = [self._normalize_name(part) for part in path]
        for name, dtype in column_types.items():
            target = normalized_path + [self._normalize_name(name)]
            _nested_set(result, target, dtype)

    return result
+
def add_table(
self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
) -> None:
@@ -204,6 +232,19 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
)
self.mapping_trie = self._build_trie(self.mapping)
def _normalize_name(self, name: str) -> str:
    """
    Normalizes a single identifier name: lowercased, unless it is quoted.

    The name is first parsed as an identifier using this schema's dialect. If
    that fails, the name is converted with `exp.to_identifier` instead.

    Args:
        name: the identifier text to normalize.

    Returns:
        The identifier's name as-is when it is quoted, otherwise its lowercase form.
    """
    try:
        identifier: t.Optional[exp.Expression] = sqlglot.parse_one(
            name, read=self.dialect, into=exp.Identifier
        )
    except Exception:
        # Best-effort fallback for names that cannot be parsed. Catching
        # `Exception` rather than using a bare `except` lets KeyboardInterrupt
        # and SystemExit propagate instead of being silently swallowed.
        identifier = exp.to_identifier(name)
    # Narrows the optional expression type to Identifier for the accesses below.
    assert isinstance(identifier, exp.Identifier)

    if identifier.quoted:
        return identifier.name
    return identifier.name.lower()
+
def _depth(self) -> int:
    """Returns the parent class's depth minus one, so the column mapping is not counted."""
    # The columns themselves are a mapping, but that level should not be included
    parent_depth = super()._depth()
    return parent_depth - 1