Diffstat (limited to 'sqlglot/schema.py')
-rw-r--r--  sqlglot/schema.py  79
1 file changed, 62 insertions(+), 17 deletions(-)
diff --git a/sqlglot/schema.py b/sqlglot/schema.py
index 12cf0b1..7a3c88b 100644
--- a/sqlglot/schema.py
+++ b/sqlglot/schema.py
@@ -31,14 +31,19 @@ class Schema(abc.ABC):
table: exp.Table | str,
column_mapping: t.Optional[ColumnMapping] = None,
dialect: DialectType = None,
+ normalize: t.Optional[bool] = None,
+ match_depth: bool = True,
) -> None:
"""
Register or update a table. Some implementing classes may require column information to also be provided.
+ The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Args:
table: the `Table` expression instance or string representing the table.
column_mapping: a column mapping that describes the structure of the table.
dialect: the SQL dialect that will be used to parse `table` if it's a string.
+ normalize: whether to normalize identifiers according to the dialect of interest.
+ match_depth: whether to enforce that the table must match the schema's depth or not.
"""
@abc.abstractmethod
@@ -47,6 +52,7 @@ class Schema(abc.ABC):
table: exp.Table | str,
only_visible: bool = False,
dialect: DialectType = None,
+ normalize: t.Optional[bool] = None,
) -> t.List[str]:
"""
Get the column names for a table.
@@ -55,6 +61,7 @@ class Schema(abc.ABC):
table: the `Table` expression instance.
only_visible: whether to include invisible columns.
dialect: the SQL dialect that will be used to parse `table` if it's a string.
+ normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The list of column names.
@@ -66,6 +73,7 @@ class Schema(abc.ABC):
table: exp.Table | str,
column: exp.Column,
dialect: DialectType = None,
+ normalize: t.Optional[bool] = None,
) -> exp.DataType:
"""
Get the `sqlglot.exp.DataType` type of a column in the schema.
@@ -74,6 +82,7 @@ class Schema(abc.ABC):
table: the source table.
column: the target column.
dialect: the SQL dialect that will be used to parse `table` if it's a string.
+ normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The resulting column type.
@@ -99,7 +108,7 @@ class AbstractMappingSchema(t.Generic[T]):
) -> None:
self.mapping = mapping or {}
self.mapping_trie = new_trie(
- tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self._depth())
+ tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
)
self._supported_table_args: t.Tuple[str, ...] = tuple()
@@ -107,13 +116,13 @@ class AbstractMappingSchema(t.Generic[T]):
def empty(self) -> bool:
return not self.mapping
- def _depth(self) -> int:
+ def depth(self) -> int:
return dict_depth(self.mapping)
@property
def supported_table_args(self) -> t.Tuple[str, ...]:
if not self._supported_table_args and self.mapping:
- depth = self._depth()
+ depth = self.depth()
if not depth: # None
self._supported_table_args = tuple()
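For orientation, a minimal sketch (not part of the diff) of what the now-public depth() reports: dict_depth counts every nesting level of the raw mapping, including the column level that MappingSchema.depth() subtracts further below.

    from sqlglot.helper import dict_depth
    from sqlglot.schema import MappingSchema

    # catalog > db > table > column: the raw mapping is four dicts deep ...
    mapping = {"catalog": {"db": {"tbl": {"col": "INT"}}}}
    assert dict_depth(mapping) == 4

    # ... but MappingSchema.depth() reports 3, since columns are not table qualifiers.
    assert MappingSchema(mapping).depth() == 3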
@@ -191,6 +200,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
self.visible = visible or {}
self.normalize = normalize
self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
+ self._depth = 0
super().__init__(self._normalize(schema or {}))
@@ -200,6 +210,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
schema=mapping_schema.mapping,
visible=mapping_schema.visible,
dialect=mapping_schema.dialect,
+ normalize=mapping_schema.normalize,
)
def copy(self, **kwargs) -> MappingSchema:
@@ -208,6 +219,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
"schema": self.mapping.copy(),
"visible": self.visible.copy(),
"dialect": self.dialect,
+ "normalize": self.normalize,
**kwargs,
}
)
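A small illustrative sketch (not part of the diff) of what propagating normalize buys: from_mapping_schema() and copy() previously dropped the flag, so derived schemas silently reverted to the default; now the setting travels with the copy.

    from sqlglot.schema import MappingSchema

    schema = MappingSchema({"db": {"users": {"id": "INT"}}}, normalize=False)

    # The copy keeps normalize=False instead of falling back to the default (True).
    assert schema.copy().normalize is False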
@@ -217,19 +229,30 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
table: exp.Table | str,
column_mapping: t.Optional[ColumnMapping] = None,
dialect: DialectType = None,
+ normalize: t.Optional[bool] = None,
+ match_depth: bool = True,
) -> None:
"""
Register or update a table. Updates are only performed if a new column mapping is provided.
+ The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Args:
table: the `Table` expression instance or string representing the table.
column_mapping: a column mapping that describes the structure of the table.
dialect: the SQL dialect that will be used to parse `table` if it's a string.
+ normalize: whether to normalize identifiers according to the dialect of interest.
+ match_depth: whether to enforce that the table must match the schema's depth or not.
"""
- normalized_table = self._normalize_table(table, dialect=dialect)
+ normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
+
+ if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
+ raise SchemaError(
+ f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
+ f"schema's nesting level: {self.depth()}."
+ )
normalized_column_mapping = {
- self._normalize_name(key, dialect=dialect): value
+ self._normalize_name(key, dialect=dialect, normalize=normalize): value
for key, value in ensure_column_mapping(column_mapping).items()
}
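A hedged usage sketch of the two new parameters (table and column names are made up): per-call normalize overrides the instance-wide setting, and match_depth=False opts out of the nesting-level check introduced above.

    from sqlglot.schema import MappingSchema

    schema = MappingSchema({"db": {"users": {"id": "INT"}}})  # depth 2: db.table

    # OK: two qualifiers, matching the schema's depth.
    schema.add_table("db.orders", {"total": "DOUBLE"})

    # Would raise SchemaError: a single qualifier against a depth-2 schema.
    # schema.add_table("events", {"ts": "TIMESTAMP"})

    # The same call is accepted when the depth check is explicitly disabled.
    schema.add_table("events", {"ts": "TIMESTAMP"}, match_depth=False)

    # Per-call normalize=False keeps these identifiers exactly as written.
    schema.add_table("db.Events", {"TS": "TIMESTAMP"}, normalize=False)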
@@ -247,8 +270,9 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
table: exp.Table | str,
only_visible: bool = False,
dialect: DialectType = None,
+ normalize: t.Optional[bool] = None,
) -> t.List[str]:
- normalized_table = self._normalize_table(table, dialect=dialect)
+ normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
schema = self.find(normalized_table)
if schema is None:
@@ -265,11 +289,12 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
table: exp.Table | str,
column: exp.Column,
dialect: DialectType = None,
+ normalize: t.Optional[bool] = None,
) -> exp.DataType:
- normalized_table = self._normalize_table(table, dialect=dialect)
+ normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
normalized_column_name = self._normalize_name(
- column if isinstance(column, str) else column.this, dialect=dialect
+ column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
)
table_schema = self.find(normalized_table, raise_on_missing=False)
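For completeness, a lookup sketch (again not part of the diff) against the same kind of schema; the per-call normalize takes precedence over the flag the schema was constructed with.

    from sqlglot import exp
    from sqlglot.schema import MappingSchema

    schema = MappingSchema({"db": {"users": {"id": "INT"}}})

    schema.column_names("db.users")                       # ["id"]
    schema.get_column_type("db.users", exp.column("id"))  # -> DataType of type INT

    # Identifiers are looked up exactly as written when normalization is disabled per call.
    schema.column_names("db.users", normalize=False)      # ["id"] here too, names are already lowercase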
@@ -293,12 +318,16 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
Returns:
The normalized schema mapping.
"""
+ normalized_mapping: t.Dict = {}
flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)
- normalized_mapping: t.Dict = {}
for keys in flattened_schema:
columns = nested_get(schema, *zip(keys, keys))
- assert columns is not None
+
+ if not isinstance(columns, dict):
+ raise SchemaError(
+ f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}."
+ )
normalized_keys = [
self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys
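As an illustration of the new guard (replacing the bare assert), a mapping whose entries do not all share the same nesting depth now fails with a descriptive SchemaError rather than an AssertionError; a minimal reproducer might look like this:

    from sqlglot.schema import MappingSchema

    # "db.users" nests db > table > columns, but "events" maps straight to columns,
    # so the value found at its column level is a string rather than a dict:
    # SchemaError: Table events must match the schema's nesting level: 2.
    MappingSchema({
        "db": {"users": {"id": "INT"}},
        "events": {"ts": "TIMESTAMP"},
    })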
@@ -312,7 +341,12 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
return normalized_mapping
- def _normalize_table(self, table: exp.Table | str, dialect: DialectType = None) -> exp.Table:
+ def _normalize_table(
+ self,
+ table: exp.Table | str,
+ dialect: DialectType = None,
+ normalize: t.Optional[bool] = None,
+ ) -> exp.Table:
normalized_table = exp.maybe_parse(
table, into=exp.Table, dialect=dialect or self.dialect, copy=True
)
@@ -322,15 +356,24 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
if isinstance(value, (str, exp.Identifier)):
normalized_table.set(
arg,
- exp.to_identifier(self._normalize_name(value, dialect=dialect, is_table=True)),
+ exp.to_identifier(
+ self._normalize_name(
+ value, dialect=dialect, is_table=True, normalize=normalize
+ )
+ ),
)
return normalized_table
def _normalize_name(
- self, name: str | exp.Identifier, dialect: DialectType = None, is_table: bool = False
+ self,
+ name: str | exp.Identifier,
+ dialect: DialectType = None,
+ is_table: bool = False,
+ normalize: t.Optional[bool] = None,
) -> str:
dialect = dialect or self.dialect
+ normalize = self.normalize if normalize is None else normalize
try:
identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
@@ -338,16 +381,18 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
return name if isinstance(name, str) else name.name
name = identifier.name
- if not self.normalize:
+ if not normalize:
return name
# This can be useful for normalize_identifier
identifier.meta["is_table"] = is_table
return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name
- def _depth(self) -> int:
- # The columns themselves are a mapping, but we don't want to include those
- return super()._depth() - 1
+ def depth(self) -> int:
+ if not self.empty and not self._depth:
+ # The columns themselves are a mapping, but we don't want to include those
+ self._depth = super().depth() - 1
+ return self._depth
def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
"""