summary refs log tree commit diff stats
path: root/sqlglot/schema.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/schema.py')
-rw-r--r-- sqlglot/schema.py 215
1 files changed, 99 insertions, 116 deletions
diff --git a/sqlglot/schema.py b/sqlglot/schema.py
index fcf7291..f6f303b 100644
--- a/sqlglot/schema.py
+++ b/sqlglot/schema.py
@@ -5,7 +5,7 @@ import typing as t
from sqlglot import expressions as exp
from sqlglot.errors import SchemaError
-from sqlglot.helper import csv_reader
+from sqlglot.helper import dict_depth
from sqlglot.trie import in_trie, new_trie
if t.TYPE_CHECKING:
@@ -15,6 +15,8 @@ if t.TYPE_CHECKING:
TABLE_ARGS = ("this", "db", "catalog")
+T = t.TypeVar("T")
+
class Schema(abc.ABC):
"""Abstract base class for database schemas"""
@@ -57,8 +59,81 @@ class Schema(abc.ABC):
The resulting column type.
"""
+ @property
+ def supported_table_args(self) -> t.Tuple[str, ...]:
+ """
+ Table arguments this schema supports, e.g. `("this", "db", "catalog")`
+ """
+ raise NotImplementedError
+
+
+class AbstractMappingSchema(t.Generic[T]):
+ def __init__(
+ self,
+ mapping: dict | None = None,
+ ) -> None:
+ self.mapping = mapping or {}
+ self.mapping_trie = self._build_trie(self.mapping)
+ self._supported_table_args: t.Tuple[str, ...] = tuple()
+
+ def _build_trie(self, schema: t.Dict) -> t.Dict:
+ return new_trie(tuple(reversed(t)) for t in flatten_schema(schema, depth=self._depth()))
+
+ def _depth(self) -> int:
+ return dict_depth(self.mapping)
+
+ @property
+ def supported_table_args(self) -> t.Tuple[str, ...]:
+ if not self._supported_table_args and self.mapping:
+ depth = self._depth()
+
+ if not depth: # None
+ self._supported_table_args = tuple()
+ elif 1 <= depth <= 3:
+ self._supported_table_args = TABLE_ARGS[:depth]
+ else:
+ raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
+
+ return self._supported_table_args
+
+ def table_parts(self, table: exp.Table) -> t.List[str]:
+ if isinstance(table.this, exp.ReadCSV):
+ return [table.this.name]
+ return [table.text(part) for part in TABLE_ARGS if table.text(part)]
+
+ def find(
+ self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
+ ) -> t.Optional[T]:
+ parts = self.table_parts(table)[0 : len(self.supported_table_args)]
+ value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
+
+ if value == 0:
+ if raise_on_missing:
+ raise SchemaError(f"Cannot find mapping for {table}.")
+ else:
+ return None
+ elif value == 1:
+ possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
+ if len(possibilities) == 1:
+ parts.extend(possibilities[0])
+ else:
+ message = ", ".join(".".join(parts) for parts in possibilities)
+ if raise_on_missing:
+ raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
+ return None
+ return self._nested_get(parts, raise_on_missing=raise_on_missing)
-class MappingSchema(Schema):
+ def _nested_get(
+ self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
+ ) -> t.Optional[t.Any]:
+ return _nested_get(
+ d or self.mapping,
+ *zip(self.supported_table_args, reversed(parts)),
+ raise_on_missing=raise_on_missing,
+ )
+
+
+class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
"""
Schema based on a nested mapping.
@@ -82,17 +157,17 @@ class MappingSchema(Schema):
visible: t.Optional[t.Dict] = None,
dialect: t.Optional[str] = None,
) -> None:
- self.schema = schema or {}
+ super().__init__(schema)
self.visible = visible or {}
- self.schema_trie = self._build_trie(self.schema)
self.dialect = dialect
- self._type_mapping_cache: t.Dict[str, exp.DataType.Type] = {}
- self._supported_table_args: t.Tuple[str, ...] = tuple()
+ self._type_mapping_cache: t.Dict[str, exp.DataType.Type] = {
+ "STR": exp.DataType.Type.TEXT,
+ }
@classmethod
def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
return MappingSchema(
- schema=mapping_schema.schema,
+ schema=mapping_schema.mapping,
visible=mapping_schema.visible,
dialect=mapping_schema.dialect,
)
@@ -100,27 +175,13 @@ class MappingSchema(Schema):
def copy(self, **kwargs) -> MappingSchema:
return MappingSchema(
**{ # type: ignore
- "schema": self.schema.copy(),
+ "schema": self.mapping.copy(),
"visible": self.visible.copy(),
"dialect": self.dialect,
**kwargs,
}
)
- @property
- def supported_table_args(self):
- if not self._supported_table_args and self.schema:
- depth = _dict_depth(self.schema)
-
- if not depth or depth == 1: # {}
- self._supported_table_args = tuple()
- elif 2 <= depth <= 4:
- self._supported_table_args = TABLE_ARGS[: depth - 1]
- else:
- raise SchemaError(f"Invalid schema shape. Depth: {depth}")
-
- return self._supported_table_args
-
def add_table(
self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
) -> None:
@@ -133,17 +194,21 @@ class MappingSchema(Schema):
"""
table_ = self._ensure_table(table)
column_mapping = ensure_column_mapping(column_mapping)
- schema = self.find_schema(table_, raise_on_missing=False)
+ schema = self.find(table_, raise_on_missing=False)
if schema and not column_mapping:
return
_nested_set(
- self.schema,
+ self.mapping,
list(reversed(self.table_parts(table_))),
column_mapping,
)
- self.schema_trie = self._build_trie(self.schema)
+ self.mapping_trie = self._build_trie(self.mapping)
+
+ def _depth(self) -> int:
+ # The columns themselves are a mapping, but we don't want to include those
+ return super()._depth() - 1
def _ensure_table(self, table: exp.Table | str) -> exp.Table:
table_ = exp.to_table(table)
@@ -153,16 +218,9 @@ class MappingSchema(Schema):
return table_
- def table_parts(self, table: exp.Table) -> t.List[str]:
- return [table.text(part) for part in TABLE_ARGS if table.text(part)]
-
def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
table_ = self._ensure_table(table)
-
- if not isinstance(table_.this, exp.Identifier):
- return fs_get(table) # type: ignore
-
- schema = self.find_schema(table_)
+ schema = self.find(table_)
if schema is None:
raise SchemaError(f"Could not find table schema {table}")
@@ -173,36 +231,13 @@ class MappingSchema(Schema):
visible = self._nested_get(self.table_parts(table_), self.visible)
return [col for col in schema if col in visible] # type: ignore
- def find_schema(
- self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
- ) -> t.Optional[t.Dict[str, str]]:
- parts = self.table_parts(table)[0 : len(self.supported_table_args)]
- value, trie = in_trie(self.schema_trie if trie is None else trie, parts)
-
- if value == 0:
- if raise_on_missing:
- raise SchemaError(f"Cannot find schema for {table}.")
- else:
- return None
- elif value == 1:
- possibilities = flatten_schema(trie)
- if len(possibilities) == 1:
- parts.extend(possibilities[0])
- else:
- message = ", ".join(".".join(parts) for parts in possibilities)
- if raise_on_missing:
- raise SchemaError(f"Ambiguous schema for {table}: {message}.")
- return None
-
- return self._nested_get(parts, raise_on_missing=raise_on_missing)
-
def get_column_type(
self, table: exp.Table | str, column: exp.Column | str
) -> exp.DataType.Type:
column_name = column if isinstance(column, str) else column.name
table_ = exp.to_table(table)
if table_:
- table_schema = self.find_schema(table_)
+ table_schema = self.find(table_)
schema_type = table_schema.get(column_name).upper() # type: ignore
return self._convert_type(schema_type)
raise SchemaError(f"Could not convert table '{table}'")
@@ -228,18 +263,6 @@ class MappingSchema(Schema):
return self._type_mapping_cache[schema_type]
- def _build_trie(self, schema: t.Dict):
- return new_trie(tuple(reversed(t)) for t in flatten_schema(schema))
-
- def _nested_get(
- self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
- ) -> t.Optional[t.Any]:
- return _nested_get(
- d or self.schema,
- *zip(self.supported_table_args, reversed(parts)),
- raise_on_missing=raise_on_missing,
- )
-
def ensure_schema(schema: t.Any) -> Schema:
if isinstance(schema, Schema):
@@ -267,29 +290,20 @@ def ensure_column_mapping(mapping: t.Optional[ColumnMapping]):
raise ValueError(f"Invalid mapping provided: {type(mapping)}")
-def flatten_schema(schema: t.Dict, keys: t.Optional[t.List[str]] = None) -> t.List[t.List[str]]:
+def flatten_schema(
+ schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
+) -> t.List[t.List[str]]:
tables = []
keys = keys or []
- depth = _dict_depth(schema)
for k, v in schema.items():
- if depth >= 3:
- tables.extend(flatten_schema(v, keys + [k]))
- elif depth == 2:
+ if depth >= 2:
+ tables.extend(flatten_schema(v, depth - 1, keys + [k]))
+ elif depth == 1:
tables.append(keys + [k])
return tables
-def fs_get(table: exp.Table) -> t.List[str]:
- name = table.this.name
-
- if name.upper() == "READ_CSV":
- with csv_reader(table) as reader:
- return next(reader)
-
- raise ValueError(f"Cannot read schema for {table}")
-
-
def _nested_get(
d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
@@ -310,7 +324,7 @@ def _nested_get(
if d is None:
if raise_on_missing:
name = "table" if name == "this" else name
- raise ValueError(f"Unknown {name}")
+ raise ValueError(f"Unknown {name}: {key}")
return None
return d
@@ -350,34 +364,3 @@ def _nested_set(d: t.Dict, keys: t.List[str], value: t.Any) -> t.Dict:
subd[keys[-1]] = value
return d
-
-
-def _dict_depth(d: t.Dict) -> int:
- """
- Get the nesting depth of a dictionary.
-
- For example:
- >>> _dict_depth(None)
- 0
- >>> _dict_depth({})
- 1
- >>> _dict_depth({"a": "b"})
- 1
- >>> _dict_depth({"a": {}})
- 2
- >>> _dict_depth({"a": {"b": {}}})
- 3
-
- Args:
- d (dict): dictionary
- Returns:
- int: depth
- """
- try:
- return 1 + _dict_depth(next(iter(d.values())))
- except AttributeError:
- # d doesn't have attribute "values"
- return 0
- except StopIteration:
- # d.values() returns an empty sequence
- return 1