summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/schema.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/schema.py')
-rw-r--r--sqlglot/optimizer/schema.py63
1 files changed, 58 insertions, 5 deletions
diff --git a/sqlglot/optimizer/schema.py b/sqlglot/optimizer/schema.py
index 1bbd86a..d7743c9 100644
--- a/sqlglot/optimizer/schema.py
+++ b/sqlglot/optimizer/schema.py
@@ -9,16 +9,28 @@ class Schema(abc.ABC):
"""Abstract base class for database schemas"""
@abc.abstractmethod
- def column_names(self, table):
+ def column_names(self, table, only_visible=False):
"""
Get the column names for a table.
-
Args:
table (sqlglot.expressions.Table): Table expression instance
+ only_visible (bool): Whether to include invisible columns
Returns:
list[str]: list of column names
"""
+ @abc.abstractmethod
+ def get_column_type(self, table, column):
+ """
+ Get the exp.DataType type of a column in the schema.
+
+ Args:
+ table (sqlglot.expressions.Table): The source table.
+ column (sqlglot.expressions.Column): The target column.
+ Returns:
+ sqlglot.expressions.DataType.Type: The resulting column type.
+ """
+
class MappingSchema(Schema):
"""
@@ -29,10 +41,19 @@ class MappingSchema(Schema):
1. {table: {col: type}}
2. {db: {table: {col: type}}}
3. {catalog: {db: {table: {col: type}}}}
+ visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns
+ are assumed to be visible. The nesting should mirror that of the schema:
+ 1. {table: set(*cols)}}
+ 2. {db: {table: set(*cols)}}}
+ 3. {catalog: {db: {table: set(*cols)}}}}
+ dialect (str): The dialect to be used for custom type mappings.
"""
- def __init__(self, schema):
+ def __init__(self, schema, visible=None, dialect=None):
self.schema = schema
+ self.visible = visible
+ self.dialect = dialect
+ self._type_mapping_cache = {}
depth = _dict_depth(schema)
@@ -49,7 +70,7 @@ class MappingSchema(Schema):
self.forbidden_args = {"catalog", "db", "this"} - set(self.supported_table_args)
- def column_names(self, table):
+ def column_names(self, table, only_visible=False):
if not isinstance(table.this, exp.Identifier):
return fs_get(table)
@@ -58,7 +79,39 @@ class MappingSchema(Schema):
for forbidden in self.forbidden_args:
if table.text(forbidden):
raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}")
- return list(_nested_get(self.schema, *zip(self.supported_table_args, args)))
+
+ columns = list(_nested_get(self.schema, *zip(self.supported_table_args, args)))
+ if not only_visible or not self.visible:
+ return columns
+
+ visible = _nested_get(self.visible, *zip(self.supported_table_args, args))
+ return [col for col in columns if col in visible]
+
+ def get_column_type(self, table, column):
+ try:
+ schema_type = self.schema.get(table.name, {}).get(column.name).upper()
+ return self._convert_type(schema_type)
+ except:
+ raise OptimizeError(f"Failed to get type for column {column.sql()}")
+
+ def _convert_type(self, schema_type):
+ """
+ Convert a type represented as a string to the corresponding exp.DataType.Type object.
+
+ Args:
+ schema_type (str): The type we want to convert.
+ Returns:
+ sqlglot.expressions.DataType.Type: The resulting expression type.
+ """
+ if schema_type not in self._type_mapping_cache:
+ try:
+ self._type_mapping_cache[schema_type] = exp.maybe_parse(
+ schema_type, into=exp.DataType, dialect=self.dialect
+ ).this
+ except AttributeError:
+ raise OptimizeError(f"Failed to convert type {schema_type}")
+
+ return self._type_mapping_cache[schema_type]
def ensure_schema(schema):