Edit on GitHub

sqlglot.executor.table

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot.helper import dict_depth
  6from sqlglot.schema import AbstractMappingSchema
  7
  8
  9class Table:
 10    def __init__(self, columns, rows=None, column_range=None):
 11        self.columns = tuple(columns)
 12        self.column_range = column_range
 13        self.reader = RowReader(self.columns, self.column_range)
 14        self.rows = rows or []
 15        if rows:
 16            assert len(rows[0]) == len(self.columns)
 17        self.range_reader = RangeReader(self)
 18
 19    def add_columns(self, *columns: str) -> None:
 20        self.columns += columns
 21        if self.column_range:
 22            self.column_range = range(
 23                self.column_range.start, self.column_range.stop + len(columns)
 24            )
 25        self.reader = RowReader(self.columns, self.column_range)
 26
 27    def append(self, row):
 28        assert len(row) == len(self.columns)
 29        self.rows.append(row)
 30
 31    def pop(self):
 32        self.rows.pop()
 33
 34    @property
 35    def width(self):
 36        return len(self.columns)
 37
 38    def __len__(self):
 39        return len(self.rows)
 40
 41    def __iter__(self):
 42        return TableIter(self)
 43
 44    def __getitem__(self, index):
 45        self.reader.row = self.rows[index]
 46        return self.reader
 47
 48    def __repr__(self):
 49        columns = tuple(
 50            column
 51            for i, column in enumerate(self.columns)
 52            if not self.column_range or i in self.column_range
 53        )
 54        widths = {column: len(column) for column in columns}
 55        lines = [" ".join(column for column in columns)]
 56
 57        for i, row in enumerate(self):
 58            if i > 10:
 59                break
 60
 61            lines.append(
 62                " ".join(
 63                    str(row[column]).rjust(widths[column])[0 : widths[column]] for column in columns
 64                )
 65            )
 66        return "\n".join(lines)
 67
 68
 69class TableIter:
 70    def __init__(self, table):
 71        self.table = table
 72        self.index = -1
 73
 74    def __iter__(self):
 75        return self
 76
 77    def __next__(self):
 78        self.index += 1
 79        if self.index < len(self.table):
 80            return self.table[self.index]
 81        raise StopIteration
 82
 83
 84class RangeReader:
 85    def __init__(self, table):
 86        self.table = table
 87        self.range = range(0)
 88
 89    def __len__(self):
 90        return len(self.range)
 91
 92    def __getitem__(self, column):
 93        return (self.table[i][column] for i in self.range)
 94
 95
 96class RowReader:
 97    def __init__(self, columns, column_range=None):
 98        self.columns = {
 99            column: i for i, column in enumerate(columns) if not column_range or i in column_range
100        }
101        self.row = None
102
103    def __getitem__(self, column):
104        return self.row[self.columns[column]]
105
106
class Tables(AbstractMappingSchema[Table]):
    """Mapping schema parameterized with ``Table`` values; adds no behavior."""
109
110
def ensure_tables(d: t.Optional[t.Dict]) -> Tables:
    """Build a ``Tables`` from ``d`` after converting its leaves to ``Table`` objects.

    ``None`` or an empty mapping produces a ``Tables`` wrapping an empty dict.
    """
    normalized = _ensure_tables(d)
    return Tables(normalized)
113
114
def _ensure_tables(d: t.Optional[t.Dict]) -> t.Dict:
    """Recursively convert the leaves of a (possibly nested) mapping into ``Table``s.

    A falsy ``d`` yields ``{}``. If ``dict_depth(d)`` exceeds 1, each value is
    converted recursively. At the innermost level, a value that is already a
    ``Table`` is kept as-is. Any other value is treated as a sequence of rows:
    iterating the first row gives the column names (e.g. dict keys), and each
    row is projected onto those columns in that order. A row lacking one of
    those columns raises ``KeyError``.
    """
    if not d:
        return {}

    if dict_depth(d) > 1:
        return {key: _ensure_tables(value) for key, value in d.items()}

    converted = {}
    for name, table in d.items():
        if not isinstance(table, Table):
            columns = tuple(table[0]) if table else ()
            table = Table(
                columns=columns,
                rows=[tuple(row[column] for column in columns) for row in table],
            )
        converted[name] = table
    return converted
