diff options
Diffstat (limited to 'sqlglot/schema.py')
-rw-r--r-- | sqlglot/schema.py | 28 |
1 files changed, 17 insertions, 11 deletions
diff --git a/sqlglot/schema.py b/sqlglot/schema.py index f73adee..b8560a1 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -6,10 +6,10 @@ import typing as t import sqlglot from sqlglot import expressions as exp from sqlglot._typing import T -from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE +from sqlglot.dialects.dialect import Dialect from sqlglot.errors import ParseError, SchemaError from sqlglot.helper import dict_depth -from sqlglot.trie import in_trie, new_trie +from sqlglot.trie import TrieResult, in_trie, new_trie if t.TYPE_CHECKING: from sqlglot.dataframe.sql.types import StructType @@ -135,10 +135,10 @@ class AbstractMappingSchema(t.Generic[T]): parts = self.table_parts(table)[0 : len(self.supported_table_args)] value, trie = in_trie(self.mapping_trie if trie is None else trie, parts) - if value == 0: + if value == TrieResult.FAILED: return None - if value == 1: + if value == TrieResult.PREFIX: possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1) if len(possibilities) == 1: @@ -289,7 +289,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): def _normalize(self, schema: t.Dict) -> t.Dict: """ - Converts all identifiers in the schema into lowercase, unless they're quoted. + Normalizes all identifiers in the schema. Args: schema: the schema to normalize. @@ -304,7 +304,9 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): columns = nested_get(schema, *zip(keys, keys)) assert columns is not None - normalized_keys = [self._normalize_name(key, dialect=self.dialect) for key in keys] + normalized_keys = [ + self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys + ] for column_name, column_type in columns.items(): nested_set( normalized_mapping, @@ -321,12 +323,15 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): value = normalized_table.args.get(arg) if isinstance(value, (str, exp.Identifier)): normalized_table.set( - arg, exp.to_identifier(self._normalize_name(value, dialect=dialect)) + arg, + exp.to_identifier(self._normalize_name(value, dialect=dialect, is_table=True)), ) return normalized_table - def _normalize_name(self, name: str | exp.Identifier, dialect: DialectType = None) -> str: + def _normalize_name( + self, name: str | exp.Identifier, dialect: DialectType = None, is_table: bool = False + ) -> str: dialect = dialect or self.dialect try: @@ -335,11 +340,12 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): return name if isinstance(name, str) else name.name name = identifier.name - - if not self.normalize or identifier.quoted: + if not self.normalize: return name - return name.upper() if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE else name.lower() + # This can be useful for normalize_identifier + identifier.meta["is_table"] = is_table + return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name def _depth(self) -> int: # The columns themselves are a mapping, but we don't want to include those |