From f2981e8e4d28233864f1ca06ecec45ab80bf9eae Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sat, 19 Nov 2022 15:50:39 +0100 Subject: Merging upstream version 10.0.8. Signed-off-by: Daniel Baumann --- sqlglot/schema.py | 215 +++++++++++++++++++++++++----------------------------- 1 file changed, 99 insertions(+), 116 deletions(-) (limited to 'sqlglot/schema.py') diff --git a/sqlglot/schema.py b/sqlglot/schema.py index fcf7291..f6f303b 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -5,7 +5,7 @@ import typing as t from sqlglot import expressions as exp from sqlglot.errors import SchemaError -from sqlglot.helper import csv_reader +from sqlglot.helper import dict_depth from sqlglot.trie import in_trie, new_trie if t.TYPE_CHECKING: @@ -15,6 +15,8 @@ if t.TYPE_CHECKING: TABLE_ARGS = ("this", "db", "catalog") +T = t.TypeVar("T") + class Schema(abc.ABC): """Abstract base class for database schemas""" @@ -57,8 +59,81 @@ class Schema(abc.ABC): The resulting column type. """ + @property + def supported_table_args(self) -> t.Tuple[str, ...]: + """ + Table arguments this schema support, e.g. `("this", "db", "catalog")` + """ + raise NotImplementedError + + +class AbstractMappingSchema(t.Generic[T]): + def __init__( + self, + mapping: dict | None = None, + ) -> None: + self.mapping = mapping or {} + self.mapping_trie = self._build_trie(self.mapping) + self._supported_table_args: t.Tuple[str, ...] = tuple() + + def _build_trie(self, schema: t.Dict) -> t.Dict: + return new_trie(tuple(reversed(t)) for t in flatten_schema(schema, depth=self._depth())) + + def _depth(self) -> int: + return dict_depth(self.mapping) + + @property + def supported_table_args(self) -> t.Tuple[str, ...]: + if not self._supported_table_args and self.mapping: + depth = self._depth() + + if not depth: # None + self._supported_table_args = tuple() + elif 1 <= depth <= 3: + self._supported_table_args = TABLE_ARGS[:depth] + else: + raise SchemaError(f"Invalid mapping shape. Depth: {depth}") + + return self._supported_table_args + + def table_parts(self, table: exp.Table) -> t.List[str]: + if isinstance(table.this, exp.ReadCSV): + return [table.this.name] + return [table.text(part) for part in TABLE_ARGS if table.text(part)] + + def find( + self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True + ) -> t.Optional[T]: + parts = self.table_parts(table)[0 : len(self.supported_table_args)] + value, trie = in_trie(self.mapping_trie if trie is None else trie, parts) + + if value == 0: + if raise_on_missing: + raise SchemaError(f"Cannot find mapping for {table}.") + else: + return None + elif value == 1: + possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1) + if len(possibilities) == 1: + parts.extend(possibilities[0]) + else: + message = ", ".join(".".join(parts) for parts in possibilities) + if raise_on_missing: + raise SchemaError(f"Ambiguous mapping for {table}: {message}.") + return None + return self._nested_get(parts, raise_on_missing=raise_on_missing) -class MappingSchema(Schema): + def _nested_get( + self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True + ) -> t.Optional[t.Any]: + return _nested_get( + d or self.mapping, + *zip(self.supported_table_args, reversed(parts)), + raise_on_missing=raise_on_missing, + ) + + +class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): """ Schema based on a nested mapping. @@ -82,17 +157,17 @@ class MappingSchema(Schema): visible: t.Optional[t.Dict] = None, dialect: t.Optional[str] = None, ) -> None: - self.schema = schema or {} + super().__init__(schema) self.visible = visible or {} - self.schema_trie = self._build_trie(self.schema) self.dialect = dialect - self._type_mapping_cache: t.Dict[str, exp.DataType.Type] = {} - self._supported_table_args: t.Tuple[str, ...] = tuple() + self._type_mapping_cache: t.Dict[str, exp.DataType.Type] = { + "STR": exp.DataType.Type.TEXT, + } @classmethod def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema: return MappingSchema( - schema=mapping_schema.schema, + schema=mapping_schema.mapping, visible=mapping_schema.visible, dialect=mapping_schema.dialect, ) @@ -100,27 +175,13 @@ class MappingSchema(Schema): def copy(self, **kwargs) -> MappingSchema: return MappingSchema( **{ # type: ignore - "schema": self.schema.copy(), + "schema": self.mapping.copy(), "visible": self.visible.copy(), "dialect": self.dialect, **kwargs, } ) - @property - def supported_table_args(self): - if not self._supported_table_args and self.schema: - depth = _dict_depth(self.schema) - - if not depth or depth == 1: # {} - self._supported_table_args = tuple() - elif 2 <= depth <= 4: - self._supported_table_args = TABLE_ARGS[: depth - 1] - else: - raise SchemaError(f"Invalid schema shape. Depth: {depth}") - - return self._supported_table_args - def add_table( self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None ) -> None: @@ -133,17 +194,21 @@ class MappingSchema(Schema): """ table_ = self._ensure_table(table) column_mapping = ensure_column_mapping(column_mapping) - schema = self.find_schema(table_, raise_on_missing=False) + schema = self.find(table_, raise_on_missing=False) if schema and not column_mapping: return _nested_set( - self.schema, + self.mapping, list(reversed(self.table_parts(table_))), column_mapping, ) - self.schema_trie = self._build_trie(self.schema) + self.mapping_trie = self._build_trie(self.mapping) + + def _depth(self) -> int: + # The columns themselves are a mapping, but we don't want to include those + return super()._depth() - 1 def _ensure_table(self, table: exp.Table | str) -> exp.Table: table_ = exp.to_table(table) @@ -153,16 +218,9 @@ class MappingSchema(Schema): return table_ - def table_parts(self, table: exp.Table) -> t.List[str]: - return [table.text(part) for part in TABLE_ARGS if table.text(part)] - def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]: table_ = self._ensure_table(table) - - if not isinstance(table_.this, exp.Identifier): - return fs_get(table) # type: ignore - - schema = self.find_schema(table_) + schema = self.find(table_) if schema is None: raise SchemaError(f"Could not find table schema {table}") @@ -173,36 +231,13 @@ class MappingSchema(Schema): visible = self._nested_get(self.table_parts(table_), self.visible) return [col for col in schema if col in visible] # type: ignore - def find_schema( - self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True - ) -> t.Optional[t.Dict[str, str]]: - parts = self.table_parts(table)[0 : len(self.supported_table_args)] - value, trie = in_trie(self.schema_trie if trie is None else trie, parts) - - if value == 0: - if raise_on_missing: - raise SchemaError(f"Cannot find schema for {table}.") - else: - return None - elif value == 1: - possibilities = flatten_schema(trie) - if len(possibilities) == 1: - parts.extend(possibilities[0]) - else: - message = ", ".join(".".join(parts) for parts in possibilities) - if raise_on_missing: - raise SchemaError(f"Ambiguous schema for {table}: {message}.") - return None - - return self._nested_get(parts, raise_on_missing=raise_on_missing) - def get_column_type( self, table: exp.Table | str, column: exp.Column | str ) -> exp.DataType.Type: column_name = column if isinstance(column, str) else column.name table_ = exp.to_table(table) if table_: - table_schema = self.find_schema(table_) + table_schema = self.find(table_) schema_type = table_schema.get(column_name).upper() # type: ignore return self._convert_type(schema_type) raise SchemaError(f"Could not convert table '{table}'") @@ -228,18 +263,6 @@ class MappingSchema(Schema): return self._type_mapping_cache[schema_type] - def _build_trie(self, schema: t.Dict): - return new_trie(tuple(reversed(t)) for t in flatten_schema(schema)) - - def _nested_get( - self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True - ) -> t.Optional[t.Any]: - return _nested_get( - d or self.schema, - *zip(self.supported_table_args, reversed(parts)), - raise_on_missing=raise_on_missing, - ) - def ensure_schema(schema: t.Any) -> Schema: if isinstance(schema, Schema): @@ -267,29 +290,20 @@ def ensure_column_mapping(mapping: t.Optional[ColumnMapping]): raise ValueError(f"Invalid mapping provided: {type(mapping)}") -def flatten_schema(schema: t.Dict, keys: t.Optional[t.List[str]] = None) -> t.List[t.List[str]]: +def flatten_schema( + schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None +) -> t.List[t.List[str]]: tables = [] keys = keys or [] - depth = _dict_depth(schema) for k, v in schema.items(): - if depth >= 3: - tables.extend(flatten_schema(v, keys + [k])) - elif depth == 2: + if depth >= 2: + tables.extend(flatten_schema(v, depth - 1, keys + [k])) + elif depth == 1: tables.append(keys + [k]) return tables -def fs_get(table: exp.Table) -> t.List[str]: - name = table.this.name - - if name.upper() == "READ_CSV": - with csv_reader(table) as reader: - return next(reader) - - raise ValueError(f"Cannot read schema for {table}") - - def _nested_get( d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True ) -> t.Optional[t.Any]: @@ -310,7 +324,7 @@ def _nested_get( if d is None: if raise_on_missing: name = "table" if name == "this" else name - raise ValueError(f"Unknown {name}") + raise ValueError(f"Unknown {name}: {key}") return None return d @@ -350,34 +364,3 @@ def _nested_set(d: t.Dict, keys: t.List[str], value: t.Any) -> t.Dict: subd[keys[-1]] = value return d - - -def _dict_depth(d: t.Dict) -> int: - """ - Get the nesting depth of a dictionary. - - For example: - >>> _dict_depth(None) - 0 - >>> _dict_depth({}) - 1 - >>> _dict_depth({"a": "b"}) - 1 - >>> _dict_depth({"a": {}}) - 2 - >>> _dict_depth({"a": {"b": {}}}) - 3 - - Args: - d (dict): dictionary - Returns: - int: depth - """ - try: - return 1 + _dict_depth(next(iter(d.values()))) - except AttributeError: - # d doesn't have attribute "values" - return 0 - except StopIteration: - # d.values() returns an empty sequence - return 1 -- cgit v1.2.3