summaryrefslogtreecommitdiffstats
path: root/tests/test_server_cursor.py
blob: f7b6c8ed63cdb093ee3cdc17cdc33cb82da1a2bb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
import pytest

import psycopg
from psycopg import rows, errors as e
from psycopg.pq import Format

# Apply the project's ``crdb_skip`` marker to every test in this module.
# NOTE(review): presumably this skips the whole module when running against
# CockroachDB, which is assumed to lack server-side cursor support; the marker
# is defined in the test suite's conftest — confirm there.
pytestmark = pytest.mark.crdb_skip("server-side cursor")


def test_init_row_factory(conn):
    """ServerCursor takes the connection's row factory unless one is passed."""
    with psycopg.ServerCursor(conn, "foo") as scur:
        assert scur.name == "foo"
        assert scur.connection is conn
        assert scur.row_factory is conn.row_factory

    # Changing the connection default affects cursors created afterwards.
    conn.row_factory = rows.dict_row

    with psycopg.ServerCursor(conn, "bar") as scur:
        assert scur.name == "bar"
        assert scur.row_factory is rows.dict_row  # type: ignore

    # An explicit row_factory wins over the connection's (now dict_row).
    with psycopg.ServerCursor(
        conn, "baz", row_factory=rows.namedtuple_row
    ) as scur:
        assert scur.name == "baz"
        assert scur.row_factory is rows.namedtuple_row  # type: ignore


def test_init_params(conn):
    """Check default and explicit values of ``scrollable`` and ``withhold``."""
    with psycopg.ServerCursor(conn, "foo") as default_cur:
        # Defaults: scrollability unspecified, no WITH HOLD.
        assert default_cur.scrollable is None
        assert default_cur.withhold is False

    with psycopg.ServerCursor(
        conn, "bar", withhold=True, scrollable=False
    ) as custom_cur:
        assert custom_cur.scrollable is False
        assert custom_cur.withhold is True


@pytest.mark.crdb_skip("cursor invalid name")
def test_funny_name(conn):
    """A cursor name that is not a plain SQL identifier still works."""
    odd_name = "1-2-3"
    scur = conn.cursor(odd_name)
    scur.execute("select generate_series(1, 3) as bar")
    assert scur.fetchall() == [(1,), (2,), (3,)]
    assert scur.name == odd_name
    scur.close()


def test_repr(conn):
    """str() shows the class name and repr() includes the cursor name."""
    scur = conn.cursor("my-name")
    as_str, as_repr = str(scur), repr(scur)
    assert "psycopg.ServerCursor" in as_str
    assert "my-name" in as_repr
    scur.close()


def test_connection(conn):
    """The cursor exposes the connection that created it."""
    scur = conn.cursor("foo")
    owner = scur.connection
    assert owner is conn
    scur.close()


def test_description(conn):
    """After execute() the columns are described but no row is fetched yet."""
    scur = conn.cursor("foo")
    assert scur.name == "foo"
    scur.execute("select generate_series(1, 10)::int4 as bar")

    columns = scur.description
    assert len(columns) == 1
    bar = columns[0]
    assert bar.name == "bar"
    assert bar.type_code == scur.adapters.types["int4"].oid

    # The current result carries no tuples: rows are fetched separately.
    assert scur.pgresult.ntuples == 0
    scur.close()


def test_format(conn):
    """The ``binary`` flag selects the cursor's default result format."""
    cases = [({}, Format.TEXT), ({"binary": True}, Format.BINARY)]
    for extra_kwargs, expected_format in cases:
        scur = conn.cursor("foo", **extra_kwargs)
        assert scur.format == expected_format
        scur.close()


def test_query_params(conn):
    """The query is wrapped in a DECLARE and its parameters use placeholders."""
    with conn.cursor("foo") as scur:
        assert scur._query is None
        scur.execute("select generate_series(1, %s) as bar", (3,))

        sent = scur._query
        assert sent
        lowered = sent.query.lower()
        assert b"declare" in lowered
        assert b"(1, $1)" in lowered
        # The value 3 is dumped as a binary int2.
        assert sent.params == [bytes([0, 3])]


def test_binary_cursor_execute(conn):
    """A cursor created with binary=True receives binary-format results."""
    scur = conn.cursor("foo", binary=True)
    scur.execute("select generate_series(1, 2)::int4")
    for value in (1, 2):
        assert scur.fetchone() == (value,)
        assert scur.pgresult.fformat(0) == 1
        # int4 on the wire: 4 bytes, big-endian.
        assert scur.pgresult.get_value(0, 0) == value.to_bytes(4, "big")
    scur.close()


