Source code for zonis.router

from __future__ import annotations

import asyncio
import json
import logging
import secrets
import traceback
import typing as t
from functools import partial

import websockets.client
from websockets.legacy.client import WebSocketClientProtocol

from zonis import (
from zonis.packet import IdentifyPacket, IdentifyDataPacket

log = logging.getLogger(__name__)

class PacketT(t.TypedDict):
    """Shape of the packets ``Router`` sends and receives over the pipe."""

    # Identifier used to pair a response with the request it answers.
    packet_id: str
    # "request" packets are handed to the registered receive handler,
    # "response" packets resolve the future of a previously sent request.
    type: t.Literal["request", "response"]
    # The payload carried by the packet.
    data: dict

class Router:
    """A router class for enabling two-way communication down a single pipe.

    This is built on the architecture ideals that if you want to send
    something down the pipe, then this class can return you the response
    just fine. However, if you wish to receive content for this class
    then you need to register an async method to handle that incoming
    data as this class doesn't have the context required to handle it.
    """

    # Queue marker which tells the pipe loop to shut down, see ``close``.
    _EXIT_LITERAL: t.Final = "CLOSE_FUTURE"

    def __init__(
        self,
        identifier: str,
        websocket_connection: WebsocketProtocol | None,
        *,
        retry_on_exception: bool = True,
        max_retries: int = 3,
        default_client_websocket_factory: t.Type[WebsocketProtocol] = Websockets,
    ):
        """
        Parameters
        ----------
        identifier: str
            How this connection identifies itself.
        websocket_connection: WebsocketProtocol
            An implementation of the websocket protocol.

            At this stage this only affects the Server,
            Client's will always use Websockets
        retry_on_exception: bool
            On failure, give it another go
        max_retries: int
            Max times to retry on exception
        default_client_websocket_factory: t.Type[WebsocketProtocol]
            Called with the raw client websocket to build the
            connection wrapper used when connecting as a client.

        Warnings
        --------
        websocket_connection should **only ever** be
        None when called from Client.

        If you ever use this, ensure it is passed
        or set before first usage.
        """
        self._websocket_connection: WebsocketProtocol | None = websocket_connection
        self._default_client_websocket_factory: t.Type[
            WebsocketProtocol
        ] = default_client_websocket_factory
        self._retry_on_exception: bool = retry_on_exception
        self._max_retries: int = max_retries
        self._attempted_connection_count: int = 0
        self.__url: str | None = None
        self.__is_open: bool = True
        self._ws_task: asyncio.Task | None = None
        self._receive_handler = None
        self.identifier: str = identifier
        self._router_id: str = secrets.token_bytes(32).hex()
        # Pending requests keyed by packet id, resolved when a response arrives
        self._futures: dict[str, asyncio.Future] = {}
        self._item_queue: asyncio.Queue[
            tuple[t.Literal["SEND_THEN_RECEIVE_RESPONSE"], dict]
            | tuple[t.Literal["SEND_RESPONSE"], dict]
            | tuple[t.Literal["CLOSE_FUTURE"]],
        ] = asyncio.Queue()
        self._connection_future: asyncio.Future | None = None
        # NOTE(review): this future is created eagerly, so building a Router
        # while no event loop is running relies on asyncio's implicit loop
        # lookup. Confirm callers always construct the Router inside a loop.
        self._block_future: asyncio.Future = asyncio.Future()
        self._receive_from_queue_task: asyncio.Task | None = None
        self._receive_from_ws_task: asyncio.Task | None = None

    async def block_until_closed(self):
        """A blocking call which releases when the WS closes.

        Notes
        -----
        This will block even if the underlying
        WS has yet to connect.
        """
        await self._block_future

    async def connect_client(self, url: str, idp: IdentifyDataPacket) -> None:
        """A blocking call which connects the underlying WS.

        Parameters
        ----------
        url: str
            The websocket URL to connect to.
        idp: IdentifyDataPacket
            The data sent inside the identify packet once connected.
        """
        log.debug("Initiating connection to %s with identity %s", url, idp)
        self._connection_future = asyncio.get_running_loop().create_future()
        self.__url = url
        self._ws_task = util.create_task(
            self._handle_ws(url, idp), router_id=self._router_id
        )
        # Released by _handle_client_identify once the identify round trip is done
        await self._connection_future

    async def connect_server(self) -> None:
        """Start processing the pipe using the connection given at creation."""
        self._ws_task = util.create_task(self._handle_pipe(), router_id=self._router_id)

    async def close(self) -> None:
        """Close the underlying WS"""
        log.debug("Closing WS connection to %s", self.__url)
        self._item_queue.put_nowait((self._EXIT_LITERAL,))  # type: ignore
        self._check_for_congestion()

        # Let the event loop close the task by giving
        # back control for a no-op. Not sure if this
        # actually then kills the task or not but whatever
        await asyncio.sleep(0)

    async def send(self, data: dict) -> asyncio.Future:
        """Send data down the pipe and receive a response.

        Parameters
        ----------
        data: dict
            The packet to send down the pipe

        Returns
        -------
        asyncio.Future
            This future will contain the returned
            data once it becomes available
        """
        packet_id = secrets.token_bytes(16).hex()
        future: asyncio.Future = asyncio.get_running_loop().create_future()
        self._futures[packet_id] = future
        self._item_queue.put_nowait(
            (
                "SEND_THEN_RECEIVE_RESPONSE",
                {"packet_id": packet_id, "type": "request", "data": data},
            )
        )
        self._check_for_congestion()
        log.debug("Queued item to send as a request for packet %s", packet_id)
        return future

    async def send_response(self, *, packet_id: str, data: dict | None = None):
        """Sends a response to a received request without waiting anything.

        Parameters
        ----------
        packet_id: str
            The packet ID to respond to
        data: dict
            The data being responded with
        """
        if data is None:
            data = {}

        packet: PacketT = {"type": "response", "packet_id": packet_id, "data": data}
        self._item_queue.put_nowait(("SEND_RESPONSE", packet))
        self._check_for_congestion()
        log.debug("Queued item to send as a response for packet %s", packet_id)

    def register_receiver(self, callback) -> Router:
        """Register the provided callback to handle
        packets with the request type.

        Parameters
        ----------
        callback
            The callback to call with the packet data.

        Returns
        -------
        Router
            This router, so calls can be chained.
        """
        if self._receive_handler is not None:
            log.debug("Router overriding existing receive handler")

        self._receive_handler = callback
        return self

    async def _handle_client_identify(self, idp):
        """Send the identify packet and release ``connect_client``."""
        future = await self.send(
            IdentifyPacket(identifier=self.identifier, type="IDENTIFY", data=idp)
        )
        await future
        # This runs again on every reconnect, by which point the future
        # from the first connection has already been resolved.
        if self._connection_future is not None and not self._connection_future.done():
            self._connection_future.set_result(None)

    async def _handle_ws(self, url, idp):
        """Handles the underlying websocket and distributed callouts.

        This method is called from clients.
        """
        try:
            async for websocket in websockets.client.connect(
                url,
                ping_timeout=100000,
            ):
                try:
                    websocket = t.cast(WebSocketClientProtocol, websocket)
                    self._websocket_connection = self._default_client_websocket_factory(
                        websocket
                    )
                    # This is a task because sending is done within _handle_pipe
                    util.create_task(
                        self._handle_client_identify(idp), router_id=self._router_id
                    )
                    log.debug("Websocket opened to %s", url)
                    result = await self._handle_pipe()
                except Exception as error:
                    log.critical("%s", "".join(traceback.format_exception(error)))
                    result = None

                if result == self._EXIT_LITERAL:
                    # A requested close, not a failure
                    break

                # The retry accounting lives outside of a ``finally`` so that
                # breaking out of the loop can never swallow an in-flight
                # BaseException such as asyncio.CancelledError.
                self._attempted_connection_count += 1
                if not self._retry_on_exception:
                    log.critical(
                        "Websocket connection lost and retries are disabled, exiting..."
                    )
                    break

                if self._max_retries < self._attempted_connection_count:
                    log.critical("Exceeded maximum websocket reconnections, exiting...")
                    break
        finally:
            # Always release block_until_closed, even if this task is cancelled
            if not self._block_future.done():
                self._block_future.set_result(None)

    def _ensure_tasks_exist(self):
        """Create the websocket and queue receive tasks if they are missing."""
        if self._receive_from_ws_task is None:
            self._receive_from_ws_task = util.create_task(
                self._websocket_connection.receive(), router_id=self._router_id
            )

        if self._receive_from_queue_task is None:
            self._receive_from_queue_task = util.create_task(
                self._item_queue.get(), router_id=self._router_id
            )

    def _release_ws_receive_task(self, *, force: bool) -> None:
        """Stop tracking a websocket receive task which can no longer be used.

        Parameters
        ----------
        force: bool
            Also drop a task which finished with an unprocessed message.
        """
        task = self._receive_from_ws_task
        if task is None:
            return

        if not task.done():
            task.cancel()
        elif not task.cancelled() and task.exception() is None and not force:
            # The task holds a message which has yet to be
            # processed, keep it for the next pipe loop.
            return

        self._receive_from_ws_task = None

    async def _handle_pipe(self):
        """Shuffle packets between the websocket and the send queue.

        Returns
        -------
        str | None
            The exit literal when a close was requested,
            otherwise None once the pipe stops for any other reason.
        """
        try:
            while self.__is_open:
                self._ensure_tasks_exist()
                done, _ = await asyncio.wait(
                    [self._receive_from_ws_task, self._receive_from_queue_task],
                    return_when=asyncio.FIRST_COMPLETED,
                )
                result = done.pop().result()
                if not isinstance(result, tuple):
                    log.debug("< Processing incoming websocket message")
                    # This has been received from the WS
                    self._receive_from_ws_task = None
                    await self._resolve_future(result)
                    continue

                self._receive_from_queue_task = None
                action = result[0]
                if action == self._EXIT_LITERAL:
                    # Kill the ws connection, the receive
                    # task itself is cancelled in finally
                    self.__is_open = False
                    log.debug("Closing WS connection to %s", self.__url)
                    return self._EXIT_LITERAL

                elif action == "SEND_RESPONSE":
                    log.debug("> Sending response to existing request")
                    await self._dispatch_future(result[1])

                elif action == "SEND_THEN_RECEIVE_RESPONSE":
                    log.debug("> Sending request pending future response")
                    await self._dispatch_future(result[1])

                else:
                    log.critical("Unhandled result: %s", result)

        except Exception as error:
            log.critical("%s", "".join(traceback.format_exception(error)))

        finally:
            # A receive task left over from this connection would otherwise
            # be reused after a reconnect and instantly re-raise its stale
            # error, or stay pending forever once the router is closed.
            self._release_ws_receive_task(force=not self.__is_open)

    async def _dispatch_future(self, data: dict):
        """Given a future from the queue, send it down the pipe

        Raises
        ------
        UnknownPacket
            The packet is missing its ``packet_id``.
        """
        if "packet_id" not in data:
            log.debug("Failed to resolve send packet as it was missing an id: %s", data)
            raise UnknownPacket

        packet_id = data["packet_id"]
        log.debug("> Sending packet %s", packet_id)
        await self._websocket_connection.send(json.dumps(data))

    async def _resolve_future(self, raw_packet: str | bytes | dict):
        """Given a packet received on the pipe, resolve its associated future

        Raises
        ------
        UnknownPacket
            The packet is missing a required field or is a
            response which doesn't match a pending request.
        MissingReceiveHandler
            A request arrived before a receive handler was registered.
        """
        packet: PacketT = (
            raw_packet if isinstance(raw_packet, dict) else json.loads(raw_packet)
        )

        if "type" not in packet:
            log.debug("Failed to resolve packet type for %s", packet)
            raise UnknownPacket

        if "data" not in packet:
            log.debug(
                "Failed to resolve packet as it was missing the 'data' field: %s",
                packet,
            )
            raise UnknownPacket

        if "packet_id" not in packet:
            log.debug("Failed to resolve packet as it was missing an id: %s", packet)
            raise UnknownPacket

        packet_id = packet["packet_id"]
        packet_data = packet["data"]
        if packet["type"] == "response":
            matching_future: asyncio.Future | None = self._futures.pop(packet_id, None)
            if matching_future is None:
                raise UnknownPacket

            if matching_future.done():
                # The caller gave up on this request (e.g. cancelled it),
                # setting a result now would raise and tear down the pipe.
                log.debug(
                    "Dropping response for packet %s as its future is already done",
                    packet_id,
                )
                return

            matching_future.set_result(packet_data)

        else:
            # This is a request to propagate up to
            # the incoming request handler method
            if self._receive_handler is None:
                raise MissingReceiveHandler

            # Give the callee the incoming data
            # and function to call to return a response
            util.create_task(
                self._receive_handler(
                    packet_data,
                    partial(self.send_response, packet_id=packet["packet_id"]),
                ),
                router_id=self._router_id,
            )

    def _check_for_congestion(self):
        """Checks if the queue appears congested and alerts"""
        # 50 is an arb number I picked out of thin air
        # Would love feedback on this size if it affects you a lot
        if self._item_queue.qsize() > 50:
            log.warning("Router queue size is over 50, possible congestion detected")