summaryrefslogtreecommitdiffstats
path: root/ptpython/repl.py
diff options
context:
space:
mode:
Diffstat (limited to 'ptpython/repl.py')
-rw-r--r--ptpython/repl.py18
1 files changed, 10 insertions, 8 deletions
diff --git a/ptpython/repl.py b/ptpython/repl.py
index fc9b9da..ea2d84f 100644
--- a/ptpython/repl.py
+++ b/ptpython/repl.py
@@ -7,6 +7,7 @@ Utility for creating a Python repl.
embed(globals(), locals(), vi_mode=False)
"""
+
from __future__ import annotations
import asyncio
@@ -18,7 +19,8 @@ import traceback
import types
import warnings
from dis import COMPILER_FLAG_NAMES
-from typing import Any, Callable, ContextManager, Iterable
+from pathlib import Path
+from typing import Any, Callable, ContextManager, Iterable, Sequence
from prompt_toolkit.formatted_text import OneStyleAndTextTuple
from prompt_toolkit.patch_stdout import patch_stdout as patch_stdout_context
@@ -63,7 +65,7 @@ def _has_coroutine_flag(code: types.CodeType) -> bool:
class PythonRepl(PythonInput):
def __init__(self, *a, **kw) -> None:
- self._startup_paths = kw.pop("startup_paths", None)
+ self._startup_paths: Sequence[str | Path] | None = kw.pop("startup_paths", None)
super().__init__(*a, **kw)
self._load_start_paths()
@@ -347,7 +349,7 @@ class PythonRepl(PythonInput):
def get_compiler_flags(self) -> int:
return super().get_compiler_flags() | PyCF_ALLOW_TOP_LEVEL_AWAIT
- def _compile_with_flags(self, code: str, mode: str):
+ def _compile_with_flags(self, code: str, mode: str) -> Any:
"Compile code with the right compiler flags."
return compile(
code,
@@ -458,13 +460,13 @@ def run_config(repl: PythonInput, config_file: str | None = None) -> None:
def embed(
- globals=None,
- locals=None,
+ globals: dict[str, Any] | None = None,
+ locals: dict[str, Any] | None = None,
configure: Callable[[PythonRepl], None] | None = None,
vi_mode: bool = False,
history_filename: str | None = None,
title: str | None = None,
- startup_paths=None,
+ startup_paths: Sequence[str | Path] | None = None,
patch_stdout: bool = False,
return_asyncio_coroutine: bool = False,
) -> None:
@@ -493,10 +495,10 @@ def embed(
locals = locals or globals
- def get_globals():
+ def get_globals() -> dict[str, Any]:
return globals
- def get_locals():
+ def get_locals() -> dict[str, Any]:
return locals
# Create REPL.