Source code for titanfe.connection
#
# Copyright (c) 2019-present, wobe-systems GmbH
#
# Licensed under the Apache License, Version 2.0 (the "License");
# found in the LICENSE file in the root directory of this source tree.
#
"""Encapsulate asyncio connections by wrapping them into a Connection"""
import asyncio
import logging
import pickle
from collections import namedtuple
from typing import Optional
from ujotypes import UjoMap, read_buffer, ujo_to_python, UjoStringUTF8
import titanfe.log
from titanfe.apps.brick_runner.connection import Buffer
from titanfe.ujo_helper import py_to_ujo_bytes
from titanfe.messages import Message
ENCODING = "UJO"
# ENCODING = "PICKLE"
PAYLOAD = UjoStringUTF8("payload")
BUFFER = UjoStringUTF8("buffer")
NetworkAddress = namedtuple("NetworkAddress", ("host", "port"))
[docs]def decode_ujo_message(ujo_bytes):
"""Decode ujo bytes into a corresponding python object, but keep an existing "Payload" as Ujo.
"""
ujoobj = read_buffer(ujo_bytes)
_, content = ujoobj[0], ujoobj[1]
payload = None
if isinstance(content, UjoMap) and PAYLOAD in content:
payload = content[PAYLOAD]
del ujoobj[1][PAYLOAD]
try:
buffer = content[BUFFER]
except KeyError:
buffer = UjoMap()
else:
del ujoobj[1][BUFFER]
pyobj = ujo_to_python(ujoobj)
if payload is not None:
# set payload to the original ujo payload
pyobj[1]["payload"] = payload
pyobj[1]["buffer"] = Buffer(buffer)
return pyobj
[docs]class Connection:
"""Wrap an asyncio StreamReader/Writer combination into a connection object.
Arguments:
reader (asyncio.StreamReader): the stream reader
writer (asyncio.StreamWriter): the stream writer
log (logging.logger): a parent logger
encoding: "PICKLE" or "UJO"
"""
def __init__(self, reader, writer, log=None, encoding=ENCODING):
self.reader = reader
self.writer = writer
self.closed = False
self.log = log.getChild("Connection") if log else titanfe.log.getLogger(__name__)
if encoding == "PICKLE":
self.decode = pickle.loads
self.encode = pickle.dumps
elif encoding == "UJO":
self.decode = decode_ujo_message
self.encode = py_to_ujo_bytes
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()
[docs] @classmethod
async def open(
cls, address: NetworkAddress, log: Optional[logging.Logger] = None
) -> "Connection":
"""open an asyncio connection to the given address (host, port)"""
reader, writer = await asyncio.open_connection(*address)
return cls(reader, writer, log)
[docs] async def close(self):
"""close the connection by closing it's reader and writer"""
if self.closed:
return
self.reader.feed_eof()
self.writer.close()
try:
await self.writer.wait_closed()
except (ConnectionAbortedError, ConnectionResetError):
pass
self.closed = True
[docs] async def receive(self):
"""wait until a message comes through and return it's content after decoding
Return:
Message: a message or None if the connection was closed remotely
"""
message = None
while not message:
try:
msg_len = await self.reader.readexactly(4)
except (asyncio.IncompleteReadError, ConnectionError):
self.log.debug("Stream at EOF - close connection.")
# self.log.debug('', exc_info=True)
await self.close()
return
msg = await self.reader.readexactly(int.from_bytes(msg_len, "big"))
self.log.debug("received message: %s", msg)
try:
msg = self.decode(msg)
except Exception:
self.log.error("Failed to decode %r", msg, exc_info=True)
raise ValueError(f"Failed to decode {msg}")
try:
message = Message(*msg)
except TypeError:
self.log.error("Received unknown Message format: %s", msg)
message = None
self.log.debug("decoded message: %r", message)
return message
[docs] async def send(self, message):
"""encode and send the content as a message"""
self.log.debug("sending: %r", message)
try:
msg = self.encode(message)
except Exception:
self.log.error("Failed to encode %r", message, exc_info=True)
raise ValueError(f"Failed to encode {message}")
msg_len = len(msg).to_bytes(4, "big")
try:
self.writer.write(msg_len)
self.writer.write(msg)
await self.writer.drain()
except (ConnectionAbortedError, ConnectionResetError):
await self.close()
def __aiter__(self):
return self
async def __anext__(self):
message = await self.receive()
if not message:
raise StopAsyncIteration
return message