diff --git a/libp2p/pubsub/gossipsub.py b/libp2p/pubsub/gossipsub.py
index 45c6cd815..60425204f 100644
--- a/libp2p/pubsub/gossipsub.py
+++ b/libp2p/pubsub/gossipsub.py
@@ -275,10 +275,9 @@ async def publish(self, msg_forwarder: ID, pubsub_msg: rpc_pb2.Message) -> None:
                 raise NoPubsubAttached
             if peer_id not in self.pubsub.peers:
                 continue
-            stream = self.pubsub.peers[peer_id]
 
             # TODO: Go use `sendRPC`, which possibly piggybacks gossip/control messages.
-            await self.pubsub.write_msg(stream, rpc_msg)
+            await self.send_rpc(to_peer=peer_id, rpc=rpc_msg, urgent=False)
 
         for topic in pubsub_msg.topicIDs:
             self.time_since_last_publish[topic] = int(time.time())
@@ -852,11 +851,9 @@ async def handle_iwant(
                 sender_peer_id,
             )
             return
-        peer_stream = self.pubsub.peers[sender_peer_id]
 
         # 4) And write the packet to the stream
-        await self.pubsub.write_msg(peer_stream, packet)
-
+        await self.send_rpc(to_peer=sender_peer_id, rpc=packet, urgent=False)
 
     async def handle_graft(
         self, graft_msg: rpc_pb2.ControlGraft, sender_peer_id: ID
     ) -> None:
@@ -1003,13 +1000,37 @@ async def emit_control_message(
         packet.control.CopyFrom(control_msg)
 
-        # Get stream for peer from pubsub
-        if to_peer not in self.pubsub.peers:
-            logger.debug(
-                "Fail to emit control message to %s: peer record not exist", to_peer
-            )
+        await self.send_rpc(to_peer, packet, False)
+
+    # `urgent` will be True for IDONTWANT messages
+    async def send_rpc(self, to_peer: ID, rpc: rpc_pb2.RPC, urgent: bool) -> None:
+        # TODO: Piggyback message retries
+
+        msg_bytes = rpc.SerializeToString()
+        msg_size = len(msg_bytes)
+        max_message_size = self.pubsub.maxMessageSize
+        if msg_size < max_message_size:
+            await self.do_send_rpc(rpc, to_peer, urgent)
             return
-        peer_stream = self.pubsub.peers[to_peer]
+        else:
+            rpc_list = self.pubsub.split_rpc(pb_rpc=rpc, limit=max_message_size)
+            for rpc in rpc_list:
+                if rpc.ByteSize() > max_message_size:
+                    self.drop_rpc(rpc)
+                    continue
+                await self.do_send_rpc(rpc, to_peer, urgent)
+
+    async def do_send_rpc(self, rpc: rpc_pb2.RPC, to_peer: ID, urgent: bool) -> None:
+        try:
+            peer_queue = self.pubsub.peer_queue[to_peer]
+            if urgent:
+                await peer_queue.urgent_push(rpc=rpc, block=False)
+            else:
+                await peer_queue.push(rpc=rpc, block=False)
+        except Exception as e:
+            logger.error("Failed to enqueue RPC to peer %s: %s", to_peer, e)
+            self.drop_rpc(rpc)
+
+    def drop_rpc(self, rpc: rpc_pb2.RPC) -> None:
+        # Placeholder: oversized or undeliverable RPCs are currently dropped silently.
+        pass
-
-        # Write rpc to stream
-        await self.pubsub.write_msg(peer_stream, packet)
diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py
index 2c605fc3a..527b13190 100644
--- a/libp2p/pubsub/pubsub.py
+++ b/libp2p/pubsub/pubsub.py
@@ -76,6 +76,7 @@ from .pubsub_notifee import (
     PubsubNotifee,
 )
+from .rpc_queue import QueueClosed, RpcQueue
 from .subscription import (
     TrioSubscriptionAPI,
 )
@@ -87,6 +88,10 @@
 # Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/40e1c94708658b155f30cf99e4574f384756d83c/topic.go#L97  # noqa: E501
 SUBSCRIPTION_CHANNEL_SIZE = 32
 
+# DefaultMaxMessageSize is 1 MiB.
+DefaultMaxMessageSize = 1 << 20
+OutBoundQueueSize = 100 + 8
+
 
 logger = logging.getLogger("libp2p.pubsub")
 
@@ -138,6 +143,9 @@ class Pubsub(Service, IPubsub):
     event_handle_peer_queue_started: trio.Event
     event_handle_dead_peer_queue_started: trio.Event
 
+    maxMessageSize: int
+    peer_queue: dict[ID, RpcQueue]
+
     def __init__(
         self,
         host: IHost,
@@ -222,6 +230,11 @@ def __init__(
         self.event_handle_peer_queue_started = trio.Event()
         self.event_handle_dead_peer_queue_started = trio.Event()
 
+        self.maxMessageSize = DefaultMaxMessageSize
+        self._sending_message_tasks: dict[ID, bool] = {}
+        # TODO: Handle deleting the values from the peer_queue map.
+        self.peer_queue = {}
+
     async def run(self) -> None:
         self.manager.run_daemon_task(self.handle_peer_queue)
         self.manager.run_daemon_task(self.handle_dead_peer_queue)
@@ -366,6 +379,11 @@ def add_to_blacklist(self, peer_id: ID) -> None:
         logger.debug("Added peer %s to blacklist", peer_id)
         self.manager.run_task(self._teardown_if_connected, peer_id)
 
+        # Close the peer's queue if it exists
+        queue = self.peer_queue.get(peer_id)
+        if queue is not None:
+            self.manager.run_task(queue.close)
+
     async def _teardown_if_connected(self, peer_id: ID) -> None:
         """Close their stream and remove them if connected"""
         stream = self.peers.get(peer_id)
@@ -412,6 +430,11 @@ def clear_blacklist(self) -> None:
         - Participate in topic subscriptions
 
         """
+        # Close the queues of all blacklisted peers
+        for peer_id in list(self.blacklisted_peers):
+            queue = self.peer_queue.get(peer_id)
+            if queue is not None:
+                self.manager.run_task(queue.close)
         self.blacklisted_peers.clear()
         logger.debug("Cleared all peers from blacklist")
 
@@ -474,13 +497,20 @@ async def _handle_new_peer(self, peer_id: ID) -> None:
         except Exception as error:
             logger.debug("fail to add new peer %s, error %s", peer_id, error)
             return
-
+        # Instead of using self.manager.run_daemon_task,
+        # spawn a background task using trio directly
+        # so that it is not tied to self.manager.wait_finished()
+        trio.lowlevel.spawn_system_task(self.handle_sending_message, peer_id, stream)
         self.peers[peer_id] = stream
 
         logger.debug("added new peer %s", peer_id)
 
     def _handle_dead_peer(self, peer_id: ID) -> None:
         if peer_id not in self.peers:
+            # Even if the peer is not in peers, close its queue if one exists
+            queue = self.peer_queue.get(peer_id)
+            if queue is not None:
+                self.manager.run_task(queue.close)
             return
         del self.peers[peer_id]
 
@@ -490,6 +520,11 @@ def _handle_dead_peer(self, peer_id: ID) -> None:
 
         self.router.remove_peer(peer_id)
 
+        # Close the peer's queue if it exists
+        queue = self.peer_queue.get(peer_id)
+        if queue is not None:
+            self.manager.run_task(queue.close)
+
         logger.debug("removed dead peer %s", peer_id)
 
     async def handle_peer_queue(self) -> None:
@@ -859,4 +894,206 @@ async def write_msg(self, stream: INetStream, rpc_msg: rpc_pb2.RPC) -> bool:
             peer_id = stream.muxed_conn.peer_id
             logger.debug("Fail to write message to %s: stream closed", peer_id)
             self._handle_dead_peer(peer_id)
             return False
+
+    async def handle_sending_message(self, to_peer: ID, stream: INetStream) -> None:
+        if to_peer in self._sending_message_tasks:
+            return
+        self._sending_message_tasks[to_peer] = True
+        try:
+            if to_peer not in self.peer_queue:
+                queue = RpcQueue(OutBoundQueueSize)
+                self.peer_queue[to_peer] = queue
+            else:
+                queue = self.peer_queue[to_peer]
+
+            while True:
+                try:
+                    rpc_msg: rpc_pb2.RPC = await queue.pop()
+                    await self.write_msg(stream, rpc_msg)
+                except QueueClosed:
+                    logger.error("The RPC queue for peer %s is already closed.", to_peer)
+                    break
+                except Exception as e:
%s", to_peer, e) + break + finally: + self._sending_message_tasks.pop(to_peer, None) + + def size_of_embedded_msg(self, msg_size: int) -> int: + def sov_rpc(x: int) -> int: + if x == 0: + return 1 + return ((x.bit_length() + 6) // 7) + + prefix_size = sov_rpc(msg_size) + return prefix_size + msg_size + + def split_rpc(self, pb_rpc: rpc_pb2.RPC, limit: int) -> list[rpc_pb2.RPC]: + """ + Splits the given pb_rpc into a list of RPCs, each not exceeding the + given size limit. + If a sub-message is too large to fit, it will be returned as an + oversized RPC. + """ + result: list[rpc_pb2.RPC] = [] + + def base_rpc() -> rpc_pb2.RPC: + return rpc_pb2.RPC() + + # Split Publish messages + publish_msgs = pb_rpc.publish + n = len(publish_msgs) + if n > 0: + msg_sizes = [msg.ByteSize() for msg in publish_msgs] + incr_sizes = [1 + self.size_of_embedded_msg(sz) for sz in msg_sizes] + i = 0 + while i < n: + new_rpc = base_rpc() + size = 0 + j = i + while j < n and size + incr_sizes[j] <= limit: + size += incr_sizes[j] + j += 1 + if j > i: + new_rpc.publish.extend(publish_msgs[i:j]) + result.append(new_rpc) + i = j + + # if the rest of the RPC (without publish) fits, add it + rest_rpc = base_rpc() + rest_rpc.CopyFrom(pb_rpc) + while rest_rpc.publish: + rest_rpc.publish.pop() + if rest_rpc.ByteSize() < limit and rest_rpc.ByteSize() > 0: + result.append(rest_rpc) + return result + + # Split subscriptions + subs = pb_rpc.subscriptions + n = len(subs) + if n > 0: + sub_sizes = [subs[i].ByteSize() for i in range(n)] + incr_sizes = [1 + self.size_of_embedded_msg(sz) for sz in sub_sizes] + i = 0 + while i < n: + new_rpc = base_rpc() + size = 0 + j = i + while j < n and size + incr_sizes[j] <= limit: + size += incr_sizes[j] + j += 1 + if j > i: + new_rpc.subscriptions.extend(subs[i:j]) + result.append(new_rpc) + i = j + + # Split control grafts + ctl = pb_rpc.control + if ctl is not None and ctl.ByteSize() > 0: + grafts = list(ctl.graft) + i = 0 + while i < len(grafts): + new_rpc = base_rpc() + new_rpc.control.CopyFrom(rpc_pb2.ControlMessage()) + size = 0 + j = i + while j < len(grafts): + graft = grafts[j] + new_rpc.control.graft.extend([graft]) + incremental_size = new_rpc.ByteSize() + if size + incremental_size > limit: + if len(new_rpc.control.graft) > 1: + new_rpc.control.graft.pop() + result.append(new_rpc) + break + size += incremental_size + j += 1 + i = j + + # Split control prunes + prunes = list(ctl.prune) + i = 0 + while i < len(prunes): + new_rpc = base_rpc() + new_rpc.control.CopyFrom(rpc_pb2.ControlMessage()) + size = 0 + j = i + while j < len(prunes): + prune = prunes[j] + new_rpc.control.prune.extend([prune]) + incremental_size = new_rpc.ByteSize() + if size + incremental_size > limit: + if len(new_rpc.control.prune) > 1: + new_rpc.control.prune.pop() + result.append(new_rpc) + break + size += incremental_size + j += 1 + i = j + + # Split control iwant + iwants = list(ctl.iwant) + all_msg_ids = [] + for iwant in iwants: + all_msg_ids.extend(iwant.messageIDs) + + k = 0 + while k < len(all_msg_ids): + new_rpc = base_rpc() + new_rpc.control.CopyFrom(rpc_pb2.ControlMessage()) + new_iwant = rpc_pb2.ControlIWant() + size = 0 + current_index = k + while current_index < len(all_msg_ids): + msg_id = all_msg_ids[current_index] + new_iwant.messageIDs.append(msg_id) + incremental_size = new_rpc.ByteSize() + new_iwant.ByteSize() + if size + incremental_size > limit: + if len(new_iwant.messageIDs) > 1: + new_iwant.messageIDs.pop() + new_rpc.control.iwant.extend([new_iwant]) + result.append(new_rpc) + break + 
+                if new_iwant.messageIDs:
+                    new_rpc.control.iwant.extend([new_iwant])
+                    result.append(new_rpc)
+                k = current_index
+
+            # Split control ihave, keeping each chunk scoped to its topic
+            ihaves = list(ctl.ihave)
+            for ihave in ihaves:
+                topic_id = ihave.topicID
+                msg_ids = list(ihave.messageIDs)
+                k = 0
+                while k < len(msg_ids):
+                    new_rpc = base_rpc()
+                    new_rpc.control.CopyFrom(rpc_pb2.ControlMessage())
+                    new_ihave = rpc_pb2.ControlIHave()
+                    new_ihave.topicID = topic_id
+                    current_index = k
+                    while current_index < len(msg_ids):
+                        new_ihave.messageIDs.append(msg_ids[current_index])
+                        if new_rpc.ByteSize() + new_ihave.ByteSize() > limit:
+                            if len(new_ihave.messageIDs) > 1:
+                                # The ID that overflowed starts the next chunk.
+                                new_ihave.messageIDs.pop()
+                            else:
+                                # Oversized single ID: emit as-is and advance.
+                                current_index += 1
+                            break
+                        current_index += 1
+                    if new_ihave.messageIDs:
+                        new_rpc.control.ihave.extend([new_ihave])
+                        result.append(new_rpc)
+                    k = current_index
+
+        # If nothing was added, but the original RPC is non-empty, add it as is
+        if not result and pb_rpc.ByteSize() > 0:
+            result.append(pb_rpc)
+        if result and result[-1].ByteSize() == 0:
+            result.pop()
+
+        return result
diff --git a/libp2p/pubsub/rpc_queue.py b/libp2p/pubsub/rpc_queue.py
new file mode 100644
index 000000000..7842652bd
--- /dev/null
+++ b/libp2p/pubsub/rpc_queue.py
@@ -0,0 +1,87 @@
+from typing import Any, List, Optional
+
+import trio
+
+
+class QueueClosed(Exception):
+    pass
+
+
+class QueueFull(Exception):
+    pass
+
+
+class QueuePushOnClosed(Exception):
+    pass
+
+
+class QueueCancelled(Exception):
+    pass
+
+
+class PriorityQueue:
+    """Two FIFO lanes: urgent entries are always drained before normal ones."""
+
+    def __init__(self) -> None:
+        self.normal: List[Any] = []
+        self.priority: List[Any] = []
+
+    def __len__(self) -> int:
+        return len(self.normal) + len(self.priority)
+
+    def normal_push(self, rpc: Any) -> None:
+        self.normal.append(rpc)
+
+    def priority_push(self, rpc: Any) -> None:
+        self.priority.append(rpc)
+
+    def pop(self) -> Optional[Any]:
+        if self.priority:
+            return self.priority.pop(0)
+        elif self.normal:
+            return self.normal.pop(0)
+        return None
+
+
+class RpcQueue:
+    def __init__(self, max_size: int) -> None:
+        self.queue: PriorityQueue = PriorityQueue()
+        self.max_size: int = max_size
+        self.closed: bool = False
+        self._lock = trio.Lock()
+        self._space_available = trio.Condition(self._lock)
+
+    async def push(self, rpc: Any, block: bool = True) -> None:
+        await self._push(rpc, urgent=False, block=block)
+
+    async def urgent_push(self, rpc: Any, block: bool = True) -> None:
+        await self._push(rpc, urgent=True, block=block)
+
+    async def _push(self, rpc: Any, urgent: bool, block: bool) -> None:
+        async with self._lock:
+            if self.closed:
+                raise QueuePushOnClosed("push on closed rpc queue")
+            while len(self.queue) == self.max_size:
+                if block:
+                    await self._space_available.wait()
+                    if self.closed:
+                        raise QueuePushOnClosed("push on closed rpc queue")
+                else:
+                    raise QueueFull("rpc queue full")
+            if urgent:
+                self.queue.priority_push(rpc)
+            else:
+                self.queue.normal_push(rpc)
+            self._space_available.notify()
+
+    async def pop(self) -> Any:
+        while True:
+            async with self._lock:
+                if self.closed:
+                    raise QueueClosed("rpc queue closed")
+                if len(self.queue) > 0:
+                    rpc = self.queue.pop()
+                    self._space_available.notify()
+                    return rpc
+                # If the queue is empty, wait for a message or for closure
+                await self._space_available.wait()
+                if self.closed:
+                    raise QueueClosed("rpc queue closed")
+
+    async def close(self) -> None:
+        async with self._lock:
+            self.closed = True
+            self._space_available.notify_all()
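For illustration only (not part of this diff): a minimal trio program exercising the RpcQueue semantics defined above, assuming the libp2p.pubsub.rpc_queue module path introduced here. Plain strings stand in for rpc_pb2.RPC objects; urgent entries are drained before normal ones, and a non-blocking push on a full queue raises QueueFull.

import trio

from libp2p.pubsub.rpc_queue import QueueFull, RpcQueue


async def main() -> None:
    queue = RpcQueue(max_size=2)

    await queue.push("normal-1", block=False)
    await queue.urgent_push("urgent-1", block=False)

    # The queue is at max_size now, so a non-blocking push raises QueueFull.
    try:
        await queue.push("normal-2", block=False)
    except QueueFull:
        print("queue full, caller drops the RPC")

    # Urgent entries pop before normal ones, regardless of push order.
    print(await queue.pop())  # urgent-1
    print(await queue.pop())  # normal-1

    await queue.close()


trio.run(main)

In the change itself, GossipSub.do_send_rpc is the only producer of these queues and Pubsub.handle_sending_message is the only consumer.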
diff --git a/newsfragments/913.internal.rst b/newsfragments/913.internal.rst
new file mode 100644
index 000000000..c9ed0ef1e
--- /dev/null
+++ b/newsfragments/913.internal.rst
@@ -0,0 +1 @@
+Add per-peer RPC message queuing and RPC splitting for gossipsub.
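For illustration only (not part of this diff): a sketch of the size check and splitting step that GossipSub.send_rpc performs, assuming an already constructed Pubsub instance; demo_split and "demo-topic" are hypothetical names. RPCs under Pubsub.maxMessageSize (DefaultMaxMessageSize, 1 << 20 bytes) are enqueued whole; larger ones are split with Pubsub.split_rpc, and any chunk that is still oversized would be dropped via drop_rpc.

from libp2p.pubsub.pb import rpc_pb2
from libp2p.pubsub.pubsub import Pubsub


def demo_split(pubsub: Pubsub) -> list[rpc_pb2.RPC]:
    # Build an RPC of roughly 2 MiB of publish messages.
    rpc = rpc_pb2.RPC()
    for i in range(2048):
        msg = rpc.publish.add()
        msg.data = b"x" * 1024
        msg.seqno = i.to_bytes(8, "big")
        msg.topicIDs.append("demo-topic")

    limit = pubsub.maxMessageSize  # DefaultMaxMessageSize == 1 << 20
    if rpc.ByteSize() < limit:
        return [rpc]

    # Each chunk fits under the limit unless a single message is itself
    # larger than the limit; send_rpc drops such chunks.
    chunks = pubsub.split_rpc(pb_rpc=rpc, limit=limit)
    return [chunk for chunk in chunks if chunk.ByteSize() <= limit]

The splitting is per-field: publish messages, subscriptions, and each control list are re-chunked independently.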