diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-05-03 09:12:24 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-05-03 09:12:24 +0000 |
commit | 98d5537435b2951b36c45f1fda667fa27c165794 (patch) | |
tree | d26b4dfa6cf91847100fe10a94a04dcc2ad36a86 /sqlglot/schema.py | |
parent | Adding upstream version 11.5.2. (diff) | |
download | sqlglot-98d5537435b2951b36c45f1fda667fa27c165794.tar.xz sqlglot-98d5537435b2951b36c45f1fda667fa27c165794.zip |
Adding upstream version 11.7.1.upstream/11.7.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/schema.py')
-rw-r--r-- | sqlglot/schema.py | 156 |
1 files changed, 83 insertions, 73 deletions
diff --git a/sqlglot/schema.py b/sqlglot/schema.py index 8e39c7f..5d60eb9 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -5,7 +5,7 @@ import typing as t import sqlglot from sqlglot import expressions as exp -from sqlglot.errors import SchemaError +from sqlglot.errors import ParseError, SchemaError from sqlglot.helper import dict_depth from sqlglot.trie import in_trie, new_trie @@ -75,12 +75,11 @@ class AbstractMappingSchema(t.Generic[T]): mapping: dict | None = None, ) -> None: self.mapping = mapping or {} - self.mapping_trie = self._build_trie(self.mapping) + self.mapping_trie = new_trie( + tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self._depth()) + ) self._supported_table_args: t.Tuple[str, ...] = tuple() - def _build_trie(self, schema: t.Dict) -> t.Dict: - return new_trie(tuple(reversed(t)) for t in flatten_schema(schema, depth=self._depth())) - def _depth(self) -> int: return dict_depth(self.mapping) @@ -179,6 +178,64 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): } ) + def add_table( + self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None + ) -> None: + """ + Register or update a table. Updates are only performed if a new column mapping is provided. + + Args: + table: the `Table` expression instance or string representing the table. + column_mapping: a column mapping that describes the structure of the table. + """ + normalized_table = self._normalize_table(self._ensure_table(table)) + normalized_column_mapping = { + self._normalize_name(key): value + for key, value in ensure_column_mapping(column_mapping).items() + } + + schema = self.find(normalized_table, raise_on_missing=False) + if schema and not normalized_column_mapping: + return + + parts = self.table_parts(normalized_table) + + _nested_set( + self.mapping, + tuple(reversed(parts)), + normalized_column_mapping, + ) + new_trie([parts], self.mapping_trie) + + def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]: + table_ = self._normalize_table(self._ensure_table(table)) + schema = self.find(table_) + + if schema is None: + return [] + + if not only_visible or not self.visible: + return list(schema) + + visible = self._nested_get(self.table_parts(table_), self.visible) + return [col for col in schema if col in visible] # type: ignore + + def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType: + column_name = self._normalize_name(column if isinstance(column, str) else column.this) + table_ = self._normalize_table(self._ensure_table(table)) + + table_schema = self.find(table_, raise_on_missing=False) + if table_schema: + column_type = table_schema.get(column_name) + + if isinstance(column_type, exp.DataType): + return column_type + elif isinstance(column_type, str): + return self._to_data_type(column_type.upper()) + raise SchemaError(f"Unknown column type '{column_type}'") + + return exp.DataType.build("unknown") + def _normalize(self, schema: t.Dict) -> t.Dict: """ Converts all identifiers in the schema into lowercase, unless they're quoted. @@ -206,84 +263,37 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): return normalized_mapping - def add_table( - self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None - ) -> None: - """ - Register or update a table. Updates are only performed if a new column mapping is provided. + def _normalize_table(self, table: exp.Table) -> exp.Table: + normalized_table = table.copy() + for arg in TABLE_ARGS: + value = normalized_table.args.get(arg) + if isinstance(value, (str, exp.Identifier)): + normalized_table.set(arg, self._normalize_name(value)) - Args: - table: the `Table` expression instance or string representing the table. - column_mapping: a column mapping that describes the structure of the table. - """ - table_ = self._ensure_table(table) - column_mapping = ensure_column_mapping(column_mapping) - schema = self.find(table_, raise_on_missing=False) - - if schema and not column_mapping: - return - - _nested_set( - self.mapping, - list(reversed(self.table_parts(table_))), - column_mapping, - ) - self.mapping_trie = self._build_trie(self.mapping) + return normalized_table - def _normalize_name(self, name: str) -> str: + def _normalize_name(self, name: str | exp.Identifier) -> str: try: - identifier: t.Optional[exp.Expression] = sqlglot.parse_one( - name, read=self.dialect, into=exp.Identifier - ) - except: - identifier = exp.to_identifier(name) - assert isinstance(identifier, exp.Identifier) + identifier = sqlglot.maybe_parse(name, dialect=self.dialect, into=exp.Identifier) + except ParseError: + return name if isinstance(name, str) else name.name - if identifier.quoted: - return identifier.name - return identifier.name.lower() + return identifier.name if identifier.quoted else identifier.name.lower() def _depth(self) -> int: # The columns themselves are a mapping, but we don't want to include those return super()._depth() - 1 def _ensure_table(self, table: exp.Table | str) -> exp.Table: - table_ = exp.to_table(table) + if isinstance(table, exp.Table): + return table + table_ = sqlglot.parse_one(table, read=self.dialect, into=exp.Table) if not table_: raise SchemaError(f"Not a valid table '{table}'") return table_ - def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]: - table_ = self._ensure_table(table) - schema = self.find(table_) - - if schema is None: - return [] - - if not only_visible or not self.visible: - return list(schema) - - visible = self._nested_get(self.table_parts(table_), self.visible) - return [col for col in schema if col in visible] # type: ignore - - def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType: - column_name = column if isinstance(column, str) else column.name - table_ = exp.to_table(table) - if table_: - table_schema = self.find(table_, raise_on_missing=False) - if table_schema: - column_type = table_schema.get(column_name) - - if isinstance(column_type, exp.DataType): - return column_type - elif isinstance(column_type, str): - return self._to_data_type(column_type.upper()) - raise SchemaError(f"Unknown column type '{column_type}'") - return exp.DataType(this=exp.DataType.Type.UNKNOWN) - raise SchemaError(f"Could not convert table '{table}'") - def _to_data_type(self, schema_type: str) -> exp.DataType: """ Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object. @@ -313,7 +323,7 @@ def ensure_schema(schema: t.Any, dialect: DialectType = None) -> Schema: return MappingSchema(schema, dialect=dialect) -def ensure_column_mapping(mapping: t.Optional[ColumnMapping]): +def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict: if isinstance(mapping, dict): return mapping elif isinstance(mapping, str): @@ -371,7 +381,7 @@ def _nested_get( return d -def _nested_set(d: t.Dict, keys: t.List[str], value: t.Any) -> t.Dict: +def _nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict: """ In-place set a value for a nested dictionary @@ -384,11 +394,11 @@ def _nested_set(d: t.Dict, keys: t.List[str], value: t.Any) -> t.Dict: Args: d: dictionary to update. - keys: the keys that makeup the path to `value`. - value: the value to set in the dictionary for the given key path. + keys: the keys that makeup the path to `value`. + value: the value to set in the dictionary for the given key path. - Returns: - The (possibly) updated dictionary. + Returns: + The (possibly) updated dictionary. """ if not keys: return d |