summaryrefslogtreecommitdiffstats
path: root/sqlglot/helper.py
blob: c4dd91ee910c78cc1fb5ddb96dd5d4b6b65ab40d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import inspect
import logging
import re
import sys
from contextlib import contextmanager
from enum import Enum

# Zero-width pattern: matches the position right before every uppercase ASCII
# letter, except when that position is the very start of the string.
CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
# Logger shared by the helpers in this module, registered under the name "sqlglot".
logger = logging.getLogger("sqlglot")


class AutoName(Enum):
    """Enum base class whose automatically generated member values equal the member names."""

    def _generate_next_value_(name, _start, _count, _last_values):
        # Hook used by the enum machinery to produce a value for `auto()`.
        # It receives the member name first (there is no `self`), and the
        # remaining arguments are deliberately ignored.
        return name


def list_get(arr, index):
    """
    Returns the item at `index` in `arr`, or None when the index is out of range.

    Args:
        arr: An indexable sequence.
        index (int): Position to look up; negative positions count from the end.
    Returns:
        The element found, or None if indexing raised IndexError.
    """
    try:
        item = arr[index]
    except IndexError:
        # Only an out-of-range index is absorbed; any other error propagates.
        item = None
    return item


def ensure_list(value):
    """
    Normalizes `value` into a collection.

    Args:
        value: Any object, or None.
    Returns:
        `value` itself when it is already a list, tuple or set; an empty list
        when it is None; otherwise a new single-element list holding `value`.
    """
    if isinstance(value, (list, tuple, set)):
        # Existing collections are handed back untouched (not copied).
        return value
    return [] if value is None else [value]


def csv(*args, sep=", "):
    """
    Joins the truthy arguments into a single delimited string.

    Args:
        *args (str): Pieces to join; falsy ones (None, "") are skipped.
        sep (str): Separator placed between the kept pieces.
    Returns:
        The joined string, which is empty when no piece is kept.
    """
    # filter(None, ...) keeps exactly the truthy items.
    return sep.join(filter(None, args))


def subclasses(module_name, classes, exclude=()):
    """
    Returns a list of all subclasses for a specified class set, possibly excluding some of them.

    A class counts as a subclass of itself, so the target classes are part of
    the result unless they are excluded.

    Args:
        module_name (str): The name of the module to search for subclasses in.
            It must already be imported (present in `sys.modules`).
        classes (type|tuple[type]): Class(es) we want to find the subclasses of.
        exclude (type|tuple[type]): Class(es) we want to exclude from the returned list.
    Returns:
        A list of all the target subclasses, in the order produced by `inspect.getmembers`.
    Raises:
        KeyError: If `module_name` is not in `sys.modules`.
    """
    # A lone class is not a container, so `obj not in exclude` would raise
    # TypeError for it; wrap it so both documented forms are accepted.
    if inspect.isclass(exclude):
        exclude = (exclude,)

    def is_target(obj):
        return inspect.isclass(obj) and issubclass(obj, classes) and obj not in exclude

    return [obj for _, obj in inspect.getmembers(sys.modules[module_name], is_target)]


def apply_index_offset(expressions, offset):
    """
    Shifts a lone integer index expression by `offset`.

    Args:
        expressions (list): Index expressions; each is expected to expose
            `is_int`, `copy()` and an `args` mapping with a "this" entry.
        offset (int): Amount to add to the index; a falsy value means no shift.
    Returns:
        A new one-element list holding a shifted copy when there is exactly
        one expression and it is an integer; otherwise `expressions` unchanged.
    """
    if offset and len(expressions) == 1:
        literal = expressions[0]

        if literal.is_int:
            # Work on a copy so the caller's expression is never mutated.
            shifted = literal.copy()
            logger.warning("Applying array index offset (%s)", offset)
            shifted.args["this"] = str(int(shifted.args["this"]) + offset)
            return [shifted]

    return expressions


