summaryrefslogtreecommitdiffstats
path: root/sqlglot/executor/table.py
blob: f1b5b5428bebd6932ce8c187ad6b0f1154709730 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from __future__ import annotations

from sqlglot.helper import dict_depth
from sqlglot.schema import AbstractMappingSchema


class Table:
    """In-memory table: a tuple of column names plus a list of row sequences.

    Rows are read through one shared ``RowReader`` stored on ``self.reader``:
    ``table[i]`` re-points that reader at row ``i`` and returns it, so the
    object returned by one lookup changes whenever the table is indexed again
    (including during iteration).

    ``column_range``, when truthy, restricts which column positions the reader
    and ``repr`` expose; ``None`` or an empty range exposes every column.
    """

    def __init__(self, columns, rows=None, column_range=None):
        """Create a table.

        Args:
            columns: iterable of column names; stored as a tuple.
            rows: optional list of rows, each indexed by column position. A
                falsy value (``None`` or empty) is replaced by a fresh list, so
                an empty list passed in is not shared with the table. Only the
                first row's width is checked against ``columns``.
            column_range: optional ``range`` of column positions to expose.
        """
        self.columns = tuple(columns)
        self.column_range = column_range
        self.reader = RowReader(self.columns, self.column_range)
        self.rows = rows or []
        if rows:
            assert len(rows[0]) == len(self.columns)
        self.range_reader = RangeReader(self)

    def add_columns(self, *columns: str) -> None:
        """Append column names to the schema and rebuild the shared reader.

        Existing rows are not widened. If a ``column_range`` is set, its stop is
        extended by the number of new columns.
        NOTE(review): that only exposes the new trailing columns when the range
        already ended at the last column; confirm this matches callers.
        """
        self.columns += columns
        if self.column_range:
            self.column_range = range(
                self.column_range.start, self.column_range.stop + len(columns)
            )
        self.reader = RowReader(self.columns, self.column_range)

    def append(self, row):
        """Append ``row`` after asserting it has one value per column."""
        assert len(row) == len(self.columns)
        self.rows.append(row)

    def pop(self):
        """Remove the last row (the removed row is not returned)."""
        self.rows.pop()

    @property
    def width(self):
        """Number of columns, including any outside ``column_range``."""
        return len(self.columns)

    def __len__(self):
        return len(self.rows)

    def __iter__(self):
        return TableIter(self)

    def __getitem__(self, index):
        # Re-point the shared reader rather than allocating one per row.
        self.reader.row = self.rows[index]
        return self.reader

    def __repr__(self):
        """Render the visible column names followed by at most 11 data rows.

        Each cell is right-justified to, and truncated at, the width of its
        column name.
        """
        columns = tuple(
            column
            for i, column in enumerate(self.columns)
            if not self.column_range or i in self.column_range
        )
        widths = {column: len(column) for column in columns}
        lines = [" ".join(columns)]

        # Resolve cells through the reader's name->position map, but read the
        # rows directly. Iterating ``self`` would move ``self.reader.row`` and
        # corrupt any row a caller is still holding from ``table[i]``.
        positions = self.reader.columns
        for row in self.rows[:11]:  # rows 0..10, as previously rendered
            lines.append(
                " ".join(
                    str(row[positions[column]]).rjust(widths[column])[0 : widths[column]]
                    for column in columns
                )
            )
        return "\n".join(lines)


class TableIter:
    """Stateful cursor that walks a table from its first row to its last.

    Each step yields ``table[index]``. For a ``Table`` that is the shared row
    reader re-pointed at the next row, not a fresh object.
    """

    def __init__(self, table):
        self.table = table
        # Index of the most recently requested row; -1 means nothing consumed.
        self.index = -1

    def __iter__(self):
        return self

    def __next__(self):
        # Advance first, then check bounds (the index keeps growing even after
        # exhaustion, exactly as every call increments it).
        self.index += 1
        position = self.index
        if position >= len(self.table):
            raise StopIteration
        return self.table[position]


class RangeReader:
    """Column-wise view over a span of a table's rows.

    ``range`` selects which row indices are visited. It starts empty and is
    presumably reassigned by callers outside this file (not shown here).
    """

    def __init__(self, table):
        self.table = table
        self.range = range(0)  # empty span until a caller assigns one

    def __len__(self):
        span = self.range
        return len(span)

    def __getitem__(self, column):
        # Lazily yields ``column`` from each row in the span. The span itself is
        # captured when the generator is created; rows are fetched on demand.
        return (self.table[row_index][column] for row_index in self.range)


class RowReader:
    """Name-based accessor over a single positional row.

    ``columns`` maps each exposed column name to its position in the row. When
    ``column_range`` is truthy, only positions inside it are exposed; ``None``
    or an empty range exposes all columns. For duplicate names the last
    position wins. ``row`` is assigned by the owner before lookups.
    """

    def __init__(self, columns, column_range=None):
        restricted = bool(column_range)
        mapping = {}
        for position, name in enumerate(columns):
            if restricted and position not in column_range:
                continue
            mapping[name] = position
        self.columns = mapping
        self.row = None

    def __getitem__(self, column):
        current = self.row
        return current[self.columns[column]]


class Tables(AbstractMappingSchema[Table]):
    """Mapping-based schema whose leaf values are ``Table`` objects.

    Instances are built by ``ensure_tables`` from (possibly nested) dicts. All
    lookup behaviour is inherited from ``AbstractMappingSchema`` in
    ``sqlglot.schema`` (not shown here).
    """

    pass


def ensure_tables(d: dict | None) -> Tables:
    """Wrap raw table data in a ``Tables`` schema object.

    ``d`` is normalized by ``_ensure_tables`` first. ``None`` or an empty
    mapping yields an empty ``Tables``.
    """
    normalized = _ensure_tables(d)
    return Tables(normalized)


def _ensure_tables(d: dict | None) -> dict:
    """Recursively convert a (possibly nested) dict of raw rows into ``Table``s.

    When ``dict_depth(d)`` is greater than 1, each value is normalized
    recursively. Otherwise each value is either kept as-is (already a
    ``Table``) or treated as a sequence of rows and converted. Assumes each row
    is a mapping (e.g. a dict). Column names come from iterating the first row,
    and every row is read in that column order. An empty sequence yields a
    table with no columns.
    NOTE(review): ``dict_depth`` lives in sqlglot.helper; how it measures mixed
    nesting should be confirmed there.
    """
    if not d:
        return {}

    if dict_depth(d) > 1:
        return {key: _ensure_tables(value) for key, value in d.items()}

    def _as_table(data):
        # Existing Table instances pass through untouched.
        if isinstance(data, Table):
            return data
        header = tuple(data[0]) if data else ()
        body = [tuple(record[name] for name in header) for record in data]
        return Table(columns=header, rows=body)

    return {name: _as_table(data) for name, data in d.items()}