summaryrefslogtreecommitdiffstats
path: root/sphinx/util
diff options
context:
space:
mode:
Diffstat (limited to 'sphinx/util')
-rw-r--r--sphinx/util/__init__.py26
-rw-r--r--sphinx/util/_io.py34
-rw-r--r--sphinx/util/_pathlib.py41
-rw-r--r--sphinx/util/build_phase.py1
-rw-r--r--sphinx/util/cfamily.py30
-rw-r--r--sphinx/util/console.py122
-rw-r--r--sphinx/util/display.py5
-rw-r--r--sphinx/util/docfields.py20
-rw-r--r--sphinx/util/docutils.py56
-rw-r--r--sphinx/util/exceptions.py5
-rw-r--r--sphinx/util/fileutil.py10
-rw-r--r--sphinx/util/http_date.py11
-rw-r--r--sphinx/util/i18n.py49
-rw-r--r--sphinx/util/inspect.py562
-rw-r--r--sphinx/util/inventory.py27
-rw-r--r--sphinx/util/logging.py88
-rw-r--r--sphinx/util/matching.py3
-rw-r--r--sphinx/util/math.py3
-rw-r--r--sphinx/util/nodes.py68
-rw-r--r--sphinx/util/osutil.py31
-rw-r--r--sphinx/util/parallel.py7
-rw-r--r--sphinx/util/requests.py13
-rw-r--r--sphinx/util/rst.py14
-rw-r--r--sphinx/util/tags.py6
-rw-r--r--sphinx/util/template.py11
-rw-r--r--sphinx/util/typing.py118
26 files changed, 882 insertions, 479 deletions
diff --git a/sphinx/util/__init__.py b/sphinx/util/__init__.py
index 69b2848..54ddc7e 100644
--- a/sphinx/util/__init__.py
+++ b/sphinx/util/__init__.py
@@ -20,8 +20,8 @@ from sphinx.util import index_entries as _index_entries
from sphinx.util import logging
from sphinx.util import osutil as _osutil
from sphinx.util.console import strip_colors # NoQA: F401
-from sphinx.util.matching import patfilter # noqa: F401
-from sphinx.util.nodes import ( # noqa: F401
+from sphinx.util.matching import patfilter # NoQA: F401
+from sphinx.util.nodes import ( # NoQA: F401
caption_ref_re,
explicit_title_re,
nested_parse_with_titles,
@@ -30,7 +30,7 @@ from sphinx.util.nodes import ( # noqa: F401
# import other utilities; partly for backwards compatibility, so don't
# prune unused ones indiscriminately
-from sphinx.util.osutil import ( # noqa: F401
+from sphinx.util.osutil import ( # NoQA: F401
SEP,
copyfile,
copytimes,
@@ -68,6 +68,7 @@ class FilenameUniqDict(dict):
interpreted as filenames, and keeps track of a set of docnames they
appear in. Used for images and downloadable files in the environment.
"""
+
def __init__(self) -> None:
self._existing: set[str] = set()
@@ -104,7 +105,7 @@ class FilenameUniqDict(dict):
self._existing = state
-def _md5(data=b'', **_kw):
+def _md5(data: bytes = b'', **_kw: Any) -> hashlib._Hash:
"""Deprecated wrapper around hashlib.md5
To be removed in Sphinx 9.0
@@ -112,7 +113,7 @@ def _md5(data=b'', **_kw):
return hashlib.md5(data, usedforsecurity=False)
-def _sha1(data=b'', **_kw):
+def _sha1(data: bytes = b'', **_kw: Any) -> hashlib._Hash:
"""Deprecated wrapper around hashlib.sha1
To be removed in Sphinx 9.0
@@ -178,6 +179,7 @@ class Tee:
"""
File-like object writing to two streams.
"""
+
def __init__(self, stream1: IO, stream2: IO) -> None:
self.stream1 = stream1
self.stream2 = stream2
@@ -202,7 +204,7 @@ def parselinenos(spec: str, total: int) -> list[int]:
for part in parts:
try:
begend = part.strip().split('-')
- if ['', ''] == begend:
+ if begend == ['', '']:
raise ValueError
if len(begend) == 1:
items.append(int(begend[0]) - 1)
@@ -256,7 +258,7 @@ def isurl(url: str) -> bool:
return bool(url) and '://' in url
-def _xml_name_checker():
+def _xml_name_checker() -> re.Pattern[str]:
# to prevent import cycles
from sphinx.builders.epub3 import _XML_NAME_PATTERN
@@ -264,7 +266,7 @@ def _xml_name_checker():
# deprecated name -> (object to return, canonical path or empty string)
-_DEPRECATED_OBJECTS = {
+_DEPRECATED_OBJECTS: dict[str, tuple[Any, str] | tuple[Any, str, tuple[int, int]]] = {
'path_stabilize': (_osutil.path_stabilize, 'sphinx.util.osutil.path_stabilize'),
'display_chunk': (_display.display_chunk, 'sphinx.util.display.display_chunk'),
'status_iterator': (_display.status_iterator, 'sphinx.util.display.status_iterator'),
@@ -285,13 +287,15 @@ _DEPRECATED_OBJECTS = {
}
-def __getattr__(name):
+def __getattr__(name: str) -> Any:
if name not in _DEPRECATED_OBJECTS:
msg = f'module {__name__!r} has no attribute {name!r}'
raise AttributeError(msg)
from sphinx.deprecation import _deprecation_warning
- deprecated_object, canonical_name = _DEPRECATED_OBJECTS[name]
- _deprecation_warning(__name__, name, canonical_name, remove=(8, 0))
+ info = _DEPRECATED_OBJECTS[name]
+ deprecated_object, canonical_name = info[:2]
+ remove = info[2] if len(info) == 3 else (8, 0)
+ _deprecation_warning(__name__, name, canonical_name, remove=remove)
return deprecated_object
diff --git a/sphinx/util/_io.py b/sphinx/util/_io.py
new file mode 100644
index 0000000..3689d9e
--- /dev/null
+++ b/sphinx/util/_io.py
@@ -0,0 +1,34 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from sphinx.util.console import strip_escape_sequences
+
+if TYPE_CHECKING:
+ from typing import Protocol
+
+ class SupportsWrite(Protocol):
+ def write(self, text: str, /) -> int | None:
+ ...
+
+
+class TeeStripANSI:
+ """File-like object writing to two streams."""
+
+ def __init__(
+ self,
+ stream_term: SupportsWrite,
+ stream_file: SupportsWrite,
+ ) -> None:
+ self.stream_term = stream_term
+ self.stream_file = stream_file
+
+ def write(self, text: str, /) -> None:
+ self.stream_term.write(text)
+ self.stream_file.write(strip_escape_sequences(text))
+
+ def flush(self) -> None:
+ if hasattr(self.stream_term, 'flush'):
+ self.stream_term.flush()
+ if hasattr(self.stream_file, 'flush'):
+ self.stream_file.flush()
diff --git a/sphinx/util/_pathlib.py b/sphinx/util/_pathlib.py
index 59980e9..8bb1f31 100644
--- a/sphinx/util/_pathlib.py
+++ b/sphinx/util/_pathlib.py
@@ -5,6 +5,7 @@ from __future__ import annotations
import sys
import warnings
from pathlib import Path, PosixPath, PurePath, WindowsPath
+from typing import Any
from sphinx.deprecation import RemovedInSphinx80Warning
@@ -21,34 +22,36 @@ _MSG = (
if sys.platform == 'win32':
class _StrPath(WindowsPath):
- def replace(self, old, new, count=-1, /):
+ def replace( # type: ignore[override]
+ self, old: str, new: str, count: int = -1, /,
+ ) -> str:
# replace exists in both Path and str;
# in Path it makes filesystem changes, so we use the safer str version
warnings.warn(_MSG, RemovedInSphinx80Warning, stacklevel=2)
return self.__str__().replace(old, new, count)
- def __getattr__(self, item):
+ def __getattr__(self, item: str) -> Any:
if item in _STR_METHODS:
warnings.warn(_MSG, RemovedInSphinx80Warning, stacklevel=2)
return getattr(self.__str__(), item)
msg = f'{_PATH_NAME!r} has no attribute {item!r}'
raise AttributeError(msg)
- def __add__(self, other):
+ def __add__(self, other: str) -> str:
warnings.warn(_MSG, RemovedInSphinx80Warning, stacklevel=2)
return self.__str__() + other
- def __bool__(self):
+ def __bool__(self) -> bool:
if not self.__str__():
warnings.warn(_MSG, RemovedInSphinx80Warning, stacklevel=2)
return False
return True
- def __contains__(self, item):
+ def __contains__(self, item: str) -> bool:
warnings.warn(_MSG, RemovedInSphinx80Warning, stacklevel=2)
return item in self.__str__()
- def __eq__(self, other):
+ def __eq__(self, other: object) -> bool:
if isinstance(other, PurePath):
return super().__eq__(other)
if isinstance(other, str):
@@ -56,46 +59,48 @@ if sys.platform == 'win32':
return self.__str__() == other
return NotImplemented
- def __hash__(self):
+ def __hash__(self) -> int:
return super().__hash__()
- def __getitem__(self, item):
+ def __getitem__(self, item: int | slice) -> str:
warnings.warn(_MSG, RemovedInSphinx80Warning, stacklevel=2)
return self.__str__()[item]
- def __len__(self):
+ def __len__(self) -> int:
warnings.warn(_MSG, RemovedInSphinx80Warning, stacklevel=2)
return len(self.__str__())
else:
class _StrPath(PosixPath):
- def replace(self, old, new, count=-1, /):
+ def replace( # type: ignore[override]
+ self, old: str, new: str, count: int = -1, /,
+ ) -> str:
# replace exists in both Path and str;
# in Path it makes filesystem changes, so we use the safer str version
warnings.warn(_MSG, RemovedInSphinx80Warning, stacklevel=2)
return self.__str__().replace(old, new, count)
- def __getattr__(self, item):
+ def __getattr__(self, item: str) -> Any:
if item in _STR_METHODS:
warnings.warn(_MSG, RemovedInSphinx80Warning, stacklevel=2)
return getattr(self.__str__(), item)
msg = f'{_PATH_NAME!r} has no attribute {item!r}'
raise AttributeError(msg)
- def __add__(self, other):
+ def __add__(self, other: str) -> str:
warnings.warn(_MSG, RemovedInSphinx80Warning, stacklevel=2)
return self.__str__() + other
- def __bool__(self):
+ def __bool__(self) -> bool:
if not self.__str__():
warnings.warn(_MSG, RemovedInSphinx80Warning, stacklevel=2)
return False
return True
- def __contains__(self, item):
+ def __contains__(self, item: str) -> bool:
warnings.warn(_MSG, RemovedInSphinx80Warning, stacklevel=2)
return item in self.__str__()
- def __eq__(self, other):
+ def __eq__(self, other: object) -> bool:
if isinstance(other, PurePath):
return super().__eq__(other)
if isinstance(other, str):
@@ -103,13 +108,13 @@ else:
return self.__str__() == other
return NotImplemented
- def __hash__(self):
+ def __hash__(self) -> int:
return super().__hash__()
- def __getitem__(self, item):
+ def __getitem__(self, item: int | slice) -> str:
warnings.warn(_MSG, RemovedInSphinx80Warning, stacklevel=2)
return self.__str__()[item]
- def __len__(self):
+ def __len__(self) -> int:
warnings.warn(_MSG, RemovedInSphinx80Warning, stacklevel=2)
return len(self.__str__())
diff --git a/sphinx/util/build_phase.py b/sphinx/util/build_phase.py
index 7f80aa5..76e94a9 100644
--- a/sphinx/util/build_phase.py
+++ b/sphinx/util/build_phase.py
@@ -5,6 +5,7 @@ from enum import IntEnum
class BuildPhase(IntEnum):
"""Build phase of Sphinx application."""
+
INITIALIZATION = 1
READING = 2
CONSISTENCY_CHECK = 3
diff --git a/sphinx/util/cfamily.py b/sphinx/util/cfamily.py
index a3fdbe3..c887983 100644
--- a/sphinx/util/cfamily.py
+++ b/sphinx/util/cfamily.py
@@ -12,6 +12,8 @@ from sphinx import addnodes
from sphinx.util import logging
if TYPE_CHECKING:
+ from collections.abc import Sequence
+
from docutils.nodes import TextElement
from sphinx.config import Config
@@ -86,7 +88,7 @@ class NoOldIdError(Exception):
class ASTBaseBase:
- def __eq__(self, other: Any) -> bool:
+ def __eq__(self, other: object) -> bool:
if type(self) is not type(other):
return False
try:
@@ -107,7 +109,7 @@ class ASTBaseBase:
raise NotImplementedError(repr(self))
def __str__(self) -> str:
- return self._stringify(lambda ast: str(ast))
+ return self._stringify(str)
def get_display_string(self) -> str:
return self._stringify(lambda ast: ast.get_display_string())
@@ -143,6 +145,11 @@ class ASTGnuAttribute(ASTBaseBase):
self.name = name
self.args = args
+ def __eq__(self, other: object) -> bool:
+ if type(other) is not ASTGnuAttribute:
+ return NotImplemented
+ return self.name == other.name and self.args == other.args
+
def _stringify(self, transform: StringifyTransform) -> str:
res = [self.name]
if self.args:
@@ -202,6 +209,11 @@ class ASTAttributeList(ASTBaseBase):
def __init__(self, attrs: list[ASTAttribute]) -> None:
self.attrs = attrs
+ def __eq__(self, other: object) -> bool:
+ if type(other) is not ASTAttributeList:
+ return NotImplemented
+ return self.attrs == other.attrs
+
def __len__(self) -> int:
return len(self.attrs)
@@ -265,14 +277,11 @@ class BaseParser:
for e in errors:
if len(e[1]) > 0:
indent = ' '
- result.append(e[1])
- result.append(':\n')
+ result.extend((e[1], ':\n'))
for line in str(e[0]).split('\n'):
if len(line) == 0:
continue
- result.append(indent)
- result.append(line)
- result.append('\n')
+ result.extend((indent, line, '\n'))
else:
result.append(str(e[0]))
return DefinitionError(''.join(result))
@@ -293,8 +302,7 @@ class BaseParser:
'Invalid %s declaration: %s [error at %d]\n %s\n %s' %
(self.language, msg, self.pos, self.definition, indicator))
errors.append((exMain, "Main error"))
- for err in self.otherErrors:
- errors.append((err, "Potential other error"))
+ errors.extend((err, "Potential other error") for err in self.otherErrors)
self.otherErrors = []
raise self._make_multi_error(errors, '')
@@ -369,11 +377,11 @@ class BaseParser:
################################################################################
@property
- def id_attributes(self):
+ def id_attributes(self) -> Sequence[str]:
raise NotImplementedError
@property
- def paren_attributes(self):
+ def paren_attributes(self) -> Sequence[str]:
raise NotImplementedError
def _parse_balanced_token_seq(self, end: list[str]) -> str:
diff --git a/sphinx/util/console.py b/sphinx/util/console.py
index 0fc9450..4257056 100644
--- a/sphinx/util/console.py
+++ b/sphinx/util/console.py
@@ -6,6 +6,37 @@ import os
import re
import shutil
import sys
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from typing import Final
+
+ # fmt: off
+ def reset(text: str) -> str: ... # NoQA: E704
+ def bold(text: str) -> str: ... # NoQA: E704
+ def faint(text: str) -> str: ... # NoQA: E704
+ def standout(text: str) -> str: ... # NoQA: E704
+ def underline(text: str) -> str: ... # NoQA: E704
+ def blink(text: str) -> str: ... # NoQA: E704
+
+ def black(text: str) -> str: ... # NoQA: E704
+ def white(text: str) -> str: ... # NoQA: E704
+ def red(text: str) -> str: ... # NoQA: E704
+ def green(text: str) -> str: ... # NoQA: E704
+ def yellow(text: str) -> str: ... # NoQA: E704
+ def blue(text: str) -> str: ... # NoQA: E704
+ def fuchsia(text: str) -> str: ... # NoQA: E704
+ def teal(text: str) -> str: ... # NoQA: E704
+
+ def darkgray(text: str) -> str: ... # NoQA: E704
+ def lightgray(text: str) -> str: ... # NoQA: E704
+ def darkred(text: str) -> str: ... # NoQA: E704
+ def darkgreen(text: str) -> str: ... # NoQA: E704
+ def brown(text: str) -> str: ... # NoQA: E704
+ def darkblue(text: str) -> str: ... # NoQA: E704
+ def purple(text: str) -> str: ... # NoQA: E704
+ def turquoise(text: str) -> str: ... # NoQA: E704
+ # fmt: on
try:
# check if colorama is installed to support color on Windows
@@ -13,8 +44,26 @@ try:
except ImportError:
colorama = None
+_CSI: Final[str] = re.escape('\x1b[') # 'ESC [': Control Sequence Introducer
+
+# Pattern matching ANSI control sequences containing colors.
+_ansi_color_re: Final[re.Pattern[str]] = re.compile(r'\x1b\[(?:\d+;){0,2}\d*m')
+
+_ansi_re: Final[re.Pattern[str]] = re.compile(
+ _CSI
+ + r"""
+ (?:
+ (?:\d+;){0,2}\d*m # ANSI color code ('m' is equivalent to '0m')
+ |
+ [012]?K # ANSI Erase in Line ('K' is equivalent to '0K')
+ )""",
+ re.VERBOSE | re.ASCII,
+)
+"""Pattern matching ANSI CSI colors (SGR) and erase line (EL) sequences.
+
+See :func:`strip_escape_sequences` for details.
+"""
-_ansi_re: re.Pattern[str] = re.compile('\x1b\\[(\\d\\d;){0,2}\\d\\dm')
codes: dict[str, str] = {}
@@ -37,7 +86,7 @@ def term_width_line(text: str) -> str:
return text + '\n'
else:
# codes are not displayed, this must be taken into account
- return text.ljust(_tw + len(text) - len(_ansi_re.sub('', text))) + '\r'
+ return text.ljust(_tw + len(text) - len(strip_escape_sequences(text))) + '\r'
def color_terminal() -> bool:
@@ -55,9 +104,7 @@ def color_terminal() -> bool:
if 'COLORTERM' in os.environ:
return True
term = os.environ.get('TERM', 'dumb').lower()
- if term in ('xterm', 'linux') or 'color' in term:
- return True
- return False
+ return term in ('xterm', 'linux') or 'color' in term
def nocolor() -> None:
@@ -87,41 +134,74 @@ def colorize(name: str, text: str, input_mode: bool = False) -> str:
def strip_colors(s: str) -> str:
- return re.compile('\x1b.*?m').sub('', s)
+ """Remove the ANSI color codes in a string *s*.
+
+ .. caution::
+
+ This function is not meant to be used in production and should only
+ be used for testing Sphinx's output messages.
+
+ .. seealso:: :func:`strip_escape_sequences`
+ """
+ return _ansi_color_re.sub('', s)
+
+
+def strip_escape_sequences(text: str, /) -> str:
+ r"""Remove the ANSI CSI colors and "erase in line" sequences.
+
+ Other `escape sequences `__ (e.g., VT100-specific functions) are not
+ supported and only control sequences *natively* known to Sphinx (i.e.,
+ colors declared in this module and "erase entire line" (``'\x1b[2K'``))
+ are eliminated by this function.
+
+ .. caution::
+
+ This function is not meant to be used in production and should only
+ be used for testing Sphinx's output messages that were not tempered
+ with by third-party extensions.
+
+ .. versionadded:: 7.3
+
+ This function is added as an *experimental* feature.
+
+ __ https://en.wikipedia.org/wiki/ANSI_escape_code
+ """
+ return _ansi_re.sub('', text)
def create_color_func(name: str) -> None:
def inner(text: str) -> str:
return colorize(name, text)
+
globals()[name] = inner
_attrs = {
- 'reset': '39;49;00m',
- 'bold': '01m',
- 'faint': '02m',
- 'standout': '03m',
+ 'reset': '39;49;00m',
+ 'bold': '01m',
+ 'faint': '02m',
+ 'standout': '03m',
'underline': '04m',
- 'blink': '05m',
+ 'blink': '05m',
}
-for _name, _value in _attrs.items():
- codes[_name] = '\x1b[' + _value
+for __name, __value in _attrs.items():
+ codes[__name] = '\x1b[' + __value
_colors = [
- ('black', 'darkgray'),
- ('darkred', 'red'),
+ ('black', 'darkgray'),
+ ('darkred', 'red'),
('darkgreen', 'green'),
- ('brown', 'yellow'),
- ('darkblue', 'blue'),
- ('purple', 'fuchsia'),
+ ('brown', 'yellow'),
+ ('darkblue', 'blue'),
+ ('purple', 'fuchsia'),
('turquoise', 'teal'),
('lightgray', 'white'),
]
-for i, (dark, light) in enumerate(_colors, 30):
- codes[dark] = '\x1b[%im' % i
- codes[light] = '\x1b[%im' % (i + 60)
+for __i, (__dark, __light) in enumerate(_colors, 30):
+ codes[__dark] = '\x1b[%im' % __i
+ codes[__light] = '\x1b[%im' % (__i + 60)
_orig_codes = codes.copy()
diff --git a/sphinx/util/display.py b/sphinx/util/display.py
index 199119c..3cb8d97 100644
--- a/sphinx/util/display.py
+++ b/sphinx/util/display.py
@@ -5,7 +5,7 @@ from typing import Any, Callable, TypeVar
from sphinx.locale import __
from sphinx.util import logging
-from sphinx.util.console import bold # type: ignore[attr-defined]
+from sphinx.util.console import bold, color_terminal
if False:
from collections.abc import Iterable, Iterator
@@ -33,7 +33,8 @@ def status_iterator(
verbosity: int = 0,
stringify_func: Callable[[Any], str] = display_chunk,
) -> Iterator[T]:
- single_line = verbosity < 1
+ # printing on a single line requires ANSI control sequences
+ single_line = verbosity < 1 and color_terminal()
bold_summary = bold(summary)
if length == 0:
logger.info(bold_summary, nonl=True)
diff --git a/sphinx/util/docfields.py b/sphinx/util/docfields.py
index c48c3be..c277a59 100644
--- a/sphinx/util/docfields.py
+++ b/sphinx/util/docfields.py
@@ -34,9 +34,7 @@ def _is_single_paragraph(node: nodes.field_body) -> bool:
for subnode in node[1:]: # type: Node
if not isinstance(subnode, nodes.system_message):
return False
- if isinstance(node[0], nodes.paragraph):
- return True
- return False
+ return isinstance(node[0], nodes.paragraph)
class Field:
@@ -52,6 +50,7 @@ class Field:
:returns: description of the return value
:rtype: description of the return type
"""
+
is_grouped = False
is_typed = False
@@ -79,7 +78,7 @@ class Field:
assert env is not None
assert (inliner is None) == (location is None), (inliner, location)
if not rolename:
- return contnode or innernode(target, target)
+ return contnode or innernode(target, target) # type: ignore[call-arg]
# The domain is passed from DocFieldTransformer. So it surely exists.
# So we don't need to take care the env.get_domain() raises an exception.
role = env.get_domain(domain).role(rolename)
@@ -90,7 +89,7 @@ class Field:
logger.warning(__(msg), domain, rolename, location=location)
refnode = addnodes.pending_xref('', refdomain=domain, refexplicit=False,
reftype=rolename, reftarget=target)
- refnode += contnode or innernode(target, target)
+ refnode += contnode or innernode(target, target) # type: ignore[call-arg]
env.get_domain(domain).process_field_xref(refnode)
return refnode
lineno = -1
@@ -152,6 +151,7 @@ class GroupedField(Field):
:raises ErrorClass: description when it is raised
"""
+
is_grouped = True
list_type = nodes.bullet_list
@@ -208,6 +208,7 @@ class TypedField(GroupedField):
:param SomeClass foo: description of parameter foo
"""
+
is_typed = True
def __init__(
@@ -233,7 +234,7 @@ class TypedField(GroupedField):
inliner: Inliner | None = None,
location: Element | None = None,
) -> nodes.field:
- def handle_item(fieldarg: str, content: str) -> nodes.paragraph:
+ def handle_item(fieldarg: str, content: list[Node]) -> nodes.paragraph:
par = nodes.paragraph()
par.extend(self.make_xrefs(self.rolename, domain, fieldarg,
addnodes.literal_strong, env=env))
@@ -251,8 +252,10 @@ class TypedField(GroupedField):
else:
par += fieldtype
par += nodes.Text(')')
- par += nodes.Text(' -- ')
- par += content
+ has_content = any(c.astext().strip() for c in content)
+ if has_content:
+ par += nodes.Text(' -- ')
+ par += content
return par
fieldname = nodes.field_name('', self.label)
@@ -272,6 +275,7 @@ class DocFieldTransformer:
Transforms field lists in "doc field" syntax into better-looking
equivalents, using the field type definitions given on a domain.
"""
+
typemap: dict[str, tuple[Field, bool]]
def __init__(self, directive: ObjectDescription) -> None:
diff --git a/sphinx/util/docutils.py b/sphinx/util/docutils.py
index a862417..6a24d2e 100644
--- a/sphinx/util/docutils.py
+++ b/sphinx/util/docutils.py
@@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
report_re = re.compile('^(.+?:(?:\\d+)?): \\((DEBUG|INFO|WARNING|ERROR|SEVERE)/(\\d+)?\\) ')
if TYPE_CHECKING:
- from collections.abc import Generator
+ from collections.abc import Iterator
from types import ModuleType
from docutils.frontend import Values
@@ -38,29 +38,12 @@ if TYPE_CHECKING:
from sphinx.environment import BuildEnvironment
from sphinx.util.typing import RoleFunction
-# deprecated name -> (object to return, canonical path or empty string)
-_DEPRECATED_OBJECTS = {
- '__version_info__': (docutils.__version_info__, 'docutils.__version_info__'),
-}
-
-
-def __getattr__(name):
- if name not in _DEPRECATED_OBJECTS:
- msg = f'module {__name__!r} has no attribute {name!r}'
- raise AttributeError(msg)
-
- from sphinx.deprecation import _deprecation_warning
-
- deprecated_object, canonical_name = _DEPRECATED_OBJECTS[name]
- _deprecation_warning(__name__, name, canonical_name, remove=(7, 0))
- return deprecated_object
-
additional_nodes: set[type[Element]] = set()
@contextmanager
-def docutils_namespace() -> Generator[None, None, None]:
+def docutils_namespace() -> Iterator[None]:
"""Create namespace for reST parsers."""
try:
_directives = copy(directives._directives) # type: ignore[attr-defined]
@@ -101,7 +84,7 @@ def register_role(name: str, role: RoleFunction) -> None:
This modifies global state of docutils. So it is better to use this
inside ``docutils_namespace()`` to prevent side-effects.
"""
- roles.register_local_role(name, role)
+ roles.register_local_role(name, role) # type: ignore[arg-type]
def unregister_role(name: str) -> None:
@@ -138,7 +121,7 @@ def unregister_node(node: type[Element]) -> None:
@contextmanager
-def patched_get_language() -> Generator[None, None, None]:
+def patched_get_language() -> Iterator[None]:
"""Patch docutils.languages.get_language() temporarily.
This ignores the second argument ``reporter`` to suppress warnings.
@@ -150,7 +133,7 @@ def patched_get_language() -> Generator[None, None, None]:
return get_language(language_code)
try:
- docutils.languages.get_language = patched_get_language
+ docutils.languages.get_language = patched_get_language # type: ignore[assignment]
yield
finally:
# restore original implementations
@@ -158,7 +141,7 @@ def patched_get_language() -> Generator[None, None, None]:
@contextmanager
-def patched_rst_get_language() -> Generator[None, None, None]:
+def patched_rst_get_language() -> Iterator[None]:
"""Patch docutils.parsers.rst.languages.get_language().
Starting from docutils 0.17, get_language() in ``rst.languages``
also has a reporter, which needs to be disabled temporarily.
@@ -174,7 +157,7 @@ def patched_rst_get_language() -> Generator[None, None, None]:
return get_language(language_code)
try:
- docutils.parsers.rst.languages.get_language = patched_get_language
+ docutils.parsers.rst.languages.get_language = patched_get_language # type: ignore[assignment]
yield
finally:
# restore original implementations
@@ -182,7 +165,7 @@ def patched_rst_get_language() -> Generator[None, None, None]:
@contextmanager
-def using_user_docutils_conf(confdir: str | None) -> Generator[None, None, None]:
+def using_user_docutils_conf(confdir: str | None) -> Iterator[None]:
"""Let docutils know the location of ``docutils.conf`` for Sphinx."""
try:
docutilsconfig = os.environ.get('DOCUTILSCONFIG', None)
@@ -198,8 +181,8 @@ def using_user_docutils_conf(confdir: str | None) -> Generator[None, None, None]
@contextmanager
-def du19_footnotes() -> Generator[None, None, None]:
- def visit_footnote(self, node):
+def du19_footnotes() -> Iterator[None]:
+ def visit_footnote(self: HTMLTranslator, node: Element) -> None:
label_style = self.settings.footnote_references
if not isinstance(node.previous_sibling(), type(node)):
self.body.append(f'<aside class="footnote-list {label_style}">\n')
@@ -207,7 +190,7 @@ def du19_footnotes() -> Generator[None, None, None]:
classes=[node.tagname, label_style],
role="note"))
- def depart_footnote(self, node):
+ def depart_footnote(self: HTMLTranslator, node: Element) -> None:
self.body.append('</aside>\n')
if not isinstance(node.next_node(descend=False, siblings=True),
type(node)):
@@ -231,7 +214,7 @@ def du19_footnotes() -> Generator[None, None, None]:
@contextmanager
-def patch_docutils(confdir: str | None = None) -> Generator[None, None, None]:
+def patch_docutils(confdir: str | None = None) -> Iterator[None]:
"""Patch to docutils temporarily."""
with patched_get_language(), \
patched_rst_get_language(), \
@@ -263,8 +246,8 @@ class CustomReSTDispatcher:
self.directive_func = directives.directive
self.role_func = roles.role
- directives.directive = self.directive
- roles.role = self.role
+ directives.directive = self.directive # type: ignore[assignment]
+ roles.role = self.role # type: ignore[assignment]
def disable(self) -> None:
directives.directive = self.directive_func
@@ -290,6 +273,7 @@ class sphinx_domains(CustomReSTDispatcher):
"""Monkey-patch directive and role dispatch, so that domain-specific
markup takes precedence.
"""
+
def __init__(self, env: BuildEnvironment) -> None:
self.env = env
super().__init__()
@@ -354,7 +338,7 @@ class WarningStream:
class LoggingReporter(Reporter):
@classmethod
- def from_reporter(cls, reporter: Reporter) -> LoggingReporter:
+ def from_reporter(cls: type[LoggingReporter], reporter: Reporter) -> LoggingReporter:
"""Create an instance of LoggingReporter from other reporter object."""
return cls(reporter.source, reporter.report_level, reporter.halt_level,
reporter.debug_flag, reporter.error_handler)
@@ -375,16 +359,16 @@ class NullReporter(Reporter):
@contextmanager
-def switch_source_input(state: State, content: StringList) -> Generator[None, None, None]:
+def switch_source_input(state: State, content: StringList) -> Iterator[None]:
"""Switch current source input of state temporarily."""
try:
# remember the original ``get_source_and_line()`` method
gsal = state.memo.reporter.get_source_and_line # type: ignore[attr-defined]
# replace it by new one
- state_machine = StateMachine([], None) # type: ignore[arg-type]
+ state_machine: StateMachine[None] = StateMachine([], None) # type: ignore[arg-type]
state_machine.input_lines = content
- state.memo.reporter.get_source_and_line = state_machine.get_source_and_line # type: ignore[attr-defined] # noqa: E501
+ state.memo.reporter.get_source_and_line = state_machine.get_source_and_line # type: ignore[attr-defined] # NoQA: E501
yield
finally:
@@ -451,6 +435,7 @@ class SphinxRole:
.. note:: The subclasses of this class might not work with docutils.
This class is strongly coupled with Sphinx.
"""
+
name: str #: The role name actually used in the document.
rawtext: str #: A string containing the entire interpreted text input.
text: str #: The interpreted text content.
@@ -519,6 +504,7 @@ class ReferenceRole(SphinxRole):
the role. The parsed result; link title and target will be stored to
``self.title`` and ``self.target``.
"""
+
has_explicit_title: bool #: A boolean indicates the role has explicit title or not.
disabled: bool #: A boolean indicates the reference is disabled.
title: str #: The link title for the interpreted text.
diff --git a/sphinx/util/exceptions.py b/sphinx/util/exceptions.py
index 9e25695..577ec73 100644
--- a/sphinx/util/exceptions.py
+++ b/sphinx/util/exceptions.py
@@ -6,7 +6,7 @@ from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING
from sphinx.errors import SphinxParallelError
-from sphinx.util.console import strip_colors
+from sphinx.util.console import strip_escape_sequences
if TYPE_CHECKING:
from sphinx.application import Sphinx
@@ -31,7 +31,8 @@ def save_traceback(app: Sphinx | None, exc: BaseException) -> str:
last_msgs = exts_list = ''
else:
extensions = app.extensions.values()
- last_msgs = '\n'.join(f'# {strip_colors(s).strip()}' for s in app.messagelog)
+ last_msgs = '\n'.join(f'# {strip_escape_sequences(s).strip()}'
+ for s in app.messagelog)
exts_list = '\n'.join(f'# {ext.name} ({ext.version})' for ext in extensions
if ext.version != 'builtin')
diff --git a/sphinx/util/fileutil.py b/sphinx/util/fileutil.py
index 316ec39..e621f55 100644
--- a/sphinx/util/fileutil.py
+++ b/sphinx/util/fileutil.py
@@ -4,7 +4,7 @@ from __future__ import annotations
import os
import posixpath
-from typing import TYPE_CHECKING, Callable
+from typing import TYPE_CHECKING, Any, Callable
from docutils.utils import relative_path
@@ -16,7 +16,7 @@ if TYPE_CHECKING:
def copy_asset_file(source: str | os.PathLike[str], destination: str | os.PathLike[str],
- context: dict | None = None,
+ context: dict[str, Any] | None = None,
renderer: BaseRenderer | None = None) -> None:
"""Copy an asset file to destination.
@@ -53,7 +53,7 @@ def copy_asset_file(source: str | os.PathLike[str], destination: str | os.PathLi
def copy_asset(source: str | os.PathLike[str], destination: str | os.PathLike[str],
excluded: PathMatcher = lambda path: False,
- context: dict | None = None, renderer: BaseRenderer | None = None,
+ context: dict[str, Any] | None = None, renderer: BaseRenderer | None = None,
onerror: Callable[[str, Exception], None] | None = None) -> None:
"""Copy asset files to destination recursively.
@@ -80,8 +80,8 @@ def copy_asset(source: str | os.PathLike[str], destination: str | os.PathLike[st
return
for root, dirs, files in os.walk(source, followlinks=True):
- reldir = relative_path(source, root) # type: ignore[arg-type]
- for dir in dirs[:]:
+ reldir = relative_path(source, root)
+ for dir in dirs.copy():
if excluded(posixpath.join(reldir, dir)):
dirs.remove(dir)
else:
diff --git a/sphinx/util/http_date.py b/sphinx/util/http_date.py
index 8e245cb..4908101 100644
--- a/sphinx/util/http_date.py
+++ b/sphinx/util/http_date.py
@@ -5,16 +5,23 @@ Reference: https://www.rfc-editor.org/rfc/rfc7231#section-7.1.1.1
import time
import warnings
-from email.utils import formatdate, parsedate_tz
+from email.utils import parsedate_tz
from sphinx.deprecation import RemovedInSphinx90Warning
+_WEEKDAY_NAME = ('Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun')
+_MONTH_NAME = ('', # Placeholder for indexing purposes
+ 'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
+ 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec')
_GMT_OFFSET = float(time.localtime().tm_gmtoff)
def epoch_to_rfc1123(epoch: float) -> str:
"""Return HTTP-date string from epoch offset."""
- return formatdate(epoch, usegmt=True)
+ yr, mn, dd, hh, mm, ss, wd, _yd, _tz = time.gmtime(epoch)
+ weekday_name = _WEEKDAY_NAME[wd]
+ month = _MONTH_NAME[mn]
+ return f'{weekday_name}, {dd:02} {month} {yr:04} {hh:02}:{mm:02}:{ss:02} GMT'
def rfc1123_to_epoch(rfc1123: str) -> float:
diff --git a/sphinx/util/i18n.py b/sphinx/util/i18n.py
index b820884..c14e3f0 100644
--- a/sphinx/util/i18n.py
+++ b/sphinx/util/i18n.py
@@ -6,7 +6,7 @@ import os
import re
from datetime import datetime, timezone
from os import path
-from typing import TYPE_CHECKING, Callable, NamedTuple
+from typing import TYPE_CHECKING, NamedTuple
import babel.dates
from babel.messages.mofile import write_mo
@@ -18,10 +18,41 @@ from sphinx.util import logging
from sphinx.util.osutil import SEP, canon_path, relpath
if TYPE_CHECKING:
- from collections.abc import Generator
+ import datetime as dt
+ from collections.abc import Iterator
+ from typing import Protocol, Union
+
+ from babel.core import Locale
from sphinx.environment import BuildEnvironment
+ class DateFormatter(Protocol):
+ def __call__( # NoQA: E704
+ self,
+ date: dt.date | None = ...,
+ format: str = ...,
+ locale: str | Locale | None = ...,
+ ) -> str: ...
+
+ class TimeFormatter(Protocol):
+ def __call__( # NoQA: E704
+ self,
+ time: dt.time | dt.datetime | float | None = ...,
+ format: str = ...,
+ tzinfo: dt.tzinfo | None = ...,
+ locale: str | Locale | None = ...,
+ ) -> str: ...
+
+ class DatetimeFormatter(Protocol):
+ def __call__( # NoQA: E704
+ self,
+ datetime: dt.date | dt.time | float | None = ...,
+ format: str = ...,
+ tzinfo: dt.tzinfo | None = ...,
+ locale: str | Locale | None = ...,
+ ) -> str: ...
+
+ Formatter = Union[DateFormatter, TimeFormatter, DatetimeFormatter]
logger = logging.getLogger(__name__)
@@ -81,7 +112,7 @@ class CatalogRepository:
self.encoding = encoding
@property
- def locale_dirs(self) -> Generator[str, None, None]:
+ def locale_dirs(self) -> Iterator[str]:
if not self.language:
return
@@ -94,14 +125,13 @@ class CatalogRepository:
logger.verbose(__('locale_dir %s does not exist'), locale_path)
@property
- def pofiles(self) -> Generator[tuple[str, str], None, None]:
+ def pofiles(self) -> Iterator[tuple[str, str]]:
for locale_dir in self.locale_dirs:
basedir = path.join(locale_dir, self.language, 'LC_MESSAGES')
for root, dirnames, filenames in os.walk(basedir):
# skip dot-directories
- for dirname in dirnames:
- if dirname.startswith('.'):
- dirnames.remove(dirname)
+ for dirname in [d for d in dirnames if d.startswith('.')]:
+ dirnames.remove(dirname)
for filename in filenames:
if filename.endswith('.po'):
@@ -109,7 +139,7 @@ class CatalogRepository:
yield basedir, relpath(fullpath, basedir)
@property
- def catalogs(self) -> Generator[CatalogInfo, None, None]:
+ def catalogs(self) -> Iterator[CatalogInfo]:
for basedir, filename in self.pofiles:
domain = canon_path(path.splitext(filename)[0])
yield CatalogInfo(basedir, domain, self.encoding)
@@ -170,7 +200,7 @@ date_format_re = re.compile('(%s)' % '|'.join(date_format_mappings))
def babel_format_date(date: datetime, format: str, locale: str,
- formatter: Callable = babel.dates.format_date) -> str:
+ formatter: Formatter = babel.dates.format_date) -> str:
# Check if we have the tzinfo attribute. If not we cannot do any time
# related formats.
if not hasattr(date, 'tzinfo'):
@@ -208,6 +238,7 @@ def format_date(
# Check if we have to use a different babel formatter then
# format_datetime, because we only want to format a date
# or a time.
+ function: Formatter
if token == '%x':
function = babel.dates.format_date
elif token == '%X':
diff --git a/sphinx/util/inspect.py b/sphinx/util/inspect.py
index 7d7fbb8..6b13b29 100644
--- a/sphinx/util/inspect.py
+++ b/sphinx/util/inspect.py
@@ -11,41 +11,45 @@ import re
import sys
import types
import typing
-from collections.abc import Mapping, Sequence
+from collections.abc import Mapping
from functools import cached_property, partial, partialmethod, singledispatchmethod
from importlib import import_module
-from inspect import ( # noqa: F401
- Parameter,
- isasyncgenfunction,
- isclass,
- ismethod,
- ismethoddescriptor,
- ismodule,
-)
+from inspect import Parameter, Signature
from io import StringIO
-from types import (
- ClassMethodDescriptorType,
- MethodDescriptorType,
- MethodType,
- ModuleType,
- WrapperDescriptorType,
-)
-from typing import Any, Callable, cast
+from types import ClassMethodDescriptorType, MethodDescriptorType, WrapperDescriptorType
+from typing import TYPE_CHECKING, Any
from sphinx.pycode.ast import unparse as ast_unparse
from sphinx.util import logging
from sphinx.util.typing import ForwardRef, stringify_annotation
+if TYPE_CHECKING:
+ from collections.abc import Callable, Sequence
+ from inspect import _ParameterKind
+ from types import MethodType, ModuleType
+ from typing import Final
+
logger = logging.getLogger(__name__)
memory_address_re = re.compile(r' at 0x[0-9a-f]{8,16}(?=>)', re.IGNORECASE)
+# re-export as is
+isasyncgenfunction = inspect.isasyncgenfunction
+ismethod = inspect.ismethod
+ismethoddescriptor = inspect.ismethoddescriptor
+isclass = inspect.isclass
+ismodule = inspect.ismodule
+
def unwrap(obj: Any) -> Any:
- """Get an original object from wrapped object (wrapped functions)."""
+ """Get an original object from wrapped object (wrapped functions).
+
+ Mocked objects are returned as is.
+ """
if hasattr(obj, '__sphinx_mock__'):
# Skip unwrapping mock object to avoid RecursionError
return obj
+
try:
return inspect.unwrap(obj)
except ValueError:
@@ -53,14 +57,28 @@ def unwrap(obj: Any) -> Any:
return obj
-def unwrap_all(obj: Any, *, stop: Callable | None = None) -> Any:
- """
- Get an original object from wrapped object (unwrapping partials, wrapped
- functions, and other decorators).
+def unwrap_all(obj: Any, *, stop: Callable[[Any], bool] | None = None) -> Any:
+ """Get an original object from wrapped object.
+
+ Unlike :func:`unwrap`, this unwraps partial functions, wrapped functions,
+ class methods and static methods.
+
+ When specified, *stop* is a predicate indicating whether an object should
+ be unwrapped or not.
"""
+ if callable(stop):
+ while not stop(obj):
+ if ispartial(obj):
+ obj = obj.func
+ elif inspect.isroutine(obj) and hasattr(obj, '__wrapped__'):
+ obj = obj.__wrapped__
+ elif isclassmethod(obj) or isstaticmethod(obj):
+ obj = obj.__func__
+ else:
+ return obj
+ return obj # in case the while loop never starts
+
while True:
- if stop and stop(obj):
- return obj
if ispartial(obj):
obj = obj.func
elif inspect.isroutine(obj) and hasattr(obj, '__wrapped__'):
@@ -72,10 +90,11 @@ def unwrap_all(obj: Any, *, stop: Callable | None = None) -> Any:
def getall(obj: Any) -> Sequence[str] | None:
- """Get __all__ attribute of the module as dict.
+ """Get the ``__all__`` attribute of an object as sequence.
- Return None if given *obj* does not have __all__.
- Raises ValueError if given *obj* have invalid __all__.
+ This returns ``None`` if the given ``obj.__all__`` does not exist and
+ raises :exc:`ValueError` if ``obj.__all__`` is not a list or tuple of
+ strings.
"""
__all__ = safe_getattr(obj, '__all__', None)
if __all__ is None:
@@ -86,35 +105,42 @@ def getall(obj: Any) -> Sequence[str] | None:
def getannotations(obj: Any) -> Mapping[str, Any]:
- """Get __annotations__ from given *obj* safely."""
- __annotations__ = safe_getattr(obj, '__annotations__', None)
+ """Safely get the ``__annotations__`` attribute of an object."""
+ if sys.version_info >= (3, 10, 0) or not isinstance(obj, type):
+ __annotations__ = safe_getattr(obj, '__annotations__', None)
+ else:
+ # Workaround for bugfix not available until python 3.10 as recommended by docs
+ # https://docs.python.org/3.10/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older
+ __dict__ = safe_getattr(obj, '__dict__', {})
+ __annotations__ = __dict__.get('__annotations__', None)
if isinstance(__annotations__, Mapping):
return __annotations__
- else:
- return {}
+ return {}
def getglobals(obj: Any) -> Mapping[str, Any]:
- """Get __globals__ from given *obj* safely."""
+ """Safely get :attr:`obj.__globals__ <function.__globals__>`."""
__globals__ = safe_getattr(obj, '__globals__', None)
if isinstance(__globals__, Mapping):
return __globals__
- else:
- return {}
+ return {}
def getmro(obj: Any) -> tuple[type, ...]:
- """Get __mro__ from given *obj* safely."""
+ """Safely get :attr:`obj.__mro__ <class.__mro__>`."""
__mro__ = safe_getattr(obj, '__mro__', None)
if isinstance(__mro__, tuple):
return __mro__
- else:
- return ()
+ return ()
def getorigbases(obj: Any) -> tuple[Any, ...] | None:
- """Get __orig_bases__ from *obj* safely."""
- if not inspect.isclass(obj):
+ """Safely get ``obj.__orig_bases__``.
+
+ This returns ``None`` if the object is not a class or if ``__orig_bases__``
+ is not well-defined (e.g., a non-tuple object or an empty sequence).
+ """
+ if not isclass(obj):
return None
# Get __orig_bases__ from obj.__dict__ to avoid accessing the parent's __orig_bases__.
@@ -123,18 +149,17 @@ def getorigbases(obj: Any) -> tuple[Any, ...] | None:
__orig_bases__ = __dict__.get('__orig_bases__')
if isinstance(__orig_bases__, tuple) and len(__orig_bases__) > 0:
return __orig_bases__
- else:
- return None
+ return None
-def getslots(obj: Any) -> dict[str, Any] | None:
- """Get __slots__ attribute of the class as dict.
+def getslots(obj: Any) -> dict[str, Any] | dict[str, None] | None:
+ """Safely get :term:`obj.__slots__ <__slots__>` as a dictionary if any.
- Return None if gienv *obj* does not have __slots__.
- Raises TypeError if given *obj* is not a class.
- Raises ValueError if given *obj* have invalid __slots__.
+ - This returns ``None`` if ``obj.__slots__`` does not exist.
+ - This raises a :exc:`TypeError` if *obj* is not a class.
+ - This raises a :exc:`ValueError` if ``obj.__slots__`` is invalid.
"""
- if not inspect.isclass(obj):
+ if not isclass(obj):
raise TypeError
__slots__ = safe_getattr(obj, '__slots__', None)
@@ -151,7 +176,7 @@ def getslots(obj: Any) -> dict[str, Any] | None:
def isNewType(obj: Any) -> bool:
- """Check the if object is a kind of NewType."""
+    """Check if the object is a kind of :class:`~typing.NewType`."""
if sys.version_info[:2] >= (3, 10):
return isinstance(obj, typing.NewType)
__module__ = safe_getattr(obj, '__module__', None)
@@ -160,72 +185,71 @@ def isNewType(obj: Any) -> bool:
def isenumclass(x: Any) -> bool:
- """Check if the object is subclass of enum."""
- return inspect.isclass(x) and issubclass(x, enum.Enum)
+ """Check if the object is an :class:`enumeration class <enum.Enum>`."""
+ return isclass(x) and issubclass(x, enum.Enum)
def isenumattribute(x: Any) -> bool:
- """Check if the object is attribute of enum."""
+ """Check if the object is an enumeration attribute."""
return isinstance(x, enum.Enum)
def unpartial(obj: Any) -> Any:
- """Get an original object from partial object.
+ """Get an original object from a partial-like object.
- This returns given object itself if not partial.
+ If *obj* is not a partial object, it is returned as is.
+
+ .. seealso:: :func:`ispartial`
"""
while ispartial(obj):
obj = obj.func
-
return obj
def ispartial(obj: Any) -> bool:
- """Check if the object is partial."""
+ """Check if the object is a partial function or method."""
return isinstance(obj, (partial, partialmethod))
def isclassmethod(obj: Any, cls: Any = None, name: str | None = None) -> bool:
- """Check if the object is classmethod."""
+ """Check if the object is a :class:`classmethod`."""
if isinstance(obj, classmethod):
return True
- if inspect.ismethod(obj) and obj.__self__ is not None and isclass(obj.__self__):
+ if ismethod(obj) and obj.__self__ is not None and isclass(obj.__self__):
return True
if cls and name:
- placeholder = object()
+ # trace __mro__ if the method is defined in parent class
+ sentinel = object()
for basecls in getmro(cls):
- meth = basecls.__dict__.get(name, placeholder)
- if meth is not placeholder:
+ meth = basecls.__dict__.get(name, sentinel)
+ if meth is not sentinel:
return isclassmethod(meth)
-
return False
def isstaticmethod(obj: Any, cls: Any = None, name: str | None = None) -> bool:
- """Check if the object is staticmethod."""
+ """Check if the object is a :class:`staticmethod`."""
if isinstance(obj, staticmethod):
return True
if cls and name:
# trace __mro__ if the method is defined in parent class
- #
- # .. note:: This only works well with new style classes.
+ sentinel = object()
for basecls in getattr(cls, '__mro__', [cls]):
- meth = basecls.__dict__.get(name)
- if meth:
+ meth = basecls.__dict__.get(name, sentinel)
+ if meth is not sentinel:
return isinstance(meth, staticmethod)
return False
def isdescriptor(x: Any) -> bool:
- """Check if the object is some kind of descriptor."""
+ """Check if the object is a :external+python:term:`descriptor`."""
return any(
- callable(safe_getattr(x, item, None))
- for item in ['__get__', '__set__', '__delete__']
+ callable(safe_getattr(x, item, None)) for item in ('__get__', '__set__', '__delete__')
)
def isabstractmethod(obj: Any) -> bool:
- """Check if the object is an abstractmethod."""
+ """Check if the object is an :func:`abstractmethod`."""
return safe_getattr(obj, '__isabstractmethod__', False) is True
@@ -242,86 +266,106 @@ def is_cython_function_or_method(obj: Any) -> bool:
return False
+_DESCRIPTOR_LIKE: Final[tuple[type, ...]] = (
+ ClassMethodDescriptorType,
+ MethodDescriptorType,
+ WrapperDescriptorType,
+)
+
+
def isattributedescriptor(obj: Any) -> bool:
- """Check if the object is an attribute like descriptor."""
+ """Check if the object is an attribute-like descriptor."""
if inspect.isdatadescriptor(obj):
# data descriptor is kind of attribute
return True
if isdescriptor(obj):
# non data descriptor
unwrapped = unwrap(obj)
- if isfunction(unwrapped) or isbuiltin(unwrapped) or inspect.ismethod(unwrapped):
+ if isfunction(unwrapped) or isbuiltin(unwrapped) or ismethod(unwrapped):
# attribute must not be either function, builtin and method
return False
if is_cython_function_or_method(unwrapped):
# attribute must not be either function and method (for cython)
return False
- if inspect.isclass(unwrapped):
+ if isclass(unwrapped):
# attribute must not be a class
return False
- if isinstance(unwrapped, (ClassMethodDescriptorType,
- MethodDescriptorType,
- WrapperDescriptorType)):
+ if isinstance(unwrapped, _DESCRIPTOR_LIKE):
# attribute must not be a method descriptor
return False
- if type(unwrapped).__name__ == "instancemethod":
- # attribute must not be an instancemethod (C-API)
- return False
- return True
+ # attribute must not be an instancemethod (C-API)
+ return type(unwrapped).__name__ != 'instancemethod'
return False
def is_singledispatch_function(obj: Any) -> bool:
- """Check if the object is singledispatch function."""
- return (inspect.isfunction(obj) and
- hasattr(obj, 'dispatch') and
- hasattr(obj, 'register') and
- obj.dispatch.__module__ == 'functools')
+ """Check if the object is a :func:`~functools.singledispatch` function."""
+ return (
+ inspect.isfunction(obj)
+ and hasattr(obj, 'dispatch')
+ and hasattr(obj, 'register')
+ and obj.dispatch.__module__ == 'functools'
+ )
def is_singledispatch_method(obj: Any) -> bool:
- """Check if the object is singledispatch method."""
+ """Check if the object is a :class:`~functools.singledispatchmethod`."""
return isinstance(obj, singledispatchmethod)
def isfunction(obj: Any) -> bool:
- """Check if the object is function."""
+ """Check if the object is a user-defined function.
+
+ Partial objects are unwrapped before checking them.
+
+ .. seealso:: :external+python:func:`inspect.isfunction`
+ """
return inspect.isfunction(unpartial(obj))
def isbuiltin(obj: Any) -> bool:
- """Check if the object is function."""
+ """Check if the object is a built-in function or method.
+
+ Partial objects are unwrapped before checking them.
+
+ .. seealso:: :external+python:func:`inspect.isbuiltin`
+ """
return inspect.isbuiltin(unpartial(obj))
def isroutine(obj: Any) -> bool:
- """Check is any kind of function or method."""
+ """Check if the object is a kind of function or method.
+
+ Partial objects are unwrapped before checking them.
+
+ .. seealso:: :external+python:func:`inspect.isroutine`
+ """
return inspect.isroutine(unpartial(obj))
def iscoroutinefunction(obj: Any) -> bool:
- """Check if the object is coroutine-function."""
- def iswrappedcoroutine(obj: Any) -> bool:
- """Check if the object is wrapped coroutine-function."""
- if isstaticmethod(obj) or isclassmethod(obj) or ispartial(obj):
- # staticmethod, classmethod and partial method are not a wrapped coroutine-function
- # Note: Since 3.10, staticmethod and classmethod becomes a kind of wrappers
- return False
- return hasattr(obj, '__wrapped__')
-
- obj = unwrap_all(obj, stop=iswrappedcoroutine)
+ """Check if the object is a :external+python:term:`coroutine` function."""
+ obj = unwrap_all(obj, stop=_is_wrapped_coroutine)
return inspect.iscoroutinefunction(obj)
+def _is_wrapped_coroutine(obj: Any) -> bool:
+    """Check if the object is a wrapped coroutine function."""
+ if isstaticmethod(obj) or isclassmethod(obj) or ispartial(obj):
+ # staticmethod, classmethod and partial method are not a wrapped coroutine-function
+ # Note: Since 3.10, staticmethod and classmethod becomes a kind of wrappers
+ return False
+ return hasattr(obj, '__wrapped__')
+
+
def isproperty(obj: Any) -> bool:
- """Check if the object is property."""
+ """Check if the object is property (possibly cached)."""
return isinstance(obj, (property, cached_property))
def isgenericalias(obj: Any) -> bool:
- """Check if the object is GenericAlias."""
- return isinstance(
- obj, (types.GenericAlias, typing._BaseGenericAlias)) # type: ignore[attr-defined]
+ """Check if the object is a generic alias."""
+ return isinstance(obj, (types.GenericAlias, typing._BaseGenericAlias)) # type: ignore[attr-defined]
def safe_getattr(obj: Any, name: str, *defargs: Any) -> Any:
@@ -346,7 +390,7 @@ def safe_getattr(obj: Any, name: str, *defargs: Any) -> Any:
raise AttributeError(name) from exc
-def object_description(obj: Any, *, _seen: frozenset = frozenset()) -> str:
+def object_description(obj: Any, *, _seen: frozenset[int] = frozenset()) -> str:
"""A repr() implementation that returns text safe to use in reST context.
Maintains a set of 'seen' object IDs to detect and avoid infinite recursion.
@@ -362,8 +406,10 @@ def object_description(obj: Any, *, _seen: frozenset = frozenset()) -> str:
# Cannot sort dict keys, fall back to using descriptions as a sort key
sorted_keys = sorted(obj, key=lambda k: object_description(k, _seen=seen))
- items = ((object_description(key, _seen=seen),
- object_description(obj[key], _seen=seen)) for key in sorted_keys)
+ items = (
+ (object_description(key, _seen=seen), object_description(obj[key], _seen=seen))
+ for key in sorted_keys
+ )
return '{%s}' % ', '.join(f'{key}: {value}' for (key, value) in items)
elif isinstance(obj, set):
if id(obj) in seen:
@@ -384,15 +430,18 @@ def object_description(obj: Any, *, _seen: frozenset = frozenset()) -> str:
except TypeError:
# Cannot sort frozenset values, fall back to using descriptions as a sort key
sorted_values = sorted(obj, key=lambda x: object_description(x, _seen=seen))
- return 'frozenset({%s})' % ', '.join(object_description(x, _seen=seen)
- for x in sorted_values)
+ return 'frozenset({%s})' % ', '.join(
+ object_description(x, _seen=seen) for x in sorted_values
+ )
elif isinstance(obj, enum.Enum):
+ if obj.__repr__.__func__ is not enum.Enum.__repr__: # type: ignore[attr-defined]
+ return repr(obj)
return f'{obj.__class__.__name__}.{obj.name}'
elif isinstance(obj, tuple):
if id(obj) in seen:
return 'tuple(...)'
seen |= frozenset([id(obj)])
- return '(%s%s)' % (
+ return '({}{})'.format(
', '.join(object_description(x, _seen=seen) for x in obj),
',' * (len(obj) == 1),
)
@@ -413,16 +462,18 @@ def object_description(obj: Any, *, _seen: frozenset = frozenset()) -> str:
def is_builtin_class_method(obj: Any, attr_name: str) -> bool:
- """If attr_name is implemented at builtin class, return True.
+ """Check whether *attr_name* is implemented on a builtin class.
>>> is_builtin_class_method(int, '__init__')
True
- Why this function needed? CPython implements int.__init__ by Descriptor
- but PyPy implements it by pure Python code.
+
+ This function is needed since CPython implements ``int.__init__`` via
+    descriptors, but the PyPy implementation is written in pure Python code.
"""
+ mro = getmro(obj)
+
try:
- mro = getmro(obj)
cls = next(c for c in mro if attr_name in safe_getattr(c, '__dict__', {}))
except StopIteration:
return False
@@ -449,10 +500,11 @@ class DefaultValue:
class TypeAliasForwardRef:
- """Pseudo typing class for autodoc_type_aliases.
+ """Pseudo typing class for :confval:`autodoc_type_aliases`.
- This avoids the error on evaluating the type inside `get_type_hints()`.
+ This avoids the error on evaluating the type inside :func:`typing.get_type_hints()`.
"""
+
def __init__(self, name: str) -> None:
self.name = name
@@ -471,9 +523,9 @@ class TypeAliasForwardRef:
class TypeAliasModule:
- """Pseudo module class for autodoc_type_aliases."""
+ """Pseudo module class for :confval:`autodoc_type_aliases`."""
- def __init__(self, modname: str, mapping: dict[str, str]) -> None:
+ def __init__(self, modname: str, mapping: Mapping[str, str]) -> None:
self.__modname = modname
self.__mapping = mapping
@@ -504,12 +556,13 @@ class TypeAliasModule:
class TypeAliasNamespace(dict[str, Any]):
- """Pseudo namespace class for autodoc_type_aliases.
+ """Pseudo namespace class for :confval:`autodoc_type_aliases`.
- This enables to look up nested modules and classes like `mod1.mod2.Class`.
+ Useful for looking up nested objects via ``namespace.foo.bar.Class``.
"""
- def __init__(self, mapping: dict[str, str]) -> None:
+ def __init__(self, mapping: Mapping[str, str]) -> None:
+ super().__init__()
self.__mapping = mapping
def __getitem__(self, key: str) -> Any:
@@ -526,19 +579,21 @@ class TypeAliasNamespace(dict[str, Any]):
raise KeyError
-def _should_unwrap(subject: Callable) -> bool:
+def _should_unwrap(subject: Callable[..., Any]) -> bool:
"""Check the function should be unwrapped on getting signature."""
__globals__ = getglobals(subject)
- if (__globals__.get('__name__') == 'contextlib' and
- __globals__.get('__file__') == contextlib.__file__):
- # contextmanger should be unwrapped
- return True
-
- return False
+    # contextmanager should be unwrapped
+ return (
+ __globals__.get('__name__') == 'contextlib'
+ and __globals__.get('__file__') == contextlib.__file__
+ )
-def signature(subject: Callable, bound_method: bool = False, type_aliases: dict | None = None,
- ) -> inspect.Signature:
+def signature(
+ subject: Callable[..., Any],
+ bound_method: bool = False,
+ type_aliases: Mapping[str, str] | None = None,
+) -> Signature:
"""Return a Signature object for the given *subject*.
:param bound_method: Specify *subject* is a bound method or not
@@ -591,37 +646,17 @@ def signature(subject: Callable, bound_method: bool = False, type_aliases: dict
#
# For example, this helps a function having a default value `inspect._empty`.
# refs: https://github.com/sphinx-doc/sphinx/issues/7935
- return inspect.Signature(parameters, return_annotation=return_annotation,
- __validate_parameters__=False)
+ return Signature(
+ parameters, return_annotation=return_annotation, __validate_parameters__=False
+ )
-def evaluate_signature(sig: inspect.Signature, globalns: dict | None = None,
- localns: dict | None = None,
- ) -> inspect.Signature:
+def evaluate_signature(
+ sig: Signature,
+ globalns: dict[str, Any] | None = None,
+ localns: dict[str, Any] | None = None,
+) -> Signature:
"""Evaluate unresolved type annotations in a signature object."""
- def evaluate_forwardref(ref: ForwardRef, globalns: dict, localns: dict) -> Any:
- """Evaluate a forward reference."""
- return ref._evaluate(globalns, localns, frozenset())
-
- def evaluate(annotation: Any, globalns: dict, localns: dict) -> Any:
- """Evaluate unresolved type annotation."""
- try:
- if isinstance(annotation, str):
- ref = ForwardRef(annotation, True)
- annotation = evaluate_forwardref(ref, globalns, localns)
-
- if isinstance(annotation, ForwardRef):
- annotation = evaluate_forwardref(ref, globalns, localns)
- elif isinstance(annotation, str):
- # might be a ForwardRef'ed annotation in overloaded functions
- ref = ForwardRef(annotation, True)
- annotation = evaluate_forwardref(ref, globalns, localns)
- except (NameError, TypeError):
- # failed to evaluate type. skipped.
- pass
-
- return annotation
-
if globalns is None:
globalns = {}
if localns is None:
@@ -630,20 +665,56 @@ def evaluate_signature(sig: inspect.Signature, globalns: dict | None = None,
parameters = list(sig.parameters.values())
for i, param in enumerate(parameters):
if param.annotation:
- annotation = evaluate(param.annotation, globalns, localns)
+ annotation = _evaluate(param.annotation, globalns, localns)
parameters[i] = param.replace(annotation=annotation)
return_annotation = sig.return_annotation
if return_annotation:
- return_annotation = evaluate(return_annotation, globalns, localns)
+ return_annotation = _evaluate(return_annotation, globalns, localns)
return sig.replace(parameters=parameters, return_annotation=return_annotation)
-def stringify_signature(sig: inspect.Signature, show_annotation: bool = True,
- show_return_annotation: bool = True,
- unqualified_typehints: bool = False) -> str:
- """Stringify a Signature object.
+def _evaluate_forwardref(
+ ref: ForwardRef,
+ globalns: dict[str, Any] | None,
+ localns: dict[str, Any] | None,
+) -> Any:
+ """Evaluate a forward reference."""
+ return ref._evaluate(globalns, localns, frozenset())
+
+
+def _evaluate(
+ annotation: Any,
+ globalns: dict[str, Any],
+ localns: dict[str, Any],
+) -> Any:
+ """Evaluate unresolved type annotation."""
+ try:
+ if isinstance(annotation, str):
+ ref = ForwardRef(annotation, True)
+ annotation = _evaluate_forwardref(ref, globalns, localns)
+
+ if isinstance(annotation, ForwardRef):
+ annotation = _evaluate_forwardref(ref, globalns, localns)
+ elif isinstance(annotation, str):
+ # might be a ForwardRef'ed annotation in overloaded functions
+ ref = ForwardRef(annotation, True)
+ annotation = _evaluate_forwardref(ref, globalns, localns)
+ except (NameError, TypeError):
+ # failed to evaluate type. skipped.
+ pass
+
+ return annotation
+
+
+def stringify_signature(
+ sig: Signature,
+ show_annotation: bool = True,
+ show_return_annotation: bool = True,
+ unqualified_typehints: bool = False,
+) -> str:
+ """Stringify a :class:`~inspect.Signature` object.
:param show_annotation: If enabled, show annotations on the signature
:param show_return_annotation: If enabled, show annotation of the return value
@@ -655,31 +726,35 @@ def stringify_signature(sig: inspect.Signature, show_annotation: bool = True,
else:
mode = 'fully-qualified'
+ EMPTY = Parameter.empty
+
args = []
last_kind = None
for param in sig.parameters.values():
- if param.kind != param.POSITIONAL_ONLY and last_kind == param.POSITIONAL_ONLY:
+ if param.kind != Parameter.POSITIONAL_ONLY and last_kind == Parameter.POSITIONAL_ONLY:
# PEP-570: Separator for Positional Only Parameter: /
args.append('/')
- if param.kind == param.KEYWORD_ONLY and last_kind in (param.POSITIONAL_OR_KEYWORD,
- param.POSITIONAL_ONLY,
- None):
+ if param.kind == Parameter.KEYWORD_ONLY and last_kind in (
+ Parameter.POSITIONAL_OR_KEYWORD,
+ Parameter.POSITIONAL_ONLY,
+ None,
+ ):
# PEP-3102: Separator for Keyword Only Parameter: *
args.append('*')
arg = StringIO()
- if param.kind == param.VAR_POSITIONAL:
+ if param.kind is Parameter.VAR_POSITIONAL:
arg.write('*' + param.name)
- elif param.kind == param.VAR_KEYWORD:
+ elif param.kind is Parameter.VAR_KEYWORD:
arg.write('**' + param.name)
else:
arg.write(param.name)
- if show_annotation and param.annotation is not param.empty:
+ if show_annotation and param.annotation is not EMPTY:
arg.write(': ')
arg.write(stringify_annotation(param.annotation, mode))
- if param.default is not param.empty:
- if show_annotation and param.annotation is not param.empty:
+ if param.default is not EMPTY:
+ if show_annotation and param.annotation is not EMPTY:
arg.write(' = ')
else:
arg.write('=')
@@ -688,91 +763,86 @@ def stringify_signature(sig: inspect.Signature, show_annotation: bool = True,
args.append(arg.getvalue())
last_kind = param.kind
- if last_kind == Parameter.POSITIONAL_ONLY:
+ if last_kind is Parameter.POSITIONAL_ONLY:
# PEP-570: Separator for Positional Only Parameter: /
args.append('/')
concatenated_args = ', '.join(args)
- if (sig.return_annotation is Parameter.empty or
- show_annotation is False or
- show_return_annotation is False):
+ if sig.return_annotation is EMPTY or not show_annotation or not show_return_annotation:
return f'({concatenated_args})'
else:
- annotation = stringify_annotation(sig.return_annotation, mode)
- return f'({concatenated_args}) -> {annotation}'
+ retann = stringify_annotation(sig.return_annotation, mode)
+ return f'({concatenated_args}) -> {retann}'
-def signature_from_str(signature: str) -> inspect.Signature:
- """Create a Signature object from string."""
+def signature_from_str(signature: str) -> Signature:
+ """Create a :class:`~inspect.Signature` object from a string."""
code = 'def func' + signature + ': pass'
module = ast.parse(code)
- function = cast(ast.FunctionDef, module.body[0])
+ function = typing.cast(ast.FunctionDef, module.body[0])
return signature_from_ast(function, code)
-def signature_from_ast(node: ast.FunctionDef, code: str = '') -> inspect.Signature:
- """Create a Signature object from AST *node*."""
- args = node.args
- defaults = list(args.defaults)
- params = []
- if hasattr(args, "posonlyargs"):
- posonlyargs = len(args.posonlyargs)
- positionals = posonlyargs + len(args.args)
- else:
- posonlyargs = 0
- positionals = len(args.args)
-
- for _ in range(len(defaults), positionals):
- defaults.insert(0, Parameter.empty) # type: ignore[arg-type]
-
- if hasattr(args, "posonlyargs"):
- for i, arg in enumerate(args.posonlyargs):
- if defaults[i] is Parameter.empty:
- default = Parameter.empty
- else:
- default = DefaultValue(
- ast_unparse(defaults[i], code)) # type: ignore[assignment]
-
- annotation = ast_unparse(arg.annotation, code) or Parameter.empty
- params.append(Parameter(arg.arg, Parameter.POSITIONAL_ONLY,
- default=default, annotation=annotation))
-
- for i, arg in enumerate(args.args):
- if defaults[i + posonlyargs] is Parameter.empty:
- default = Parameter.empty
- else:
- default = DefaultValue(
- ast_unparse(defaults[i + posonlyargs], code), # type: ignore[assignment]
- )
-
- annotation = ast_unparse(arg.annotation, code) or Parameter.empty
- params.append(Parameter(arg.arg, Parameter.POSITIONAL_OR_KEYWORD,
- default=default, annotation=annotation))
+def signature_from_ast(node: ast.FunctionDef, code: str = '') -> Signature:
+ """Create a :class:`~inspect.Signature` object from an AST node."""
+ EMPTY = Parameter.empty
+ args: ast.arguments = node.args
+ defaults: tuple[ast.expr | None, ...] = tuple(args.defaults)
+ pos_only_offset = len(args.posonlyargs)
+ defaults_offset = pos_only_offset + len(args.args) - len(defaults)
+ # The sequence ``D = args.defaults`` contains non-None AST expressions,
+ # so we can use ``None`` as a sentinel value for that to indicate that
+ # there is no default value for a specific parameter.
+ #
+ # Let *p* be the number of positional-only and positional-or-keyword
+ # arguments. Note that ``0 <= len(D) <= p`` and ``D[0]`` is the default
+ # value corresponding to a positional-only *or* a positional-or-keyword
+ # argument. Since a non-default argument cannot follow a default argument,
+ # the sequence *D* can be completed on the left by adding None sentinels
+ # so that ``len(D) == p`` and ``D[i]`` is the *i*-th default argument.
+ defaults = (None,) * defaults_offset + defaults
+
+ # construct the parameter list
+ params: list[Parameter] = []
+
+ # positional-only arguments (introduced in Python 3.8)
+ for arg, defexpr in zip(args.posonlyargs, defaults):
+ params.append(_define(Parameter.POSITIONAL_ONLY, arg, code, defexpr=defexpr))
+
+ # normal arguments
+ for arg, defexpr in zip(args.args, defaults[pos_only_offset:]):
+ params.append(_define(Parameter.POSITIONAL_OR_KEYWORD, arg, code, defexpr=defexpr))
+
+ # variadic positional argument (no possible default expression)
if args.vararg:
- annotation = ast_unparse(args.vararg.annotation, code) or Parameter.empty
- params.append(Parameter(args.vararg.arg, Parameter.VAR_POSITIONAL,
- annotation=annotation))
+ params.append(_define(Parameter.VAR_POSITIONAL, args.vararg, code, defexpr=None))
- for i, arg in enumerate(args.kwonlyargs):
- if args.kw_defaults[i] is None:
- default = Parameter.empty
- else:
- default = DefaultValue(
- ast_unparse(args.kw_defaults[i], code)) # type: ignore[arg-type,assignment]
- annotation = ast_unparse(arg.annotation, code) or Parameter.empty
- params.append(Parameter(arg.arg, Parameter.KEYWORD_ONLY, default=default,
- annotation=annotation))
+ # keyword-only arguments
+ for arg, defexpr in zip(args.kwonlyargs, args.kw_defaults):
+ params.append(_define(Parameter.KEYWORD_ONLY, arg, code, defexpr=defexpr))
+ # variadic keyword argument (no possible default expression)
if args.kwarg:
- annotation = ast_unparse(args.kwarg.annotation, code) or Parameter.empty
- params.append(Parameter(args.kwarg.arg, Parameter.VAR_KEYWORD,
- annotation=annotation))
+ params.append(_define(Parameter.VAR_KEYWORD, args.kwarg, code, defexpr=None))
+
+ return_annotation = ast_unparse(node.returns, code) or EMPTY
+ return Signature(params, return_annotation=return_annotation)
+
- return_annotation = ast_unparse(node.returns, code) or Parameter.empty
+def _define(
+ kind: _ParameterKind,
+ arg: ast.arg,
+ code: str,
+ *,
+ defexpr: ast.expr | None,
+) -> Parameter:
+ EMPTY = Parameter.empty
- return inspect.Signature(params, return_annotation=return_annotation)
+ default = EMPTY if defexpr is None else DefaultValue(ast_unparse(defexpr, code))
+ annotation = ast_unparse(arg.annotation, code) or EMPTY
+ return Parameter(arg.arg, kind, default=default, annotation=annotation)
def getdoc(
@@ -790,13 +860,6 @@ def getdoc(
* inherited docstring
* inherited decorated methods
"""
- def getdoc_internal(obj: Any, attrgetter: Callable = safe_getattr) -> str | None:
- doc = attrgetter(obj, '__doc__', None)
- if isinstance(doc, str):
- return doc
- else:
- return None
-
if cls and name and isclassmethod(obj, cls, name):
for basecls in getmro(cls):
meth = basecls.__dict__.get(name)
@@ -805,7 +868,7 @@ def getdoc(
if doc is not None or not allow_inherited:
return doc
- doc = getdoc_internal(obj)
+ doc = _getdoc_internal(obj)
if ispartial(obj) and doc == obj.__class__.__doc__:
return getdoc(obj.func)
elif doc is None and allow_inherited:
@@ -814,7 +877,7 @@ def getdoc(
for basecls in getmro(cls):
meth = safe_getattr(basecls, name, None)
if meth is not None:
- doc = getdoc_internal(meth)
+ doc = _getdoc_internal(meth)
if doc is not None:
break
@@ -831,3 +894,12 @@ def getdoc(
doc = inspect.getdoc(obj)
return doc
+
+
+def _getdoc_internal(
+ obj: Any, attrgetter: Callable[[Any, str, Any], Any] = safe_getattr
+) -> str | None:
+ doc = attrgetter(obj, '__doc__', None)
+ if isinstance(doc, str):
+ return doc
+ return None
diff --git a/sphinx/util/inventory.py b/sphinx/util/inventory.py
index 89f0070..a43fd03 100644
--- a/sphinx/util/inventory.py
+++ b/sphinx/util/inventory.py
@@ -25,7 +25,7 @@ class InventoryFileReader:
This reader supports mixture of texts and compressed texts.
"""
- def __init__(self, stream: IO) -> None:
+ def __init__(self, stream: IO[bytes]) -> None:
self.stream = stream
self.buffer = b''
self.eof = False
@@ -77,7 +77,12 @@ class InventoryFileReader:
class InventoryFile:
@classmethod
- def load(cls, stream: IO, uri: str, joinfunc: Callable) -> Inventory:
+ def load(
+ cls: type[InventoryFile],
+ stream: IO[bytes],
+ uri: str,
+ joinfunc: Callable[[str, str], str],
+ ) -> Inventory:
reader = InventoryFileReader(stream)
line = reader.readline().rstrip()
if line == '# Sphinx inventory version 1':
@@ -88,7 +93,12 @@ class InventoryFile:
raise ValueError('invalid inventory header: %s' % line)
@classmethod
- def load_v1(cls, stream: InventoryFileReader, uri: str, join: Callable) -> Inventory:
+ def load_v1(
+ cls: type[InventoryFile],
+ stream: InventoryFileReader,
+ uri: str,
+ join: Callable[[str, str], str],
+ ) -> Inventory:
invdata: Inventory = {}
projname = stream.readline().rstrip()[11:]
version = stream.readline().rstrip()[11:]
@@ -106,7 +116,12 @@ class InventoryFile:
return invdata
@classmethod
- def load_v2(cls, stream: InventoryFileReader, uri: str, join: Callable) -> Inventory:
+ def load_v2(
+ cls: type[InventoryFile],
+ stream: InventoryFileReader,
+ uri: str,
+ join: Callable[[str, str], str],
+ ) -> Inventory:
invdata: Inventory = {}
projname = stream.readline().rstrip()[11:]
version = stream.readline().rstrip()[11:]
@@ -140,7 +155,9 @@ class InventoryFile:
return invdata
@classmethod
- def dump(cls, filename: str, env: BuildEnvironment, builder: Builder) -> None:
+ def dump(
+ cls: type[InventoryFile], filename: str, env: BuildEnvironment, builder: Builder,
+ ) -> None:
def escape(string: str) -> str:
return re.sub("\\s+", " ", string)
diff --git a/sphinx/util/logging.py b/sphinx/util/logging.py
index 429018a..e107a56 100644
--- a/sphinx/util/logging.py
+++ b/sphinx/util/logging.py
@@ -16,7 +16,7 @@ from sphinx.util.console import colorize
from sphinx.util.osutil import abspath
if TYPE_CHECKING:
- from collections.abc import Generator
+ from collections.abc import Iterator
from docutils.nodes import Node
@@ -85,6 +85,7 @@ def convert_serializable(records: list[logging.LogRecord]) -> None:
class SphinxLogRecord(logging.LogRecord):
"""Log record class supporting location"""
+
prefix = ''
location: Any = None
@@ -101,11 +102,13 @@ class SphinxLogRecord(logging.LogRecord):
class SphinxInfoLogRecord(SphinxLogRecord):
"""Info log record class supporting location"""
+
prefix = '' # do not show any prefix for INFO messages
class SphinxWarningLogRecord(SphinxLogRecord):
"""Warning log record class supporting location"""
+
@property
def prefix(self) -> str: # type: ignore[override]
if self.levelno >= logging.CRITICAL:
@@ -118,6 +121,7 @@ class SphinxWarningLogRecord(SphinxLogRecord):
class SphinxLoggerAdapter(logging.LoggerAdapter):
"""LoggerAdapter allowing ``type`` and ``subtype`` keywords."""
+
KEYWORDS = ['type', 'subtype', 'location', 'nonl', 'color', 'once']
def log( # type: ignore[override]
@@ -143,9 +147,56 @@ class SphinxLoggerAdapter(logging.LoggerAdapter):
def handle(self, record: logging.LogRecord) -> None:
self.logger.handle(record)
+ def warning( # type: ignore[override]
+ self,
+ msg: object,
+ *args: object,
+ type: None | str = None,
+ subtype: None | str = None,
+ location: None | str | tuple[str | None, int | None] | Node = None,
+ nonl: bool = True,
+ color: str | None = None,
+ once: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ """Log a sphinx warning.
+
+ It is recommended to include a ``type`` and ``subtype`` for warnings as
+ these can be displayed to the user using :confval:`show_warning_types`
+ and used in :confval:`suppress_warnings` to suppress specific warnings.
+
+ It is also recommended to specify a ``location`` whenever possible
+ to help users in correcting the warning.
+
+ :param msg: The message, which may contain placeholders for ``args``.
+ :param args: The arguments to substitute into ``msg``.
+ :param type: The type of the warning.
+ :param subtype: The subtype of the warning.
+ :param location: The source location of the warning's origin,
+ which can be a string (the ``docname`` or ``docname:lineno``),
+ a tuple of ``(docname, lineno)``,
+ or the docutils node object.
+ :param nonl: Whether to append a new line terminator to the message.
+ :param color: A color code for the message.
+ :param once: Do not log this warning,
+ if a previous warning already has same ``msg``, ``args`` and ``once=True``.
+ """
+ return super().warning(
+ msg,
+ *args,
+ type=type,
+ subtype=subtype,
+ location=location,
+ nonl=nonl,
+ color=color,
+ once=once,
+ **kwargs,
+ )
+
class WarningStreamHandler(logging.StreamHandler):
"""StreamHandler for warnings."""
+
pass
@@ -195,7 +246,7 @@ class MemoryHandler(logging.handlers.BufferingHandler):
@contextmanager
-def pending_warnings() -> Generator[logging.Handler, None, None]:
+def pending_warnings() -> Iterator[logging.Handler]:
"""Context manager to postpone logging warnings temporarily.
Similar to :func:`pending_logging`.
@@ -223,7 +274,7 @@ def pending_warnings() -> Generator[logging.Handler, None, None]:
@contextmanager
-def suppress_logging() -> Generator[MemoryHandler, None, None]:
+def suppress_logging() -> Iterator[MemoryHandler]:
"""Context manager to suppress logging all logs temporarily.
For example::
@@ -252,7 +303,7 @@ def suppress_logging() -> Generator[MemoryHandler, None, None]:
@contextmanager
-def pending_logging() -> Generator[MemoryHandler, None, None]:
+def pending_logging() -> Iterator[MemoryHandler]:
"""Context manager to postpone logging all logs temporarily.
For example::
@@ -272,7 +323,7 @@ def pending_logging() -> Generator[MemoryHandler, None, None]:
@contextmanager
-def skip_warningiserror(skip: bool = True) -> Generator[None, None, None]:
+def skip_warningiserror(skip: bool = True) -> Iterator[None]:
"""Context manager to skip WarningIsErrorFilter temporarily."""
logger = logging.getLogger(NAMESPACE)
@@ -292,7 +343,7 @@ def skip_warningiserror(skip: bool = True) -> Generator[None, None, None]:
@contextmanager
-def prefixed_warnings(prefix: str) -> Generator[None, None, None]:
+def prefixed_warnings(prefix: str) -> Iterator[None]:
"""Context manager to prepend prefix to all warning log records temporarily.
For example::
@@ -342,7 +393,7 @@ class LogCollector:
self.logs: list[logging.LogRecord] = []
@contextmanager
- def collect(self) -> Generator[None, None, None]:
+ def collect(self) -> Iterator[None]:
with pending_logging() as memhandler:
yield
@@ -475,7 +526,9 @@ class SphinxLogRecordTranslator(logging.Filter):
* Make a instance of SphinxLogRecord
* docname to path if location given
+ * append warning type/subtype to message if :confval:`show_warning_types` is ``True``
"""
+
LogRecordClass: type[logging.LogRecord]
def __init__(self, app: Sphinx) -> None:
@@ -507,13 +560,32 @@ class SphinxLogRecordTranslator(logging.Filter):
class InfoLogRecordTranslator(SphinxLogRecordTranslator):
"""LogRecordTranslator for INFO level log records."""
+
LogRecordClass = SphinxInfoLogRecord
class WarningLogRecordTranslator(SphinxLogRecordTranslator):
"""LogRecordTranslator for WARNING level log records."""
+
LogRecordClass = SphinxWarningLogRecord
+ def filter(self, record: SphinxWarningLogRecord) -> bool: # type: ignore[override]
+ ret = super().filter(record)
+
+ try:
+ show_warning_types = self.app.config.show_warning_types
+ except AttributeError:
+ # config is not initialized yet (ex. in conf.py)
+ show_warning_types = False
+ if show_warning_types:
+ if log_type := getattr(record, 'type', ''):
+ if log_subtype := getattr(record, 'subtype', ''):
+ record.msg += f' [{log_type}.{log_subtype}]'
+ else:
+ record.msg += f' [{log_type}]'
+
+ return ret
+
def get_node_location(node: Node) -> str | None:
source, line = get_source_line(node)
@@ -543,6 +615,7 @@ class ColorizeFormatter(logging.Formatter):
class SafeEncodingWriter:
"""Stream writer which ignores UnicodeEncodeError silently"""
+
def __init__(self, stream: IO) -> None:
self.stream = stream
self.encoding = getattr(stream, 'encoding', 'ascii') or 'ascii'
@@ -562,6 +635,7 @@ class SafeEncodingWriter:
class LastMessagesWriter:
"""Stream writer storing last 10 messages in memory to save trackback"""
+
def __init__(self, app: Sphinx, stream: IO) -> None:
self.app = app
diff --git a/sphinx/util/matching.py b/sphinx/util/matching.py
index dd91905..481ca12 100644
--- a/sphinx/util/matching.py
+++ b/sphinx/util/matching.py
@@ -91,7 +91,8 @@ _pat_cache: dict[str, re.Pattern[str]] = {}
def patmatch(name: str, pat: str) -> re.Match[str] | None:
"""Return if name matches the regular expression (pattern)
- ``pat```. Adapted from fnmatch module."""
+ ``pat```. Adapted from fnmatch module.
+ """
if pat not in _pat_cache:
_pat_cache[pat] = re.compile(_translate_pattern(pat))
return _pat_cache[pat].match(name)
diff --git a/sphinx/util/math.py b/sphinx/util/math.py
index ef0eb39..97b8440 100644
--- a/sphinx/util/math.py
+++ b/sphinx/util/math.py
@@ -54,8 +54,7 @@ def wrap_displaymath(text: str, label: str | None, numbering: bool) -> str:
else:
begin = r'\begin{align*}%s\!\begin{aligned}' % labeldef
end = r'\end{aligned}\end{align*}'
- for part in parts:
- equations.append('%s\\\\\n' % part.strip())
+ equations.extend('%s\\\\\n' % part.strip() for part in parts)
concatenated_equations = ''.join(equations)
return f'{begin}\n{concatenated_equations}{end}'
diff --git a/sphinx/util/nodes.py b/sphinx/util/nodes.py
index b68b7fd..bbc1f64 100644
--- a/sphinx/util/nodes.py
+++ b/sphinx/util/nodes.py
@@ -5,18 +5,19 @@ from __future__ import annotations
import contextlib
import re
import unicodedata
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, cast
from docutils import nodes
+from docutils.nodes import Node
from sphinx import addnodes
from sphinx.locale import __
from sphinx.util import logging
if TYPE_CHECKING:
- from collections.abc import Iterable
+ from collections.abc import Iterable, Iterator
- from docutils.nodes import Element, Node
+ from docutils.nodes import Element
from docutils.parsers.rst import Directive
from docutils.parsers.rst.states import Inliner
from docutils.statemachine import StringList
@@ -33,7 +34,10 @@ explicit_title_re = re.compile(r'^(.+?)\s*(?<!\x00)<([^<]*?)>$', re.DOTALL)
caption_ref_re = explicit_title_re # b/w compat alias
-class NodeMatcher:
+N = TypeVar("N", bound=Node)
+
+
+class NodeMatcher(Generic[N]):
"""A helper class for Node.findall().
It checks that the given node is an instance of the specified node-classes and
@@ -43,20 +47,18 @@ class NodeMatcher:
and ``reftype`` attributes::
matcher = NodeMatcher(nodes.reference, refdomain='std', reftype='citation')
- doctree.findall(matcher)
+ matcher.findall(doctree)
# => [<reference ...>, <reference ...>, ...]
A special value ``typing.Any`` matches any kind of node-attributes. For example,
following example searches ``reference`` node having ``refdomain`` attributes::
- from __future__ import annotations
-from typing import TYPE_CHECKING, Any
matcher = NodeMatcher(nodes.reference, refdomain=Any)
- doctree.findall(matcher)
+ matcher.findall(doctree)
# => [<reference ...>, <reference ...>, ...]
"""
- def __init__(self, *node_classes: type[Node], **attrs: Any) -> None:
+ def __init__(self, *node_classes: type[N], **attrs: Any) -> None:
self.classes = node_classes
self.attrs = attrs
@@ -85,6 +87,15 @@ from typing import TYPE_CHECKING, Any
def __call__(self, node: Node) -> bool:
return self.match(node)
+ def findall(self, node: Node) -> Iterator[N]:
+ """An alternative to `Node.findall` with improved type safety.
+
+ While the `NodeMatcher` object can be used as an argument to `Node.findall`, doing so
+ confounds type checkers' ability to determine the return type of the iterator.
+ """
+ for found in node.findall(self):
+ yield cast(N, found)
+
def get_full_module_name(node: Node) -> str:
"""
@@ -99,7 +110,7 @@ def get_full_module_name(node: Node) -> str:
def repr_domxml(node: Node, length: int = 80) -> str:
"""
return DOM XML representation of the specified node like:
- '<paragraph translatable="False"><inline classes="versionmodified">New in version...'
+ '<paragraph translatable="False"><inline classes="versionadded">Added in version...'
:param nodes.Node node: target node
:param int length:
@@ -127,7 +138,7 @@ def apply_source_workaround(node: Element) -> None:
get_full_module_name(node), repr_domxml(node))
definition_list_item = node.parent
node.source = definition_list_item.source
- node.line = definition_list_item.line - 1
+ node.line = definition_list_item.line - 1 # type: ignore[operator]
node.rawsource = node.astext() # set 'classifier1' (or 'classifier2')
elif isinstance(node, nodes.classifier) and not node.source:
# docutils-0.15 fills in rawsource attribute, but not in source.
@@ -220,16 +231,13 @@ def is_translatable(node: Node) -> bool:
return False
# <field_name>orphan</field_name>
# XXX ignore all metadata (== docinfo)
- if isinstance(node, nodes.field_name) and node.children[0] == 'orphan':
+ if isinstance(node, nodes.field_name) and (node.children[0] == 'orphan'):
logger.debug('[i18n] SKIP %r because orphan node: %s',
get_full_module_name(node), repr_domxml(node))
return False
return True
- if isinstance(node, nodes.meta): # type: ignore[attr-defined]
- return True
-
- return False
+ return isinstance(node, nodes.meta)
LITERAL_TYPE_NODES = (
@@ -245,10 +253,10 @@ IMAGE_TYPE_NODES = (
def extract_messages(doctree: Element) -> Iterable[tuple[Element, str]]:
"""Extract translatable messages from a document tree."""
- for node in doctree.findall(is_translatable): # type: Element
+ for node in doctree.findall(is_translatable):
if isinstance(node, addnodes.translatable):
for msg in node.extract_original_messages():
- yield node, msg
+ yield node, msg # type: ignore[misc]
continue
if isinstance(node, LITERAL_TYPE_NODES):
msg = node.rawsource
@@ -262,14 +270,14 @@ def extract_messages(doctree: Element) -> Iterable[tuple[Element, str]]:
msg = f'.. image:: {image_uri}'
else:
msg = ''
- elif isinstance(node, nodes.meta): # type: ignore[attr-defined]
+ elif isinstance(node, nodes.meta):
msg = node["content"]
else:
- msg = node.rawsource.replace('\n', ' ').strip()
+ msg = node.rawsource.replace('\n', ' ').strip() # type: ignore[attr-defined]
# XXX nodes rendering empty are likely a bug in sphinx.addnodes
if msg:
- yield node, msg
+ yield node, msg # type: ignore[misc]
def get_node_source(node: Element) -> str:
@@ -308,7 +316,7 @@ def traverse_translatable_index(
) -> Iterable[tuple[Element, list[tuple[str, str, str, str, str | None]]]]:
"""Traverse translatable index node from a document tree."""
matcher = NodeMatcher(addnodes.index, inline=False)
- for node in doctree.findall(matcher): # type: addnodes.index
+ for node in matcher.findall(doctree):
if 'raw_entries' in node:
entries = node['raw_entries']
else:
@@ -402,9 +410,14 @@ def process_index_entry(entry: str, targetid: str,
return indexentries
-def inline_all_toctrees(builder: Builder, docnameset: set[str], docname: str,
- tree: nodes.document, colorfunc: Callable, traversed: list[str],
- ) -> nodes.document:
+def inline_all_toctrees(
+ builder: Builder,
+ docnameset: set[str],
+ docname: str,
+ tree: nodes.document,
+ colorfunc: Callable[[str], str],
+ traversed: list[str],
+) -> nodes.document:
"""Inline all toctrees in the *tree*.
Record all docnames in *docnameset*, and output docnames with *colorfunc*.
@@ -599,10 +612,7 @@ def is_smartquotable(node: Node) -> bool:
if pnode.get('support_smartquotes', None) is False:
return False
- if getattr(node, 'support_smartquotes', None) is False:
- return False
-
- return True
+ return getattr(node, 'support_smartquotes', None) is not False
def process_only_nodes(document: Node, tags: Tags) -> None:
diff --git a/sphinx/util/osutil.py b/sphinx/util/osutil.py
index c6adbe4..97a298e 100644
--- a/sphinx/util/osutil.py
+++ b/sphinx/util/osutil.py
@@ -11,12 +11,14 @@ import sys
import unicodedata
from io import StringIO
from os import path
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING
from sphinx.deprecation import _deprecation_warning
if TYPE_CHECKING:
from collections.abc import Iterator
+ from types import TracebackType
+ from typing import Any
# SEP separates path elements in the canonical file names
#
@@ -36,7 +38,7 @@ def canon_path(native_path: str | os.PathLike[str], /) -> str:
def path_stabilize(filepath: str | os.PathLike[str], /) -> str:
- "Normalize path separator and unicode string"
+ """Normalize path separator and unicode string"""
new_path = canon_path(filepath)
return unicodedata.normalize('NFC', new_path)
@@ -88,7 +90,16 @@ def copytimes(source: str | os.PathLike[str], dest: str | os.PathLike[str]) -> N
def copyfile(source: str | os.PathLike[str], dest: str | os.PathLike[str]) -> None:
"""Copy a file and its modification times, if possible.
- Note: ``copyfile`` skips copying if the file has not been changed"""
+ :param source: An existing source to copy.
+ :param dest: The destination path.
+ :raise FileNotFoundError: The *source* does not exist.
+
+ .. note:: :func:`copyfile` is a no-op if *source* and *dest* are identical.
+ """
+ if not path.exists(source):
+ msg = f'{os.fsdecode(source)} does not exist'
+ raise FileNotFoundError(msg)
+
if not path.exists(dest) or not filecmp.cmp(source, dest):
shutil.copyfile(source, dest)
with contextlib.suppress(OSError):
@@ -131,15 +142,22 @@ abspath = path.abspath
class _chdir:
"""Remove this fall-back once support for Python 3.10 is removed."""
- def __init__(self, target_dir: str, /):
+
+ def __init__(self, target_dir: str, /) -> None:
self.path = target_dir
self._dirs: list[str] = []
- def __enter__(self):
+ def __enter__(self) -> None:
self._dirs.append(os.getcwd())
os.chdir(self.path)
- def __exit__(self, _exc_type, _exc_value, _traceback, /):
+ def __exit__(
+ self,
+ type: type[BaseException] | None,
+ value: BaseException | None,
+ traceback: TracebackType | None,
+ /,
+ ) -> None:
os.chdir(self._dirs.pop())
@@ -163,6 +181,7 @@ class FileAvoidWrite:
Objects can be used as context managers.
"""
+
def __init__(self, path: str) -> None:
self._path = path
self._io: StringIO | None = None
diff --git a/sphinx/util/parallel.py b/sphinx/util/parallel.py
index 0afdff9..10f8c89 100644
--- a/sphinx/util/parallel.py
+++ b/sphinx/util/parallel.py
@@ -94,7 +94,12 @@ class ParallelTasks:
proc = context.Process(target=self._process, args=(psend, task_func, arg))
self._procs[tid] = proc
self._precvsWaiting[tid] = precv
- self._join_one()
+ try:
+ self._join_one()
+ except Exception:
+ # shutdown other child processes on failure
+ # (e.g. OSError: Failed to allocate memory)
+ self.terminate()
def join(self) -> None:
try:
diff --git a/sphinx/util/requests.py b/sphinx/util/requests.py
index ec3d8d2..4afbd37 100644
--- a/sphinx/util/requests.py
+++ b/sphinx/util/requests.py
@@ -30,17 +30,19 @@ def _get_tls_cacert(url: str, certs: str | dict[str, str] | None) -> str | bool:
def get(url: str, **kwargs: Any) -> requests.Response:
- """Sends a GET request like requests.get().
+ """Sends a GET request like ``requests.get()``.
- This sets up User-Agent header and TLS verification automatically."""
+ This sets up User-Agent header and TLS verification automatically.
+ """
with _Session() as session:
return session.get(url, **kwargs)
def head(url: str, **kwargs: Any) -> requests.Response:
- """Sends a HEAD request like requests.head().
+ """Sends a HEAD request like ``requests.head()``.
- This sets up User-Agent header and TLS verification automatically."""
+ This sets up User-Agent header and TLS verification automatically.
+ """
with _Session() as session:
return session.head(url, **kwargs)
@@ -54,7 +56,8 @@ class _Session(requests.Session):
) -> requests.Response:
"""Sends a request with an HTTP verb and url.
- This sets up User-Agent header and TLS verification automatically."""
+ This sets up User-Agent header and TLS verification automatically.
+ """
headers = kwargs.setdefault('headers', {})
headers.setdefault('User-Agent', _user_agent or _USER_AGENT)
if _tls_info:
diff --git a/sphinx/util/rst.py b/sphinx/util/rst.py
index 1e8fd66..4e8fdee 100644
--- a/sphinx/util/rst.py
+++ b/sphinx/util/rst.py
@@ -5,11 +5,11 @@ from __future__ import annotations
import re
from collections import defaultdict
from contextlib import contextmanager
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, cast
from unicodedata import east_asian_width
from docutils.parsers.rst import roles
-from docutils.parsers.rst.languages import en as english
+from docutils.parsers.rst.languages import en as english # type: ignore[attr-defined]
from docutils.parsers.rst.states import Body
from docutils.utils import Reporter
from jinja2 import Environment, pass_environment
@@ -18,7 +18,7 @@ from sphinx.locale import __
from sphinx.util import docutils, logging
if TYPE_CHECKING:
- from collections.abc import Generator
+ from collections.abc import Iterator
from docutils.statemachine import StringList
@@ -54,17 +54,18 @@ def textwidth(text: str, widechars: str = 'WF') -> int:
def heading(env: Environment, text: str, level: int = 1) -> str:
"""Create a heading for *level*."""
assert level <= 3
- width = textwidth(text, WIDECHARS[env.language])
+ # ``env.language`` is injected by ``sphinx.util.template.ReSTRenderer``
+ width = textwidth(text, WIDECHARS[env.language]) # type: ignore[attr-defined]
sectioning_char = SECTIONING_CHARS[level - 1]
return f'{text}\n{sectioning_char * width}'
@contextmanager
-def default_role(docname: str, name: str) -> Generator[None, None, None]:
+def default_role(docname: str, name: str) -> Iterator[None]:
if name:
dummy_reporter = Reporter('', 4, 4)
role_fn, _ = roles.role(name, english, 0, dummy_reporter)
- if role_fn: # type: ignore[truthy-function]
+ if role_fn:
docutils.register_role('', role_fn) # type: ignore[arg-type]
else:
logger.warning(__('default role %s not found'), name, location=docname)
@@ -102,6 +103,7 @@ def append_epilog(content: StringList, epilog: str) -> None:
if epilog:
if len(content) > 0:
source, lineno = content.info(-1)
+ lineno = cast(int, lineno) # lineno will never be None, since len(content) > 0
else:
source = '<generated>'
lineno = 0
diff --git a/sphinx/util/tags.py b/sphinx/util/tags.py
index 73e1a83..5d8d890 100644
--- a/sphinx/util/tags.py
+++ b/sphinx/util/tags.py
@@ -20,8 +20,8 @@ class BooleanParser(Parser):
Only allow condition exprs and/or/not operations.
"""
- def parse_compare(self) -> Node:
- node: Node
+ def parse_compare(self) -> nodes.Expr:
+ node: nodes.Expr
token = self.stream.current
if token.type == 'name':
if token.value in ('true', 'false', 'True', 'False'):
@@ -67,7 +67,7 @@ class Tags:
msg = 'chunk after expression'
raise ValueError(msg)
- def eval_node(node: Node) -> bool:
+ def eval_node(node: Node | None) -> bool:
if isinstance(node, nodes.CondExpr):
if eval_node(node.test):
return eval_node(node.expr1)
diff --git a/sphinx/util/template.py b/sphinx/util/template.py
index a16a7a1..0ddc29e 100644
--- a/sphinx/util/template.py
+++ b/sphinx/util/template.py
@@ -26,7 +26,8 @@ class BaseRenderer:
def __init__(self, loader: BaseLoader | None = None) -> None:
self.env = SandboxedEnvironment(loader=loader, extensions=['jinja2.ext.i18n'])
self.env.filters['repr'] = repr
- self.env.install_gettext_translations(get_translator())
+ # ``install_gettext_translations`` is injected by the ``jinja2.ext.i18n`` extension
+ self.env.install_gettext_translations(get_translator()) # type: ignore[attr-defined]
def render(self, template_name: str, context: dict[str, Any]) -> str:
return self.env.get_template(template_name).render(context)
@@ -47,7 +48,9 @@ class FileRenderer(BaseRenderer):
super().__init__(loader)
@classmethod
- def render_from_file(cls, filename: str, context: dict[str, Any]) -> str:
+ def render_from_file(
+ cls: type[FileRenderer], filename: str, context: dict[str, Any],
+ ) -> str:
dirname = os.path.dirname(filename)
basename = os.path.basename(filename)
return cls(dirname).render(basename, context)
@@ -60,7 +63,9 @@ class SphinxRenderer(FileRenderer):
super().__init__(template_path)
@classmethod
- def render_from_file(cls, filename: str, context: dict[str, Any]) -> str:
+ def render_from_file(
+ cls: type[FileRenderer], filename: str, context: dict[str, Any],
+ ) -> str:
return FileRenderer.render_from_file(filename, context)
diff --git a/sphinx/util/typing.py b/sphinx/util/typing.py
index 171420d..4fbb592 100644
--- a/sphinx/util/typing.py
+++ b/sphinx/util/typing.py
@@ -3,11 +3,12 @@
from __future__ import annotations
import sys
+import types
import typing
from collections.abc import Sequence
+from contextvars import Context, ContextVar, Token
from struct import Struct
-from types import TracebackType
-from typing import TYPE_CHECKING, Any, Callable, ForwardRef, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Callable, ForwardRef, TypedDict, TypeVar, Union
from docutils import nodes
from docutils.parsers.rst.states import Inliner
@@ -15,22 +16,47 @@ from docutils.parsers.rst.states import Inliner
if TYPE_CHECKING:
import enum
-try:
- from types import UnionType # type: ignore[attr-defined] # python 3.10 or above
-except ImportError:
+ from sphinx.application import Sphinx
+
+if sys.version_info >= (3, 10):
+ from types import UnionType
+else:
UnionType = None
-# classes that have incorrect __module__
-INVALID_BUILTIN_CLASSES = {
+# classes that have an incorrect .__module__ attribute
+_INVALID_BUILTIN_CLASSES = {
+ Context: 'contextvars.Context', # Context.__module__ == '_contextvars'
+ ContextVar: 'contextvars.ContextVar', # ContextVar.__module__ == '_contextvars'
+ Token: 'contextvars.Token', # Token.__module__ == '_contextvars'
Struct: 'struct.Struct', # Struct.__module__ == '_struct'
- TracebackType: 'types.TracebackType', # TracebackType.__module__ == 'builtins'
+ # types in 'types' with <type>.__module__ == 'builtins':
+ types.AsyncGeneratorType: 'types.AsyncGeneratorType',
+ types.BuiltinFunctionType: 'types.BuiltinFunctionType',
+ types.BuiltinMethodType: 'types.BuiltinMethodType',
+ types.CellType: 'types.CellType',
+ types.ClassMethodDescriptorType: 'types.ClassMethodDescriptorType',
+ types.CodeType: 'types.CodeType',
+ types.CoroutineType: 'types.CoroutineType',
+ types.FrameType: 'types.FrameType',
+ types.FunctionType: 'types.FunctionType',
+ types.GeneratorType: 'types.GeneratorType',
+ types.GetSetDescriptorType: 'types.GetSetDescriptorType',
+ types.LambdaType: 'types.LambdaType',
+ types.MappingProxyType: 'types.MappingProxyType',
+ types.MemberDescriptorType: 'types.MemberDescriptorType',
+ types.MethodDescriptorType: 'types.MethodDescriptorType',
+ types.MethodType: 'types.MethodType',
+ types.MethodWrapperType: 'types.MethodWrapperType',
+ types.ModuleType: 'types.ModuleType',
+ types.TracebackType: 'types.TracebackType',
+ types.WrapperDescriptorType: 'types.WrapperDescriptorType',
}
def is_invalid_builtin_class(obj: Any) -> bool:
"""Check *obj* is an invalid built-in class."""
try:
- return obj in INVALID_BUILTIN_CLASSES
+ return obj in _INVALID_BUILTIN_CLASSES
except TypeError: # unhashable type
return False
@@ -64,6 +90,30 @@ InventoryItem = tuple[
Inventory = dict[str, dict[str, InventoryItem]]
+class ExtensionMetadata(TypedDict, total=False):
+ """The metadata returned by an extension's ``setup()`` function.
+
+ See :ref:`ext-metadata`.
+ """
+
+ version: str
+ """The extension version (default: ``'unknown version'``)."""
+ env_version: int
+ """An integer that identifies the version of env data added by the extension."""
+ parallel_read_safe: bool
+ """Indicate whether parallel reading of source files is supported
+ by the extension.
+ """
+ parallel_write_safe: bool
+ """Indicate whether parallel writing of output files is supported
+ by the extension (default: ``True``).
+ """
+
+
+if TYPE_CHECKING:
+ _ExtensionSetupFunc = Callable[[Sphinx], ExtensionMetadata]
+
+
def get_type_hints(
obj: Any, globalns: dict[str, Any] | None = None, localns: dict[str, Any] | None = None,
) -> dict[str, Any]:
@@ -128,19 +178,15 @@ def restify(cls: type | None, mode: str = 'fully-qualified-except-typing') -> st
elif ismock(cls):
return f':py:class:`{modprefix}{cls.__module__}.{cls.__name__}`'
elif is_invalid_builtin_class(cls):
- return f':py:class:`{modprefix}{INVALID_BUILTIN_CLASSES[cls]}`'
+ return f':py:class:`{modprefix}{_INVALID_BUILTIN_CLASSES[cls]}`'
elif inspect.isNewType(cls):
if sys.version_info[:2] >= (3, 10):
# newtypes have correct module info since Python 3.10+
return f':py:class:`{modprefix}{cls.__module__}.{cls.__name__}`'
else:
- return ':py:class:`%s`' % cls.__name__
+ return f':py:class:`{cls.__name__}`'
elif UnionType and isinstance(cls, UnionType):
- if len(cls.__args__) > 1 and None in cls.__args__:
- args = ' | '.join(restify(a, mode) for a in cls.__args__ if a)
- return 'Optional[%s]' % args
- else:
- return ' | '.join(restify(a, mode) for a in cls.__args__)
+ return ' | '.join(restify(a, mode) for a in cls.__args__)
elif cls.__module__ in ('__builtin__', 'builtins'):
if hasattr(cls, '__args__'):
if not cls.__args__: # Empty tuple, list, ...
@@ -149,23 +195,11 @@ def restify(cls: type | None, mode: str = 'fully-qualified-except-typing') -> st
concatenated_args = ', '.join(restify(arg, mode) for arg in cls.__args__)
return fr':py:class:`{cls.__name__}`\ [{concatenated_args}]'
else:
- return ':py:class:`%s`' % cls.__name__
+ return f':py:class:`{cls.__name__}`'
elif (inspect.isgenericalias(cls)
and cls.__module__ == 'typing'
and cls.__origin__ is Union): # type: ignore[attr-defined]
- if (len(cls.__args__) > 1 # type: ignore[attr-defined]
- and cls.__args__[-1] is NoneType): # type: ignore[attr-defined]
- if len(cls.__args__) > 2: # type: ignore[attr-defined]
- args = ', '.join(restify(a, mode)
- for a in cls.__args__[:-1]) # type: ignore[attr-defined]
- return ':py:obj:`~typing.Optional`\\ [:obj:`~typing.Union`\\ [%s]]' % args
- else:
- return ':py:obj:`~typing.Optional`\\ [%s]' % restify(
- cls.__args__[0], mode) # type: ignore[attr-defined]
- else:
- args = ', '.join(restify(a, mode)
- for a in cls.__args__) # type: ignore[attr-defined]
- return ':py:obj:`~typing.Union`\\ [%s]' % args
+ return ' | '.join(restify(a, mode) for a in cls.__args__) # type: ignore[attr-defined]
elif inspect.isgenericalias(cls):
if isinstance(cls.__origin__, typing._SpecialForm): # type: ignore[attr-defined]
text = restify(cls.__origin__, mode) # type: ignore[attr-defined,arg-type]
@@ -195,14 +229,14 @@ def restify(cls: type | None, mode: str = 'fully-qualified-except-typing') -> st
literal_args.append(_format_literal_enum_arg(a, mode=mode))
else:
literal_args.append(repr(a))
- text += r"\ [%s]" % ', '.join(literal_args)
+ text += fr"\ [{', '.join(literal_args)}]"
del literal_args
elif cls.__args__:
- text += r"\ [%s]" % ", ".join(restify(a, mode) for a in cls.__args__)
+ text += fr"\ [{', '.join(restify(a, mode) for a in cls.__args__)}]"
return text
elif isinstance(cls, typing._SpecialForm):
- return f':py:obj:`~{cls.__module__}.{cls._name}`' # type: ignore[attr-defined]
+ return f':py:obj:`~{cls.__module__}.{cls._name}`'
elif sys.version_info[:2] >= (3, 11) and cls is typing.Any:
# handle bpo-46998
return f':py:obj:`~{cls.__module__}.{cls.__name__}`'
@@ -212,7 +246,7 @@ def restify(cls: type | None, mode: str = 'fully-qualified-except-typing') -> st
else:
return f':py:class:`{modprefix}{cls.__module__}.{cls.__qualname__}`'
elif isinstance(cls, ForwardRef):
- return ':py:class:`%s`' % cls.__forward_arg__
+ return f':py:class:`{cls.__forward_arg__}`'
else:
# not a class (ex. TypeVar)
if cls.__module__ == 'typing':
@@ -285,7 +319,7 @@ def stringify_annotation(
elif ismock(annotation):
return module_prefix + f'{annotation_module}.{annotation_name}'
elif is_invalid_builtin_class(annotation):
- return module_prefix + INVALID_BUILTIN_CLASSES[annotation]
+ return module_prefix + _INVALID_BUILTIN_CLASSES[annotation]
elif str(annotation).startswith('typing.Annotated'): # for py310+
pass
elif annotation_module == 'builtins' and annotation_qualname:
@@ -350,7 +384,7 @@ def stringify_annotation(
elif qualname == 'Literal':
from sphinx.util.inspect import isenumattribute # lazy loading
- def format_literal_arg(arg):
+ def format_literal_arg(arg: Any) -> str:
if isenumattribute(arg):
enumcls = arg.__class__
@@ -384,19 +418,19 @@ def _format_literal_enum_arg(arg: enum.Enum, /, *, mode: str) -> str:
return f':py:attr:`{enum_cls.__module__}.{enum_cls.__qualname__}.{arg.name}`'
-# deprecated name -> (object to return, canonical path or empty string)
-_DEPRECATED_OBJECTS = {
- 'stringify': (stringify_annotation, 'sphinx.util.typing.stringify_annotation'),
+# deprecated name -> (object to return, canonical path or empty string, removal version)
+_DEPRECATED_OBJECTS: dict[str, tuple[Any, str, tuple[int, int]]] = {
+ 'stringify': (stringify_annotation, 'sphinx.util.typing.stringify_annotation', (8, 0)),
}
-def __getattr__(name):
+def __getattr__(name: str) -> Any:
if name not in _DEPRECATED_OBJECTS:
msg = f'module {__name__!r} has no attribute {name!r}'
raise AttributeError(msg)
from sphinx.deprecation import _deprecation_warning
- deprecated_object, canonical_name = _DEPRECATED_OBJECTS[name]
- _deprecation_warning(__name__, name, canonical_name, remove=(8, 0))
+ deprecated_object, canonical_name, remove = _DEPRECATED_OBJECTS[name]
+ _deprecation_warning(__name__, name, canonical_name, remove=remove)
return deprecated_object