def camel_to_snake_case(name):
    """
    Converts a CamelCase name to upper-case snake case, e.g. "FooBar" -> "FOO_BAR".

    Args:
        name (str): The CamelCase identifier.
    Returns:
        The name with "_" inserted before every non-leading uppercase letter,
        then upper-cased.
    """
    underscored = CAMEL_CASE_PATTERN.sub("_", name)
    return underscored.upper()


def while_changing(expression, func):
    """
    Applies `func` repeatedly until the hash of its result stops changing.

    Args:
        expression: Hashable starting value.
        func (callable): Transformation taking and returning a hashable value.
    Returns:
        The first result whose hash equals the hash of the value it was
        computed from. `func` is always called at least once.
    """
    current = hash(expression)

    while True:
        expression = func(expression)
        previous, current = current, hash(expression)
        if previous == current:
            return expression


def tsort(dag):
    """
    Sorts a directed acyclic graph in topological order, dependencies first.

    Args:
        dag (dict): Maps each node to an iterable of the nodes it depends on.
            Dependencies that are not keys of `dag` are treated as nodes
            without dependencies of their own.
    Returns:
        A list of nodes in which every node appears after all of its dependencies.
    Raises:
        ValueError: If the graph contains a cycle.
    """
    result = []
    # Mirrors `result` as a set so the "already emitted" check is O(1)
    # instead of a linear scan of the list on every visit.
    emitted = set()

    def visit(node, visited):
        if node in emitted:
            return
        # `visited` holds the nodes on the current DFS path, so meeting one
        # of them again means a dependency leads back to itself.
        if node in visited:
            raise ValueError("Cycle error")

        visited.add(node)

        for dep in dag.get(node, []):
            visit(dep, visited)

        visited.remove(node)
        emitted.add(node)
        result.append(node)

    for node in dag:
        visit(node, set())

    return result


def open_file(file_name):
    """
    Open a file that may be compressed as gzip and return in newline mode.

    The file is treated as gzip when its first two bytes are the gzip magic
    number. Both the plain and the gzip path decode text as UTF-8 and disable
    newline translation (`newline=""`).

    Args:
        file_name (str): Path of the file to open.
    Returns:
        A text-mode file object; the caller is responsible for closing it.
    """
    with open(file_name, "rb") as f:
        gzipped = f.read(2) == b"\x1f\x8b"

    if gzipped:
        import gzip

        # Pass the encoding explicitly so compressed files are decoded the
        # same way as plain ones rather than with the platform default.
        return gzip.open(file_name, "rt", encoding="utf-8", newline="")

    return open(file_name, "rt", encoding="utf-8", newline="")


@contextmanager
def csv_reader(table):
    """
    Context manager yielding a csv reader for READ_CSV(name, ['delimiter', '|', ...]).

    Args:
        table (exp.Table): A table expression with an anonymous function READ_CSV in it.
            The first argument names the file; the remaining ones are read as
            alternating option-name / option-value pairs.

    Yields:
        A python csv reader over the opened file. The file is closed when the
        context exits.
    """
    source, *options = table.this.expressions
    handle = open_file(source.name)

    delimiter = ","
    # Feeding the same iterator to zip twice pairs consecutive items: (key, value).
    names = (option.name for option in options)
    for key, value in zip(names, names):
        if key == "delimiter":
            delimiter = value

    try:
        # Aliased because this module defines its own function named `csv`.
        import csv as csv_

        yield csv_.reader(handle, delimiter=delimiter)
    finally:
        handle.close()


def find_new_name(taken, base):
    """
    Searches for a name that is not already taken.

    Args:
        taken (Sequence[str]): set of taken names
        base (str): base name to alter
    Returns:
        `base` itself when it is free, otherwise the first free name of the
        form "<base>_<n>" with n counting up from 2.
    """
    candidate = base
    suffix = 2

    while candidate in taken:
        candidate = f"{base}_{suffix}"
        suffix += 1

    return candidate