summaryrefslogtreecommitdiffstats
path: root/sqlglot/schema.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2023-06-02 23:59:40 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2023-06-02 23:59:46 +0000
commit20739a12c39121a9e7ad3c9a2469ec5a6876199d (patch)
treec000de91c59fd29b2d9beecf9f93b84e69727f37 /sqlglot/schema.py
parentReleasing debian version 12.2.0-1. (diff)
downloadsqlglot-20739a12c39121a9e7ad3c9a2469ec5a6876199d.tar.xz
sqlglot-20739a12c39121a9e7ad3c9a2469ec5a6876199d.zip
Merging upstream version 15.0.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/schema.py')
-rw-r--r--sqlglot/schema.py201
1 files changed, 131 insertions, 70 deletions
diff --git a/sqlglot/schema.py b/sqlglot/schema.py
index 5d60eb9..f1c4a09 100644
--- a/sqlglot/schema.py
+++ b/sqlglot/schema.py
@@ -5,6 +5,8 @@ import typing as t
import sqlglot
from sqlglot import expressions as exp
+from sqlglot._typing import T
+from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE
from sqlglot.errors import ParseError, SchemaError
from sqlglot.helper import dict_depth
from sqlglot.trie import in_trie, new_trie
@@ -17,62 +19,83 @@ if t.TYPE_CHECKING:
TABLE_ARGS = ("this", "db", "catalog")
-T = t.TypeVar("T")
-
class Schema(abc.ABC):
"""Abstract base class for database schemas"""
+ dialect: DialectType
+
@abc.abstractmethod
def add_table(
- self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
+ self,
+ table: exp.Table | str,
+ column_mapping: t.Optional[ColumnMapping] = None,
+ dialect: DialectType = None,
) -> None:
"""
Register or update a table. Some implementing classes may require column information to also be provided.
Args:
- table: table expression instance or string representing the table.
+ table: the `Table` expression instance or string representing the table.
column_mapping: a column mapping that describes the structure of the table.
+ dialect: the SQL dialect that will be used to parse `table` if it's a string.
"""
@abc.abstractmethod
- def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
+ def column_names(
+ self,
+ table: exp.Table | str,
+ only_visible: bool = False,
+ dialect: DialectType = None,
+ ) -> t.List[str]:
"""
Get the column names for a table.
Args:
table: the `Table` expression instance.
only_visible: whether to include invisible columns.
+ dialect: the SQL dialect that will be used to parse `table` if it's a string.
Returns:
The list of column names.
"""
@abc.abstractmethod
- def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType:
+ def get_column_type(
+ self,
+ table: exp.Table | str,
+ column: exp.Column,
+ dialect: DialectType = None,
+ ) -> exp.DataType:
"""
- Get the :class:`sqlglot.exp.DataType` type of a column in the schema.
+ Get the `sqlglot.exp.DataType` type of a column in the schema.
Args:
table: the source table.
column: the target column.
+ dialect: the SQL dialect that will be used to parse `table` if it's a string.
Returns:
The resulting column type.
"""
@property
+ @abc.abstractmethod
def supported_table_args(self) -> t.Tuple[str, ...]:
"""
Table arguments this schema support, e.g. `("this", "db", "catalog")`
"""
- raise NotImplementedError
+
+ @property
+ def empty(self) -> bool:
+ """Returns whether or not the schema is empty."""
+ return True
class AbstractMappingSchema(t.Generic[T]):
def __init__(
self,
- mapping: dict | None = None,
+ mapping: t.Optional[t.Dict] = None,
) -> None:
self.mapping = mapping or {}
self.mapping_trie = new_trie(
@@ -80,6 +103,10 @@ class AbstractMappingSchema(t.Generic[T]):
)
self._supported_table_args: t.Tuple[str, ...] = tuple()
+ @property
+ def empty(self) -> bool:
+ return not self.mapping
+
def _depth(self) -> int:
return dict_depth(self.mapping)
@@ -110,8 +137,10 @@ class AbstractMappingSchema(t.Generic[T]):
if value == 0:
return None
- elif value == 1:
+
+ if value == 1:
possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
+
if len(possibilities) == 1:
parts.extend(possibilities[0])
else:
@@ -119,12 +148,13 @@ class AbstractMappingSchema(t.Generic[T]):
if raise_on_missing:
raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
return None
- return self._nested_get(parts, raise_on_missing=raise_on_missing)
- def _nested_get(
+ return self.nested_get(parts, raise_on_missing=raise_on_missing)
+
+ def nested_get(
self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
) -> t.Optional[t.Any]:
- return _nested_get(
+ return nested_get(
d or self.mapping,
*zip(self.supported_table_args, reversed(parts)),
raise_on_missing=raise_on_missing,
@@ -136,17 +166,18 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
Schema based on a nested mapping.
Args:
- schema (dict): Mapping in one of the following forms:
+ schema: Mapping in one of the following forms:
1. {table: {col: type}}
2. {db: {table: {col: type}}}
3. {catalog: {db: {table: {col: type}}}}
4. None - Tables will be added later
- visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns
+ visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
are assumed to be visible. The nesting should mirror that of the schema:
1. {table: set(*cols)}}
2. {db: {table: set(*cols)}}}
3. {catalog: {db: {table: set(*cols)}}}}
- dialect (str): The dialect to be used for custom type mappings.
+ dialect: The dialect to be used for custom type mappings & parsing string arguments.
+ normalize: Whether to normalize identifier names according to the given dialect or not.
"""
def __init__(
@@ -154,10 +185,13 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
schema: t.Optional[t.Dict] = None,
visible: t.Optional[t.Dict] = None,
dialect: DialectType = None,
+ normalize: bool = True,
) -> None:
self.dialect = dialect
self.visible = visible or {}
+ self.normalize = normalize
self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
+
super().__init__(self._normalize(schema or {}))
@classmethod
@@ -179,7 +213,10 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
)
def add_table(
- self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
+ self,
+ table: exp.Table | str,
+ column_mapping: t.Optional[ColumnMapping] = None,
+ dialect: DialectType = None,
) -> None:
"""
Register or update a table. Updates are only performed if a new column mapping is provided.
@@ -187,10 +224,13 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
Args:
table: the `Table` expression instance or string representing the table.
column_mapping: a column mapping that describes the structure of the table.
+ dialect: the SQL dialect that will be used to parse `table` if it's a string.
"""
- normalized_table = self._normalize_table(self._ensure_table(table))
+ normalized_table = self._normalize_table(
+ self._ensure_table(table, dialect=dialect), dialect=dialect
+ )
normalized_column_mapping = {
- self._normalize_name(key): value
+ self._normalize_name(key, dialect=dialect): value
for key, value in ensure_column_mapping(column_mapping).items()
}
@@ -200,38 +240,51 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
parts = self.table_parts(normalized_table)
- _nested_set(
- self.mapping,
- tuple(reversed(parts)),
- normalized_column_mapping,
- )
+ nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
new_trie([parts], self.mapping_trie)
- def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
- table_ = self._normalize_table(self._ensure_table(table))
- schema = self.find(table_)
+ def column_names(
+ self,
+ table: exp.Table | str,
+ only_visible: bool = False,
+ dialect: DialectType = None,
+ ) -> t.List[str]:
+ normalized_table = self._normalize_table(
+ self._ensure_table(table, dialect=dialect), dialect=dialect
+ )
+ schema = self.find(normalized_table)
if schema is None:
return []
if not only_visible or not self.visible:
return list(schema)
- visible = self._nested_get(self.table_parts(table_), self.visible)
- return [col for col in schema if col in visible] # type: ignore
+ visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
+ return [col for col in schema if col in visible]
- def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
- column_name = self._normalize_name(column if isinstance(column, str) else column.this)
- table_ = self._normalize_table(self._ensure_table(table))
+ def get_column_type(
+ self,
+ table: exp.Table | str,
+ column: exp.Column,
+ dialect: DialectType = None,
+ ) -> exp.DataType:
+ normalized_table = self._normalize_table(
+ self._ensure_table(table, dialect=dialect), dialect=dialect
+ )
+ normalized_column_name = self._normalize_name(
+ column if isinstance(column, str) else column.this, dialect=dialect
+ )
- table_schema = self.find(table_, raise_on_missing=False)
+ table_schema = self.find(normalized_table, raise_on_missing=False)
if table_schema:
- column_type = table_schema.get(column_name)
+ column_type = table_schema.get(normalized_column_name)
if isinstance(column_type, exp.DataType):
return column_type
elif isinstance(column_type, str):
- return self._to_data_type(column_type.upper())
+ return self._to_data_type(column_type.upper(), dialect=dialect)
+
raise SchemaError(f"Unknown column type '{column_type}'")
return exp.DataType.build("unknown")
@@ -250,81 +303,88 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
normalized_mapping: t.Dict = {}
for keys in flattened_schema:
- columns = _nested_get(schema, *zip(keys, keys))
+ columns = nested_get(schema, *zip(keys, keys))
assert columns is not None
- normalized_keys = [self._normalize_name(key) for key in keys]
+ normalized_keys = [self._normalize_name(key, dialect=self.dialect) for key in keys]
for column_name, column_type in columns.items():
- _nested_set(
+ nested_set(
normalized_mapping,
- normalized_keys + [self._normalize_name(column_name)],
+ normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)],
column_type,
)
return normalized_mapping
- def _normalize_table(self, table: exp.Table) -> exp.Table:
+ def _normalize_table(self, table: exp.Table, dialect: DialectType = None) -> exp.Table:
normalized_table = table.copy()
+
for arg in TABLE_ARGS:
value = normalized_table.args.get(arg)
if isinstance(value, (str, exp.Identifier)):
- normalized_table.set(arg, self._normalize_name(value))
+ normalized_table.set(
+ arg, exp.to_identifier(self._normalize_name(value, dialect=dialect))
+ )
return normalized_table
- def _normalize_name(self, name: str | exp.Identifier) -> str:
+ def _normalize_name(self, name: str | exp.Identifier, dialect: DialectType = None) -> str:
+ dialect = dialect or self.dialect
+
try:
- identifier = sqlglot.maybe_parse(name, dialect=self.dialect, into=exp.Identifier)
+ identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
except ParseError:
return name if isinstance(name, str) else name.name
- return identifier.name if identifier.quoted else identifier.name.lower()
+ name = identifier.name
+
+ if not self.normalize or identifier.quoted:
+ return name
+
+ return name.upper() if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE else name.lower()
def _depth(self) -> int:
# The columns themselves are a mapping, but we don't want to include those
return super()._depth() - 1
- def _ensure_table(self, table: exp.Table | str) -> exp.Table:
- if isinstance(table, exp.Table):
- return table
-
- table_ = sqlglot.parse_one(table, read=self.dialect, into=exp.Table)
- if not table_:
- raise SchemaError(f"Not a valid table '{table}'")
+ def _ensure_table(self, table: exp.Table | str, dialect: DialectType = None) -> exp.Table:
+ return exp.maybe_parse(table, into=exp.Table, dialect=dialect or self.dialect)
- return table_
-
- def _to_data_type(self, schema_type: str) -> exp.DataType:
+ def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
"""
- Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object.
+ Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.
Args:
schema_type: the type we want to convert.
+ dialect: the SQL dialect that will be used to parse `schema_type`, if needed.
Returns:
The resulting expression type.
"""
if schema_type not in self._type_mapping_cache:
+ dialect = dialect or self.dialect
+
try:
- expression = exp.maybe_parse(schema_type, into=exp.DataType, dialect=self.dialect)
- if expression is None:
- raise ValueError(f"Could not parse {schema_type}")
- self._type_mapping_cache[schema_type] = expression # type: ignore
+ expression = exp.DataType.build(schema_type, dialect=dialect)
+ self._type_mapping_cache[schema_type] = expression
except AttributeError:
- raise SchemaError(f"Failed to convert type {schema_type}")
+ in_dialect = f" in dialect {dialect}" if dialect else ""
+ raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")
return self._type_mapping_cache[schema_type]
-def ensure_schema(schema: t.Any, dialect: DialectType = None) -> Schema:
+def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
if isinstance(schema, Schema):
return schema
- return MappingSchema(schema, dialect=dialect)
+ return MappingSchema(schema, **kwargs)
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
- if isinstance(mapping, dict):
+ if mapping is None:
+ return {}
+ elif isinstance(mapping, dict):
return mapping
elif isinstance(mapping, str):
col_name_type_strs = [x.strip() for x in mapping.split(",")]
@@ -334,11 +394,10 @@ def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
}
# Check if mapping looks like a DataFrame StructType
elif hasattr(mapping, "simpleString"):
- return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping} # type: ignore
+ return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
elif isinstance(mapping, list):
return {x.strip(): None for x in mapping}
- elif mapping is None:
- return {}
+
raise ValueError(f"Invalid mapping provided: {type(mapping)}")
@@ -353,10 +412,11 @@ def flatten_schema(
tables.extend(flatten_schema(v, depth - 1, keys + [k]))
elif depth == 1:
tables.append(keys + [k])
+
return tables
-def _nested_get(
+def nested_get(
d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
"""
@@ -378,18 +438,19 @@ def _nested_get(
name = "table" if name == "this" else name
raise ValueError(f"Unknown {name}: {key}")
return None
+
return d
-def _nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
+def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
"""
In-place set a value for a nested dictionary
Example:
- >>> _nested_set({}, ["top_key", "second_key"], "value")
+ >>> nested_set({}, ["top_key", "second_key"], "value")
{'top_key': {'second_key': 'value'}}
- >>> _nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
+ >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
{'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Args: