summaryrefslogtreecommitdiffstats
path: root/tests/test_generators.py
blob: 8aba73f6ded435b5853cab20fc887e0d4eba2448 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
from collections import deque
from functools import partial
from typing import List

import pytest

import psycopg
from psycopg import waiting
from psycopg import pq


@pytest.fixture
def pipeline(pgconn):
    """Run the test with `pgconn` in non-blocking pipeline mode.

    The previous value of `pgconn.nonblocking` is saved before the test and
    put back afterwards.
    """
    previous_nonblocking = pgconn.nonblocking
    pgconn.nonblocking = True
    assert pgconn.nonblocking
    pgconn.enter_pipeline_mode()

    yield

    # Only leave pipeline mode if the connection still reports a non-zero
    # pipeline status.
    if pgconn.pipeline_status:
        pgconn.exit_pipeline_mode()
    pgconn.nonblocking = previous_nonblocking


def _run_pipeline_communicate(pgconn, generators, commands, expected_statuses):
    actual_statuses: List[pq.ExecStatus] = []
    while len(actual_statuses) != len(expected_statuses):
        if commands:
            gen = generators.pipeline_communicate(pgconn, commands)
            results = waiting.wait(gen, pgconn.socket)
            for (result,) in results:
                actual_statuses.append(result.status)
        else:
            gen = generators.fetch_many(pgconn)
            results = waiting.wait(gen, pgconn.socket)
            for result in results:
                actual_statuses.append(result.status)

    assert actual_statuses == expected_statuses


@pytest.mark.pipeline
def test_pipeline_communicate_multi_pipeline(pgconn, pipeline, generators):
    """Two queries, each followed by a sync, give a result and a sync each."""
    commands = deque()
    expected_statuses = []
    for query in (b"select 1", b"select 2"):
        commands.append(partial(pgconn.send_query_params, query, None))
        commands.append(pgconn.pipeline_sync)
        expected_statuses.append(pq.ExecStatus.TUPLES_OK)
        expected_statuses.append(pq.ExecStatus.PIPELINE_SYNC)

    _run_pipeline_communicate(pgconn, generators, commands, expected_statuses)


@pytest.mark.pipeline
def test_pipeline_communicate_no_sync(pgconn, pipeline, generators):
    """Queries followed by a flush request, but no sync, all return tuples."""
    query_count = 10
    send_query = partial(
        pgconn.send_query_params, b"select repeat('xyzxz', 12)", None
    )

    # The same callable is queued once per query, then the flush request.
    commands = deque([send_query] * query_count)
    commands.append(pgconn.send_flush_request)

    expected_statuses = [pq.ExecStatus.TUPLES_OK for _ in range(query_count)]
    _run_pipeline_communicate(pgconn, generators, commands, expected_statuses)


@pytest.fixture
def pipeline_demo(pgconn):
    """Create the `pg_pipeline` table for a test and drop it afterwards.

    Yields the name of the table. The connection must not be in pipeline mode
    when the fixture is set up.
    """
    drop_table = b"DROP TABLE IF EXISTS pg_pipeline"
    create_table = (
        b"CREATE UNLOGGED TABLE pg_pipeline(" b" id serial primary key, itemno integer)"
    )

    assert pgconn.pipeline_status == 0
    # Start from a clean slate: drop any leftover table, then create it.
    for statement in (drop_table, create_table):
        outcome = pgconn.exec_(statement)
        assert outcome.status == pq.ExecStatus.COMMAND_OK, outcome.error_message

    yield "pg_pipeline"

    outcome = pgconn.exec_(drop_table)
    assert outcome.status == pq.ExecStatus.COMMAND_OK, outcome.error_message


# TODOCRDB: the row with itemno 1 doesn't get rolled back. Open a ticket?
@pytest.mark.pipeline
@pytest.mark.crdb("skip", reason="pipeline aborted")
def test_pipeline_communicate_abort(pgconn, pipeline_demo, pipeline, generators):
    """A failing query aborts the pipeline up to the next sync.

    After the pipeline is exited, the table is expected to hold only the row
    inserted after the first sync (itemno 3).
    """
    insert_sql = b"insert into pg_pipeline(itemno) values ($1)"

    def queue_insert(itemno):
        # Deferred insert of a single item number into the demo table.
        return partial(pgconn.send_query_params, insert_sql, [itemno])

    commands = deque()
    commands.append(queue_insert(b"1"))
    commands.append(
        partial(pgconn.send_query_params, b"select no_such_function(1)", None)
    )
    commands.append(queue_insert(b"2"))
    commands.append(pgconn.pipeline_sync)
    commands.append(queue_insert(b"3"))
    commands.append(pgconn.pipeline_sync)

    status = pq.ExecStatus
    expected_statuses = [
        status.COMMAND_OK,
        status.FATAL_ERROR,
        status.PIPELINE_ABORTED,
        status.PIPELINE_SYNC,
        status.COMMAND_OK,
        status.PIPELINE_SYNC,
    ]
    _run_pipeline_communicate(pgconn, generators, commands, expected_statuses)

    pgconn.exit_pipeline_mode()
    rows = pgconn.exec_(b"select itemno from pg_pipeline order by itemno")
    assert rows.ntuples == 1
    assert rows.get_value(0, 0) == b"3"


@pytest.fixture
def pipeline_uniqviol(pgconn):
    """Set up the `pg_pipeline_uniqviol` table and an "insertion" statement.

    The test is skipped if `psycopg.Pipeline` reports that pipeline mode is
    not supported. The fixture (re)creates the table, opens a transaction with
    BEGIN and prepares the `insertion` statement, then returns the table name.
    """
    if not psycopg.Pipeline.is_supported():
        pytest.skip(psycopg.Pipeline._not_supported_reason())
    assert pgconn.pipeline_status == 0

    setup_statements = (
        b"DROP TABLE IF EXISTS pg_pipeline_uniqviol",
        b"CREATE UNLOGGED TABLE pg_pipeline_uniqviol("
        b" id bigint primary key, idata bigint)",
        b"BEGIN",
    )
    for statement in setup_statements:
        outcome = pgconn.exec_(statement)
        assert outcome.status == pq.ExecStatus.COMMAND_OK, outcome.error_message

    prepared = pgconn.prepare(
        b"insertion",
        b"insert into pg_pipeline_uniqviol values ($1, $2) returning id",
    )
    assert prepared.status == pq.ExecStatus.COMMAND_OK, prepared.error_message
    return "pg_pipeline_uniqviol"


def test_pipeline_communicate_uniqviol(pgconn, pipeline_uniqviol, pipeline, generators):
    """A repeated id fails, and everything queued after it is aborted.

    The third insertion reuses id 1; it is expected to fail, and the
    following insertions and the commit to report PIPELINE_ABORTED.
    """
    insertion_params = (
        [b"1", b"2"],
        [b"2", b"2"],
        [b"1", b"2"],  # same id as the first row
        [b"3", b"2"],
        [b"4", b"2"],
    )
    commands = deque(
        partial(pgconn.send_query_prepared, b"insertion", params)
        for params in insertion_params
    )
    commands.append(partial(pgconn.send_query_params, b"commit", None))

    expected_statuses = (
        [pq.ExecStatus.TUPLES_OK] * 2
        + [pq.ExecStatus.FATAL_ERROR]
        + [pq.ExecStatus.PIPELINE_ABORTED] * 3
    )
    _run_pipeline_communicate(pgconn, generators, commands, expected_statuses)