summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/schema.py
blob: 996810806273b9c358490baaef3cdd26f023820a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import abc

from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import csv_reader


class Schema(abc.ABC):
    """Abstract base class for database schemas"""

    @abc.abstractmethod
    def column_names(self, table):
        """
        Get the column names for a table.

        Args:
            table (sqlglot.expressions.Table): Table expression instance
        Returns:
            list[str]: list of column names
        """


class MappingSchema(Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema (dict): Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
    """

    def __init__(self, schema):
        self.schema = schema

        depth = _dict_depth(schema)

        if not depth:  # {}
            self.supported_table_args = []
        elif depth == 2:  # {table: {col: type}}
            self.supported_table_args = ("this",)
        elif depth == 3:  # {db: {table: {col: type}}}
            self.supported_table_args = ("db", "this")
        elif depth == 4:  # {catalog: {db: {table: {col: type}}}}
            self.supported_table_args = ("catalog", "db", "this")
        else:
            raise OptimizeError(f"Invalid schema shape. Depth: {depth}")

        self.forbidden_args = {"catalog", "db", "this"} - set(self.supported_table_args)

    def column_names(self, table):
        if not isinstance(table.this, exp.Identifier):
            return fs_get(table)

        args = tuple(table.text(p) for p in self.supported_table_args)

        for forbidden in self.forbidden_args:
            if table.text(forbidden):
                raise ValueError(
                    f"Schema doesn't support {forbidden}. Received: {table.sql()}"
                )
        return list(_nested_get(self.schema, *zip(self.supported_table_args, args)))


def ensure_schema(schema):
    if isinstance(schema, Schema):
        return schema

    return MappingSchema(schema)


def fs_get(table):
    name = table.this.name.upper()

    if name.upper() == "READ_CSV":
        with csv_reader(table) as reader:
            return next(reader)

    raise ValueError(f"Cannot read schema for {table}")


def _nested_get(d, *path):
    """
    Get a value for a nested dictionary.

    Args:
        d (dict): dictionary
        *path (tuple[str, str]): tuples of (name, key)
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.
    """
    for name, key in path:
        d = d.get(key)
        if d is None:
            name = "table" if name == "this" else name
            raise ValueError(f"Unknown {name}")
    return d


def _dict_depth(d):
    """
    Get the nesting depth of a dictionary.

    For example:
        >>> _dict_depth(None)
        0
        >>> _dict_depth({})
        1
        >>> _dict_depth({"a": "b"})
        1
        >>> _dict_depth({"a": {}})
        2
        >>> _dict_depth({"a": {"b": {}}})
        3

    Args:
        d (dict): dictionary
    Returns:
        int: depth
    """
    try:
        return 1 + _dict_depth(next(iter(d.values())))
    except AttributeError:
        # d doesn't have attribute "values"
        return 0
    except StopIteration:
        # d.values() returns an empty sequence
        return 1