class Table:
class Table:
    """In-memory table whose rows are addressed by column name.

    ``columns`` holds every column name as a tuple, and ``rows`` holds row
    sequences positionally aligned with it. ``column_range``, when set, limits
    which column positions are visible through the row reader and ``repr``.
    """

    def __init__(self, columns, rows=None, column_range=None):
        """Create a table.

        Args:
            columns: iterable of column names, materialized into a tuple.
            rows: optional list of rows. A falsy value (``None`` or an empty
                list) is replaced by a fresh empty list, so an empty list passed
                by the caller is not shared with the table.
            column_range: optional ``range`` of column positions to expose. A
                falsy value exposes all columns.
        """
        self.columns = tuple(columns)
        self.column_range = column_range
        self.reader = RowReader(self.columns, self.column_range)
        self.rows = rows or []
        if rows:
            # Only the first row's width is checked; later rows are trusted.
            assert len(rows[0]) == len(self.columns)
        self.range_reader = RangeReader(self)

    def add_columns(self, *columns: str) -> None:
        """Append column names and widen ``column_range`` to cover them.

        Existing rows are not modified. The row reader is rebuilt so the new
        names become addressable.
        """
        self.columns += columns
        if self.column_range:
            self.column_range = range(
                self.column_range.start, self.column_range.stop + len(columns)
            )
        self.reader = RowReader(self.columns, self.column_range)

    def append(self, row):
        """Append one row; its length must equal the number of columns."""
        assert len(row) == len(self.columns)
        self.rows.append(row)

    def pop(self):
        """Remove the last row (it is not returned)."""
        self.rows.pop()

    @property
    def width(self):
        """Number of columns, ignoring ``column_range``."""
        return len(self.columns)

    def __len__(self):
        return len(self.rows)

    def __iter__(self):
        # Every step yields the same shared RowReader, rebound to the next row.
        return TableIter(self)

    def __getitem__(self, index):
        """Point the shared reader at ``self.rows[index]`` and return it.

        The same reader object is returned for every index, so a previously
        returned reader sees the new row after any later lookup.
        """
        self.reader.row = self.rows[index]
        return self.reader

    def __repr__(self):
        """Render the visible column names and at most the first 11 rows.

        Each cell is right-justified to the width of its column name and
        truncated to that many leading characters.
        """
        columns = tuple(
            column
            for i, column in enumerate(self.columns)
            if not self.column_range or i in self.column_range
        )
        widths = {column: len(column) for column in columns}
        lines = [" ".join(column for column in columns)]

        # Index the raw rows directly instead of iterating ``self``. Iterating
        # would rebind the shared ``self.reader``, which a caller may still be
        # holding from ``table[i]``; building a repr should have no side effects.
        positions = self.reader.columns
        for row in self.rows[:11]:
            lines.append(
                " ".join(
                    str(row[positions[column]]).rjust(widths[column])[0 : widths[column]]
                    for column in columns
                )
            )
        return "\n".join(lines)
Table(columns, rows=None, column_range=None)
11    def __init__(self, columns, rows=None, column_range=None):
12        self.columns = tuple(columns)
13        self.column_range = column_range
14        self.reader = RowReader(self.columns, self.column_range)
15        self.rows = rows or []
16        if rows:
17            assert len(rows[0]) == len(self.columns)
18        self.range_reader = RangeReader(self)
def add_columns(self, *columns: str) -> None:
20    def add_columns(self, *columns: str) -> None:
21        self.columns += columns
22        if self.column_range:
23            self.column_range = range(
24                self.column_range.start, self.column_range.stop + len(columns)
25            )
26        self.reader = RowReader(self.columns, self.column_range)
def append(self, row):
28    def append(self, row):
29        assert len(row) == len(self.columns)
30        self.rows.append(row)
def pop(self):
32    def pop(self):
33        self.rows.pop()
class TableIter:
class TableIter:
    """Forward iterator yielding ``table[0]``, ``table[1]``, ... up to ``len(table)``.

    For a ``Table``, every yielded value is the table's shared row reader,
    rebound to the current row.
    """

    def __init__(self, table):
        self.table = table
        # One before the first row; __next__ advances before reading.
        self.index = -1

    def __iter__(self):
        return self

    def __next__(self):
        position = self.index + 1
        self.index = position
        if position < len(self.table):
            return self.table[position]
        raise StopIteration
TableIter(table)
71    def __init__(self, table):
72        self.table = table
73        self.index = -1
class RangeReader:
class RangeReader:
    """Column-wise view over a selection of a table's rows.

    ``range`` holds the selected row positions and starts out empty. Indexing by
    a column name returns a generator of that column's value for each selected
    row. For a ``Table``, each access goes through ``Table.__getitem__`` and so
    rebinds its shared reader.
    """

    def __init__(self, table):
        self.table = table
        self.range = range(0)

    def __len__(self):
        return len(self.range)

    def __getitem__(self, column):
        selected = self.range
        return (self.table[position][column] for position in selected)
RangeReader(table)
86    def __init__(self, table):
87        self.table = table
88        self.range = range(0)
class RowReader:
 97class RowReader:
 98    def __init__(self, columns, column_range=None):
 99        self.columns = {
100            column: i for i, column in enumerate(columns) if not column_range or i in column_range
101        }
102        self.row = None
103
104    def __getitem__(self, column):
105        return self.row[self.columns[column]]
RowReader(columns, column_range=None)
 98    def __init__(self, columns, column_range=None):
 99        self.columns = {
100            column: i for i, column in enumerate(columns) if not column_range or i in column_range
101        }
102        self.row = None
class Tables(AbstractMappingSchema[Table]):
    """Mapping schema parameterized with ``Table`` values; adds no behavior."""

Abstract base class for generic types.

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::

    class Mapping(Generic[KT, VT]):
        def __getitem__(self, key: KT) -> VT:
            ...
            # Etc.

This class can then be used as follows::

    def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
        try:
            return mapping[key]
        except KeyError:
            return default

def ensure_tables(d: Optional[Dict]) -> sqlglot.executor.table.Tables:
def ensure_tables(d: t.Optional[t.Dict]) -> Tables:
    """Build a ``Tables`` from ``d`` after converting its leaves to ``Table`` objects.

    ``None`` or an empty mapping produces a ``Tables`` wrapping an empty dict.
    """
    normalized = _ensure_tables(d)
    return Tables(normalized)