summaryrefslogtreecommitdiffstats
path: root/ptpython/repl.py
diff options
context:
space:
mode:
Diffstat (limited to 'ptpython/repl.py')
-rw-r--r--ptpython/repl.py608
1 files changed, 187 insertions, 421 deletions
diff --git a/ptpython/repl.py b/ptpython/repl.py
index ae7b1d0..fc9b9da 100644
--- a/ptpython/repl.py
+++ b/ptpython/repl.py
@@ -7,44 +7,32 @@ Utility for creating a Python repl.
embed(globals(), locals(), vi_mode=False)
"""
+from __future__ import annotations
+
import asyncio
import builtins
import os
+import signal
import sys
-import threading
import traceback
import types
import warnings
from dis import COMPILER_FLAG_NAMES
-from enum import Enum
-from typing import Any, Callable, ContextManager, Dict, Optional
-
-from prompt_toolkit.formatted_text import (
- HTML,
- AnyFormattedText,
- FormattedText,
- PygmentsTokens,
- StyleAndTextTuples,
- fragment_list_width,
- merge_formatted_text,
- to_formatted_text,
-)
-from prompt_toolkit.formatted_text.utils import fragment_list_to_text, split_lines
-from prompt_toolkit.key_binding import KeyBindings, KeyPressEvent
+from typing import Any, Callable, ContextManager, Iterable
+
+from prompt_toolkit.formatted_text import OneStyleAndTextTuple
from prompt_toolkit.patch_stdout import patch_stdout as patch_stdout_context
from prompt_toolkit.shortcuts import (
- PromptSession,
clear_title,
- print_formatted_text,
set_title,
)
-from prompt_toolkit.styles import BaseStyle
-from prompt_toolkit.utils import DummyContext, get_cwidth
-from pygments.lexers import PythonLexer, PythonTracebackLexer
-from pygments.token import Token
+from prompt_toolkit.utils import DummyContext
+from pygments.lexers import PythonTracebackLexer # noqa: F401
+from .printer import OutputPrinter
from .python_input import PythonInput
+PyCF_ALLOW_TOP_LEVEL_AWAIT: int
try:
from ast import PyCF_ALLOW_TOP_LEVEL_AWAIT # type: ignore
except ImportError:
@@ -53,7 +41,7 @@ except ImportError:
__all__ = ["PythonRepl", "enable_deprecation_warnings", "run_config", "embed"]
-def _get_coroutine_flag() -> Optional[int]:
+def _get_coroutine_flag() -> int | None:
for k, v in COMPILER_FLAG_NAMES.items():
if v == "COROUTINE":
return k
@@ -62,7 +50,7 @@ def _get_coroutine_flag() -> Optional[int]:
return None
-COROUTINE_FLAG: Optional[int] = _get_coroutine_flag()
+COROUTINE_FLAG: int | None = _get_coroutine_flag()
def _has_coroutine_flag(code: types.CodeType) -> bool:
@@ -80,7 +68,7 @@ class PythonRepl(PythonInput):
self._load_start_paths()
def _load_start_paths(self) -> None:
- " Start the Read-Eval-Print Loop. "
+ "Start the Read-Eval-Print Loop."
if self._startup_paths:
for path in self._startup_paths:
if os.path.exists(path):
@@ -89,7 +77,57 @@ class PythonRepl(PythonInput):
exec(code, self.get_globals(), self.get_locals())
else:
output = self.app.output
- output.write("WARNING | File not found: {}\n\n".format(path))
+ output.write(f"WARNING | File not found: {path}\n\n")
+
+ def run_and_show_expression(self, expression: str) -> None:
+ try:
+ # Eval.
+ try:
+ result = self.eval(expression)
+ except KeyboardInterrupt:
+ # KeyboardInterrupt doesn't inherit from Exception.
+ raise
+ except SystemExit:
+ raise
+ except BaseException as e:
+ self._handle_exception(e)
+ else:
+ # Print.
+ if result is not None:
+ self._show_result(result)
+ if self.insert_blank_line_after_output:
+ self.app.output.write("\n")
+
+ # Loop.
+ self.current_statement_index += 1
+ self.signatures = []
+
+ except KeyboardInterrupt as e:
+ # Handle all possible `KeyboardInterrupt` errors. This can
+ # happen during the `eval`, but also during the
+ # `show_result` if something takes too long.
+ # (Try/catch is around the whole block, because we want to
+ # prevent that a Control-C keypress terminates the REPL in
+ # any case.)
+ self._handle_keyboard_interrupt(e)
+
+ def _get_output_printer(self) -> OutputPrinter:
+ return OutputPrinter(
+ output=self.app.output,
+ input=self.app.input,
+ style=self._current_style,
+ style_transformation=self.style_transformation,
+ title=self.title,
+ )
+
+ def _show_result(self, result: object) -> None:
+ self._get_output_printer().display_result(
+ result=result,
+ out_prompt=self.get_output_prompt(),
+ reformat=self.enable_output_formatting,
+ highlight=self.enable_syntax_highlighting,
+ paginate=self.enable_pager,
+ )
def run(self) -> None:
"""
@@ -102,44 +140,78 @@ class PythonRepl(PythonInput):
try:
while True:
+ # Pull text from the user.
try:
- # Read.
- try:
- text = self.read()
- except EOFError:
- return
-
- # Eval.
- try:
- result = self.eval(text)
- except KeyboardInterrupt as e: # KeyboardInterrupt doesn't inherit from Exception.
- raise
- except SystemExit:
- return
- except BaseException as e:
- self._handle_exception(e)
- else:
- # Print.
- if result is not None:
- self.show_result(result)
-
- # Loop.
- self.current_statement_index += 1
- self.signatures = []
-
- except KeyboardInterrupt as e:
- # Handle all possible `KeyboardInterrupt` errors. This can
- # happen during the `eval`, but also during the
- # `show_result` if something takes too long.
- # (Try/catch is around the whole block, because we want to
- # prevent that a Control-C keypress terminates the REPL in
- # any case.)
- self._handle_keyboard_interrupt(e)
+ text = self.read()
+ except EOFError:
+ return
+ except BaseException:
+ # Something went wrong while reading input.
+ # (E.g., a bug in the completer that propagates. Don't
+ # crash the REPL.)
+ traceback.print_exc()
+ continue
+
+ # Run it; display the result (or errors if applicable).
+ self.run_and_show_expression(text)
finally:
if self.terminal_title:
clear_title()
self._remove_from_namespace()
+ async def run_and_show_expression_async(self, text: str) -> Any:
+ loop = asyncio.get_running_loop()
+ system_exit: SystemExit | None = None
+
+ try:
+ try:
+ # Create `eval` task. Ensure that control-c will cancel this
+ # task.
+ async def eval() -> Any:
+ nonlocal system_exit
+ try:
+ return await self.eval_async(text)
+ except SystemExit as e:
+ # Don't propagate SystemExit in `create_task()`. That
+ # will kill the event loop. We want to handle it
+ # gracefully.
+ system_exit = e
+
+ task = asyncio.create_task(eval())
+ loop.add_signal_handler(signal.SIGINT, lambda *_: task.cancel())
+ result = await task
+
+ if system_exit is not None:
+ raise system_exit
+ except KeyboardInterrupt:
+ # KeyboardInterrupt doesn't inherit from Exception.
+ raise
+ except SystemExit:
+ raise
+ except BaseException as e:
+ self._handle_exception(e)
+ else:
+ # Print.
+ if result is not None:
+ await loop.run_in_executor(None, lambda: self._show_result(result))
+
+ # Loop.
+ self.current_statement_index += 1
+ self.signatures = []
+ # Return the result for future consumers.
+ return result
+ finally:
+ loop.remove_signal_handler(signal.SIGINT)
+
+ except KeyboardInterrupt as e:
+ # Handle all possible `KeyboardInterrupt` errors. This can
+ # happen during the `eval`, but also during the
+ # `show_result` if something takes too long.
+ # (Try/catch is around the whole block, because we want to
+ # prevent that a Control-C keypress terminates the REPL in
+ # any case.)
+ self._handle_keyboard_interrupt(e)
+
async def run_async(self) -> None:
"""
Run the REPL loop, but run the blocking parts in an executor, so that
@@ -152,7 +224,7 @@ class PythonRepl(PythonInput):
(Both for control-C to work, as well as for the code to see the right
thread in which it was embedded).
"""
- loop = asyncio.get_event_loop()
+ loop = asyncio.get_running_loop()
if self.terminal_title:
set_title(self.terminal_title)
@@ -167,32 +239,23 @@ class PythonRepl(PythonInput):
text = await loop.run_in_executor(None, self.read)
except EOFError:
return
+ except BaseException:
+ # Something went wrong while reading input.
+ # (E.g., a bug in the completer that propagates. Don't
+ # crash the REPL.)
+ traceback.print_exc()
+ continue
# Eval.
- try:
- result = await self.eval_async(text)
- except KeyboardInterrupt as e: # KeyboardInterrupt doesn't inherit from Exception.
- raise
- except SystemExit:
- return
- except BaseException as e:
- self._handle_exception(e)
- else:
- # Print.
- if result is not None:
- await loop.run_in_executor(
- None, lambda: self.show_result(result)
- )
-
- # Loop.
- self.current_statement_index += 1
- self.signatures = []
+ await self.run_and_show_expression_async(text)
except KeyboardInterrupt as e:
# XXX: This does not yet work properly. In some situations,
# `KeyboardInterrupt` exceptions can end up in the event
# loop selector.
self._handle_keyboard_interrupt(e)
+ except SystemExit:
+ return
finally:
if self.terminal_title:
clear_title()
@@ -221,7 +284,7 @@ class PythonRepl(PythonInput):
result = eval(code, self.get_globals(), self.get_locals())
if _has_coroutine_flag(code):
- result = asyncio.get_event_loop().run_until_complete(result)
+ result = asyncio.get_running_loop().run_until_complete(result)
self._store_eval_result(result)
return result
@@ -231,7 +294,10 @@ class PythonRepl(PythonInput):
# above, then `sys.exc_info()` would not report the right error.
# See issue: https://github.com/prompt-toolkit/ptpython/issues/435
code = self._compile_with_flags(line, "exec")
- exec(code, self.get_globals(), self.get_locals())
+ result = eval(code, self.get_globals(), self.get_locals())
+
+ if _has_coroutine_flag(code):
+ result = asyncio.get_running_loop().run_until_complete(result)
return None
@@ -263,21 +329,26 @@ class PythonRepl(PythonInput):
self._store_eval_result(result)
return result
- # If not a valid `eval` expression, run using `exec` instead.
+ # If not a valid `eval` expression, compile as `exec` expression
+ # but still run with eval to get an awaitable in case of a
+ # awaitable expression.
code = self._compile_with_flags(line, "exec")
- exec(code, self.get_globals(), self.get_locals())
+ result = eval(code, self.get_globals(), self.get_locals())
+
+ if _has_coroutine_flag(code):
+ result = await result
return None
def _store_eval_result(self, result: object) -> None:
- locals: Dict[str, Any] = self.get_locals()
+ locals: dict[str, Any] = self.get_locals()
locals["_"] = locals["_%i" % self.current_statement_index] = result
def get_compiler_flags(self) -> int:
return super().get_compiler_flags() | PyCF_ALLOW_TOP_LEVEL_AWAIT
def _compile_with_flags(self, code: str, mode: str):
- " Compile code with the right compiler flags. "
+ "Compile code with the right compiler flags."
return compile(
code,
"<stdin>",
@@ -286,257 +357,13 @@ class PythonRepl(PythonInput):
dont_inherit=True,
)
- def show_result(self, result: object) -> None:
- """
- Show __repr__ for an `eval` result.
-
- Note: this can raise `KeyboardInterrupt` if either calling `__repr__`,
- `__pt_repr__` or formatting the output with "Black" takes to long
- and the user presses Control-C.
- """
- out_prompt = to_formatted_text(self.get_output_prompt())
-
- # If the repr is valid Python code, use the Pygments lexer.
- try:
- result_repr = repr(result)
- except KeyboardInterrupt:
- raise # Don't catch here.
- except BaseException as e:
- # Calling repr failed.
- self._handle_exception(e)
- return
-
- try:
- compile(result_repr, "", "eval")
- except SyntaxError:
- formatted_result_repr = to_formatted_text(result_repr)
- else:
- # Syntactically correct. Format with black and syntax highlight.
- if self.enable_output_formatting:
- # Inline import. Slightly speed up start-up time if black is
- # not used.
- import black
-
- result_repr = black.format_str(
- result_repr,
- mode=black.FileMode(line_length=self.app.output.get_size().columns),
- )
-
- formatted_result_repr = to_formatted_text(
- PygmentsTokens(list(_lex_python_result(result_repr)))
- )
-
- # If __pt_repr__ is present, take this. This can return prompt_toolkit
- # formatted text.
- try:
- if hasattr(result, "__pt_repr__"):
- formatted_result_repr = to_formatted_text(
- getattr(result, "__pt_repr__")()
- )
- if isinstance(formatted_result_repr, list):
- formatted_result_repr = FormattedText(formatted_result_repr)
- except KeyboardInterrupt:
- raise # Don't catch here.
- except:
- # For bad code, `__getattr__` can raise something that's not an
- # `AttributeError`. This happens already when calling `hasattr()`.
- pass
-
- # Align every line to the prompt.
- line_sep = "\n" + " " * fragment_list_width(out_prompt)
- indented_repr: StyleAndTextTuples = []
-
- lines = list(split_lines(formatted_result_repr))
-
- for i, fragment in enumerate(lines):
- indented_repr.extend(fragment)
-
- # Add indentation separator between lines, not after the last line.
- if i != len(lines) - 1:
- indented_repr.append(("", line_sep))
-
- # Write output tokens.
- if self.enable_syntax_highlighting:
- formatted_output = merge_formatted_text([out_prompt, indented_repr])
- else:
- formatted_output = FormattedText(
- out_prompt + [("", fragment_list_to_text(formatted_result_repr))]
- )
-
- if self.enable_pager:
- self.print_paginated_formatted_text(to_formatted_text(formatted_output))
- else:
- self.print_formatted_text(to_formatted_text(formatted_output))
-
- self.app.output.flush()
-
- if self.insert_blank_line_after_output:
- self.app.output.write("\n")
-
- def print_formatted_text(
- self, formatted_text: StyleAndTextTuples, end: str = "\n"
- ) -> None:
- print_formatted_text(
- FormattedText(formatted_text),
- style=self._current_style,
- style_transformation=self.style_transformation,
- include_default_pygments_style=False,
- output=self.app.output,
- end=end,
- )
-
- def print_paginated_formatted_text(
- self,
- formatted_text: StyleAndTextTuples,
- end: str = "\n",
- ) -> None:
- """
- Print formatted text, using --MORE-- style pagination.
- (Avoid filling up the terminal's scrollback buffer.)
- """
- pager_prompt = self.create_pager_prompt()
- size = self.app.output.get_size()
-
- abort = False
- print_all = False
-
- # Max number of lines allowed in the buffer before painting.
- max_rows = size.rows - 1
-
- # Page buffer.
- rows_in_buffer = 0
- columns_in_buffer = 0
- page: StyleAndTextTuples = []
-
- def flush_page() -> None:
- nonlocal page, columns_in_buffer, rows_in_buffer
- self.print_formatted_text(page, end="")
- page = []
- columns_in_buffer = 0
- rows_in_buffer = 0
-
- def show_pager() -> None:
- nonlocal abort, max_rows, print_all
-
- # Run pager prompt in another thread.
- # Same as for the input. This prevents issues with nested event
- # loops.
- pager_result = None
-
- def in_thread() -> None:
- nonlocal pager_result
- pager_result = pager_prompt.prompt()
-
- th = threading.Thread(target=in_thread)
- th.start()
- th.join()
-
- if pager_result == PagerResult.ABORT:
- print("...")
- abort = True
-
- elif pager_result == PagerResult.NEXT_LINE:
- max_rows = 1
-
- elif pager_result == PagerResult.NEXT_PAGE:
- max_rows = size.rows - 1
-
- elif pager_result == PagerResult.PRINT_ALL:
- print_all = True
-
- # Loop over lines. Show --MORE-- prompt when page is filled.
-
- formatted_text = formatted_text + [("", end)]
- lines = list(split_lines(formatted_text))
-
- for lineno, line in enumerate(lines):
- for style, text, *_ in line:
- for c in text:
- width = get_cwidth(c)
-
- # (Soft) wrap line if it doesn't fit.
- if columns_in_buffer + width > size.columns:
- # Show pager first if we get too many lines after
- # wrapping.
- if rows_in_buffer + 1 >= max_rows and not print_all:
- page.append(("", "\n"))
- flush_page()
- show_pager()
- if abort:
- return
-
- rows_in_buffer += 1
- columns_in_buffer = 0
-
- columns_in_buffer += width
- page.append((style, c))
-
- if rows_in_buffer + 1 >= max_rows and not print_all:
- page.append(("", "\n"))
- flush_page()
- show_pager()
- if abort:
- return
- else:
- # Add line ending between lines (if `end="\n"` was given, one
- # more empty line is added in `split_lines` automatically to
- # take care of the final line ending).
- if lineno != len(lines) - 1:
- page.append(("", "\n"))
- rows_in_buffer += 1
- columns_in_buffer = 0
-
- flush_page()
-
- def create_pager_prompt(self) -> PromptSession["PagerResult"]:
- """
- Create pager --MORE-- prompt.
- """
- return create_pager_prompt(self._current_style, self.title)
-
def _handle_exception(self, e: BaseException) -> None:
- output = self.app.output
-
- # Instead of just calling ``traceback.format_exc``, we take the
- # traceback and skip the bottom calls of this framework.
- t, v, tb = sys.exc_info()
-
- # Required for pdb.post_mortem() to work.
- sys.last_type, sys.last_value, sys.last_traceback = t, v, tb
-
- tblist = list(traceback.extract_tb(tb))
-
- for line_nr, tb_tuple in enumerate(tblist):
- if tb_tuple[0] == "<stdin>":
- tblist = tblist[line_nr:]
- break
-
- l = traceback.format_list(tblist)
- if l:
- l.insert(0, "Traceback (most recent call last):\n")
- l.extend(traceback.format_exception_only(t, v))
-
- tb_str = "".join(l)
-
- # Format exception and write to output.
- # (We use the default style. Most other styles result
- # in unreadable colors for the traceback.)
- if self.enable_syntax_highlighting:
- tokens = list(_lex_python_traceback(tb_str))
- else:
- tokens = [(Token, tb_str)]
-
- print_formatted_text(
- PygmentsTokens(tokens),
- style=self._current_style,
- style_transformation=self.style_transformation,
- include_default_pygments_style=False,
- output=output,
+ self._get_output_printer().display_exception(
+ e,
+ highlight=self.enable_syntax_highlighting,
+ paginate=self.enable_pager,
)
- output.write("%s\n" % e)
- output.flush()
-
def _handle_keyboard_interrupt(self, e: KeyboardInterrupt) -> None:
output = self.app.output
@@ -562,21 +389,16 @@ class PythonRepl(PythonInput):
globals = self.get_globals()
del globals["get_ptpython"]
-
-def _lex_python_traceback(tb):
- " Return token list for traceback string. "
- lexer = PythonTracebackLexer()
- return lexer.get_tokens(tb)
-
-
-def _lex_python_result(tb):
- " Return token list for Python string. "
- lexer = PythonLexer()
- # Use `get_tokens_unprocessed`, so that we get exactly the same string,
- # without line endings appended. `print_formatted_text` already appends a
- # line ending, and otherwise we'll have two line endings.
- tokens = lexer.get_tokens_unprocessed(tb)
- return [(tokentype, value) for index, tokentype, value in tokens]
+ def print_paginated_formatted_text(
+ self,
+ formatted_text: Iterable[OneStyleAndTextTuple],
+ end: str = "\n",
+ ) -> None:
+ # Warning: This is mainly here backwards-compatibility. Some projects
+ # call `print_paginated_formatted_text` on the Repl object.
+ self._get_output_printer().display_style_and_text_tuples(
+ formatted_text, paginate=True
+ )
def enable_deprecation_warnings() -> None:
@@ -590,28 +412,36 @@ def enable_deprecation_warnings() -> None:
warnings.filterwarnings("default", category=DeprecationWarning, module="__main__")
-def run_config(repl: PythonInput, config_file: str = "~/.ptpython/config.py") -> None:
+DEFAULT_CONFIG_FILE = "~/.config/ptpython/config.py"
+
+
+def run_config(repl: PythonInput, config_file: str | None = None) -> None:
"""
Execute REPL config file.
:param repl: `PythonInput` instance.
:param config_file: Path of the configuration file.
"""
+ explicit_config_file = config_file is not None
+
# Expand tildes.
- config_file = os.path.expanduser(config_file)
+ config_file = os.path.expanduser(
+ config_file if config_file is not None else DEFAULT_CONFIG_FILE
+ )
def enter_to_continue() -> None:
input("\nPress ENTER to continue...")
# Check whether this file exists.
if not os.path.exists(config_file):
- print("Impossible to read %r" % config_file)
- enter_to_continue()
+ if explicit_config_file:
+ print(f"Impossible to read {config_file}")
+ enter_to_continue()
return
# Run the config file in an empty namespace.
try:
- namespace: Dict[str, Any] = {}
+ namespace: dict[str, Any] = {}
with open(config_file, "rb") as f:
code = compile(f.read(), config_file, "exec")
@@ -630,10 +460,10 @@ def run_config(repl: PythonInput, config_file: str = "~/.ptpython/config.py") ->
def embed(
globals=None,
locals=None,
- configure: Optional[Callable[[PythonRepl], None]] = None,
+ configure: Callable[[PythonRepl], None] | None = None,
vi_mode: bool = False,
- history_filename: Optional[str] = None,
- title: Optional[str] = None,
+ history_filename: str | None = None,
+ title: str | None = None,
startup_paths=None,
patch_stdout: bool = False,
return_asyncio_coroutine: bool = False,
@@ -685,81 +515,17 @@ def embed(
configure(repl)
# Start repl.
- patch_context: ContextManager = (
+ patch_context: ContextManager[None] = (
patch_stdout_context() if patch_stdout else DummyContext()
)
if return_asyncio_coroutine:
- async def coroutine():
+ async def coroutine() -> None:
with patch_context:
await repl.run_async()
- return coroutine()
+ return coroutine() # type: ignore
else:
with patch_context:
repl.run()
-
-
-class PagerResult(Enum):
- ABORT = "ABORT"
- NEXT_LINE = "NEXT_LINE"
- NEXT_PAGE = "NEXT_PAGE"
- PRINT_ALL = "PRINT_ALL"
-
-
-def create_pager_prompt(
- style: BaseStyle, title: AnyFormattedText = ""
-) -> PromptSession[PagerResult]:
- """
- Create a "continue" prompt for paginated output.
- """
- bindings = KeyBindings()
-
- @bindings.add("enter")
- @bindings.add("down")
- def next_line(event: KeyPressEvent) -> None:
- event.app.exit(result=PagerResult.NEXT_LINE)
-
- @bindings.add("space")
- def next_page(event: KeyPressEvent) -> None:
- event.app.exit(result=PagerResult.NEXT_PAGE)
-
- @bindings.add("a")
- def print_all(event: KeyPressEvent) -> None:
- event.app.exit(result=PagerResult.PRINT_ALL)
-
- @bindings.add("q")
- @bindings.add("c-c")
- @bindings.add("c-d")
- @bindings.add("escape", eager=True)
- def no(event: KeyPressEvent) -> None:
- event.app.exit(result=PagerResult.ABORT)
-
- @bindings.add("<any>")
- def _(event: KeyPressEvent) -> None:
- " Disallow inserting other text. "
- pass
-
- style
-
- session: PromptSession[PagerResult] = PromptSession(
- merge_formatted_text(
- [
- title,
- HTML(
- "<status-toolbar>"
- "<more> -- MORE -- </more> "
- "<key>[Enter]</key> Scroll "
- "<key>[Space]</key> Next page "
- "<key>[a]</key> Print all "
- "<key>[q]</key> Quit "
- "</status-toolbar>: "
- ),
- ]
- ),
- key_bindings=bindings,
- erase_when_done=True,
- style=style,
- )
- return session