diff options
Diffstat (limited to 'ptpython/repl.py')
-rw-r--r-- | ptpython/repl.py | 608 |
1 files changed, 187 insertions, 421 deletions
diff --git a/ptpython/repl.py b/ptpython/repl.py index ae7b1d0..fc9b9da 100644 --- a/ptpython/repl.py +++ b/ptpython/repl.py @@ -7,44 +7,32 @@ Utility for creating a Python repl. embed(globals(), locals(), vi_mode=False) """ +from __future__ import annotations + import asyncio import builtins import os +import signal import sys -import threading import traceback import types import warnings from dis import COMPILER_FLAG_NAMES -from enum import Enum -from typing import Any, Callable, ContextManager, Dict, Optional - -from prompt_toolkit.formatted_text import ( - HTML, - AnyFormattedText, - FormattedText, - PygmentsTokens, - StyleAndTextTuples, - fragment_list_width, - merge_formatted_text, - to_formatted_text, -) -from prompt_toolkit.formatted_text.utils import fragment_list_to_text, split_lines -from prompt_toolkit.key_binding import KeyBindings, KeyPressEvent +from typing import Any, Callable, ContextManager, Iterable + +from prompt_toolkit.formatted_text import OneStyleAndTextTuple from prompt_toolkit.patch_stdout import patch_stdout as patch_stdout_context from prompt_toolkit.shortcuts import ( - PromptSession, clear_title, - print_formatted_text, set_title, ) -from prompt_toolkit.styles import BaseStyle -from prompt_toolkit.utils import DummyContext, get_cwidth -from pygments.lexers import PythonLexer, PythonTracebackLexer -from pygments.token import Token +from prompt_toolkit.utils import DummyContext +from pygments.lexers import PythonTracebackLexer # noqa: F401 +from .printer import OutputPrinter from .python_input import PythonInput +PyCF_ALLOW_TOP_LEVEL_AWAIT: int try: from ast import PyCF_ALLOW_TOP_LEVEL_AWAIT # type: ignore except ImportError: @@ -53,7 +41,7 @@ except ImportError: __all__ = ["PythonRepl", "enable_deprecation_warnings", "run_config", "embed"] -def _get_coroutine_flag() -> Optional[int]: +def _get_coroutine_flag() -> int | None: for k, v in COMPILER_FLAG_NAMES.items(): if v == "COROUTINE": return k @@ -62,7 +50,7 @@ def _get_coroutine_flag() -> Optional[int]: return None -COROUTINE_FLAG: Optional[int] = _get_coroutine_flag() +COROUTINE_FLAG: int | None = _get_coroutine_flag() def _has_coroutine_flag(code: types.CodeType) -> bool: @@ -80,7 +68,7 @@ class PythonRepl(PythonInput): self._load_start_paths() def _load_start_paths(self) -> None: - " Start the Read-Eval-Print Loop. " + "Start the Read-Eval-Print Loop." if self._startup_paths: for path in self._startup_paths: if os.path.exists(path): @@ -89,7 +77,57 @@ class PythonRepl(PythonInput): exec(code, self.get_globals(), self.get_locals()) else: output = self.app.output - output.write("WARNING | File not found: {}\n\n".format(path)) + output.write(f"WARNING | File not found: {path}\n\n") + + def run_and_show_expression(self, expression: str) -> None: + try: + # Eval. + try: + result = self.eval(expression) + except KeyboardInterrupt: + # KeyboardInterrupt doesn't inherit from Exception. + raise + except SystemExit: + raise + except BaseException as e: + self._handle_exception(e) + else: + # Print. + if result is not None: + self._show_result(result) + if self.insert_blank_line_after_output: + self.app.output.write("\n") + + # Loop. + self.current_statement_index += 1 + self.signatures = [] + + except KeyboardInterrupt as e: + # Handle all possible `KeyboardInterrupt` errors. This can + # happen during the `eval`, but also during the + # `show_result` if something takes too long. + # (Try/catch is around the whole block, because we want to + # prevent that a Control-C keypress terminates the REPL in + # any case.) + self._handle_keyboard_interrupt(e) + + def _get_output_printer(self) -> OutputPrinter: + return OutputPrinter( + output=self.app.output, + input=self.app.input, + style=self._current_style, + style_transformation=self.style_transformation, + title=self.title, + ) + + def _show_result(self, result: object) -> None: + self._get_output_printer().display_result( + result=result, + out_prompt=self.get_output_prompt(), + reformat=self.enable_output_formatting, + highlight=self.enable_syntax_highlighting, + paginate=self.enable_pager, + ) def run(self) -> None: """ @@ -102,44 +140,78 @@ class PythonRepl(PythonInput): try: while True: + # Pull text from the user. try: - # Read. - try: - text = self.read() - except EOFError: - return - - # Eval. - try: - result = self.eval(text) - except KeyboardInterrupt as e: # KeyboardInterrupt doesn't inherit from Exception. - raise - except SystemExit: - return - except BaseException as e: - self._handle_exception(e) - else: - # Print. - if result is not None: - self.show_result(result) - - # Loop. - self.current_statement_index += 1 - self.signatures = [] - - except KeyboardInterrupt as e: - # Handle all possible `KeyboardInterrupt` errors. This can - # happen during the `eval`, but also during the - # `show_result` if something takes too long. - # (Try/catch is around the whole block, because we want to - # prevent that a Control-C keypress terminates the REPL in - # any case.) - self._handle_keyboard_interrupt(e) + text = self.read() + except EOFError: + return + except BaseException: + # Something went wrong while reading input. + # (E.g., a bug in the completer that propagates. Don't + # crash the REPL.) + traceback.print_exc() + continue + + # Run it; display the result (or errors if applicable). + self.run_and_show_expression(text) finally: if self.terminal_title: clear_title() self._remove_from_namespace() + async def run_and_show_expression_async(self, text: str) -> Any: + loop = asyncio.get_running_loop() + system_exit: SystemExit | None = None + + try: + try: + # Create `eval` task. Ensure that control-c will cancel this + # task. + async def eval() -> Any: + nonlocal system_exit + try: + return await self.eval_async(text) + except SystemExit as e: + # Don't propagate SystemExit in `create_task()`. That + # will kill the event loop. We want to handle it + # gracefully. + system_exit = e + + task = asyncio.create_task(eval()) + loop.add_signal_handler(signal.SIGINT, lambda *_: task.cancel()) + result = await task + + if system_exit is not None: + raise system_exit + except KeyboardInterrupt: + # KeyboardInterrupt doesn't inherit from Exception. + raise + except SystemExit: + raise + except BaseException as e: + self._handle_exception(e) + else: + # Print. + if result is not None: + await loop.run_in_executor(None, lambda: self._show_result(result)) + + # Loop. + self.current_statement_index += 1 + self.signatures = [] + # Return the result for future consumers. + return result + finally: + loop.remove_signal_handler(signal.SIGINT) + + except KeyboardInterrupt as e: + # Handle all possible `KeyboardInterrupt` errors. This can + # happen during the `eval`, but also during the + # `show_result` if something takes too long. + # (Try/catch is around the whole block, because we want to + # prevent that a Control-C keypress terminates the REPL in + # any case.) + self._handle_keyboard_interrupt(e) + async def run_async(self) -> None: """ Run the REPL loop, but run the blocking parts in an executor, so that @@ -152,7 +224,7 @@ class PythonRepl(PythonInput): (Both for control-C to work, as well as for the code to see the right thread in which it was embedded). """ - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() if self.terminal_title: set_title(self.terminal_title) @@ -167,32 +239,23 @@ class PythonRepl(PythonInput): text = await loop.run_in_executor(None, self.read) except EOFError: return + except BaseException: + # Something went wrong while reading input. + # (E.g., a bug in the completer that propagates. Don't + # crash the REPL.) + traceback.print_exc() + continue # Eval. - try: - result = await self.eval_async(text) - except KeyboardInterrupt as e: # KeyboardInterrupt doesn't inherit from Exception. - raise - except SystemExit: - return - except BaseException as e: - self._handle_exception(e) - else: - # Print. - if result is not None: - await loop.run_in_executor( - None, lambda: self.show_result(result) - ) - - # Loop. - self.current_statement_index += 1 - self.signatures = [] + await self.run_and_show_expression_async(text) except KeyboardInterrupt as e: # XXX: This does not yet work properly. In some situations, # `KeyboardInterrupt` exceptions can end up in the event # loop selector. self._handle_keyboard_interrupt(e) + except SystemExit: + return finally: if self.terminal_title: clear_title() @@ -221,7 +284,7 @@ class PythonRepl(PythonInput): result = eval(code, self.get_globals(), self.get_locals()) if _has_coroutine_flag(code): - result = asyncio.get_event_loop().run_until_complete(result) + result = asyncio.get_running_loop().run_until_complete(result) self._store_eval_result(result) return result @@ -231,7 +294,10 @@ class PythonRepl(PythonInput): # above, then `sys.exc_info()` would not report the right error. # See issue: https://github.com/prompt-toolkit/ptpython/issues/435 code = self._compile_with_flags(line, "exec") - exec(code, self.get_globals(), self.get_locals()) + result = eval(code, self.get_globals(), self.get_locals()) + + if _has_coroutine_flag(code): + result = asyncio.get_running_loop().run_until_complete(result) return None @@ -263,21 +329,26 @@ class PythonRepl(PythonInput): self._store_eval_result(result) return result - # If not a valid `eval` expression, run using `exec` instead. + # If not a valid `eval` expression, compile as `exec` expression + # but still run with eval to get an awaitable in case of a + # awaitable expression. code = self._compile_with_flags(line, "exec") - exec(code, self.get_globals(), self.get_locals()) + result = eval(code, self.get_globals(), self.get_locals()) + + if _has_coroutine_flag(code): + result = await result return None def _store_eval_result(self, result: object) -> None: - locals: Dict[str, Any] = self.get_locals() + locals: dict[str, Any] = self.get_locals() locals["_"] = locals["_%i" % self.current_statement_index] = result def get_compiler_flags(self) -> int: return super().get_compiler_flags() | PyCF_ALLOW_TOP_LEVEL_AWAIT def _compile_with_flags(self, code: str, mode: str): - " Compile code with the right compiler flags. " + "Compile code with the right compiler flags." return compile( code, "<stdin>", @@ -286,257 +357,13 @@ class PythonRepl(PythonInput): dont_inherit=True, ) - def show_result(self, result: object) -> None: - """ - Show __repr__ for an `eval` result. - - Note: this can raise `KeyboardInterrupt` if either calling `__repr__`, - `__pt_repr__` or formatting the output with "Black" takes to long - and the user presses Control-C. - """ - out_prompt = to_formatted_text(self.get_output_prompt()) - - # If the repr is valid Python code, use the Pygments lexer. - try: - result_repr = repr(result) - except KeyboardInterrupt: - raise # Don't catch here. - except BaseException as e: - # Calling repr failed. - self._handle_exception(e) - return - - try: - compile(result_repr, "", "eval") - except SyntaxError: - formatted_result_repr = to_formatted_text(result_repr) - else: - # Syntactically correct. Format with black and syntax highlight. - if self.enable_output_formatting: - # Inline import. Slightly speed up start-up time if black is - # not used. - import black - - result_repr = black.format_str( - result_repr, - mode=black.FileMode(line_length=self.app.output.get_size().columns), - ) - - formatted_result_repr = to_formatted_text( - PygmentsTokens(list(_lex_python_result(result_repr))) - ) - - # If __pt_repr__ is present, take this. This can return prompt_toolkit - # formatted text. - try: - if hasattr(result, "__pt_repr__"): - formatted_result_repr = to_formatted_text( - getattr(result, "__pt_repr__")() - ) - if isinstance(formatted_result_repr, list): - formatted_result_repr = FormattedText(formatted_result_repr) - except KeyboardInterrupt: - raise # Don't catch here. - except: - # For bad code, `__getattr__` can raise something that's not an - # `AttributeError`. This happens already when calling `hasattr()`. - pass - - # Align every line to the prompt. - line_sep = "\n" + " " * fragment_list_width(out_prompt) - indented_repr: StyleAndTextTuples = [] - - lines = list(split_lines(formatted_result_repr)) - - for i, fragment in enumerate(lines): - indented_repr.extend(fragment) - - # Add indentation separator between lines, not after the last line. - if i != len(lines) - 1: - indented_repr.append(("", line_sep)) - - # Write output tokens. - if self.enable_syntax_highlighting: - formatted_output = merge_formatted_text([out_prompt, indented_repr]) - else: - formatted_output = FormattedText( - out_prompt + [("", fragment_list_to_text(formatted_result_repr))] - ) - - if self.enable_pager: - self.print_paginated_formatted_text(to_formatted_text(formatted_output)) - else: - self.print_formatted_text(to_formatted_text(formatted_output)) - - self.app.output.flush() - - if self.insert_blank_line_after_output: - self.app.output.write("\n") - - def print_formatted_text( - self, formatted_text: StyleAndTextTuples, end: str = "\n" - ) -> None: - print_formatted_text( - FormattedText(formatted_text), - style=self._current_style, - style_transformation=self.style_transformation, - include_default_pygments_style=False, - output=self.app.output, - end=end, - ) - - def print_paginated_formatted_text( - self, - formatted_text: StyleAndTextTuples, - end: str = "\n", - ) -> None: - """ - Print formatted text, using --MORE-- style pagination. - (Avoid filling up the terminal's scrollback buffer.) - """ - pager_prompt = self.create_pager_prompt() - size = self.app.output.get_size() - - abort = False - print_all = False - - # Max number of lines allowed in the buffer before painting. - max_rows = size.rows - 1 - - # Page buffer. - rows_in_buffer = 0 - columns_in_buffer = 0 - page: StyleAndTextTuples = [] - - def flush_page() -> None: - nonlocal page, columns_in_buffer, rows_in_buffer - self.print_formatted_text(page, end="") - page = [] - columns_in_buffer = 0 - rows_in_buffer = 0 - - def show_pager() -> None: - nonlocal abort, max_rows, print_all - - # Run pager prompt in another thread. - # Same as for the input. This prevents issues with nested event - # loops. - pager_result = None - - def in_thread() -> None: - nonlocal pager_result - pager_result = pager_prompt.prompt() - - th = threading.Thread(target=in_thread) - th.start() - th.join() - - if pager_result == PagerResult.ABORT: - print("...") - abort = True - - elif pager_result == PagerResult.NEXT_LINE: - max_rows = 1 - - elif pager_result == PagerResult.NEXT_PAGE: - max_rows = size.rows - 1 - - elif pager_result == PagerResult.PRINT_ALL: - print_all = True - - # Loop over lines. Show --MORE-- prompt when page is filled. - - formatted_text = formatted_text + [("", end)] - lines = list(split_lines(formatted_text)) - - for lineno, line in enumerate(lines): - for style, text, *_ in line: - for c in text: - width = get_cwidth(c) - - # (Soft) wrap line if it doesn't fit. - if columns_in_buffer + width > size.columns: - # Show pager first if we get too many lines after - # wrapping. - if rows_in_buffer + 1 >= max_rows and not print_all: - page.append(("", "\n")) - flush_page() - show_pager() - if abort: - return - - rows_in_buffer += 1 - columns_in_buffer = 0 - - columns_in_buffer += width - page.append((style, c)) - - if rows_in_buffer + 1 >= max_rows and not print_all: - page.append(("", "\n")) - flush_page() - show_pager() - if abort: - return - else: - # Add line ending between lines (if `end="\n"` was given, one - # more empty line is added in `split_lines` automatically to - # take care of the final line ending). - if lineno != len(lines) - 1: - page.append(("", "\n")) - rows_in_buffer += 1 - columns_in_buffer = 0 - - flush_page() - - def create_pager_prompt(self) -> PromptSession["PagerResult"]: - """ - Create pager --MORE-- prompt. - """ - return create_pager_prompt(self._current_style, self.title) - def _handle_exception(self, e: BaseException) -> None: - output = self.app.output - - # Instead of just calling ``traceback.format_exc``, we take the - # traceback and skip the bottom calls of this framework. - t, v, tb = sys.exc_info() - - # Required for pdb.post_mortem() to work. - sys.last_type, sys.last_value, sys.last_traceback = t, v, tb - - tblist = list(traceback.extract_tb(tb)) - - for line_nr, tb_tuple in enumerate(tblist): - if tb_tuple[0] == "<stdin>": - tblist = tblist[line_nr:] - break - - l = traceback.format_list(tblist) - if l: - l.insert(0, "Traceback (most recent call last):\n") - l.extend(traceback.format_exception_only(t, v)) - - tb_str = "".join(l) - - # Format exception and write to output. - # (We use the default style. Most other styles result - # in unreadable colors for the traceback.) - if self.enable_syntax_highlighting: - tokens = list(_lex_python_traceback(tb_str)) - else: - tokens = [(Token, tb_str)] - - print_formatted_text( - PygmentsTokens(tokens), - style=self._current_style, - style_transformation=self.style_transformation, - include_default_pygments_style=False, - output=output, + self._get_output_printer().display_exception( + e, + highlight=self.enable_syntax_highlighting, + paginate=self.enable_pager, ) - output.write("%s\n" % e) - output.flush() - def _handle_keyboard_interrupt(self, e: KeyboardInterrupt) -> None: output = self.app.output @@ -562,21 +389,16 @@ class PythonRepl(PythonInput): globals = self.get_globals() del globals["get_ptpython"] - -def _lex_python_traceback(tb): - " Return token list for traceback string. " - lexer = PythonTracebackLexer() - return lexer.get_tokens(tb) - - -def _lex_python_result(tb): - " Return token list for Python string. " - lexer = PythonLexer() - # Use `get_tokens_unprocessed`, so that we get exactly the same string, - # without line endings appended. `print_formatted_text` already appends a - # line ending, and otherwise we'll have two line endings. - tokens = lexer.get_tokens_unprocessed(tb) - return [(tokentype, value) for index, tokentype, value in tokens] + def print_paginated_formatted_text( + self, + formatted_text: Iterable[OneStyleAndTextTuple], + end: str = "\n", + ) -> None: + # Warning: This is mainly here backwards-compatibility. Some projects + # call `print_paginated_formatted_text` on the Repl object. + self._get_output_printer().display_style_and_text_tuples( + formatted_text, paginate=True + ) def enable_deprecation_warnings() -> None: @@ -590,28 +412,36 @@ def enable_deprecation_warnings() -> None: warnings.filterwarnings("default", category=DeprecationWarning, module="__main__") -def run_config(repl: PythonInput, config_file: str = "~/.ptpython/config.py") -> None: +DEFAULT_CONFIG_FILE = "~/.config/ptpython/config.py" + + +def run_config(repl: PythonInput, config_file: str | None = None) -> None: """ Execute REPL config file. :param repl: `PythonInput` instance. :param config_file: Path of the configuration file. """ + explicit_config_file = config_file is not None + # Expand tildes. - config_file = os.path.expanduser(config_file) + config_file = os.path.expanduser( + config_file if config_file is not None else DEFAULT_CONFIG_FILE + ) def enter_to_continue() -> None: input("\nPress ENTER to continue...") # Check whether this file exists. if not os.path.exists(config_file): - print("Impossible to read %r" % config_file) - enter_to_continue() + if explicit_config_file: + print(f"Impossible to read {config_file}") + enter_to_continue() return # Run the config file in an empty namespace. try: - namespace: Dict[str, Any] = {} + namespace: dict[str, Any] = {} with open(config_file, "rb") as f: code = compile(f.read(), config_file, "exec") @@ -630,10 +460,10 @@ def run_config(repl: PythonInput, config_file: str = "~/.ptpython/config.py") -> def embed( globals=None, locals=None, - configure: Optional[Callable[[PythonRepl], None]] = None, + configure: Callable[[PythonRepl], None] | None = None, vi_mode: bool = False, - history_filename: Optional[str] = None, - title: Optional[str] = None, + history_filename: str | None = None, + title: str | None = None, startup_paths=None, patch_stdout: bool = False, return_asyncio_coroutine: bool = False, @@ -685,81 +515,17 @@ def embed( configure(repl) # Start repl. - patch_context: ContextManager = ( + patch_context: ContextManager[None] = ( patch_stdout_context() if patch_stdout else DummyContext() ) if return_asyncio_coroutine: - async def coroutine(): + async def coroutine() -> None: with patch_context: await repl.run_async() - return coroutine() + return coroutine() # type: ignore else: with patch_context: repl.run() - - -class PagerResult(Enum): - ABORT = "ABORT" - NEXT_LINE = "NEXT_LINE" - NEXT_PAGE = "NEXT_PAGE" - PRINT_ALL = "PRINT_ALL" - - -def create_pager_prompt( - style: BaseStyle, title: AnyFormattedText = "" -) -> PromptSession[PagerResult]: - """ - Create a "continue" prompt for paginated output. - """ - bindings = KeyBindings() - - @bindings.add("enter") - @bindings.add("down") - def next_line(event: KeyPressEvent) -> None: - event.app.exit(result=PagerResult.NEXT_LINE) - - @bindings.add("space") - def next_page(event: KeyPressEvent) -> None: - event.app.exit(result=PagerResult.NEXT_PAGE) - - @bindings.add("a") - def print_all(event: KeyPressEvent) -> None: - event.app.exit(result=PagerResult.PRINT_ALL) - - @bindings.add("q") - @bindings.add("c-c") - @bindings.add("c-d") - @bindings.add("escape", eager=True) - def no(event: KeyPressEvent) -> None: - event.app.exit(result=PagerResult.ABORT) - - @bindings.add("<any>") - def _(event: KeyPressEvent) -> None: - " Disallow inserting other text. " - pass - - style - - session: PromptSession[PagerResult] = PromptSession( - merge_formatted_text( - [ - title, - HTML( - "<status-toolbar>" - "<more> -- MORE -- </more> " - "<key>[Enter]</key> Scroll " - "<key>[Space]</key> Next page " - "<key>[a]</key> Print all " - "<key>[q]</key> Quit " - "</status-toolbar>: " - ), - ] - ), - key_bindings=bindings, - erase_when_done=True, - style=style, - ) - return session |