def test_execute_binary(conn):
    """execute(binary=True) affects only that execution."""
    scur = conn.cursor("foo")
    scur.execute("select generate_series(1, 2)::int4", binary=True)
    for value in (1, 2):
        assert scur.fetchone() == (value,)
        assert scur.pgresult.fformat(0) == 1
        assert scur.pgresult.get_value(0, 0) == value.to_bytes(4, "big")

    # Without the flag the next execution is back to text format.
    scur.execute("select generate_series(1, 1)::int4")
    assert scur.fetchone() == (1,)
    assert scur.pgresult.fformat(0) == 0
    assert scur.pgresult.get_value(0, 0) == b"1"
    scur.close()


def test_binary_cursor_text_override(conn):
    """execute(binary=False) overrides a binary cursor for that execution."""
    scur = conn.cursor("foo", binary=True)
    scur.execute("select generate_series(1, 2)", binary=False)
    for value in (1, 2):
        assert scur.fetchone() == (value,)
        assert scur.pgresult.fformat(0) == 0
        assert scur.pgresult.get_value(0, 0) == str(value).encode()

    # The following execution uses the cursor's binary default again.
    scur.execute("select generate_series(1, 2)::int4")
    assert scur.fetchone() == (1,)
    assert scur.pgresult.fformat(0) == 1
    assert scur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x01"
    scur.close()


def test_close(conn, recwarn):
    """Closing an executed cursor removes it server-side, without warnings."""
    if conn.info.transaction_status == conn.TransactionStatus.INTRANS:
        # connection dirty from previous failure: presumably a cursor named
        # 'foo' was left open in the current transaction, so close it first.
        conn.execute("close foo")
    # Only warnings raised from here on are relevant to this test.
    recwarn.clear()
    cur = conn.cursor("foo")
    cur.execute("select generate_series(1, 10) as bar")
    cur.close()
    assert cur.closed

    # The cursor must no longer be listed in the server's pg_cursors view.
    assert not conn.execute("select * from pg_cursors where name = 'foo'").fetchone()
    # Dropping the last reference to a closed cursor must not emit warnings.
    del cur
    assert not recwarn, [str(w.message) for w in recwarn.list]


def test_close_idempotent(conn):
    """Calling close() a second time is harmless."""
    scur = conn.cursor("foo")
    scur.execute("select 1")
    scur.fetchall()
    for _attempt in range(2):
        scur.close()


def test_close_broken_conn(conn):
    """A cursor can still be closed after its connection was closed."""
    orphan = conn.cursor("foo")
    conn.close()
    orphan.close()
    assert orphan.closed


def test_cursor_close_fetchone(conn):
    """fetchone() on a closed cursor raises InterfaceError."""
    scur = conn.cursor("foo")
    assert not scur.closed

    scur.execute("select * from generate_series(1, 10)")
    # Consume part of the result, leaving rows pending on the server.
    for _ in range(5):
        scur.fetchone()

    scur.close()
    assert scur.closed
    with pytest.raises(e.InterfaceError):
        scur.fetchone()


def test_cursor_close_fetchmany(conn):
    """fetchmany() on a closed cursor raises InterfaceError."""
    scur = conn.cursor("foo")
    assert not scur.closed

    scur.execute("select * from generate_series(1, 10)")
    batch = scur.fetchmany(2)
    assert len(batch) == 2

    scur.close()
    assert scur.closed
    with pytest.raises(e.InterfaceError):
        scur.fetchmany(2)


def test_cursor_close_fetchall(conn):
    """fetchall() on a closed cursor raises InterfaceError."""
    scur = conn.cursor("foo")
    assert not scur.closed

    scur.execute("select * from generate_series(1, 10)")
    everything = scur.fetchall()
    assert len(everything) == 10

    scur.close()
    assert scur.closed
    with pytest.raises(e.InterfaceError):
        scur.fetchall()


def test_close_noop(conn, recwarn):
    """Closing a cursor that was never executed emits no warning."""
    recwarn.clear()
    idle = conn.cursor("foo")
    idle.close()
    messages = [str(w.message) for w in recwarn.list]
    assert not recwarn, messages


