From 28cc22419e32a65fea2d1678400265b8cabc3aff Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Thu, 15 Sep 2022 18:46:17 +0200 Subject: Adding upstream version 6.0.4. Signed-off-by: Daniel Baumann --- sqlglot/optimizer/schema.py | 129 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 129 insertions(+) create mode 100644 sqlglot/optimizer/schema.py (limited to 'sqlglot/optimizer/schema.py') diff --git a/sqlglot/optimizer/schema.py b/sqlglot/optimizer/schema.py new file mode 100644 index 0000000..9968108 --- /dev/null +++ b/sqlglot/optimizer/schema.py @@ -0,0 +1,129 @@ +import abc + +from sqlglot import exp +from sqlglot.errors import OptimizeError +from sqlglot.helper import csv_reader + + +class Schema(abc.ABC): + """Abstract base class for database schemas""" + + @abc.abstractmethod + def column_names(self, table): + """ + Get the column names for a table. + + Args: + table (sqlglot.expressions.Table): Table expression instance + Returns: + list[str]: list of column names + """ + + +class MappingSchema(Schema): + """ + Schema based on a nested mapping. + + Args: + schema (dict): Mapping in one of the following forms: + 1. {table: {col: type}} + 2. {db: {table: {col: type}}} + 3. {catalog: {db: {table: {col: type}}}} + """ + + def __init__(self, schema): + self.schema = schema + + depth = _dict_depth(schema) + + if not depth: # {} + self.supported_table_args = [] + elif depth == 2: # {table: {col: type}} + self.supported_table_args = ("this",) + elif depth == 3: # {db: {table: {col: type}}} + self.supported_table_args = ("db", "this") + elif depth == 4: # {catalog: {db: {table: {col: type}}}} + self.supported_table_args = ("catalog", "db", "this") + else: + raise OptimizeError(f"Invalid schema shape. Depth: {depth}") + + self.forbidden_args = {"catalog", "db", "this"} - set(self.supported_table_args) + + def column_names(self, table): + if not isinstance(table.this, exp.Identifier): + return fs_get(table) + + args = tuple(table.text(p) for p in self.supported_table_args) + + for forbidden in self.forbidden_args: + if table.text(forbidden): + raise ValueError( + f"Schema doesn't support {forbidden}. Received: {table.sql()}" + ) + return list(_nested_get(self.schema, *zip(self.supported_table_args, args))) + + +def ensure_schema(schema): + if isinstance(schema, Schema): + return schema + + return MappingSchema(schema) + + +def fs_get(table): + name = table.this.name.upper() + + if name.upper() == "READ_CSV": + with csv_reader(table) as reader: + return next(reader) + + raise ValueError(f"Cannot read schema for {table}") + + +def _nested_get(d, *path): + """ + Get a value for a nested dictionary. + + Args: + d (dict): dictionary + *path (tuple[str, str]): tuples of (name, key) + `key` is the key in the dictionary to get. + `name` is a string to use in the error if `key` isn't found. + """ + for name, key in path: + d = d.get(key) + if d is None: + name = "table" if name == "this" else name + raise ValueError(f"Unknown {name}") + return d + + +def _dict_depth(d): + """ + Get the nesting depth of a dictionary. + + For example: + >>> _dict_depth(None) + 0 + >>> _dict_depth({}) + 1 + >>> _dict_depth({"a": "b"}) + 1 + >>> _dict_depth({"a": {}}) + 2 + >>> _dict_depth({"a": {"b": {}}}) + 3 + + Args: + d (dict): dictionary + Returns: + int: depth + """ + try: + return 1 + _dict_depth(next(iter(d.values()))) + except AttributeError: + # d doesn't have attribute "values" + return 0 + except StopIteration: + # d.values() returns an empty sequence + return 1 -- cgit v1.2.3