From 20739a12c39121a9e7ad3c9a2469ec5a6876199d Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sat, 3 Jun 2023 01:59:40 +0200 Subject: Merging upstream version 15.0.0. Signed-off-by: Daniel Baumann --- sqlglot/schema.py | 201 +++++++++++++++++++++++++++++++++++------------------- 1 file changed, 131 insertions(+), 70 deletions(-) (limited to 'sqlglot/schema.py') diff --git a/sqlglot/schema.py b/sqlglot/schema.py index 5d60eb9..f1c4a09 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -5,6 +5,8 @@ import typing as t import sqlglot from sqlglot import expressions as exp +from sqlglot._typing import T +from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE from sqlglot.errors import ParseError, SchemaError from sqlglot.helper import dict_depth from sqlglot.trie import in_trie, new_trie @@ -17,62 +19,83 @@ if t.TYPE_CHECKING: TABLE_ARGS = ("this", "db", "catalog") -T = t.TypeVar("T") - class Schema(abc.ABC): """Abstract base class for database schemas""" + dialect: DialectType + @abc.abstractmethod def add_table( - self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None + self, + table: exp.Table | str, + column_mapping: t.Optional[ColumnMapping] = None, + dialect: DialectType = None, ) -> None: """ Register or update a table. Some implementing classes may require column information to also be provided. Args: - table: table expression instance or string representing the table. + table: the `Table` expression instance or string representing the table. column_mapping: a column mapping that describes the structure of the table. + dialect: the SQL dialect that will be used to parse `table` if it's a string. """ @abc.abstractmethod - def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]: + def column_names( + self, + table: exp.Table | str, + only_visible: bool = False, + dialect: DialectType = None, + ) -> t.List[str]: """ Get the column names for a table. Args: table: the `Table` expression instance. only_visible: whether to include invisible columns. + dialect: the SQL dialect that will be used to parse `table` if it's a string. Returns: The list of column names. """ @abc.abstractmethod - def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType: + def get_column_type( + self, + table: exp.Table | str, + column: exp.Column, + dialect: DialectType = None, + ) -> exp.DataType: """ - Get the :class:`sqlglot.exp.DataType` type of a column in the schema. + Get the `sqlglot.exp.DataType` type of a column in the schema. Args: table: the source table. column: the target column. + dialect: the SQL dialect that will be used to parse `table` if it's a string. Returns: The resulting column type. """ @property + @abc.abstractmethod def supported_table_args(self) -> t.Tuple[str, ...]: """ Table arguments this schema support, e.g. `("this", "db", "catalog")` """ - raise NotImplementedError + + @property + def empty(self) -> bool: + """Returns whether or not the schema is empty.""" + return True class AbstractMappingSchema(t.Generic[T]): def __init__( self, - mapping: dict | None = None, + mapping: t.Optional[t.Dict] = None, ) -> None: self.mapping = mapping or {} self.mapping_trie = new_trie( @@ -80,6 +103,10 @@ class AbstractMappingSchema(t.Generic[T]): ) self._supported_table_args: t.Tuple[str, ...] = tuple() + @property + def empty(self) -> bool: + return not self.mapping + def _depth(self) -> int: return dict_depth(self.mapping) @@ -110,8 +137,10 @@ class AbstractMappingSchema(t.Generic[T]): if value == 0: return None - elif value == 1: + + if value == 1: possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1) + if len(possibilities) == 1: parts.extend(possibilities[0]) else: @@ -119,12 +148,13 @@ class AbstractMappingSchema(t.Generic[T]): if raise_on_missing: raise SchemaError(f"Ambiguous mapping for {table}: {message}.") return None - return self._nested_get(parts, raise_on_missing=raise_on_missing) - def _nested_get( + return self.nested_get(parts, raise_on_missing=raise_on_missing) + + def nested_get( self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True ) -> t.Optional[t.Any]: - return _nested_get( + return nested_get( d or self.mapping, *zip(self.supported_table_args, reversed(parts)), raise_on_missing=raise_on_missing, @@ -136,17 +166,18 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): Schema based on a nested mapping. Args: - schema (dict): Mapping in one of the following forms: + schema: Mapping in one of the following forms: 1. {table: {col: type}} 2. {db: {table: {col: type}}} 3. {catalog: {db: {table: {col: type}}}} 4. None - Tables will be added later - visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns + visible: Optional mapping of which columns in the schema are visible. If not provided, all columns are assumed to be visible. The nesting should mirror that of the schema: 1. {table: set(*cols)}} 2. {db: {table: set(*cols)}}} 3. {catalog: {db: {table: set(*cols)}}}} - dialect (str): The dialect to be used for custom type mappings. + dialect: The dialect to be used for custom type mappings & parsing string arguments. + normalize: Whether to normalize identifier names according to the given dialect or not. """ def __init__( @@ -154,10 +185,13 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): schema: t.Optional[t.Dict] = None, visible: t.Optional[t.Dict] = None, dialect: DialectType = None, + normalize: bool = True, ) -> None: self.dialect = dialect self.visible = visible or {} + self.normalize = normalize self._type_mapping_cache: t.Dict[str, exp.DataType] = {} + super().__init__(self._normalize(schema or {})) @classmethod @@ -179,7 +213,10 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): ) def add_table( - self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None + self, + table: exp.Table | str, + column_mapping: t.Optional[ColumnMapping] = None, + dialect: DialectType = None, ) -> None: """ Register or update a table. Updates are only performed if a new column mapping is provided. @@ -187,10 +224,13 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): Args: table: the `Table` expression instance or string representing the table. column_mapping: a column mapping that describes the structure of the table. + dialect: the SQL dialect that will be used to parse `table` if it's a string. """ - normalized_table = self._normalize_table(self._ensure_table(table)) + normalized_table = self._normalize_table( + self._ensure_table(table, dialect=dialect), dialect=dialect + ) normalized_column_mapping = { - self._normalize_name(key): value + self._normalize_name(key, dialect=dialect): value for key, value in ensure_column_mapping(column_mapping).items() } @@ -200,38 +240,51 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): parts = self.table_parts(normalized_table) - _nested_set( - self.mapping, - tuple(reversed(parts)), - normalized_column_mapping, - ) + nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping) new_trie([parts], self.mapping_trie) - def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]: - table_ = self._normalize_table(self._ensure_table(table)) - schema = self.find(table_) + def column_names( + self, + table: exp.Table | str, + only_visible: bool = False, + dialect: DialectType = None, + ) -> t.List[str]: + normalized_table = self._normalize_table( + self._ensure_table(table, dialect=dialect), dialect=dialect + ) + schema = self.find(normalized_table) if schema is None: return [] if not only_visible or not self.visible: return list(schema) - visible = self._nested_get(self.table_parts(table_), self.visible) - return [col for col in schema if col in visible] # type: ignore + visible = self.nested_get(self.table_parts(normalized_table), self.visible) or [] + return [col for col in schema if col in visible] - def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType: - column_name = self._normalize_name(column if isinstance(column, str) else column.this) - table_ = self._normalize_table(self._ensure_table(table)) + def get_column_type( + self, + table: exp.Table | str, + column: exp.Column, + dialect: DialectType = None, + ) -> exp.DataType: + normalized_table = self._normalize_table( + self._ensure_table(table, dialect=dialect), dialect=dialect + ) + normalized_column_name = self._normalize_name( + column if isinstance(column, str) else column.this, dialect=dialect + ) - table_schema = self.find(table_, raise_on_missing=False) + table_schema = self.find(normalized_table, raise_on_missing=False) if table_schema: - column_type = table_schema.get(column_name) + column_type = table_schema.get(normalized_column_name) if isinstance(column_type, exp.DataType): return column_type elif isinstance(column_type, str): - return self._to_data_type(column_type.upper()) + return self._to_data_type(column_type.upper(), dialect=dialect) + raise SchemaError(f"Unknown column type '{column_type}'") return exp.DataType.build("unknown") @@ -250,81 +303,88 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): normalized_mapping: t.Dict = {} for keys in flattened_schema: - columns = _nested_get(schema, *zip(keys, keys)) + columns = nested_get(schema, *zip(keys, keys)) assert columns is not None - normalized_keys = [self._normalize_name(key) for key in keys] + normalized_keys = [self._normalize_name(key, dialect=self.dialect) for key in keys] for column_name, column_type in columns.items(): - _nested_set( + nested_set( normalized_mapping, - normalized_keys + [self._normalize_name(column_name)], + normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)], column_type, ) return normalized_mapping - def _normalize_table(self, table: exp.Table) -> exp.Table: + def _normalize_table(self, table: exp.Table, dialect: DialectType = None) -> exp.Table: normalized_table = table.copy() + for arg in TABLE_ARGS: value = normalized_table.args.get(arg) if isinstance(value, (str, exp.Identifier)): - normalized_table.set(arg, self._normalize_name(value)) + normalized_table.set( + arg, exp.to_identifier(self._normalize_name(value, dialect=dialect)) + ) return normalized_table - def _normalize_name(self, name: str | exp.Identifier) -> str: + def _normalize_name(self, name: str | exp.Identifier, dialect: DialectType = None) -> str: + dialect = dialect or self.dialect + try: - identifier = sqlglot.maybe_parse(name, dialect=self.dialect, into=exp.Identifier) + identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier) except ParseError: return name if isinstance(name, str) else name.name - return identifier.name if identifier.quoted else identifier.name.lower() + name = identifier.name + + if not self.normalize or identifier.quoted: + return name + + return name.upper() if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE else name.lower() def _depth(self) -> int: # The columns themselves are a mapping, but we don't want to include those return super()._depth() - 1 - def _ensure_table(self, table: exp.Table | str) -> exp.Table: - if isinstance(table, exp.Table): - return table - - table_ = sqlglot.parse_one(table, read=self.dialect, into=exp.Table) - if not table_: - raise SchemaError(f"Not a valid table '{table}'") + def _ensure_table(self, table: exp.Table | str, dialect: DialectType = None) -> exp.Table: + return exp.maybe_parse(table, into=exp.Table, dialect=dialect or self.dialect) - return table_ - - def _to_data_type(self, schema_type: str) -> exp.DataType: + def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType: """ - Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object. + Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object. Args: schema_type: the type we want to convert. + dialect: the SQL dialect that will be used to parse `schema_type`, if needed. Returns: The resulting expression type. """ if schema_type not in self._type_mapping_cache: + dialect = dialect or self.dialect + try: - expression = exp.maybe_parse(schema_type, into=exp.DataType, dialect=self.dialect) - if expression is None: - raise ValueError(f"Could not parse {schema_type}") - self._type_mapping_cache[schema_type] = expression # type: ignore + expression = exp.DataType.build(schema_type, dialect=dialect) + self._type_mapping_cache[schema_type] = expression except AttributeError: - raise SchemaError(f"Failed to convert type {schema_type}") + in_dialect = f" in dialect {dialect}" if dialect else "" + raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.") return self._type_mapping_cache[schema_type] -def ensure_schema(schema: t.Any, dialect: DialectType = None) -> Schema: +def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema: if isinstance(schema, Schema): return schema - return MappingSchema(schema, dialect=dialect) + return MappingSchema(schema, **kwargs) def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict: - if isinstance(mapping, dict): + if mapping is None: + return {} + elif isinstance(mapping, dict): return mapping elif isinstance(mapping, str): col_name_type_strs = [x.strip() for x in mapping.split(",")] @@ -334,11 +394,10 @@ def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict: } # Check if mapping looks like a DataFrame StructType elif hasattr(mapping, "simpleString"): - return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping} # type: ignore + return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping} elif isinstance(mapping, list): return {x.strip(): None for x in mapping} - elif mapping is None: - return {} + raise ValueError(f"Invalid mapping provided: {type(mapping)}") @@ -353,10 +412,11 @@ def flatten_schema( tables.extend(flatten_schema(v, depth - 1, keys + [k])) elif depth == 1: tables.append(keys + [k]) + return tables -def _nested_get( +def nested_get( d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True ) -> t.Optional[t.Any]: """ @@ -378,18 +438,19 @@ def _nested_get( name = "table" if name == "this" else name raise ValueError(f"Unknown {name}: {key}") return None + return d -def _nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict: +def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict: """ In-place set a value for a nested dictionary Example: - >>> _nested_set({}, ["top_key", "second_key"], "value") + >>> nested_set({}, ["top_key", "second_key"], "value") {'top_key': {'second_key': 'value'}} - >>> _nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") + >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") {'top_key': {'third_key': 'third_value', 'second_key': 'value'}} Args: -- cgit v1.2.3