def test_close_on_error(conn):
    """close() succeeds while the transaction is in an error state."""
    scur = conn.cursor("foo")
    scur.execute("select 1")

    # Break the transaction with an invalid statement.
    with pytest.raises(e.ProgrammingError):
        conn.execute("wat")
    status = conn.info.transaction_status
    assert status == conn.TransactionStatus.INERROR

    scur.close()


def test_pgresult(conn):
    """A server-side cursor exposes its current result until it is closed.

    The cursor must be named: ``conn.cursor()`` without a name returns a
    client-side cursor, which would not exercise ServerCursor at all.
    """
    cur = conn.cursor("foo")
    cur.execute("select 1")
    assert cur.pgresult
    cur.close()
    # Closing resets the cursor state, dropping the last result.
    assert not cur.pgresult


def test_context(conn, recwarn):
    """Leaving the with block closes the cursor server-side, without warnings."""
    # Only warnings raised from here on are relevant to this test.
    recwarn.clear()
    with conn.cursor("foo") as cur:
        cur.execute("select generate_series(1, 10) as bar")

    assert cur.closed
    # No cursor named 'foo' may remain in the server's pg_cursors view.
    assert not conn.execute("select * from pg_cursors where name = 'foo'").fetchone()
    # Dropping the last reference to the closed cursor must not emit warnings.
    del cur
    assert not recwarn, [str(w.message) for w in recwarn.list]


def test_close_no_clobber(conn):
    """An error inside the with block is not masked by the implicit close."""
    with pytest.raises(e.DivisionByZero), conn.cursor("foo") as scur:
        scur.execute("select 1 / %s", (0,))
        scur.fetchall()


def test_warn_close(conn, recwarn):
    """Dropping an executed but unclosed cursor emits a ResourceWarning.

    The warning message is expected to mention ``.close()``.
    """
    # Only warnings raised from here on are relevant to this test.
    recwarn.clear()
    cur = conn.cursor("foo")
    cur.execute("select generate_series(1, 10) as bar")
    # Deleting the only reference should finalize the cursor immediately
    # (presumably relying on CPython refcounting) and emit the warning.
    del cur
    assert ".close()" in str(recwarn.pop(ResourceWarning).message)


def test_execute_reuse(conn):
    """A cursor can run a second, differently shaped query after a first."""
    with conn.cursor("foo") as scur:
        scur.execute("select generate_series(1, %s) as foo", (3,))
        assert scur.fetchone() == (1,)

        scur.execute("select %s::text as bar, %s::text as baz", ("hello", "world"))
        assert scur.fetchone() == ("hello", "world")

        # The description reflects the new query, not the previous one.
        first_col, second_col = scur.description[0], scur.description[1]
        assert first_col.name == "bar"
        assert first_col.type_code == scur.adapters.types["text"].oid
        assert second_col.name == "baz"


@pytest.mark.parametrize(
    "stmt",
    [
        "",
        "wat",
        "create table ssc ()",
        "select 1; select 2",
    ],
)
def test_execute_error(conn, stmt):
    """Empty, invalid, non-query or multiple statements raise ProgrammingError."""
    scur = conn.cursor("foo")
    with pytest.raises(e.ProgrammingError):
        scur.execute(stmt)
    scur.close()


def test_executemany(conn):
    """executemany() is not supported on a server-side cursor."""
    scur = conn.cursor("foo")
    params_seq = [(1,), (2,)]
    with pytest.raises(e.NotSupportedError):
        scur.executemany("select %s", params_seq)
    scur.close()


def test_fetchone(conn):
    """fetchone() returns the rows one at a time, then None."""
    with conn.cursor("foo") as scur:
        scur.execute("select generate_series(1, %s) as bar", (2,))
        # List items are evaluated left to right: two consecutive fetches.
        assert [scur.fetchone(), scur.fetchone()] == [(1,), (2,)]
        assert scur.fetchone() is None


def test_fetchmany(conn):
    """fetchmany() returns up to the requested size, then an empty list."""
    batch_size = 3
    with conn.cursor("foo") as scur:
        scur.execute("select generate_series(1, %s) as bar", (5,))
        assert scur.fetchmany(batch_size) == [(n,) for n in range(1, 4)]
        assert scur.fetchone() == (4,)
        # Fewer rows remain than requested, then none at all.
        assert scur.fetchmany(batch_size) == [(5,)]
        assert scur.fetchmany(batch_size) == []


