diff --git a/server/recceiver2/__init__.py b/server/recceiver2/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/recceiver2/app.py b/server/recceiver2/app.py new file mode 100644 index 0000000..552707a --- /dev/null +++ b/server/recceiver2/app.py @@ -0,0 +1,230 @@ +# See files LICENSE and COPYRIGHT +# SPDX-License-Identifier: EPICS + +import asyncio +from collections import defaultdict +from dataclasses import dataclass, field +from enum import Enum +import logging +from random import randint +import socket +from typing import Dict, Set, Tuple +from weakref import WeakSet + +import proto +from .proto import readmsg, protoID +from .conf import ConfigParser, parse_ep + +_log = logging.getLogger(__name__) + +RecID = int +RecName = RecType = str +Infos = Dict[str, str] + +class Op(Enum): + Update = 0 + Disconnect = 1 + +@dataclass +class Transaction: + """Batch of updates (or Disconnect) from a client + """ + op: Op + # Ignore remaining if op is Disconnect + + # Connected IP:port + peer: Tuple[str, int] + info: Infos = field(default_factory=dict) + add_record: Dict[RecID, Tuple[RecName, RecType]] = field(default_factory=dict) + del_record: Set[RecID] = field(default_factory=set) + record_info: Dict[RecID, Infos] = field(default_factory=defaultdict(dict)) + +class Recceiver: + cfg : ConfigParser + key : int + announcer : asyncio.Task + client : Set["ClientHandler"] + + @classmethod + async def start(klass, cfg: ConfigParser): + R = klass(cfg) + + _log.debug('Starting %r', R) + await R.listener() + await R.announcer() + _log.debug('Started %r', R) + + return R + + def __init__(self, cfg): + self.cfg = cfg + # pick a random key to distinguish this instance + self.key = randint(0,0xffffffff) + self.clients = WeakSet() + self.maxActive = asyncio.Semaphore(int(cfg['maxActive'])) + self.tcptimeout = float(cfg['tcptimeout']) + self.commitSizeLimit = int(cfg['commitSizeLimit']) + self.commitInterval = int(cfg['commitInterval']) + + async def close(self): + _log.debug('Stopping %r', self) + + # first, stop announcing + self.announcer.cancel() + try: + await self.announcer + except asyncio.CancelledError: + pass + + # stop accepting new connections + self.server.close() + await self.server.wait_closed() + + # close existing connections + clients, self.clients = set(self.clients), None + # spoil self.clients because of possible race with pending new_client() callback + for C in clients: + C.writer.close() + for C in clients: + await C.writer.wait_closed() + + _log.debug('Stopped %r', self) + + def __aenter__(self): + pass + + def __aexit__(self,A,B,C): + await self.close() + + async def announcer(self): + "Start announcer task" + + # digest configuration and prepare before launching Task + # so that any error is immediate + + announceInterval = float(self.cfg['announceInterval']) + + dests = [parse_ep(ep, defport=5049) for ep in self.cfg['addrlist'].split(',')] + + # bind the same interface as the TCP socket, with a random port + local_addr = (self.local_addr[0], 0) + + tr, _proto = await asyncio.get_running_loop() \ + .create_datagram_endpoint(asyncio.DatagramProtocol, reuse_address=True, + local_addr=local_addr) + + # since the announcement message is static, prepare it now + msg = proto.Announce.pack( + protoID, + 0, + socket.inet_aton(self.local_addr[0]), + self.local_addr[1], + self.key + ) + + self.announcer = asyncio.create_task(self.announcer_loop(dests, tr, msg, announceInterval), "Announcer") + + async def announcer_loop(self, dests, tr, msg, announceInterval): + try: + while True: + _log.debug('Ping') + for d in dests: + try: + tr.sendto(msg, d) + except: # TODO: ignore / info / warn to reduce error spam (eg. destination unreachable) + _log.exception('UDP Send error') + + await asyncio.sleep(announceInterval) + except: + _log.exception('Announcer fails') + raise + + async def listener(self): + "Start TCP listener" + local_addr = parse_ep(self.cfg['bind']) + + self.server = await asyncio.start_server(self.new_client, + host=local_addr[0], port=local_addr[1]) + + # find endpoint (w/ port#) actually bound + self.local_addr = self.server.sockets[0].sockets[0].getsockname()[:2] + + async def new_client(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter): + # we are already in a Task + C = ClientHandler(self, reader, writer) + self.clients.add(C) + await C.handle() + +@dataclass +class ClientHandler: + serv : Recceiver + reader : asyncio.StreamReader + writer : asyncio.StreamWriter + peer : Tuple[str, int] = None + active : Transaction = None + activeSize : int = field(default=0) + + def __post_init__(self): + self.peer = self.writer.get_extra_info('peername') + self.active = Transaction(Op.Update, self.peer) + + async def handle(self): + try: + # initially waiting for client greeting + msg = await readmsg(self.reader, server=False, timeout=self.serv.tcptimeout) + if not isinstance(msg, proto.ClientGreet): + raise RuntimeError("Protocol Violation") + + if msg.key!=self.server.key: + # client acting on an announcement with a different key (maybe we just restarted?) + _log.warn("Client w/ stale key %s != %s", msg.key, self.server.key) + self.writer.close() + yield self.writer.wait_closed() + return + + # limit the number of clients concurrently dumping + # to ~bound our resource usage + with self.server.maxActive: + # send greeting to provoke client to begin dumping + self.writer.write(proto.ServerGreet(0).encode()) + + while True: + msg = await readmsg(self.reader, server=False, timeout=self.serv.tcptimeout) + if isinstance(msg, proto.ClientDone): + break + + self.handle_msg(msg) + + while True: + msg = await readmsg(self.reader, server=False, timeout=self.serv.tcptimeout) + self.handle_msg(msg) + + while True: + if not self.active: + msg = readmsg(self.reader) + + except: + _log.exception("Error from %s", self.peer) + self.writer.close() + # TODO: commit Transaction(Op.Disconnect) + raise + + def handle_msg(self, msg: proto.Message): + if isinstance(msg, proto.ClientAddRecord): + self.active.add_record[msg.recid] = (msg.rname, msg.rtype) + + elif isinstance(msg, proto.ClientInfo): + if msg.recid==0: + self.active.info[msg.key] = msg.val + else: + self.active.record_info[msg.recid][msg.key] = msg.val + + else: + return # ignore unexpected, but valid, messages + + self.activeSize += 1 + + if self.commitSizeLimit and self.activeSize>=self.commitSizeLimit: + pass # TODO: commit now + elif self.activeSize==1: + pass # TODO: start commit interval timer diff --git a/server/recceiver2/conf.py b/server/recceiver2/conf.py new file mode 100644 index 0000000..afe74c1 --- /dev/null +++ b/server/recceiver2/conf.py @@ -0,0 +1,29 @@ +# See files LICENSE and COPYRIGHT +# SPDX-License-Identifier: EPICS + +from configparser import ConfigParser + +__all__ = ( + 'loadConfig', +) + +def loadConfig(fname : str) -> ConfigParser: + P = ConfigParser() + + P['recceiver'] = { + 'announceInterval': '30.0', + 'tcptimeout': '15.0', + 'commitInterval': '5.0', + 'commitSizeLimit': str(16*1024), + 'maxActive': '20', + 'addrlist': '', + 'bind': '0.0.0.0:0', + } + + with open(fname, 'r') as F: + P.read(F) + +def parse_ep(s, *, defport=0): + addr, _sep, port = s.partition(':') + + return (addr, int(port or defport)) diff --git a/server/recceiver2/proto.py b/server/recceiver2/proto.py new file mode 100644 index 0000000..a12f3a2 --- /dev/null +++ b/server/recceiver2/proto.py @@ -0,0 +1,184 @@ +# See files LICENSE and COPYRIGHT +# SPDX-License-Identifier: EPICS + +import asyncio +from dataclasses import dataclass +import logging +import struct +import socket + +_log = logging.getLogger(__name__) + +# Protocol ID +protoID = 0x5243 + +## UDP Protocol ## + +# (protoid, 0, addr, port, servKey) +Announce = struct.Struct('>HH4sHxxI') +assert Announce.size==16 + +## TCP Protocol ## + +# (protoid, msgid, bodylen) +Header = struct.Struct('>HHI') +assert Header.size==8 + +class Message: + id: int + msg: struct.Struct + __slots__ = () + + @classmethod + def decode(klass, body : bytes): + return klass(*klass.msg.unpack(body[:klass.msg.size])) + + def encode(self) -> bytes: + return self.msg.pack(*self) + +@dataclass +class ServerPing(Message): + id = 0x8002 + nonce = int + msg = struct.Struct('>I') + assert msg.size==4 + __slots__ = () + +@dataclass +class ClientPong(Message): + id = 0x0002 + nonce = int + msg = struct.Struct('>I') + assert msg.size==4 + __slots__ = () + +@dataclass +class ServerGreet(Message): + id = 0x8001 + msg = struct.Struct('>B') + assert msg.size==1 + zero : int + __slots__ = () + +@dataclass +class ClientGreet(Message): + id = 0x0001 + msg = struct.Struct('>HxxI') + assert msg.size==8 + zero : int + key : int + __slots__ = () + +@dataclass +class ClientInfo(Message): + id = 0x0006 + msg = struct.Struct('>IBxH') + assert msg.size==9 + recid : int + key : str + val : str + __slots__ = () + + @classmethod + def decode(klass, body : bytes): + recid, keylen, vallen = klass.msg.unpack(body[:klass.msg.size]) + key = body[klass.msg.size:klass.msg.size+keylen].decode() + val = body[klass.msg.size+keylen:klass.msg.size+keylen+vallen].decode() + return klass(recid, key, val) + + def encode(self) -> bytes: + return b''.join(( + self.msg.pack(self.recid, len(self.key), len(self.val)), + self.key, + self.val, + )) + +@dataclass +class ClientDone(Message): + id = 0x0005 + __slots__ = () + + @classmethod + def decode(klass, body : bytes): + return klass() + + def encode(self) -> bytes: + return b'' + +@dataclass +class ClientAddRecord(Message): + id = 0x0003 + msg = struct.Struct('>IBBH') + assert msg.size==8 + recid : int + atype : int + rtype : str + rname : str + __slots__ = () + + @classmethod + def decode(klass, body : bytes): + recid, atype, tlen, nlen = klass.msg.unpack(body[:klass.msg.size]) + rtype = body[klass.msg.size:klass.msg.size+tlen].decode() + rname = body[klass.msg.size+tlen:klass.msg.size+tlen+nlen].decode() + return klass(recid, atype, rtype, rname) + + def encode(self) -> bytes: + return b''.join(( + self.msg.pack(self.recid, self.atype, len(self.rtype), len(self.rname)), + self.rtype, + self.rname, + )) + +@dataclass +class ClientDelRecord(Message): + id = 0x0001 + msg = struct.Struct('>I') + assert msg.size==4 + recid : int + __slots__ = () + +messages = ( + ServerPing, + ClientPong, + ServerGreet, + ClientGreet, + ClientInfo, + ClientAddRecord, + ClientDelRecord, + ClientDone, +) +messages = {msg.id: msg for msg in messages} + +async def _readmsg(reader: asyncio.StreamReader, server=False) -> Message: + while True: + ID, msg, blen = Header.unpack(await reader.readexactly(Header.size)) + if ID!=protoID or server ^ (ID >= 0x8000): + raise RuntimeError("Header error") + body = await reader.readexactly(blen) + + try: + Msg = messages[msg] + except KeyError: + continue + + return Msg.decode(body) + + +async def readmsg(reader: asyncio.StreamReader, server=False, timeout=None) -> Message: + return (await asyncio.wait_for(_readmsg(reader, server), timeout=timeout)) + +class UDPListener(asyncio.DatagramProtocol): + def datagram_received(self, data, src): + ID, zero, addr4, portn, key = Announce.unpack(data[:Announce.size]) + if ID!=protoID or zero!=0: + return + addr4 = socket.inet_ntoa(addr4) + + self.announcement(ep=(addr4, portn), key=key) + + def error_received(self, e): + try: + raise e + except Exception: + _log.exception('UDP Socket Error')