summaryrefslogtreecommitdiffstats
path: root/tools/update_oids.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-05-04 17:41:08 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-05-04 17:41:08 +0000
commit506ed8899b3a97e512be3fd6d44d5b11463bf9bf (patch)
tree808913770c5e6935d3714058c2a066c57b4632ec /tools/update_oids.py
parentInitial commit. (diff)
downloadpsycopg3-506ed8899b3a97e512be3fd6d44d5b11463bf9bf.tar.xz
psycopg3-506ed8899b3a97e512be3fd6d44d5b11463bf9bf.zip
Adding upstream version 3.1.7.upstream/3.1.7upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'tools/update_oids.py')
-rwxr-xr-xtools/update_oids.py217
1 files changed, 217 insertions, 0 deletions
diff --git a/tools/update_oids.py b/tools/update_oids.py
new file mode 100755
index 0000000..df4f969
--- /dev/null
+++ b/tools/update_oids.py
@@ -0,0 +1,217 @@
+#!/usr/bin/env python
+"""
+Update the maps of builtin types and names.
+
+This script updates some of the files in psycopg source code with data read
+from a database catalog.
+
+Hint: use docker to upgrade types from a new version in isolation. Run:
+
+ docker run --rm -p 11111:5432 --name pg -e POSTGRES_PASSWORD=password postgres:TAG
+
+with a specified version tag, and then query it using:
+
+ %(prog)s "host=localhost port=11111 user=postgres password=password"
+"""
+
+import re
+import argparse
+import subprocess as sp
+from typing import List
+from pathlib import Path
+from typing_extensions import TypeAlias
+
+import psycopg
+from psycopg.rows import TupleRow
+from psycopg.crdb import CrdbConnection
+
+Connection: TypeAlias = psycopg.Connection[TupleRow]
+
+ROOT = Path(__file__).parent.parent
+
+
+def main() -> None:
+ opt = parse_cmdline()
+ conn = psycopg.connect(opt.dsn, autocommit=True)
+
+ if CrdbConnection.is_crdb(conn):
+ conn = CrdbConnection.connect(opt.dsn, autocommit=True)
+ update_crdb_python_oids(conn)
+ else:
+ update_python_oids(conn)
+ update_cython_oids(conn)
+
+
+def update_python_oids(conn: Connection) -> None:
+ fn = ROOT / "psycopg/psycopg/postgres.py"
+
+ lines = []
+ lines.extend(get_version_comment(conn))
+ lines.extend(get_py_types(conn))
+ lines.extend(get_py_ranges(conn))
+ lines.extend(get_py_multiranges(conn))
+
+ update_file(fn, lines)
+ sp.check_call(["black", "-q", fn])
+
+
+def update_cython_oids(conn: Connection) -> None:
+ fn = ROOT / "psycopg_c/psycopg_c/_psycopg/oids.pxd"
+
+ lines = []
+ lines.extend(get_version_comment(conn))
+ lines.extend(get_cython_oids(conn))
+
+ update_file(fn, lines)
+
+
+def update_crdb_python_oids(conn: Connection) -> None:
+ fn = ROOT / "psycopg/psycopg/crdb/_types.py"
+
+ lines = []
+ lines.extend(get_version_comment(conn))
+ lines.extend(get_py_types(conn))
+
+ update_file(fn, lines)
+ sp.check_call(["black", "-q", fn])
+
+
+def get_version_comment(conn: Connection) -> List[str]:
+ if conn.info.vendor == "PostgreSQL":
+ # Assume PG > 10
+ num = conn.info.server_version
+ version = f"{num // 10000}.{num % 100}"
+ elif conn.info.vendor == "CockroachDB":
+ assert isinstance(conn, CrdbConnection)
+ num = conn.info.server_version
+ version = f"{num // 10000}.{num % 10000 // 100}.{num % 100}"
+ else:
+ raise NotImplementedError(f"unexpected vendor: {conn.info.vendor}")
+ return ["", f" # Generated from {conn.info.vendor} {version}", ""]
+
+
+def get_py_types(conn: Connection) -> List[str]:
+ # Note: "record" is a pseudotype but still a useful one to have.
+ # "pg_lsn" is a documented public type and useful in streaming replication
+ lines = []
+ for (typname, oid, typarray, regtype, typdelim) in conn.execute(
+ """
+select typname, oid, typarray,
+ -- CRDB might have quotes in the regtype representation
+ replace(typname::regtype::text, '''', '') as regtype,
+ typdelim
+from pg_type t
+where
+ oid < 10000
+ and oid != '"char"'::regtype
+ and (typtype = 'b' or typname = 'record')
+ and (typname !~ '^(_|pg_)' or typname = 'pg_lsn')
+order by typname
+"""
+ ):
+ # Weird legacy type in postgres catalog
+ if typname == "char":
+ typname = regtype = '"char"'
+
+ # https://github.com/cockroachdb/cockroach/issues/81645
+ if typname == "int4" and conn.info.vendor == "CockroachDB":
+ regtype = typname
+
+ params = [f"{typname!r}, {oid}, {typarray}"]
+ if regtype != typname:
+ params.append(f"regtype={regtype!r}")
+ if typdelim != ",":
+ params.append(f"delimiter={typdelim!r}")
+ lines.append(f"TypeInfo({','.join(params)}),")
+
+ return lines
+
+
+def get_py_ranges(conn: Connection) -> List[str]:
+ lines = []
+ for (typname, oid, typarray, rngsubtype) in conn.execute(
+ """
+select typname, oid, typarray, rngsubtype
+from
+ pg_type t
+ join pg_range r on t.oid = rngtypid
+where
+ oid < 10000
+ and typtype = 'r'
+order by typname
+"""
+ ):
+ params = [f"{typname!r}, {oid}, {typarray}, subtype_oid={rngsubtype}"]
+ lines.append(f"RangeInfo({','.join(params)}),")
+
+ return lines
+
+
+def get_py_multiranges(conn: Connection) -> List[str]:
+ lines = []
+ for (typname, oid, typarray, rngtypid, rngsubtype) in conn.execute(
+ """
+select typname, oid, typarray, rngtypid, rngsubtype
+from
+ pg_type t
+ join pg_range r on t.oid = rngmultitypid
+where
+ oid < 10000
+ and typtype = 'm'
+order by typname
+"""
+ ):
+ params = [
+ f"{typname!r}, {oid}, {typarray},"
+ f" range_oid={rngtypid}, subtype_oid={rngsubtype}"
+ ]
+ lines.append(f"MultirangeInfo({','.join(params)}),")
+
+ return lines
+
+
+def get_cython_oids(conn: Connection) -> List[str]:
+ lines = []
+ for (typname, oid) in conn.execute(
+ """
+select typname, oid
+from pg_type
+where
+ oid < 10000
+ and (typtype = any('{b,r,m}') or typname = 'record')
+ and (typname !~ '^(_|pg_)' or typname = 'pg_lsn')
+order by typname
+"""
+ ):
+ const_name = typname.upper() + "_OID"
+ lines.append(f" {const_name} = {oid}")
+
+ return lines
+
+
+def update_file(fn: Path, new: List[str]) -> None:
+ with fn.open("r") as f:
+ lines = f.read().splitlines()
+ istart, iend = [
+ i
+ for i, line in enumerate(lines)
+ if re.match(r"\s*#\s*autogenerated:\s+(start|end)", line)
+ ]
+ lines[istart + 1 : iend] = new
+
+ with fn.open("w") as f:
+ f.write("\n".join(lines))
+ f.write("\n")
+
+
+def parse_cmdline() -> argparse.Namespace:
+ parser = argparse.ArgumentParser(
+ description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
+ )
+ parser.add_argument("dsn", help="where to connect to")
+ opt = parser.parse_args()
+ return opt
+
+
+if __name__ == "__main__":
+ main()