def test_fetchall(conn):
    """fetchall() returns the remaining rows; afterwards it returns []."""
    query = "select generate_series(1, %s) as bar"

    with conn.cursor("foo") as scur:
        scur.execute(query, (3,))
        assert scur.fetchall() == [(1,), (2,), (3,)]
        assert scur.fetchall() == []

    # After a partial fetch only the remaining rows are returned.
    with conn.cursor("foo") as scur:
        scur.execute(query, (3,))
        assert scur.fetchone() == (1,)
        assert scur.fetchall() == [(2,), (3,)]
        assert scur.fetchall() == []


def test_nextset(conn):
    """nextset() reports no further result set after a single query."""
    with conn.cursor("foo") as scur:
        scur.execute("select generate_series(1, %s) as bar", (3,))
        more_sets = scur.nextset()
        assert not more_sets


def test_no_result(conn):
    """A query matching no rows is still described and fetches []."""
    with conn.cursor("foo") as scur:
        scur.execute("select generate_series(1, %s) as bar where false", (3,))
        columns = scur.description
        assert len(columns) == 1
        assert scur.fetchall() == []


@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
def test_standard_row_factory(conn, row_factory):
    """Each standard row factory shapes rows fetched from a server cursor.

    ``row_factory`` is the name of a factory in ``psycopg.rows``.
    """
    # How to read the "bar" column out of a row built by each factory. A
    # lookup table replaces the previous ``assert False`` fallback, which is
    # stripped under ``python -O`` and would leave the getter unbound; an
    # unknown name now fails loudly with a KeyError.
    getters = {
        "tuple_row": lambda r: r[0],
        "dict_row": lambda r: r["bar"],
        "namedtuple_row": lambda r: r.bar,
    }
    getter = getters[row_factory]

    # Keep the parameter a string; bind the factory object to its own name.
    factory = getattr(rows, row_factory)
    with conn.cursor("foo", row_factory=factory) as cur:
        cur.execute("select generate_series(1, 5) as bar")
        assert getter(cur.fetchone()) == 1
        assert [getter(r) for r in cur.fetchmany(2)] == [2, 3]
        assert [getter(r) for r in cur.fetchall()] == [4, 5]


@pytest.mark.crdb_skip("scroll cursor")
def test_row_factory(conn):
    """A custom row factory is invoked once per execute and can be swapped.

    Scrolling back and re-fetching does not invoke the factory again: every
    row carries the same call counter (1).
    """
    calls = 0

    def negating_factory(cur):
        nonlocal calls
        calls += 1
        return lambda values: [calls] + [-v for v in values]

    scur = conn.cursor("foo", row_factory=negating_factory, scrollable=True)
    scur.execute("select generate_series(1, 3) as x")
    recs = scur.fetchall()

    # Rewind and read everything again, one row at a time, until None.
    scur.scroll(0, "absolute")
    recs.extend(iter(scur.fetchone, None))
    assert recs == [[1, -1], [1, -2], [1, -3]] * 2

    # A new row factory applies to the following fetches.
    scur.scroll(0, "absolute")
    scur.row_factory = rows.dict_row
    assert scur.fetchone() == {"x": 1}
    scur.close()


def test_rownumber(conn):
    """rownumber is None before execute, then counts the rows consumed."""
    scur = conn.cursor("foo")
    assert scur.rownumber is None

    scur.execute("select 1 from generate_series(1, 42)")
    assert scur.rownumber == 0

    # (fetch operation, expected rownumber afterwards)
    steps = [
        (scur.fetchone, 1),
        (scur.fetchone, 2),
        (lambda: scur.fetchmany(10), 12),
        (scur.fetchall, 42),
    ]
    for fetch, expected in steps:
        fetch()
        assert scur.rownumber == expected
    scur.close()


def test_iter(conn):
    """Iterating the cursor yields the rows not fetched yet."""
    query = "select generate_series(1, %s) as bar"

    with conn.cursor("foo") as scur:
        scur.execute(query, (3,))
        recs = [*scur]
    assert recs == [(1,), (2,), (3,)]

    # Iteration resumes from the current position.
    with conn.cursor("foo") as scur:
        scur.execute(query, (3,))
        assert scur.fetchone() == (1,)
        recs = [*scur]
    assert recs == [(2,), (3,)]


def test_iter_rownumber(conn):
    """While iterating, rownumber matches the generated value 1..3."""
    with conn.cursor("foo") as scur:
        scur.execute("select generate_series(1, %s) as bar", (3,))
        for (value,) in scur:
            assert scur.rownumber == value


