1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
|
import os
import sys
import ctypes
from typing import Iterator, List, NamedTuple
from tempfile import TemporaryFile
import pytest
from .utils import check_libpq_version
# Import the libpq wrapper at module level so the hooks and fixtures below
# can use it; fall back to None when psycopg cannot be imported.
# NOTE(review): pytest_runtest_setup, libpq and trace use pq.version(),
# pq.__impl__ and pq.__build_version__ without a None check, so they assume
# psycopg is importable whenever they actually run.
try:
    from psycopg import pq
except ImportError:
    pq = None  # type: ignore
def pytest_report_header(config):
    """Describe the libpq in use, for the pytest session header.

    Return an empty list when psycopg cannot be imported; otherwise three
    lines: the wrapper implementation, the runtime libpq version and the
    libpq version psycopg was compiled against.
    """
    try:
        from psycopg import pq
    except ImportError:
        return []

    header_lines = []
    header_lines.append(f"libpq wrapper implementation: {pq.__impl__}")
    header_lines.append(f"libpq used: {pq.version()}")
    header_lines.append(f"libpq compiled: {pq.__build_version__}")
    return header_lines
def pytest_configure(config):
    """Register the custom ``libpq`` marker with pytest."""
    marker_doc = (
        "libpq(version_expr): run the test only with matching libpq"
        " (e.g. '>= 10', '< 9.6')"
    )
    config.addinivalue_line("markers", marker_doc)
def pytest_runtest_setup(item):
    """Skip *item* when one of its ``libpq`` markers does not match.

    Every ``libpq`` marker must carry exactly one positional argument, a
    version expression. ``check_libpq_version`` is called with the runtime
    libpq version and that expression; a truthy result is used as the skip
    message.
    """
    for marker in item.iter_markers(name="libpq"):
        assert len(marker.args) == 1
        skip_reason = check_libpq_version(pq.version(), marker.args[0])
        if not skip_reason:
            continue
        pytest.skip(skip_reason)
@pytest.fixture
def libpq():
    """Return a ctypes wrapper to access the libpq.

    If locating or loading the library fails for any reason, the test is
    skipped when psycopg uses the "binary" implementation; otherwise the
    error is propagated.
    """
    try:
        from psycopg.pq.misc import find_libpq_full_path

        # Not available when testing the binary package
        libname = find_libpq_full_path()
        assert libname, "libpq libname not found"
        return ctypes.pydll.LoadLibrary(libname)
    except Exception as e:
        if pq.__impl__ != "binary":
            raise
        pytest.skip(f"can't load libpq for testing: {e}")
@pytest.fixture
def setpgenv(monkeypatch):
    """Replace the PG* env vars with the vars provided.

    The fixture returns a callable taking a mapping (or a falsy value such
    as None). It first removes every environment variable whose name starts
    with "PG", then sets each given name/value pair. All changes go through
    *monkeypatch*.
    """

    def setpgenv_(env):
        # Snapshot the names before deleting so os.environ isn't mutated
        # while it is being iterated.
        for name in [k for k in os.environ if k.startswith("PG")]:
            monkeypatch.delenv(name)
        for name, value in (env or {}).items():
            monkeypatch.setenv(name, value)

    return setpgenv_
@pytest.fixture
def trace(libpq):
    """Provide a Tracer, skipping when libpq tracing can't be used.

    Tracing requires psycopg compiled against libpq 14 or later
    (``__build_version__ >= 140000``) and a Linux platform.
    """
    pqver = pq.__build_version__
    skip_reason = None
    if pqver < 140000:
        skip_reason = f"trace not available on libpq {pqver}"
    elif sys.platform != "linux":
        skip_reason = f"trace not available on {sys.platform}"
    if skip_reason is not None:
        pytest.skip(skip_reason)
    yield Tracer()
class Tracer:
    """Factory that attaches a TraceLog to a connection."""

    def trace(self, conn):
        """Start tracing the libpq traffic of *conn* and return a TraceLog.

        *conn* may be a low-level ``PGconn`` (recognised by having an
        ``exec_`` attribute) or a higher-level connection (recognised by
        having a ``cursor`` attribute), in which case its ``pgconn`` is
        traced.

        :raises TypeError: if *conn* is neither kind of object. TypeError is
            a subclass of Exception, so existing ``except Exception``
            handlers keep working.
        """
        pgconn: "pq.abc.PGconn"
        if hasattr(conn, "exec_"):
            pgconn = conn
        elif hasattr(conn, "cursor"):
            pgconn = conn.pgconn
        else:
            # Previously a bare Exception() with no message: say what was
            # passed and what was expected.
            raise TypeError(
                f"can't trace {type(conn).__name__!r}: expected a PGconn"
                " or a connection object"
            )
        return TraceLog(pgconn)
class TraceLog:
    """Capture the libpq trace output of a connection in a temporary file.

    Iterating over the object re-reads everything traced so far and yields
    it as TraceEntry records, one per trace line.
    """

    def __init__(self, pgconn: "pq.abc.PGconn"):
        self.pgconn = pgconn
        # pgconn.trace() is handed the raw file descriptor; the file object
        # is unbuffered so reads here don't go through a Python-level buffer.
        self.tempfile = TemporaryFile(buffering=0)
        pgconn.trace(self.tempfile.fileno())
        # Drop timestamps so each line starts with the direction field,
        # which is what _parse_entries expects.
        pgconn.set_trace_flags(pq.Trace.SUPPRESS_TIMESTAMPS)

    def __del__(self):
        # getattr guards: __del__ also runs on a partially-initialised object
        # (e.g. if __init__ raised), where the attributes may be missing.
        pgconn = getattr(self, "pgconn", None)
        if pgconn is not None and pgconn.status == pq.ConnStatus.OK:
            pgconn.untrace()
        tempfile = getattr(self, "tempfile", None)
        if tempfile is not None:
            tempfile.close()

    def __iter__(self) -> "Iterator[TraceEntry]":
        self.tempfile.seek(0)
        data = self.tempfile.read()
        yield from self._parse_entries(data)

    def _parse_entries(self, data: bytes) -> "Iterator[TraceEntry]":
        """Parse raw trace bytes into TraceEntry records, one per line.

        Each line is ``direction<TAB>length<TAB>type[<TAB>content]``.
        """
        for line in data.splitlines():
            # Split at most 3 times: the content after the type may itself
            # contain tab characters (e.g. inside a query string) and must
            # be kept whole rather than truncated at the first tab.
            direction, length, msg_type, *content = line.split(b"\t", 3)
            yield TraceEntry(
                direction=direction.decode(),
                length=int(length.decode()),
                type=msg_type.decode(),
                # Note: the items encoding is not very solid: no escaped
                # backslash, no escaped quotes.
                # At the moment we don't need a proper parser.
                # content is [] or a single-element list with the remainder.
                content=content,
            )


class TraceEntry(NamedTuple):
    """One parsed line of libpq trace output."""

    direction: str  # first tab-separated field of the trace line
    length: int  # second field, parsed as an integer
    type: str  # third field: the message type name
    content: List[bytes]  # rest of the line after the type, if any (0 or 1 items)
|