summaryrefslogtreecommitdiffstats
path: root/src/prompt_toolkit/contrib/telnet
diff options
context:
space:
mode:
Diffstat (limited to 'src/prompt_toolkit/contrib/telnet')
-rw-r--r--src/prompt_toolkit/contrib/telnet/__init__.py7
-rw-r--r--src/prompt_toolkit/contrib/telnet/log.py12
-rw-r--r--src/prompt_toolkit/contrib/telnet/protocol.py208
-rw-r--r--src/prompt_toolkit/contrib/telnet/server.py427
4 files changed, 654 insertions, 0 deletions
diff --git a/src/prompt_toolkit/contrib/telnet/__init__.py b/src/prompt_toolkit/contrib/telnet/__init__.py
new file mode 100644
index 0000000..de902b4
--- /dev/null
+++ b/src/prompt_toolkit/contrib/telnet/__init__.py
@@ -0,0 +1,7 @@
+from __future__ import annotations
+
+from .server import TelnetServer
+
+__all__ = [
+ "TelnetServer",
+]
diff --git a/src/prompt_toolkit/contrib/telnet/log.py b/src/prompt_toolkit/contrib/telnet/log.py
new file mode 100644
index 0000000..0fe8433
--- /dev/null
+++ b/src/prompt_toolkit/contrib/telnet/log.py
@@ -0,0 +1,12 @@
+"""
+Python logger for the telnet server.
+"""
+from __future__ import annotations
+
+import logging
+
+logger = logging.getLogger(__package__)
+
+__all__ = [
+ "logger",
+]
diff --git a/src/prompt_toolkit/contrib/telnet/protocol.py b/src/prompt_toolkit/contrib/telnet/protocol.py
new file mode 100644
index 0000000..4b90e98
--- /dev/null
+++ b/src/prompt_toolkit/contrib/telnet/protocol.py
@@ -0,0 +1,208 @@
+"""
+Parser for the Telnet protocol. (Not a complete implementation of the telnet
+specification, but sufficient for a command line interface.)
+
+Inspired by `Twisted.conch.telnet`.
+"""
+from __future__ import annotations
+
+import struct
+from typing import Callable, Generator
+
+from .log import logger
+
+__all__ = [
+ "TelnetProtocolParser",
+]
+
+
+def int2byte(number: int) -> bytes:
+ return bytes((number,))
+
+
+# Telnet constants.
+NOP = int2byte(0)
+SGA = int2byte(3)
+
+IAC = int2byte(255)
+DO = int2byte(253)
+DONT = int2byte(254)
+LINEMODE = int2byte(34)
+SB = int2byte(250)
+WILL = int2byte(251)
+WONT = int2byte(252)
+MODE = int2byte(1)
+SE = int2byte(240)
+ECHO = int2byte(1)
+NAWS = int2byte(31)
+LINEMODE = int2byte(34)
+SUPPRESS_GO_AHEAD = int2byte(3)
+
+TTYPE = int2byte(24)
+SEND = int2byte(1)
+IS = int2byte(0)
+
+DM = int2byte(242)
+BRK = int2byte(243)
+IP = int2byte(244)
+AO = int2byte(245)
+AYT = int2byte(246)
+EC = int2byte(247)
+EL = int2byte(248)
+GA = int2byte(249)
+
+
+class TelnetProtocolParser:
+ """
+ Parser for the Telnet protocol.
+ Usage::
+
+ def data_received(data):
+ print(data)
+
+ def size_received(rows, columns):
+ print(rows, columns)
+
+ p = TelnetProtocolParser(data_received, size_received)
+ p.feed(binary_data)
+ """
+
+ def __init__(
+ self,
+ data_received_callback: Callable[[bytes], None],
+ size_received_callback: Callable[[int, int], None],
+ ttype_received_callback: Callable[[str], None],
+ ) -> None:
+ self.data_received_callback = data_received_callback
+ self.size_received_callback = size_received_callback
+ self.ttype_received_callback = ttype_received_callback
+
+ self._parser = self._parse_coroutine()
+ self._parser.send(None) # type: ignore
+
+ def received_data(self, data: bytes) -> None:
+ self.data_received_callback(data)
+
+ def do_received(self, data: bytes) -> None:
+ """Received telnet DO command."""
+ logger.info("DO %r", data)
+
+ def dont_received(self, data: bytes) -> None:
+ """Received telnet DONT command."""
+ logger.info("DONT %r", data)
+
+ def will_received(self, data: bytes) -> None:
+ """Received telnet WILL command."""
+ logger.info("WILL %r", data)
+
+ def wont_received(self, data: bytes) -> None:
+ """Received telnet WONT command."""
+ logger.info("WONT %r", data)
+
+ def command_received(self, command: bytes, data: bytes) -> None:
+ if command == DO:
+ self.do_received(data)
+
+ elif command == DONT:
+ self.dont_received(data)
+
+ elif command == WILL:
+ self.will_received(data)
+
+ elif command == WONT:
+ self.wont_received(data)
+
+ else:
+ logger.info("command received %r %r", command, data)
+
+ def naws(self, data: bytes) -> None:
+ """
+ Received NAWS. (Window dimensions.)
+ """
+ if len(data) == 4:
+ # NOTE: the first parameter of struct.unpack should be
+ # a 'str' object. Both on Py2/py3. This crashes on OSX
+ # otherwise.
+ columns, rows = struct.unpack("!HH", data)
+ self.size_received_callback(rows, columns)
+ else:
+ logger.warning("Wrong number of NAWS bytes")
+
+ def ttype(self, data: bytes) -> None:
+ """
+ Received terminal type.
+ """
+ subcmd, data = data[0:1], data[1:]
+ if subcmd == IS:
+ ttype = data.decode("ascii")
+ self.ttype_received_callback(ttype)
+ else:
+ logger.warning("Received a non-IS terminal type Subnegotiation")
+
+ def negotiate(self, data: bytes) -> None:
+ """
+ Got negotiate data.
+ """
+ command, payload = data[0:1], data[1:]
+
+ if command == NAWS:
+ self.naws(payload)
+ elif command == TTYPE:
+ self.ttype(payload)
+ else:
+ logger.info("Negotiate (%r got bytes)", len(data))
+
+ def _parse_coroutine(self) -> Generator[None, bytes, None]:
+ """
+ Parser state machine.
+ Every 'yield' expression returns the next byte.
+ """
+ while True:
+ d = yield
+
+ if d == int2byte(0):
+ pass # NOP
+
+ # Go to state escaped.
+ elif d == IAC:
+ d2 = yield
+
+ if d2 == IAC:
+ self.received_data(d2)
+
+ # Handle simple commands.
+ elif d2 in (NOP, DM, BRK, IP, AO, AYT, EC, EL, GA):
+ self.command_received(d2, b"")
+
+ # Handle IAC-[DO/DONT/WILL/WONT] commands.
+ elif d2 in (DO, DONT, WILL, WONT):
+ d3 = yield
+ self.command_received(d2, d3)
+
+ # Subnegotiation
+ elif d2 == SB:
+ # Consume everything until next IAC-SE
+ data = []
+
+ while True:
+ d3 = yield
+
+ if d3 == IAC:
+ d4 = yield
+ if d4 == SE:
+ break
+ else:
+ data.append(d4)
+ else:
+ data.append(d3)
+
+ self.negotiate(b"".join(data))
+ else:
+ self.received_data(d)
+
+ def feed(self, data: bytes) -> None:
+ """
+ Feed data to the parser.
+ """
+ for b in data:
+ self._parser.send(int2byte(b))
diff --git a/src/prompt_toolkit/contrib/telnet/server.py b/src/prompt_toolkit/contrib/telnet/server.py
new file mode 100644
index 0000000..9ebe66c
--- /dev/null
+++ b/src/prompt_toolkit/contrib/telnet/server.py
@@ -0,0 +1,427 @@
+"""
+Telnet server.
+"""
+from __future__ import annotations
+
+import asyncio
+import contextvars
+import socket
+from asyncio import get_running_loop
+from typing import Any, Callable, Coroutine, TextIO, cast
+
+from prompt_toolkit.application.current import create_app_session, get_app
+from prompt_toolkit.application.run_in_terminal import run_in_terminal
+from prompt_toolkit.data_structures import Size
+from prompt_toolkit.formatted_text import AnyFormattedText, to_formatted_text
+from prompt_toolkit.input import PipeInput, create_pipe_input
+from prompt_toolkit.output.vt100 import Vt100_Output
+from prompt_toolkit.renderer import print_formatted_text as print_formatted_text
+from prompt_toolkit.styles import BaseStyle, DummyStyle
+
+from .log import logger
+from .protocol import (
+ DO,
+ ECHO,
+ IAC,
+ LINEMODE,
+ MODE,
+ NAWS,
+ SB,
+ SE,
+ SEND,
+ SUPPRESS_GO_AHEAD,
+ TTYPE,
+ WILL,
+ TelnetProtocolParser,
+)
+
+__all__ = [
+ "TelnetServer",
+]
+
+
+def int2byte(number: int) -> bytes:
+ return bytes((number,))
+
+
+def _initialize_telnet(connection: socket.socket) -> None:
+ logger.info("Initializing telnet connection")
+
+ # Iac Do Linemode
+ connection.send(IAC + DO + LINEMODE)
+
+ # Suppress Go Ahead. (This seems important for Putty to do correct echoing.)
+ # This will allow bi-directional operation.
+ connection.send(IAC + WILL + SUPPRESS_GO_AHEAD)
+
+ # Iac sb
+ connection.send(IAC + SB + LINEMODE + MODE + int2byte(0) + IAC + SE)
+
+ # IAC Will Echo
+ connection.send(IAC + WILL + ECHO)
+
+ # Negotiate window size
+ connection.send(IAC + DO + NAWS)
+
+ # Negotiate terminal type
+ # Assume the client will accept the negotiation with `IAC + WILL + TTYPE`
+ connection.send(IAC + DO + TTYPE)
+
+ # We can then select the first terminal type supported by the client,
+ # which is generally the best type the client supports
+ # The client should reply with a `IAC + SB + TTYPE + IS + ttype + IAC + SE`
+ connection.send(IAC + SB + TTYPE + SEND + IAC + SE)
+
+
+class _ConnectionStdout:
+ """
+ Wrapper around socket which provides `write` and `flush` methods for the
+ Vt100_Output output.
+ """
+
+ def __init__(self, connection: socket.socket, encoding: str) -> None:
+ self._encoding = encoding
+ self._connection = connection
+ self._errors = "strict"
+ self._buffer: list[bytes] = []
+ self._closed = False
+
+ def write(self, data: str) -> None:
+ data = data.replace("\n", "\r\n")
+ self._buffer.append(data.encode(self._encoding, errors=self._errors))
+ self.flush()
+
+ def isatty(self) -> bool:
+ return True
+
+ def flush(self) -> None:
+ try:
+ if not self._closed:
+ self._connection.send(b"".join(self._buffer))
+ except OSError as e:
+ logger.warning("Couldn't send data over socket: %s" % e)
+
+ self._buffer = []
+
+ def close(self) -> None:
+ self._closed = True
+
+ @property
+ def encoding(self) -> str:
+ return self._encoding
+
+ @property
+ def errors(self) -> str:
+ return self._errors
+
+
+class TelnetConnection:
+ """
+ Class that represents one Telnet connection.
+ """
+
+ def __init__(
+ self,
+ conn: socket.socket,
+ addr: tuple[str, int],
+ interact: Callable[[TelnetConnection], Coroutine[Any, Any, None]],
+ server: TelnetServer,
+ encoding: str,
+ style: BaseStyle | None,
+ vt100_input: PipeInput,
+ enable_cpr: bool = True,
+ ) -> None:
+ self.conn = conn
+ self.addr = addr
+ self.interact = interact
+ self.server = server
+ self.encoding = encoding
+ self.style = style
+ self._closed = False
+ self._ready = asyncio.Event()
+ self.vt100_input = vt100_input
+ self.enable_cpr = enable_cpr
+ self.vt100_output: Vt100_Output | None = None
+
+ # Create "Output" object.
+ self.size = Size(rows=40, columns=79)
+
+ # Initialize.
+ _initialize_telnet(conn)
+
+ # Create output.
+ def get_size() -> Size:
+ return self.size
+
+ self.stdout = cast(TextIO, _ConnectionStdout(conn, encoding=encoding))
+
+ def data_received(data: bytes) -> None:
+ """TelnetProtocolParser 'data_received' callback"""
+ self.vt100_input.send_bytes(data)
+
+ def size_received(rows: int, columns: int) -> None:
+ """TelnetProtocolParser 'size_received' callback"""
+ self.size = Size(rows=rows, columns=columns)
+ if self.vt100_output is not None and self.context:
+ self.context.run(lambda: get_app()._on_resize())
+
+ def ttype_received(ttype: str) -> None:
+ """TelnetProtocolParser 'ttype_received' callback"""
+ self.vt100_output = Vt100_Output(
+ self.stdout, get_size, term=ttype, enable_cpr=enable_cpr
+ )
+ self._ready.set()
+
+ self.parser = TelnetProtocolParser(data_received, size_received, ttype_received)
+ self.context: contextvars.Context | None = None
+
+ async def run_application(self) -> None:
+ """
+ Run application.
+ """
+
+ def handle_incoming_data() -> None:
+ data = self.conn.recv(1024)
+ if data:
+ self.feed(data)
+ else:
+ # Connection closed by client.
+ logger.info("Connection closed by client. {!r} {!r}".format(*self.addr))
+ self.close()
+
+ # Add reader.
+ loop = get_running_loop()
+ loop.add_reader(self.conn, handle_incoming_data)
+
+ try:
+ # Wait for v100_output to be properly instantiated
+ await self._ready.wait()
+ with create_app_session(input=self.vt100_input, output=self.vt100_output):
+ self.context = contextvars.copy_context()
+ await self.interact(self)
+ finally:
+ self.close()
+
+ def feed(self, data: bytes) -> None:
+ """
+ Handler for incoming data. (Called by TelnetServer.)
+ """
+ self.parser.feed(data)
+
+ def close(self) -> None:
+ """
+ Closed by client.
+ """
+ if not self._closed:
+ self._closed = True
+
+ self.vt100_input.close()
+ get_running_loop().remove_reader(self.conn)
+ self.conn.close()
+ self.stdout.close()
+
+ def send(self, formatted_text: AnyFormattedText) -> None:
+ """
+ Send text to the client.
+ """
+ if self.vt100_output is None:
+ return
+ formatted_text = to_formatted_text(formatted_text)
+ print_formatted_text(
+ self.vt100_output, formatted_text, self.style or DummyStyle()
+ )
+
+ def send_above_prompt(self, formatted_text: AnyFormattedText) -> None:
+ """
+ Send text to the client.
+ This is asynchronous, returns a `Future`.
+ """
+ formatted_text = to_formatted_text(formatted_text)
+ return self._run_in_terminal(lambda: self.send(formatted_text))
+
+ def _run_in_terminal(self, func: Callable[[], None]) -> None:
+ # Make sure that when an application was active for this connection,
+ # that we print the text above the application.
+ if self.context:
+ self.context.run(run_in_terminal, func) # type: ignore
+ else:
+ raise RuntimeError("Called _run_in_terminal outside `run_application`.")
+
+ def erase_screen(self) -> None:
+ """
+ Erase the screen and move the cursor to the top.
+ """
+ if self.vt100_output is None:
+ return
+ self.vt100_output.erase_screen()
+ self.vt100_output.cursor_goto(0, 0)
+ self.vt100_output.flush()
+
+
+async def _dummy_interact(connection: TelnetConnection) -> None:
+ pass
+
+
+class TelnetServer:
+ """
+ Telnet server implementation.
+
+ Example::
+
+ async def interact(connection):
+ connection.send("Welcome")
+ session = PromptSession()
+ result = await session.prompt_async(message="Say something: ")
+ connection.send(f"You said: {result}\n")
+
+ async def main():
+ server = TelnetServer(interact=interact, port=2323)
+ await server.run()
+ """
+
+ def __init__(
+ self,
+ host: str = "127.0.0.1",
+ port: int = 23,
+ interact: Callable[
+ [TelnetConnection], Coroutine[Any, Any, None]
+ ] = _dummy_interact,
+ encoding: str = "utf-8",
+ style: BaseStyle | None = None,
+ enable_cpr: bool = True,
+ ) -> None:
+ self.host = host
+ self.port = port
+ self.interact = interact
+ self.encoding = encoding
+ self.style = style
+ self.enable_cpr = enable_cpr
+
+ self._run_task: asyncio.Task[None] | None = None
+ self._application_tasks: list[asyncio.Task[None]] = []
+
+ self.connections: set[TelnetConnection] = set()
+
+ @classmethod
+ def _create_socket(cls, host: str, port: int) -> socket.socket:
+ # Create and bind socket
+ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ s.bind((host, port))
+
+ s.listen(4)
+ return s
+
+ async def run(self, ready_cb: Callable[[], None] | None = None) -> None:
+ """
+ Run the telnet server, until this gets cancelled.
+
+ :param ready_cb: Callback that will be called at the point that we're
+ actually listening.
+ """
+ socket = self._create_socket(self.host, self.port)
+ logger.info(
+ "Listening for telnet connections on %s port %r", self.host, self.port
+ )
+
+ get_running_loop().add_reader(socket, lambda: self._accept(socket))
+
+ if ready_cb:
+ ready_cb()
+
+ try:
+ # Run forever, until cancelled.
+ await asyncio.Future()
+ finally:
+ get_running_loop().remove_reader(socket)
+ socket.close()
+
+ # Wait for all applications to finish.
+ for t in self._application_tasks:
+ t.cancel()
+
+ # (This is similar to
+ # `Application.cancel_and_wait_for_background_tasks`. We wait for the
+ # background tasks to complete, but don't propagate exceptions, because
+ # we can't use `ExceptionGroup` yet.)
+ if len(self._application_tasks) > 0:
+ await asyncio.wait(
+ self._application_tasks,
+ timeout=None,
+ return_when=asyncio.ALL_COMPLETED,
+ )
+
+ def start(self) -> None:
+ """
+ Deprecated: Use `.run()` instead.
+
+ Start the telnet server (stop by calling and awaiting `stop()`).
+ """
+ if self._run_task is not None:
+ # Already running.
+ return
+
+ self._run_task = get_running_loop().create_task(self.run())
+
+ async def stop(self) -> None:
+ """
+ Deprecated: Use `.run()` instead.
+
+ Stop a telnet server that was started using `.start()` and wait for the
+ cancellation to complete.
+ """
+ if self._run_task is not None:
+ self._run_task.cancel()
+ try:
+ await self._run_task
+ except asyncio.CancelledError:
+ pass
+
+ def _accept(self, listen_socket: socket.socket) -> None:
+ """
+ Accept new incoming connection.
+ """
+ conn, addr = listen_socket.accept()
+ logger.info("New connection %r %r", *addr)
+
+ # Run application for this connection.
+ async def run() -> None:
+ try:
+ with create_pipe_input() as vt100_input:
+ connection = TelnetConnection(
+ conn,
+ addr,
+ self.interact,
+ self,
+ encoding=self.encoding,
+ style=self.style,
+ vt100_input=vt100_input,
+ enable_cpr=self.enable_cpr,
+ )
+ self.connections.add(connection)
+
+ logger.info("Starting interaction %r %r", *addr)
+ try:
+ await connection.run_application()
+ finally:
+ self.connections.remove(connection)
+ logger.info("Stopping interaction %r %r", *addr)
+ except EOFError:
+ # Happens either when the connection is closed by the client
+ # (e.g., when the user types 'control-]', then 'quit' in the
+ # telnet client) or when the user types control-d in a prompt
+ # and this is not handled by the interact function.
+ logger.info("Unhandled EOFError in telnet application.")
+ except KeyboardInterrupt:
+ # Unhandled control-c propagated by a prompt.
+ logger.info("Unhandled KeyboardInterrupt in telnet application.")
+ except BaseException as e:
+ print("Got %s" % type(e).__name__, e)
+ import traceback
+
+ traceback.print_exc()
+ finally:
+ self._application_tasks.remove(task)
+
+ task = get_running_loop().create_task(run())
+ self._application_tasks.append(task)