def test_itersize(conn, commands):
    """Iteration fetches in batches of ``itersize`` rows (default 100)."""
    with conn.cursor("foo") as scur:
        assert scur.itersize == 100
        scur.itersize = 2
        scur.execute("select generate_series(1, %s) as bar", (3,))
        commands.popall()  # discard begin and other setup noise

        for _row in scur:
            pass

        # 3 rows in batches of 2: two FETCH commands.
        fetches = commands.popall()
        assert len(fetches) == 2
        assert all("fetch forward 2" in cmd.lower() for cmd in fetches)


def test_cant_scroll_by_default(conn):
    """scroll() on a cursor with default scrollable (None) raises."""
    default_cur = conn.cursor("tmp")
    assert default_cur.scrollable is None
    with pytest.raises(e.ProgrammingError):
        default_cur.scroll(0)
    default_cur.close()


@pytest.mark.crdb_skip("scroll cursor")
def test_scroll(conn):
    """scroll() moves relatively by default, or absolutely; bad modes raise."""
    scur = conn.cursor("tmp", scrollable=True)
    scur.execute("select generate_series(0,9)")

    def scroll_then_fetch(*args, **kwargs):
        scur.scroll(*args, **kwargs)
        return scur.fetchone()

    assert scroll_then_fetch(2) == (2,)
    assert scroll_then_fetch(2) == (5,)
    assert scroll_then_fetch(2, mode="relative") == (8,)
    assert scroll_then_fetch(9, mode="absolute") == (9,)

    with pytest.raises(ValueError):
        scur.scroll(9, mode="wat")
    scur.close()


@pytest.mark.crdb_skip("scroll cursor")
def test_scrollable(conn):
    """A scrollable cursor can move backwards."""
    scur = conn.cursor("foo", scrollable=True)
    assert scur.scrollable is True
    scur.execute("select generate_series(0, 5)")
    scur.scroll(5)
    # Walk back through 4, 3, 2, 1, 0; each fetch advances by one, so step
    # back once more after it.
    for expected in reversed(range(5)):
        scur.scroll(-1)
        assert expected == scur.fetchone()[0]
        scur.scroll(-1)
    scur.close()


def test_non_scrollable(conn):
    """A cursor created with scrollable=False cannot move backwards."""
    forward_only = conn.cursor("foo", scrollable=False)
    assert forward_only.scrollable is False
    forward_only.execute("select generate_series(0, 5)")
    forward_only.scroll(5)  # moving forward is allowed
    with pytest.raises(e.OperationalError):
        forward_only.scroll(-1)
    forward_only.close()


@pytest.mark.parametrize("kwargs", [{}, {"withhold": False}])
def test_no_hold(conn, kwargs):
    """A cursor without withhold is gone once the transaction commits."""
    with conn.cursor("foo", **kwargs) as scur:
        assert scur.withhold is False
        scur.execute("select generate_series(0, 2)")
        assert scur.fetchone() == (0,)

        conn.commit()
        # After the commit the server no longer knows the cursor name.
        with pytest.raises(e.InvalidCursorName):
            scur.fetchone()


@pytest.mark.crdb_skip("cursor with hold")
def test_hold(conn):
    """A withhold cursor keeps producing rows after a commit."""
    with conn.cursor("foo", withhold=True) as held:
        assert held.withhold is True
        held.execute("select generate_series(0, 5)")
        assert held.fetchone() == (0,)

        conn.commit()
        # Fetching continues where it left off.
        assert held.fetchone() == (1,)


@pytest.mark.parametrize("row_factory", ["tuple_row", "namedtuple_row"])
def test_steal_cursor(conn, row_factory):
    """A ServerCursor can adopt a cursor declared with plain SQL."""
    client_cur = conn.cursor()
    client_cur.execute("declare test cursor for select generate_series(1, 6) as s")

    stolen = conn.cursor("test", row_factory=getattr(rows, row_factory))
    # No execute() needed: the named cursor already exists on the server.
    first = stolen.fetchone()
    assert first == (1,)
    if row_factory == "namedtuple_row":
        assert first.s == 1
    assert stolen.fetchmany(3) == [(2,), (3,), (4,)]
    assert stolen.fetchall() == [(5,), (6,)]
    stolen.close()


def test_stolen_cursor_close(conn):
    """Closing an adopted cursor frees its name so it can be declared again."""
    client_cur = conn.cursor()
    for _round in range(2):
        client_cur.execute("declare test cursor for select generate_series(1, 6)")
        stolen = conn.cursor("test")
        stolen.close()