summaryrefslogtreecommitdiffstats
path: root/pgcli/packages/formatter/sqlformatter.py
blob: 5224eff6382d94d1f80ec488281b895d47f04a2f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# coding=utf-8

from pgcli.packages.parseutils.tables import extract_tables


# Output styles registered by this module. For "sql-update-N" the first N
# columns form the WHERE-clause key; plain "sql-update" uses one key column.
supported_formats = ("sql-insert", "sql-update", "sql-update-1", "sql-update-2")

# Empty: no preprocessors are registered for these formats.
preprocessors = ()


def escape_for_sql_statement(value):
    """Render *value* as a SQL literal for the generated statements.

    ``None`` becomes ``NULL``. ``bytes`` become a hex literal ``X'..'``.
    Everything else is formatted with ``"{}".format`` and wrapped in single
    quotes. Embedded single quotes are doubled, as SQL requires, so a value
    such as ``O'Brien`` yields the valid literal ``'O''Brien'`` instead of
    terminating the string early.
    """
    if value is None:
        return "NULL"

    if isinstance(value, bytes):
        # NOTE(review): in PostgreSQL X'..' is a bit-string constant, not a
        # bytea literal -- confirm this is the intended representation.
        return f"X'{value.hex()}'"

    return "'{}'".format("{}".format(value).replace("'", "''"))


def _quote_identifier(name):
    """Return *name* as a double-quoted SQL identifier.

    Embedded double quotes are doubled so that names containing ``"`` still
    produce a single, valid identifier.
    """
    return '"{}"'.format(str(name).replace('"', '""'))


def _where_condition(column, value):
    """Return the WHERE-clause predicate matching *column* against *value*.

    ``None`` is rendered as ``IS NULL``. In SQL, ``column = NULL`` is never
    true, so the generated UPDATE would silently match no rows.
    """
    if value is None:
        return "{} IS NULL".format(_quote_identifier(column))
    return "{} = {}".format(
        _quote_identifier(column), escape_for_sql_statement(value)
    )


def _target_table():
    """Return the quoted name of the first table referenced by the query.

    The query is read from the module-level ``formatter``. Falls back to
    ``"DUAL"`` when no table can be extracted.
    """
    tables = extract_tables(formatter.query)
    if not tables:
        return _quote_identifier("DUAL")
    schema, name = tables[0][:2]
    if schema:
        # Quote schema and table separately: '"schema.table"' would name a
        # single table whose name contains a dot.
        return "{}.{}".format(
            _quote_identifier(schema), _quote_identifier(name)
        )
    return _quote_identifier(name)


def adapter(data, headers, table_format=None, **kwargs):
    """Yield the lines of SQL statements that reproduce *data*.

    :param data: iterable of rows. For the update formats each row must
        support slicing and indexing by column position.
    :param headers: column names, in the same order as each row's values.
    :param table_format: one of ``supported_formats``.

        - ``"sql-insert"`` yields one multi-row INSERT.
        - ``"sql-update"`` / ``"sql-update-N"`` yields one UPDATE per row.
          The first N columns (default 1) form the WHERE clause and the
          remaining columns form the SET list.

        Any other value yields nothing.
    :param kwargs: accepted and ignored.

    The target table comes from the query on the module-level ``formatter``,
    which ``register_new_formatter`` sets.
    """
    table_name = _target_table()

    if table_format == "sql-insert":
        columns = ", ".join(_quote_identifier(h) for h in headers)
        yield "INSERT INTO {} ({}) VALUES".format(table_name, columns)
        prefix = "  "
        for row in data:
            values = ", ".join(escape_for_sql_statement(v) for v in row)
            yield "{}({})".format(prefix, values)
            # Every row after the first is introduced by a comma.
            prefix = ", "
        yield ";"
    elif table_format and table_format.startswith("sql-update"):
        # "sql-update-2" -> 2 key columns; bare "sql-update" -> 1.
        parts = table_format.split("-")
        keys = int(parts[-1]) if len(parts) > 2 else 1
        for row in data:
            yield "UPDATE {} SET".format(table_name)
            prefix = "  "
            # Non-key columns, numbered from their real position in headers.
            for i, value in enumerate(row[keys:], keys):
                column = _quote_identifier(headers[i])
                literal = escape_for_sql_statement(value)
                yield "{}{} = {}".format(prefix, column, literal)
                prefix = ", "
            where = " AND ".join(
                _where_condition(headers[i], row[i]) for i in range(keys)
            )
            yield "WHERE {};".format(where)


def register_new_formatter(TabularOutputFormatter):
    """Register each of ``supported_formats`` on *TabularOutputFormatter*.

    Every format is registered with ``adapter`` and ``preprocessors``, plus a
    ``table_format`` keyword set to the format's own name. The formatter is
    also saved in the module-level ``formatter`` global; ``adapter`` reads
    ``formatter.query`` from it.
    """
    global formatter
    formatter = TabularOutputFormatter
    register = TabularOutputFormatter.register_new_formatter
    for format_name in supported_formats:
        register(format_name, adapter, preprocessors, {"table_format": format_name})