summaryrefslogtreecommitdiffstats
path: root/sqlglot/schema.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/schema.py')
-rw-r--r--sqlglot/schema.py156
1 files changed, 83 insertions, 73 deletions
diff --git a/sqlglot/schema.py b/sqlglot/schema.py
index 8e39c7f..5d60eb9 100644
--- a/sqlglot/schema.py
+++ b/sqlglot/schema.py
@@ -5,7 +5,7 @@ import typing as t
import sqlglot
from sqlglot import expressions as exp
-from sqlglot.errors import SchemaError
+from sqlglot.errors import ParseError, SchemaError
from sqlglot.helper import dict_depth
from sqlglot.trie import in_trie, new_trie
@@ -75,12 +75,11 @@ class AbstractMappingSchema(t.Generic[T]):
mapping: dict | None = None,
) -> None:
self.mapping = mapping or {}
- self.mapping_trie = self._build_trie(self.mapping)
+ self.mapping_trie = new_trie(
+ tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self._depth())
+ )
self._supported_table_args: t.Tuple[str, ...] = tuple()
- def _build_trie(self, schema: t.Dict) -> t.Dict:
- return new_trie(tuple(reversed(t)) for t in flatten_schema(schema, depth=self._depth()))
-
def _depth(self) -> int:
return dict_depth(self.mapping)
@@ -179,6 +178,64 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
}
)
+ def add_table(
+ self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
+ ) -> None:
+ """
+ Register or update a table. Updates are only performed if a new column mapping is provided.
+
+ Args:
+ table: the `Table` expression instance or string representing the table.
+ column_mapping: a column mapping that describes the structure of the table.
+ """
+ normalized_table = self._normalize_table(self._ensure_table(table))
+ normalized_column_mapping = {
+ self._normalize_name(key): value
+ for key, value in ensure_column_mapping(column_mapping).items()
+ }
+
+ schema = self.find(normalized_table, raise_on_missing=False)
+ if schema and not normalized_column_mapping:
+ return
+
+ parts = self.table_parts(normalized_table)
+
+ _nested_set(
+ self.mapping,
+ tuple(reversed(parts)),
+ normalized_column_mapping,
+ )
+ new_trie([parts], self.mapping_trie)
+
+ def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
+ table_ = self._normalize_table(self._ensure_table(table))
+ schema = self.find(table_)
+
+ if schema is None:
+ return []
+
+ if not only_visible or not self.visible:
+ return list(schema)
+
+ visible = self._nested_get(self.table_parts(table_), self.visible)
+ return [col for col in schema if col in visible] # type: ignore
+
+ def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
+ column_name = self._normalize_name(column if isinstance(column, str) else column.this)
+ table_ = self._normalize_table(self._ensure_table(table))
+
+ table_schema = self.find(table_, raise_on_missing=False)
+ if table_schema:
+ column_type = table_schema.get(column_name)
+
+ if isinstance(column_type, exp.DataType):
+ return column_type
+ elif isinstance(column_type, str):
+ return self._to_data_type(column_type.upper())
+ raise SchemaError(f"Unknown column type '{column_type}'")
+
+ return exp.DataType.build("unknown")
+
def _normalize(self, schema: t.Dict) -> t.Dict:
"""
Converts all identifiers in the schema into lowercase, unless they're quoted.
@@ -206,84 +263,37 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
return normalized_mapping
- def add_table(
- self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
- ) -> None:
- """
- Register or update a table. Updates are only performed if a new column mapping is provided.
+ def _normalize_table(self, table: exp.Table) -> exp.Table:
+ normalized_table = table.copy()
+ for arg in TABLE_ARGS:
+ value = normalized_table.args.get(arg)
+ if isinstance(value, (str, exp.Identifier)):
+ normalized_table.set(arg, self._normalize_name(value))
- Args:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
- """
- table_ = self._ensure_table(table)
- column_mapping = ensure_column_mapping(column_mapping)
- schema = self.find(table_, raise_on_missing=False)
-
- if schema and not column_mapping:
- return
-
- _nested_set(
- self.mapping,
- list(reversed(self.table_parts(table_))),
- column_mapping,
- )
- self.mapping_trie = self._build_trie(self.mapping)
+ return normalized_table
- def _normalize_name(self, name: str) -> str:
+ def _normalize_name(self, name: str | exp.Identifier) -> str:
try:
- identifier: t.Optional[exp.Expression] = sqlglot.parse_one(
- name, read=self.dialect, into=exp.Identifier
- )
- except:
- identifier = exp.to_identifier(name)
- assert isinstance(identifier, exp.Identifier)
+ identifier = sqlglot.maybe_parse(name, dialect=self.dialect, into=exp.Identifier)
+ except ParseError:
+ return name if isinstance(name, str) else name.name
- if identifier.quoted:
- return identifier.name
- return identifier.name.lower()
+ return identifier.name if identifier.quoted else identifier.name.lower()
def _depth(self) -> int:
# The columns themselves are a mapping, but we don't want to include those
return super()._depth() - 1
def _ensure_table(self, table: exp.Table | str) -> exp.Table:
- table_ = exp.to_table(table)
+ if isinstance(table, exp.Table):
+ return table
+ table_ = sqlglot.parse_one(table, read=self.dialect, into=exp.Table)
if not table_:
raise SchemaError(f"Not a valid table '{table}'")
return table_
- def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
- table_ = self._ensure_table(table)
- schema = self.find(table_)
-
- if schema is None:
- return []
-
- if not only_visible or not self.visible:
- return list(schema)
-
- visible = self._nested_get(self.table_parts(table_), self.visible)
- return [col for col in schema if col in visible] # type: ignore
-
- def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
- column_name = column if isinstance(column, str) else column.name
- table_ = exp.to_table(table)
- if table_:
- table_schema = self.find(table_, raise_on_missing=False)
- if table_schema:
- column_type = table_schema.get(column_name)
-
- if isinstance(column_type, exp.DataType):
- return column_type
- elif isinstance(column_type, str):
- return self._to_data_type(column_type.upper())
- raise SchemaError(f"Unknown column type '{column_type}'")
- return exp.DataType(this=exp.DataType.Type.UNKNOWN)
- raise SchemaError(f"Could not convert table '{table}'")
-
def _to_data_type(self, schema_type: str) -> exp.DataType:
"""
Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object.
@@ -313,7 +323,7 @@ def ensure_schema(schema: t.Any, dialect: DialectType = None) -> Schema:
return MappingSchema(schema, dialect=dialect)
-def ensure_column_mapping(mapping: t.Optional[ColumnMapping]):
+def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
if isinstance(mapping, dict):
return mapping
elif isinstance(mapping, str):
@@ -371,7 +381,7 @@ def _nested_get(
return d
-def _nested_set(d: t.Dict, keys: t.List[str], value: t.Any) -> t.Dict:
+def _nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
"""
In-place set a value for a nested dictionary
@@ -384,11 +394,11 @@ def _nested_set(d: t.Dict, keys: t.List[str], value: t.Any) -> t.Dict:
Args:
d: dictionary to update.
- keys: the keys that makeup the path to `value`.
- value: the value to set in the dictionary for the given key path.
+ keys: the keys that makeup the path to `value`.
+ value: the value to set in the dictionary for the given key path.
- Returns:
- The (possibly) updated dictionary.
+ Returns:
+ The (possibly) updated dictionary.
"""
if not keys:
return d