summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/schema.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/schema.py')
-rw-r--r--sqlglot/optimizer/schema.py129
1 files changed, 129 insertions, 0 deletions
diff --git a/sqlglot/optimizer/schema.py b/sqlglot/optimizer/schema.py
new file mode 100644
index 0000000..9968108
--- /dev/null
+++ b/sqlglot/optimizer/schema.py
@@ -0,0 +1,129 @@
+import abc
+
+from sqlglot import exp
+from sqlglot.errors import OptimizeError
+from sqlglot.helper import csv_reader
+
+
+class Schema(abc.ABC):
+ """Abstract base class for database schemas"""
+
+ @abc.abstractmethod
+ def column_names(self, table):
+ """
+ Get the column names for a table.
+
+ Args:
+ table (sqlglot.expressions.Table): Table expression instance
+ Returns:
+ list[str]: list of column names
+ """
+
+
+class MappingSchema(Schema):
+ """
+ Schema based on a nested mapping.
+
+ Args:
+ schema (dict): Mapping in one of the following forms:
+ 1. {table: {col: type}}
+ 2. {db: {table: {col: type}}}
+ 3. {catalog: {db: {table: {col: type}}}}
+ """
+
+ def __init__(self, schema):
+ self.schema = schema
+
+ depth = _dict_depth(schema)
+
+ if not depth: # {}
+ self.supported_table_args = []
+ elif depth == 2: # {table: {col: type}}
+ self.supported_table_args = ("this",)
+ elif depth == 3: # {db: {table: {col: type}}}
+ self.supported_table_args = ("db", "this")
+ elif depth == 4: # {catalog: {db: {table: {col: type}}}}
+ self.supported_table_args = ("catalog", "db", "this")
+ else:
+ raise OptimizeError(f"Invalid schema shape. Depth: {depth}")
+
+ self.forbidden_args = {"catalog", "db", "this"} - set(self.supported_table_args)
+
+ def column_names(self, table):
+ if not isinstance(table.this, exp.Identifier):
+ return fs_get(table)
+
+ args = tuple(table.text(p) for p in self.supported_table_args)
+
+ for forbidden in self.forbidden_args:
+ if table.text(forbidden):
+ raise ValueError(
+ f"Schema doesn't support {forbidden}. Received: {table.sql()}"
+ )
+ return list(_nested_get(self.schema, *zip(self.supported_table_args, args)))
+
+
+def ensure_schema(schema):
+ if isinstance(schema, Schema):
+ return schema
+
+ return MappingSchema(schema)
+
+
+def fs_get(table):
+ name = table.this.name.upper()
+
+ if name.upper() == "READ_CSV":
+ with csv_reader(table) as reader:
+ return next(reader)
+
+ raise ValueError(f"Cannot read schema for {table}")
+
+
+def _nested_get(d, *path):
+ """
+ Get a value for a nested dictionary.
+
+ Args:
+ d (dict): dictionary
+ *path (tuple[str, str]): tuples of (name, key)
+ `key` is the key in the dictionary to get.
+ `name` is a string to use in the error if `key` isn't found.
+ """
+ for name, key in path:
+ d = d.get(key)
+ if d is None:
+ name = "table" if name == "this" else name
+ raise ValueError(f"Unknown {name}")
+ return d
+
+
+def _dict_depth(d):
+ """
+ Get the nesting depth of a dictionary.
+
+ For example:
+ >>> _dict_depth(None)
+ 0
+ >>> _dict_depth({})
+ 1
+ >>> _dict_depth({"a": "b"})
+ 1
+ >>> _dict_depth({"a": {}})
+ 2
+ >>> _dict_depth({"a": {"b": {}}})
+ 3
+
+ Args:
+ d (dict): dictionary
+ Returns:
+ int: depth
+ """
+ try:
+ return 1 + _dict_depth(next(iter(d.values())))
+ except AttributeError:
+ # d doesn't have attribute "values"
+ return 0
+ except StopIteration:
+ # d.values() returns an empty sequence
+ return 1