summaryrefslogtreecommitdiffstats
path: root/python/qemu/aqmp/protocol.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/qemu/aqmp/protocol.py')
-rw-r--r--python/qemu/aqmp/protocol.py1048
1 files changed, 0 insertions, 1048 deletions
diff --git a/python/qemu/aqmp/protocol.py b/python/qemu/aqmp/protocol.py
deleted file mode 100644
index 36fae57f27..0000000000
--- a/python/qemu/aqmp/protocol.py
+++ /dev/null
@@ -1,1048 +0,0 @@
-"""
-Generic Asynchronous Message-based Protocol Support
-
-This module provides a generic framework for sending and receiving
-messages over an asyncio stream. `AsyncProtocol` is an abstract class
-that implements the core mechanisms of a simple send/receive protocol,
-and is designed to be extended.
-
-In this package, it is used as the implementation for the `QMPClient`
-class.
-"""
-
-# It's all the docstrings ... ! It's long for a good reason ^_^;
-# pylint: disable=too-many-lines
-
-import asyncio
-from asyncio import StreamReader, StreamWriter
-from enum import Enum
-from functools import wraps
-import logging
-from ssl import SSLContext
-from typing import (
- Any,
- Awaitable,
- Callable,
- Generic,
- List,
- Optional,
- Tuple,
- TypeVar,
- Union,
- cast,
-)
-
-from .error import QMPError
-from .util import (
- bottom_half,
- create_task,
- exception_summary,
- flush,
- is_closing,
- pretty_traceback,
- upper_half,
- wait_closed,
-)
-
-
-T = TypeVar('T')
-_U = TypeVar('_U')
-_TaskFN = Callable[[], Awaitable[None]] # aka ``async def func() -> None``
-
-InternetAddrT = Tuple[str, int]
-UnixAddrT = str
-SocketAddrT = Union[UnixAddrT, InternetAddrT]
-
-
-class Runstate(Enum):
- """Protocol session runstate."""
-
- #: Fully quiesced and disconnected.
- IDLE = 0
- #: In the process of connecting or establishing a session.
- CONNECTING = 1
- #: Fully connected and active session.
- RUNNING = 2
- #: In the process of disconnecting.
- #: Runstate may be returned to `IDLE` by calling `disconnect()`.
- DISCONNECTING = 3
-
-
-class ConnectError(QMPError):
- """
- Raised when the initial connection process has failed.
-
- This Exception always wraps a "root cause" exception that can be
- interrogated for additional information.
-
- :param error_message: Human-readable string describing the error.
- :param exc: The root-cause exception.
- """
- def __init__(self, error_message: str, exc: Exception):
- super().__init__(error_message)
- #: Human-readable error string
- self.error_message: str = error_message
- #: Wrapped root cause exception
- self.exc: Exception = exc
-
- def __str__(self) -> str:
- cause = str(self.exc)
- if not cause:
- # If there's no error string, use the exception name.
- cause = exception_summary(self.exc)
- return f"{self.error_message}: {cause}"
-
-
-class StateError(QMPError):
- """
- An API command (connect, execute, etc) was issued at an inappropriate time.
-
- This error is raised when a command like
- :py:meth:`~AsyncProtocol.connect()` is issued at an inappropriate
- time.
-
- :param error_message: Human-readable string describing the state violation.
- :param state: The actual `Runstate` seen at the time of the violation.
- :param required: The `Runstate` required to process this command.
- """
- def __init__(self, error_message: str,
- state: Runstate, required: Runstate):
- super().__init__(error_message)
- self.error_message = error_message
- self.state = state
- self.required = required
-
-
-F = TypeVar('F', bound=Callable[..., Any]) # pylint: disable=invalid-name
-
-
-# Don't Panic.
-def require(required_state: Runstate) -> Callable[[F], F]:
- """
- Decorator: protect a method so it can only be run in a certain `Runstate`.
-
- :param required_state: The `Runstate` required to invoke this method.
- :raise StateError: When the required `Runstate` is not met.
- """
- def _decorator(func: F) -> F:
- # _decorator is the decorator that is built by calling the
- # require() decorator factory; e.g.:
- #
- # @require(Runstate.IDLE) def foo(): ...
- # will replace 'foo' with the result of '_decorator(foo)'.
-
- @wraps(func)
- def _wrapper(proto: 'AsyncProtocol[Any]',
- *args: Any, **kwargs: Any) -> Any:
- # _wrapper is the function that gets executed prior to the
- # decorated method.
-
- name = type(proto).__name__
-
- if proto.runstate != required_state:
- if proto.runstate == Runstate.CONNECTING:
- emsg = f"{name} is currently connecting."
- elif proto.runstate == Runstate.DISCONNECTING:
- emsg = (f"{name} is disconnecting."
- " Call disconnect() to return to IDLE state.")
- elif proto.runstate == Runstate.RUNNING:
- emsg = f"{name} is already connected and running."
- elif proto.runstate == Runstate.IDLE:
- emsg = f"{name} is disconnected and idle."
- else:
- assert False
- raise StateError(emsg, proto.runstate, required_state)
- # No StateError, so call the wrapped method.
- return func(proto, *args, **kwargs)
-
- # Return the decorated method;
- # Transforming Func to Decorated[Func].
- return cast(F, _wrapper)
-
- # Return the decorator instance from the decorator factory. Phew!
- return _decorator
-
-
-class AsyncProtocol(Generic[T]):
- """
- AsyncProtocol implements a generic async message-based protocol.
-
- This protocol assumes the basic unit of information transfer between
- client and server is a "message", the details of which are left up
- to the implementation. It assumes the sending and receiving of these
- messages is full-duplex and not necessarily correlated; i.e. it
- supports asynchronous inbound messages.
-
- It is designed to be extended by a specific protocol which provides
- the implementations for how to read and send messages. These must be
- defined in `_do_recv()` and `_do_send()`, respectively.
-
- Other callbacks have a default implementation, but are intended to be
- either extended or overridden:
-
- - `_establish_session`:
- The base implementation starts the reader/writer tasks.
- A protocol implementation can override this call, inserting
- actions to be taken prior to starting the reader/writer tasks
- before the super() call; actions needing to occur afterwards
- can be written after the super() call.
- - `_on_message`:
- Actions to be performed when a message is received.
- - `_cb_outbound`:
- Logging/Filtering hook for all outbound messages.
- - `_cb_inbound`:
- Logging/Filtering hook for all inbound messages.
- This hook runs *before* `_on_message()`.
-
- :param name:
- Name used for logging messages, if any. By default, messages
- will log to 'qemu.aqmp.protocol', but each individual connection
- can be given its own logger by giving it a name; messages will
- then log to 'qemu.aqmp.protocol.${name}'.
- """
- # pylint: disable=too-many-instance-attributes
-
- #: Logger object for debugging messages from this connection.
- logger = logging.getLogger(__name__)
-
- # Maximum allowable size of read buffer
- _limit = (64 * 1024)
-
- # -------------------------
- # Section: Public interface
- # -------------------------
-
- def __init__(self, name: Optional[str] = None) -> None:
- #: The nickname for this connection, if any.
- self.name: Optional[str] = name
- if self.name is not None:
- self.logger = self.logger.getChild(self.name)
-
- # stream I/O
- self._reader: Optional[StreamReader] = None
- self._writer: Optional[StreamWriter] = None
-
- # Outbound Message queue
- self._outgoing: asyncio.Queue[T]
-
- # Special, long-running tasks:
- self._reader_task: Optional[asyncio.Future[None]] = None
- self._writer_task: Optional[asyncio.Future[None]] = None
-
- # Aggregate of the above two tasks, used for Exception management.
- self._bh_tasks: Optional[asyncio.Future[Tuple[None, None]]] = None
-
- #: Disconnect task. The disconnect implementation runs in a task
- #: so that asynchronous disconnects (initiated by the
- #: reader/writer) are allowed to wait for the reader/writers to
- #: exit.
- self._dc_task: Optional[asyncio.Future[None]] = None
-
- self._runstate = Runstate.IDLE
- self._runstate_changed: Optional[asyncio.Event] = None
-
- # Server state for start_server() and _incoming()
- self._server: Optional[asyncio.AbstractServer] = None
- self._accepted: Optional[asyncio.Event] = None
-
- def __repr__(self) -> str:
- cls_name = type(self).__name__
- tokens = []
- if self.name is not None:
- tokens.append(f"name={self.name!r}")
- tokens.append(f"runstate={self.runstate.name}")
- return f"<{cls_name} {' '.join(tokens)}>"
-
- @property # @upper_half
- def runstate(self) -> Runstate:
- """The current `Runstate` of the connection."""
- return self._runstate
-
- @upper_half
- async def runstate_changed(self) -> Runstate:
- """
- Wait for the `runstate` to change, then return that runstate.
- """
- await self._runstate_event.wait()
- return self.runstate
-
- @upper_half
- @require(Runstate.IDLE)
- async def start_server_and_accept(
- self, address: SocketAddrT,
- ssl: Optional[SSLContext] = None
- ) -> None:
- """
- Accept a connection and begin processing message queues.
-
- If this call fails, `runstate` is guaranteed to be set back to `IDLE`.
- This method is precisely equivalent to calling `start_server()`
- followed by `accept()`.
-
- :param address:
- Address to listen on; UNIX socket path or TCP address/port.
- :param ssl: SSL context to use, if any.
-
- :raise StateError: When the `Runstate` is not `IDLE`.
- :raise ConnectError:
- When a connection or session cannot be established.
-
- This exception will wrap a more concrete one. In most cases,
- the wrapped exception will be `OSError` or `EOFError`. If a
- protocol-level failure occurs while establishing a new
- session, the wrapped error may also be an `QMPError`.
- """
- await self.start_server(address, ssl)
- await self.accept()
- assert self.runstate == Runstate.RUNNING
-
- @upper_half
- @require(Runstate.IDLE)
- async def start_server(self, address: SocketAddrT,
- ssl: Optional[SSLContext] = None) -> None:
- """
- Start listening for an incoming connection, but do not wait for a peer.
-
- This method starts listening for an incoming connection, but
- does not block waiting for a peer. This call will return
- immediately after binding and listening on a socket. A later
- call to `accept()` must be made in order to finalize the
- incoming connection.
-
- :param address:
- Address to listen on; UNIX socket path or TCP address/port.
- :param ssl: SSL context to use, if any.
-
- :raise StateError: When the `Runstate` is not `IDLE`.
- :raise ConnectError:
- When the server could not start listening on this address.
-
- This exception will wrap a more concrete one. In most cases,
- the wrapped exception will be `OSError`.
- """
- await self._session_guard(
- self._do_start_server(address, ssl),
- 'Failed to establish connection')
- assert self.runstate == Runstate.CONNECTING
-
- @upper_half
- @require(Runstate.CONNECTING)
- async def accept(self) -> None:
- """
- Accept an incoming connection and begin processing message queues.
-
- If this call fails, `runstate` is guaranteed to be set back to `IDLE`.
-
- :raise StateError: When the `Runstate` is not `CONNECTING`.
- :raise QMPError: When `start_server()` was not called yet.
- :raise ConnectError:
- When a connection or session cannot be established.
-
- This exception will wrap a more concrete one. In most cases,
- the wrapped exception will be `OSError` or `EOFError`. If a
- protocol-level failure occurs while establishing a new
- session, the wrapped error may also be an `QMPError`.
- """
- if self._accepted is None:
- raise QMPError("Cannot call accept() before start_server().")
- await self._session_guard(
- self._do_accept(),
- 'Failed to establish connection')
- await self._session_guard(
- self._establish_session(),
- 'Failed to establish session')
- assert self.runstate == Runstate.RUNNING
-
- @upper_half
- @require(Runstate.IDLE)
- async def connect(self, address: SocketAddrT,
- ssl: Optional[SSLContext] = None) -> None:
- """
- Connect to the server and begin processing message queues.
-
- If this call fails, `runstate` is guaranteed to be set back to `IDLE`.
-
- :param address:
- Address to connect to; UNIX socket path or TCP address/port.
- :param ssl: SSL context to use, if any.
-
- :raise StateError: When the `Runstate` is not `IDLE`.
- :raise ConnectError:
- When a connection or session cannot be established.
-
- This exception will wrap a more concrete one. In most cases,
- the wrapped exception will be `OSError` or `EOFError`. If a
- protocol-level failure occurs while establishing a new
- session, the wrapped error may also be an `QMPError`.
- """
- await self._session_guard(
- self._do_connect(address, ssl),
- 'Failed to establish connection')
- await self._session_guard(
- self._establish_session(),
- 'Failed to establish session')
- assert self.runstate == Runstate.RUNNING
-
- @upper_half
- async def disconnect(self) -> None:
- """
- Disconnect and wait for all tasks to fully stop.
-
- If there was an exception that caused the reader/writers to
- terminate prematurely, it will be raised here.
-
- :raise Exception: When the reader or writer terminate unexpectedly.
- """
- self.logger.debug("disconnect() called.")
- self._schedule_disconnect()
- await self._wait_disconnect()
-
- # --------------------------
- # Section: Session machinery
- # --------------------------
-
- async def _session_guard(self, coro: Awaitable[None], emsg: str) -> None:
- """
- Async guard function used to roll back to `IDLE` on any error.
-
- On any Exception, the state machine will be reset back to
- `IDLE`. Most Exceptions will be wrapped with `ConnectError`, but
- `BaseException` events will be left alone (This includes
- asyncio.CancelledError, even prior to Python 3.8).
-
- :param error_message:
- Human-readable string describing what connection phase failed.
-
- :raise BaseException:
- When `BaseException` occurs in the guarded block.
- :raise ConnectError:
- When any other error is encountered in the guarded block.
- """
- # Note: After Python 3.6 support is removed, this should be an
- # @asynccontextmanager instead of accepting a callback.
- try:
- await coro
- except BaseException as err:
- self.logger.error("%s: %s", emsg, exception_summary(err))
- self.logger.debug("%s:\n%s\n", emsg, pretty_traceback())
- try:
- # Reset the runstate back to IDLE.
- await self.disconnect()
- except:
- # We don't expect any Exceptions from the disconnect function
- # here, because we failed to connect in the first place.
- # The disconnect() function is intended to perform
- # only cannot-fail cleanup here, but you never know.
- emsg = (
- "Unexpected bottom half exception. "
- "This is a bug in the QMP library. "
- "Please report it to <qemu-devel@nongnu.org> and "
- "CC: John Snow <jsnow@redhat.com>."
- )
- self.logger.critical("%s:\n%s\n", emsg, pretty_traceback())
- raise
-
- # CancelledError is an Exception with special semantic meaning;
- # We do NOT want to wrap it up under ConnectError.
- # NB: CancelledError is not a BaseException before Python 3.8
- if isinstance(err, asyncio.CancelledError):
- raise
-
- # Any other kind of error can be treated as some kind of connection
- # failure broadly. Inspect the 'exc' field to explore the root
- # cause in greater detail.
- if isinstance(err, Exception):
- raise ConnectError(emsg, err) from err
-
- # Raise BaseExceptions un-wrapped, they're more important.
- raise
-
- @property
- def _runstate_event(self) -> asyncio.Event:
- # asyncio.Event() objects should not be created prior to entrance into
- # an event loop, so we can ensure we create it in the correct context.
- # Create it on-demand *only* at the behest of an 'async def' method.
- if not self._runstate_changed:
- self._runstate_changed = asyncio.Event()
- return self._runstate_changed
-
- @upper_half
- @bottom_half
- def _set_state(self, state: Runstate) -> None:
- """
- Change the `Runstate` of the protocol connection.
-
- Signals the `runstate_changed` event.
- """
- if state == self._runstate:
- return
-
- self.logger.debug("Transitioning from '%s' to '%s'.",
- str(self._runstate), str(state))
- self._runstate = state
- self._runstate_event.set()
- self._runstate_event.clear()
-
- @bottom_half
- async def _stop_server(self) -> None:
- """
- Stop listening for / accepting new incoming connections.
- """
- if self._server is None:
- return
-
- try:
- self.logger.debug("Stopping server.")
- self._server.close()
- await self._server.wait_closed()
- self.logger.debug("Server stopped.")
- finally:
- self._server = None
-
- @bottom_half # However, it does not run from the R/W tasks.
- async def _incoming(self,
- reader: asyncio.StreamReader,
- writer: asyncio.StreamWriter) -> None:
- """
- Accept an incoming connection and signal the upper_half.
-
- This method does the minimum necessary to accept a single
- incoming connection. It signals back to the upper_half ASAP so
- that any errors during session initialization can occur
- naturally in the caller's stack.
-
- :param reader: Incoming `asyncio.StreamReader`
- :param writer: Incoming `asyncio.StreamWriter`
- """
- peer = writer.get_extra_info('peername', 'Unknown peer')
- self.logger.debug("Incoming connection from %s", peer)
-
- if self._reader or self._writer:
- # Sadly, we can have more than one pending connection
- # because of https://bugs.python.org/issue46715
- # Close any extra connections we don't actually want.
- self.logger.warning("Extraneous connection inadvertently accepted")
- writer.close()
- return
-
- # A connection has been accepted; stop listening for new ones.
- assert self._accepted is not None
- await self._stop_server()
- self._reader, self._writer = (reader, writer)
- self._accepted.set()
-
- @upper_half
- async def _do_start_server(self, address: SocketAddrT,
- ssl: Optional[SSLContext] = None) -> None:
- """
- Start listening for an incoming connection, but do not wait for a peer.
-
- This method starts listening for an incoming connection, but does not
- block waiting for a peer. This call will return immediately after
- binding and listening to a socket. A later call to accept() must be
- made in order to finalize the incoming connection.
-
- :param address:
- Address to listen on; UNIX socket path or TCP address/port.
- :param ssl: SSL context to use, if any.
-
- :raise OSError: For stream-related errors.
- """
- assert self.runstate == Runstate.IDLE
- self._set_state(Runstate.CONNECTING)
-
- self.logger.debug("Awaiting connection on %s ...", address)
- self._accepted = asyncio.Event()
-
- if isinstance(address, tuple):
- coro = asyncio.start_server(
- self._incoming,
- host=address[0],
- port=address[1],
- ssl=ssl,
- backlog=1,
- limit=self._limit,
- )
- else:
- coro = asyncio.start_unix_server(
- self._incoming,
- path=address,
- ssl=ssl,
- backlog=1,
- limit=self._limit,
- )
-
- # Allow runstate watchers to witness 'CONNECTING' state; some
- # failures in the streaming layer are synchronous and will not
- # otherwise yield.
- await asyncio.sleep(0)
-
- # This will start the server (bind(2), listen(2)). It will also
- # call accept(2) if we yield, but we don't block on that here.
- self._server = await coro
- self.logger.debug("Server listening on %s", address)
-
- @upper_half
- async def _do_accept(self) -> None:
- """
- Wait for and accept an incoming connection.
-
- Requires that we have not yet accepted an incoming connection
- from the upper_half, but it's OK if the server is no longer
- running because the bottom_half has already accepted the
- connection.
- """
- assert self._accepted is not None
- await self._accepted.wait()
- assert self._server is None
- self._accepted = None
-
- self.logger.debug("Connection accepted.")
-
- @upper_half
- async def _do_connect(self, address: SocketAddrT,
- ssl: Optional[SSLContext] = None) -> None:
- """
- Acting as the transport client, initiate a connection to a server.
-
- :param address:
- Address to connect to; UNIX socket path or TCP address/port.
- :param ssl: SSL context to use, if any.
-
- :raise OSError: For stream-related errors.
- """
- assert self.runstate == Runstate.IDLE
- self._set_state(Runstate.CONNECTING)
-
- # Allow runstate watchers to witness 'CONNECTING' state; some
- # failures in the streaming layer are synchronous and will not
- # otherwise yield.
- await asyncio.sleep(0)
-
- self.logger.debug("Connecting to %s ...", address)
-
- if isinstance(address, tuple):
- connect = asyncio.open_connection(
- address[0],
- address[1],
- ssl=ssl,
- limit=self._limit,
- )
- else:
- connect = asyncio.open_unix_connection(
- path=address,
- ssl=ssl,
- limit=self._limit,
- )
- self._reader, self._writer = await connect
-
- self.logger.debug("Connected.")
-
- @upper_half
- async def _establish_session(self) -> None:
- """
- Establish a new session.
-
- Starts the readers/writer tasks; subclasses may perform their
- own negotiations here. The Runstate will be RUNNING upon
- successful conclusion.
- """
- assert self.runstate == Runstate.CONNECTING
-
- self._outgoing = asyncio.Queue()
-
- reader_coro = self._bh_loop_forever(self._bh_recv_message, 'Reader')
- writer_coro = self._bh_loop_forever(self._bh_send_message, 'Writer')
-
- self._reader_task = create_task(reader_coro)
- self._writer_task = create_task(writer_coro)
-
- self._bh_tasks = asyncio.gather(
- self._reader_task,
- self._writer_task,
- )
-
- self._set_state(Runstate.RUNNING)
- await asyncio.sleep(0) # Allow runstate_event to process
-
- @upper_half
- @bottom_half
- def _schedule_disconnect(self) -> None:
- """
- Initiate a disconnect; idempotent.
-
- This method is used both in the upper-half as a direct
- consequence of `disconnect()`, and in the bottom-half in the
- case of unhandled exceptions in the reader/writer tasks.
-
- It can be invoked no matter what the `runstate` is.
- """
- if not self._dc_task:
- self._set_state(Runstate.DISCONNECTING)
- self.logger.debug("Scheduling disconnect.")
- self._dc_task = create_task(self._bh_disconnect())
-
- @upper_half
- async def _wait_disconnect(self) -> None:
- """
- Waits for a previously scheduled disconnect to finish.
-
- This method will gather any bottom half exceptions and re-raise
- the one that occurred first; presuming it to be the root cause
- of any subsequent Exceptions. It is intended to be used in the
- upper half of the call chain.
-
- :raise Exception:
- Arbitrary exception re-raised on behalf of the reader/writer.
- """
- assert self.runstate == Runstate.DISCONNECTING
- assert self._dc_task
-
- aws: List[Awaitable[object]] = [self._dc_task]
- if self._bh_tasks:
- aws.insert(0, self._bh_tasks)
- all_defined_tasks = asyncio.gather(*aws)
-
- # Ensure disconnect is done; Exception (if any) is not raised here:
- await asyncio.wait((self._dc_task,))
-
- try:
- await all_defined_tasks # Raise Exceptions from the bottom half.
- finally:
- self._cleanup()
- self._set_state(Runstate.IDLE)
-
- @upper_half
- def _cleanup(self) -> None:
- """
- Fully reset this object to a clean state and return to `IDLE`.
- """
- def _paranoid_task_erase(task: Optional['asyncio.Future[_U]']
- ) -> Optional['asyncio.Future[_U]']:
- # Help to erase a task, ENSURING it is fully quiesced first.
- assert (task is None) or task.done()
- return None if (task and task.done()) else task
-
- assert self.runstate == Runstate.DISCONNECTING
- self._dc_task = _paranoid_task_erase(self._dc_task)
- self._reader_task = _paranoid_task_erase(self._reader_task)
- self._writer_task = _paranoid_task_erase(self._writer_task)
- self._bh_tasks = _paranoid_task_erase(self._bh_tasks)
-
- self._reader = None
- self._writer = None
- self._accepted = None
-
- # NB: _runstate_changed cannot be cleared because we still need it to
- # send the final runstate changed event ...!
-
- # ----------------------------
- # Section: Bottom Half methods
- # ----------------------------
-
- @bottom_half
- async def _bh_disconnect(self) -> None:
- """
- Disconnect and cancel all outstanding tasks.
-
- It is designed to be called from its task context,
- :py:obj:`~AsyncProtocol._dc_task`. By running in its own task,
- it is free to wait on any pending actions that may still need to
- occur in either the reader or writer tasks.
- """
- assert self.runstate == Runstate.DISCONNECTING
-
- def _done(task: Optional['asyncio.Future[Any]']) -> bool:
- return task is not None and task.done()
-
- # If the server is running, stop it.
- await self._stop_server()
-
- # Are we already in an error pathway? If either of the tasks are
- # already done, or if we have no tasks but a reader/writer; we
- # must be.
- #
- # NB: We can't use _bh_tasks to check for premature task
- # completion, because it may not yet have had a chance to run
- # and gather itself.
- tasks = tuple(filter(None, (self._writer_task, self._reader_task)))
- error_pathway = _done(self._reader_task) or _done(self._writer_task)
- if not tasks:
- error_pathway |= bool(self._reader) or bool(self._writer)
-
- try:
- # Try to flush the writer, if possible.
- # This *may* cause an error and force us over into the error path.
- if not error_pathway:
- await self._bh_flush_writer()
- except BaseException as err:
- error_pathway = True
- emsg = "Failed to flush the writer"
- self.logger.error("%s: %s", emsg, exception_summary(err))
- self.logger.debug("%s:\n%s\n", emsg, pretty_traceback())
- raise
- finally:
- # Cancel any still-running tasks (Won't raise):
- if self._writer_task is not None and not self._writer_task.done():
- self.logger.debug("Cancelling writer task.")
- self._writer_task.cancel()
- if self._reader_task is not None and not self._reader_task.done():
- self.logger.debug("Cancelling reader task.")
- self._reader_task.cancel()
-
- # Close out the tasks entirely (Won't raise):
- if tasks:
- self.logger.debug("Waiting for tasks to complete ...")
- await asyncio.wait(tasks)
-
- # Lastly, close the stream itself. (*May raise*!):
- await self._bh_close_stream(error_pathway)
- self.logger.debug("Disconnected.")
-
- @bottom_half
- async def _bh_flush_writer(self) -> None:
- if not self._writer_task:
- return
-
- self.logger.debug("Draining the outbound queue ...")
- await self._outgoing.join()
- if self._writer is not None:
- self.logger.debug("Flushing the StreamWriter ...")
- await flush(self._writer)
-
- @bottom_half
- async def _bh_close_stream(self, error_pathway: bool = False) -> None:
- # NB: Closing the writer also implcitly closes the reader.
- if not self._writer:
- return
-
- if not is_closing(self._writer):
- self.logger.debug("Closing StreamWriter.")
- self._writer.close()
-
- self.logger.debug("Waiting for StreamWriter to close ...")
- try:
- await wait_closed(self._writer)
- except Exception: # pylint: disable=broad-except
- # It's hard to tell if the Stream is already closed or
- # not. Even if one of the tasks has failed, it may have
- # failed for a higher-layered protocol reason. The
- # stream could still be open and perfectly fine.
- # I don't know how to discern its health here.
-
- if error_pathway:
- # We already know that *something* went wrong. Let's
- # just trust that the Exception we already have is the
- # better one to present to the user, even if we don't
- # genuinely *know* the relationship between the two.
- self.logger.debug(
- "Discarding Exception from wait_closed:\n%s\n",
- pretty_traceback(),
- )
- else:
- # Oops, this is a brand-new error!
- raise
- finally:
- self.logger.debug("StreamWriter closed.")
-
- @bottom_half
- async def _bh_loop_forever(self, async_fn: _TaskFN, name: str) -> None:
- """
- Run one of the bottom-half methods in a loop forever.
-
- If the bottom half ever raises any exception, schedule a
- disconnect that will terminate the entire loop.
-
- :param async_fn: The bottom-half method to run in a loop.
- :param name: The name of this task, used for logging.
- """
- try:
- while True:
- await async_fn()
- except asyncio.CancelledError:
- # We have been cancelled by _bh_disconnect, exit gracefully.
- self.logger.debug("Task.%s: cancelled.", name)
- return
- except BaseException as err:
- self.logger.log(
- logging.INFO if isinstance(err, EOFError) else logging.ERROR,
- "Task.%s: %s",
- name, exception_summary(err)
- )
- self.logger.debug("Task.%s: failure:\n%s\n",
- name, pretty_traceback())
- self._schedule_disconnect()
- raise
- finally:
- self.logger.debug("Task.%s: exiting.", name)
-
- @bottom_half
- async def _bh_send_message(self) -> None:
- """
- Wait for an outgoing message, then send it.
-
- Designed to be run in `_bh_loop_forever()`.
- """
- msg = await self._outgoing.get()
- try:
- await self._send(msg)
- finally:
- self._outgoing.task_done()
-
- @bottom_half
- async def _bh_recv_message(self) -> None:
- """
- Wait for an incoming message and call `_on_message` to route it.
-
- Designed to be run in `_bh_loop_forever()`.
- """
- msg = await self._recv()
- await self._on_message(msg)
-
- # --------------------
- # Section: Message I/O
- # --------------------
-
- @upper_half
- @bottom_half
- def _cb_outbound(self, msg: T) -> T:
- """
- Callback: outbound message hook.
-
- This is intended for subclasses to be able to add arbitrary
- hooks to filter or manipulate outgoing messages. The base
- implementation does nothing but log the message without any
- manipulation of the message.
-
- :param msg: raw outbound message
- :return: final outbound message
- """
- self.logger.debug("--> %s", str(msg))
- return msg
-
- @upper_half
- @bottom_half
- def _cb_inbound(self, msg: T) -> T:
- """
- Callback: inbound message hook.
-
- This is intended for subclasses to be able to add arbitrary
- hooks to filter or manipulate incoming messages. The base
- implementation does nothing but log the message without any
- manipulation of the message.
-
- This method does not "handle" incoming messages; it is a filter.
- The actual "endpoint" for incoming messages is `_on_message()`.
-
- :param msg: raw inbound message
- :return: processed inbound message
- """
- self.logger.debug("<-- %s", str(msg))
- return msg
-
- @upper_half
- @bottom_half
- async def _readline(self) -> bytes:
- """
- Wait for a newline from the incoming reader.
-
- This method is provided as a convenience for upper-layer
- protocols, as many are line-based.
-
- This method *may* return a sequence of bytes without a trailing
- newline if EOF occurs, but *some* bytes were received. In this
- case, the next call will raise `EOFError`. It is assumed that
- the layer 5 protocol will decide if there is anything meaningful
- to be done with a partial message.
-
- :raise OSError: For stream-related errors.
- :raise EOFError:
- If the reader stream is at EOF and there are no bytes to return.
- :return: bytes, including the newline.
- """
- assert self._reader is not None
- msg_bytes = await self._reader.readline()
-
- if not msg_bytes:
- if self._reader.at_eof():
- raise EOFError
-
- return msg_bytes
-
- @upper_half
- @bottom_half
- async def _do_recv(self) -> T:
- """
- Abstract: Read from the stream and return a message.
-
- Very low-level; intended to only be called by `_recv()`.
- """
- raise NotImplementedError
-
- @upper_half
- @bottom_half
- async def _recv(self) -> T:
- """
- Read an arbitrary protocol message.
-
- .. warning::
- This method is intended primarily for `_bh_recv_message()`
- to use in an asynchronous task loop. Using it outside of
- this loop will "steal" messages from the normal routing
- mechanism. It is safe to use prior to `_establish_session()`,
- but should not be used otherwise.
-
- This method uses `_do_recv()` to retrieve the raw message, and
- then transforms it using `_cb_inbound()`.
-
- :return: A single (filtered, processed) protocol message.
- """
- message = await self._do_recv()
- return self._cb_inbound(message)
-
- @upper_half
- @bottom_half
- def _do_send(self, msg: T) -> None:
- """
- Abstract: Write a message to the stream.
-
- Very low-level; intended to only be called by `_send()`.
- """
- raise NotImplementedError
-
- @upper_half
- @bottom_half
- async def _send(self, msg: T) -> None:
- """
- Send an arbitrary protocol message.
-
- This method will transform any outgoing messages according to
- `_cb_outbound()`.
-
- .. warning::
- Like `_recv()`, this method is intended to be called by
- the writer task loop that processes outgoing
- messages. Calling it directly may circumvent logic
- implemented by the caller meant to correlate outgoing and
- incoming messages.
-
- :raise OSError: For problems with the underlying stream.
- """
- msg = self._cb_outbound(msg)
- self._do_send(msg)
-
- @bottom_half
- async def _on_message(self, msg: T) -> None:
- """
- Called to handle the receipt of a new message.
-
- .. caution::
- This is executed from within the reader loop, so be advised
- that waiting on either the reader or writer task will lead
- to deadlock. Additionally, any unhandled exceptions will
- directly cause the loop to halt, so logic may be best-kept
- to a minimum if at all possible.
-
- :param msg: The incoming message, already logged/filtered.
- """
- # Nothing to do in the abstract case.