diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-05-04 17:41:08 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-05-04 17:41:08 +0000 |
commit | 506ed8899b3a97e512be3fd6d44d5b11463bf9bf (patch) | |
tree | 808913770c5e6935d3714058c2a066c57b4632ec /tools/update_oids.py | |
parent | Initial commit. (diff) | |
download | psycopg3-506ed8899b3a97e512be3fd6d44d5b11463bf9bf.tar.xz psycopg3-506ed8899b3a97e512be3fd6d44d5b11463bf9bf.zip |
Adding upstream version 3.1.7.upstream/3.1.7upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'tools/update_oids.py')
-rwxr-xr-x | tools/update_oids.py | 217 |
1 files changed, 217 insertions, 0 deletions
diff --git a/tools/update_oids.py b/tools/update_oids.py new file mode 100755 index 0000000..df4f969 --- /dev/null +++ b/tools/update_oids.py @@ -0,0 +1,217 @@ +#!/usr/bin/env python +""" +Update the maps of builtin types and names. + +This script updates some of the files in psycopg source code with data read +from a database catalog. + +Hint: use docker to upgrade types from a new version in isolation. Run: + + docker run --rm -p 11111:5432 --name pg -e POSTGRES_PASSWORD=password postgres:TAG + +with a specified version tag, and then query it using: + + %(prog)s "host=localhost port=11111 user=postgres password=password" +""" + +import re +import argparse +import subprocess as sp +from typing import List +from pathlib import Path +from typing_extensions import TypeAlias + +import psycopg +from psycopg.rows import TupleRow +from psycopg.crdb import CrdbConnection + +Connection: TypeAlias = psycopg.Connection[TupleRow] + +ROOT = Path(__file__).parent.parent + + +def main() -> None: + opt = parse_cmdline() + conn = psycopg.connect(opt.dsn, autocommit=True) + + if CrdbConnection.is_crdb(conn): + conn = CrdbConnection.connect(opt.dsn, autocommit=True) + update_crdb_python_oids(conn) + else: + update_python_oids(conn) + update_cython_oids(conn) + + +def update_python_oids(conn: Connection) -> None: + fn = ROOT / "psycopg/psycopg/postgres.py" + + lines = [] + lines.extend(get_version_comment(conn)) + lines.extend(get_py_types(conn)) + lines.extend(get_py_ranges(conn)) + lines.extend(get_py_multiranges(conn)) + + update_file(fn, lines) + sp.check_call(["black", "-q", fn]) + + +def update_cython_oids(conn: Connection) -> None: + fn = ROOT / "psycopg_c/psycopg_c/_psycopg/oids.pxd" + + lines = [] + lines.extend(get_version_comment(conn)) + lines.extend(get_cython_oids(conn)) + + update_file(fn, lines) + + +def update_crdb_python_oids(conn: Connection) -> None: + fn = ROOT / "psycopg/psycopg/crdb/_types.py" + + lines = [] + lines.extend(get_version_comment(conn)) + lines.extend(get_py_types(conn)) + + update_file(fn, lines) + sp.check_call(["black", "-q", fn]) + + +def get_version_comment(conn: Connection) -> List[str]: + if conn.info.vendor == "PostgreSQL": + # Assume PG > 10 + num = conn.info.server_version + version = f"{num // 10000}.{num % 100}" + elif conn.info.vendor == "CockroachDB": + assert isinstance(conn, CrdbConnection) + num = conn.info.server_version + version = f"{num // 10000}.{num % 10000 // 100}.{num % 100}" + else: + raise NotImplementedError(f"unexpected vendor: {conn.info.vendor}") + return ["", f" # Generated from {conn.info.vendor} {version}", ""] + + +def get_py_types(conn: Connection) -> List[str]: + # Note: "record" is a pseudotype but still a useful one to have. + # "pg_lsn" is a documented public type and useful in streaming replication + lines = [] + for (typname, oid, typarray, regtype, typdelim) in conn.execute( + """ +select typname, oid, typarray, + -- CRDB might have quotes in the regtype representation + replace(typname::regtype::text, '''', '') as regtype, + typdelim +from pg_type t +where + oid < 10000 + and oid != '"char"'::regtype + and (typtype = 'b' or typname = 'record') + and (typname !~ '^(_|pg_)' or typname = 'pg_lsn') +order by typname +""" + ): + # Weird legacy type in postgres catalog + if typname == "char": + typname = regtype = '"char"' + + # https://github.com/cockroachdb/cockroach/issues/81645 + if typname == "int4" and conn.info.vendor == "CockroachDB": + regtype = typname + + params = [f"{typname!r}, {oid}, {typarray}"] + if regtype != typname: + params.append(f"regtype={regtype!r}") + if typdelim != ",": + params.append(f"delimiter={typdelim!r}") + lines.append(f"TypeInfo({','.join(params)}),") + + return lines + + +def get_py_ranges(conn: Connection) -> List[str]: + lines = [] + for (typname, oid, typarray, rngsubtype) in conn.execute( + """ +select typname, oid, typarray, rngsubtype +from + pg_type t + join pg_range r on t.oid = rngtypid +where + oid < 10000 + and typtype = 'r' +order by typname +""" + ): + params = [f"{typname!r}, {oid}, {typarray}, subtype_oid={rngsubtype}"] + lines.append(f"RangeInfo({','.join(params)}),") + + return lines + + +def get_py_multiranges(conn: Connection) -> List[str]: + lines = [] + for (typname, oid, typarray, rngtypid, rngsubtype) in conn.execute( + """ +select typname, oid, typarray, rngtypid, rngsubtype +from + pg_type t + join pg_range r on t.oid = rngmultitypid +where + oid < 10000 + and typtype = 'm' +order by typname +""" + ): + params = [ + f"{typname!r}, {oid}, {typarray}," + f" range_oid={rngtypid}, subtype_oid={rngsubtype}" + ] + lines.append(f"MultirangeInfo({','.join(params)}),") + + return lines + + +def get_cython_oids(conn: Connection) -> List[str]: + lines = [] + for (typname, oid) in conn.execute( + """ +select typname, oid +from pg_type +where + oid < 10000 + and (typtype = any('{b,r,m}') or typname = 'record') + and (typname !~ '^(_|pg_)' or typname = 'pg_lsn') +order by typname +""" + ): + const_name = typname.upper() + "_OID" + lines.append(f" {const_name} = {oid}") + + return lines + + +def update_file(fn: Path, new: List[str]) -> None: + with fn.open("r") as f: + lines = f.read().splitlines() + istart, iend = [ + i + for i, line in enumerate(lines) + if re.match(r"\s*#\s*autogenerated:\s+(start|end)", line) + ] + lines[istart + 1 : iend] = new + + with fn.open("w") as f: + f.write("\n".join(lines)) + f.write("\n") + + +def parse_cmdline() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter + ) + parser.add_argument("dsn", help="where to connect to") + opt = parser.parse_args() + return opt + + +if __name__ == "__main__": + main() |