summaryrefslogtreecommitdiffstats
path: root/tests/fix_pq.py
blob: 6811a26c32a3337bf0d1364684be9c06c82ee3a9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import os
import sys
import ctypes
from typing import Iterator, List, NamedTuple
from tempfile import TemporaryFile

import pytest

from .utils import check_libpq_version

try:
    from psycopg import pq
except ImportError:
    pq = None  # type: ignore


def pytest_report_header(config):
    """Return the libpq details to display in the pytest report header.

    The returned list contains one line each for the wrapper implementation
    in use, the libpq version reported by ``pq.version()`` and the version
    stored in ``pq.__build_version__``.

    If ``psycopg.pq`` cannot be imported no line is added to the header.
    """
    try:
        from psycopg import pq
    except ImportError:
        # Without the pq package there is nothing to report.
        header = []
    else:
        header = [
            f"libpq wrapper implementation: {pq.__impl__}",
            f"libpq used: {pq.version()}",
            f"libpq compiled: {pq.__build_version__}",
        ]

    return header


def pytest_configure(config):
    """Register the ``libpq`` marker in the pytest configuration.

    The marker takes a single version expression; it is consumed by
    `pytest_runtest_setup()` in this module.
    """
    marker_description = (
        "libpq(version_expr): run the test only with matching libpq"
        " (e.g. '>= 10', '< 9.6')"
    )
    config.addinivalue_line("markers", marker_description)


def pytest_runtest_setup(item):
    """Skip the test if one of its ``libpq`` markers is not satisfied.

    Every ``libpq`` marker on *item* must carry exactly one argument, which
    is passed to `check_libpq_version()` together with the libpq version in
    use. If the check returns a non-empty message the test is skipped with
    that message as the reason.
    """
    for marker in item.iter_markers(name="libpq"):
        assert len(marker.args) == 1
        skip_reason = check_libpq_version(pq.version(), marker.args[0])
        if not skip_reason:
            # Requirement met: go on with the next marker, if any.
            continue
        pytest.skip(skip_reason)


@pytest.fixture
def libpq():
    """Return a ctypes wrapper to access the libpq.

    The library path is looked up with ``find_libpq_full_path()`` and the
    library is loaded via `ctypes.pydll`.

    If anything fails while locating or loading the library, the test is
    skipped when the ``binary`` pq implementation is in use; with any other
    implementation the original error is propagated.
    """
    try:
        from psycopg.pq.misc import find_libpq_full_path

        # Not available when testing the binary package
        libname = find_libpq_full_path()
        assert libname, "libpq libname not found"
        return ctypes.pydll.LoadLibrary(libname)
    except Exception as e:
        if pq.__impl__ != "binary":
            # A failure is unexpected outside the binary package: surface it.
            raise
        pytest.skip(f"can't load libpq for testing: {e}")


@pytest.fixture
def setpgenv(monkeypatch):
    """Replace the PG* env vars with the vars provided.

    The fixture value is a function taking a mapping of variables to set
    (or a false value to set none). The changes are made through
    *monkeypatch*.
    """

    def setpgenv_(env):
        # Collect the names up front, so that the environment is not
        # modified while it is being iterated.
        pg_names = [name for name in os.environ if name.startswith("PG")]
        for name in pg_names:
            monkeypatch.delenv(name)

        # A false value (None, empty mapping) means "no variable to set".
        for name, value in (env or {}).items():
            monkeypatch.setenv(name, value)

    return setpgenv_


@pytest.fixture
def trace(libpq):
    """Yield a `Tracer` to collect libpq trace output in a test.

    The test is skipped if the libpq build version is lower than 140000 or
    if the platform is not Linux. The *libpq* fixture is requested but its
    value is not used here.
    """
    pqver = pq.__build_version__
    on_linux = sys.platform == "linux"

    # The version check comes first, so it decides the skip reason when both
    # conditions fail.
    if pqver < 140000:
        pytest.skip(f"trace not available on libpq {pqver}")
    elif not on_linux:
        pytest.skip(f"trace not available on {sys.platform}")

    yield Tracer()


class Tracer:
    """Factory of `TraceLog` objects, returned by the ``trace`` fixture."""

    def trace(self, conn):
        """Start tracing a connection and return the related `TraceLog`.

        *conn* can be either:

        - an object with an ``exec_`` attribute, which is traced directly;
        - an object with a ``cursor`` attribute, whose ``pgconn`` attribute
          is traced.

        Raise `TypeError` if *conn* has neither attribute.
        """
        pgconn: "pq.abc.PGconn"

        if hasattr(conn, "exec_"):
            pgconn = conn
        elif hasattr(conn, "cursor"):
            pgconn = conn.pgconn
        else:
            # Say what was received: a bare Exception() with no message makes
            # a misuse of the fixture very hard to diagnose.
            raise TypeError(
                f"can't trace {type(conn).__name__!r} object:"
                " expected an object with an 'exec_' or a 'cursor' attribute"
            )

        return TraceLog(pgconn)


class TraceLog:
    """Collect the trace of a connection into a temporary file.

    Tracing starts when the object is created. Iterating the object parses
    what has been written so far and yields one `TraceEntry` per line; every
    iteration starts again from the beginning of the file.
    """

    def __init__(self, pgconn: "pq.abc.PGconn"):
        """Start tracing *pgconn* into a new unbuffered temporary file."""
        self.pgconn = pgconn
        self.tempfile = TemporaryFile(buffering=0)
        fd = self.tempfile.fileno()
        pgconn.trace(fd)
        pgconn.set_trace_flags(pq.Trace.SUPPRESS_TIMESTAMPS)

    def __del__(self):
        """Stop tracing, if the connection status is OK, and close the file."""
        conn = self.pgconn
        if conn.status == pq.ConnStatus.OK:
            conn.untrace()
        self.tempfile.close()

    def __iter__(self) -> "Iterator[TraceEntry]":
        """Yield the entries traced so far, from the first one."""
        self.tempfile.seek(0)
        yield from self._parse_entries(self.tempfile.read())

    def _parse_entries(self, data: bytes) -> "Iterator[TraceEntry]":
        """Yield a `TraceEntry` for each line of *data*.

        Each line is split on tab characters and must contain at least
        three fields: direction, length and type. Of the remaining fields
        only the first one, if any, is kept as content.
        """
        for raw_line in data.splitlines():
            direction, length, msg_type, *items = raw_line.split(b"\t")
            yield TraceEntry(
                direction=direction.decode(),
                length=int(length.decode()),
                type=msg_type.decode(),
                # Note: the items encoding is not very solid: no escaped
                # backslash, no escaped quotes.
                # At the moment we don't need a proper parser.
                content=items[:1],
            )


class TraceEntry(NamedTuple):
    """A single line of trace output, as parsed by `TraceLog`.

    The fields map to the tab-separated columns of the line.
    """

    # First column, decoded to str.
    # NOTE(review): presumably "F"/"B" for frontend/backend messages; confirm
    # against the libpq trace format.
    direction: str
    # Second column, converted to int.
    length: int
    # Third column, decoded to str.
    type: str
    # The first of the remaining columns, left as bytes, or an empty list if
    # the line has no further column.
    content: List[bytes]