summary | refs | log | tree | commit | diff | stats
path: root/sqlglot/schema.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/schema.py')
-rw-r--r-- sqlglot/schema.py 28
1 files changed, 17 insertions, 11 deletions
diff --git a/sqlglot/schema.py b/sqlglot/schema.py
index f73adee..b8560a1 100644
--- a/sqlglot/schema.py
+++ b/sqlglot/schema.py
@@ -6,10 +6,10 @@ import typing as t
import sqlglot
from sqlglot import expressions as exp
from sqlglot._typing import T
-from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE
+from sqlglot.dialects.dialect import Dialect
from sqlglot.errors import ParseError, SchemaError
from sqlglot.helper import dict_depth
-from sqlglot.trie import in_trie, new_trie
+from sqlglot.trie import TrieResult, in_trie, new_trie
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql.types import StructType
@@ -135,10 +135,10 @@ class AbstractMappingSchema(t.Generic[T]):
parts = self.table_parts(table)[0 : len(self.supported_table_args)]
value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
- if value == 0:
+ if value == TrieResult.FAILED:
return None
- if value == 1:
+ if value == TrieResult.PREFIX:
possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
if len(possibilities) == 1:
@@ -289,7 +289,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
def _normalize(self, schema: t.Dict) -> t.Dict:
"""
- Converts all identifiers in the schema into lowercase, unless they're quoted.
+ Normalizes all identifiers in the schema.
Args:
schema: the schema to normalize.
@@ -304,7 +304,9 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
columns = nested_get(schema, *zip(keys, keys))
assert columns is not None
- normalized_keys = [self._normalize_name(key, dialect=self.dialect) for key in keys]
+ normalized_keys = [
+ self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys
+ ]
for column_name, column_type in columns.items():
nested_set(
normalized_mapping,
@@ -321,12 +323,15 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
value = normalized_table.args.get(arg)
if isinstance(value, (str, exp.Identifier)):
normalized_table.set(
- arg, exp.to_identifier(self._normalize_name(value, dialect=dialect))
+ arg,
+ exp.to_identifier(self._normalize_name(value, dialect=dialect, is_table=True)),
)
return normalized_table
- def _normalize_name(self, name: str | exp.Identifier, dialect: DialectType = None) -> str:
+ def _normalize_name(
+ self, name: str | exp.Identifier, dialect: DialectType = None, is_table: bool = False
+ ) -> str:
dialect = dialect or self.dialect
try:
@@ -335,11 +340,12 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
return name if isinstance(name, str) else name.name
name = identifier.name
-
- if not self.normalize or identifier.quoted:
+ if not self.normalize:
return name
- return name.upper() if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE else name.lower()
+ # This can be useful for normalize_identifier
+ identifier.meta["is_table"] = is_table
+ return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name
def _depth(self) -> int:
# The columns themselves are a mapping, but we don't want to include those