47 changes: 34 additions & 13 deletions libp2p/pubsub/gossipsub.py
@@ -275,10 +275,9 @@ async def publish(self, msg_forwarder: ID, pubsub_msg: rpc_pb2.Message) -> None:
                raise NoPubsubAttached
            if peer_id not in self.pubsub.peers:
                continue
-           stream = self.pubsub.peers[peer_id]

-           # TODO: Go use `sendRPC`, which possibly piggybacks gossip/control messages.
-           await self.pubsub.write_msg(stream, rpc_msg)
            await self.send_rpc(to_peer=peer_id, rpc=rpc_msg, urgent=False)

        for topic in pubsub_msg.topicIDs:
            self.time_since_last_publish[topic] = int(time.time())
@@ -852,11 +851,9 @@ async def handle_iwant(
                sender_peer_id,
            )
            return
-       peer_stream = self.pubsub.peers[sender_peer_id]

-       # 4) And write the packet to the stream
-       await self.pubsub.write_msg(peer_stream, packet)

        await self.send_rpc(to_peer=sender_peer_id, rpc=packet)
Contributor
The `urgent` bool parameter has to be mentioned here also, right?
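
A hypothetical way to address this (not part of the PR) would be a default value, so call sites that omit the flag keep working:

    async def send_rpc(self, to_peer: ID, rpc: rpc_pb2.RPC, urgent: bool = False) -> None:
        ...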

async def handle_graft(
self, graft_msg: rpc_pb2.ControlGraft, sender_peer_id: ID
) -> None:
@@ -1003,13 +1000,37 @@ async def emit_control_message(

        packet.control.CopyFrom(control_msg)

-       # Get stream for peer from pubsub
-       if to_peer not in self.pubsub.peers:
-           logger.debug(
-               "Fail to emit control message to %s: peer record not exist", to_peer
-           )
        await self.send_rpc(to_peer, packet, False)

    # `urgent` will be True in the case of an IDONTWANT message
    async def send_rpc(self, to_peer: ID, rpc: rpc_pb2.RPC, urgent: bool) -> None:
        # TODO: Piggyback message retries

        msg_bytes = rpc.SerializeToString()
        msg_size = len(msg_bytes)
        max_message_size = self.pubsub.maxMessageSize
        if msg_size < max_message_size:
            await self.do_send_rpc(rpc, to_peer, urgent)
            return
-       peer_stream = self.pubsub.peers[to_peer]
        else:
            rpc_list = self.pubsub.split_rpc(pb_rpc=rpc, limit=max_message_size)
            for sub_rpc in rpc_list:
                if sub_rpc.ByteSize() > max_message_size:
                    self.drop_rpc(sub_rpc)
                    continue
                await self.do_send_rpc(sub_rpc, to_peer, urgent)

    async def do_send_rpc(self, rpc: rpc_pb2.RPC, to_peer: ID, urgent: bool) -> None:
        peer_queue = self.pubsub.peer_queue[to_peer]
        try:
            if urgent:
                await peer_queue.urgent_push(rpc=rpc, block=False)
            else:
                await peer_queue.push(rpc=rpc, block=False)
        except Exception as e:
            logger.error(f"Failed to enqueue RPC to peer {to_peer}: {e}")
            self.drop_rpc(rpc)

    def drop_rpc(self, rpc: rpc_pb2.RPC) -> None:
        # Placeholder: oversized or un-queueable RPCs are currently dropped silently.
        pass

-       # Write rpc to stream
-       await self.pubsub.write_msg(peer_stream, packet)
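
Editor's note: the diff imports RpcQueue and QueueClosed from libp2p/pubsub/rpc_queue.py, which is not included in this diff. The sketch below is an assumed minimal implementation consistent with the call sites above (non-blocking push/urgent_push, an awaitable pop, a synchronous close); the real module may differ, and QueueFull is a hypothetical exception name.

from collections import deque
from typing import Any

import trio


class QueueClosed(Exception):
    """Raised by pop() or push() once the queue has been closed."""


class QueueFull(Exception):
    """Raised by a non-blocking push when the queue is at capacity."""


class RpcQueue:
    """Bounded two-lane queue: urgent RPCs are popped before normal ones."""

    def __init__(self, max_size: int) -> None:
        self._max_size = max_size
        self._urgent: deque[Any] = deque()
        self._normal: deque[Any] = deque()
        self._closed = False
        self._wakeup = trio.Event()

    async def push(self, rpc: Any, block: bool = True) -> None:
        # Only the non-blocking path is sketched here.
        if self._closed:
            raise QueueClosed
        if len(self._urgent) + len(self._normal) >= self._max_size:
            raise QueueFull
        self._normal.append(rpc)
        self._wakeup.set()

    async def urgent_push(self, rpc: Any, block: bool = True) -> None:
        if self._closed:
            raise QueueClosed
        self._urgent.append(rpc)  # the urgent lane ignores the cap in this sketch
        self._wakeup.set()

    async def pop(self) -> Any:
        # Assumes a single consumer: one handle_sending_message task per peer.
        while True:
            if self._urgent:
                return self._urgent.popleft()
            if self._normal:
                return self._normal.popleft()
            if self._closed:
                raise QueueClosed
            await self._wakeup.wait()
            self._wakeup = trio.Event()  # re-arm for the next idle wait

    def close(self) -> None:
        self._closed = True
        self._wakeup.set()  # wake a blocked pop() so it raises QueueClosed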
241 changes: 239 additions & 2 deletions libp2p/pubsub/pubsub.py
@@ -76,6 +76,7 @@
from .pubsub_notifee import (
PubsubNotifee,
)
from .rpc_queue import QueueClosed, RpcQueue
from .subscription import (
TrioSubscriptionAPI,
)
@@ -87,6 +88,10 @@
# Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/40e1c94708658b155f30cf99e4574f384756d83c/topic.go#L97 # noqa: E501
SUBSCRIPTION_CHANNEL_SIZE = 32

# DefaultMaxMessageSize is 1 MiB (1 << 20 bytes).
DefaultMaxMessageSize = 1 << 20
OutBoundQueueSize = 100 + 8

logger = logging.getLogger("libp2p.pubsub")


@@ -138,6 +143,9 @@ class Pubsub(Service, IPubsub):
event_handle_peer_queue_started: trio.Event
event_handle_dead_peer_queue_started: trio.Event

maxMessageSize: int
peer_queue: dict[ID, RpcQueue]

def __init__(
self,
host: IHost,
@@ -222,6 +230,11 @@ def __init__(
self.event_handle_peer_queue_started = trio.Event()
self.event_handle_dead_peer_queue_started = trio.Event()

        self.maxMessageSize = DefaultMaxMessageSize
        self._sending_message_tasks = {}
        # TODO: Handle deleting the values from the queue.
        self.peer_queue = {}

async def run(self) -> None:
self.manager.run_daemon_task(self.handle_peer_queue)
self.manager.run_daemon_task(self.handle_dead_peer_queue)
@@ -366,6 +379,11 @@ def add_to_blacklist(self, peer_id: ID) -> None:
logger.debug("Added peer %s to blacklist", peer_id)
self.manager.run_task(self._teardown_if_connected, peer_id)

# Close and remove the peer's queue if it exists
queue = self.peer_queue.get(peer_id)
if queue is not None:
queue.close()

async def _teardown_if_connected(self, peer_id: ID) -> None:
"""Close their stream and remove them if connected"""
stream = self.peers.get(peer_id)
@@ -412,6 +430,11 @@ def clear_blacklist(self) -> None:
- Participate in topic subscriptions

"""
        # Close the queues of all blacklisted peers
        for peer_id in list(self.blacklisted_peers):
            queue = self.peer_queue.get(peer_id)
            if queue is not None:
                queue.close()
        self.blacklisted_peers.clear()
        logger.debug("Cleared all peers from blacklist")

@@ -474,13 +497,20 @@ async def _handle_new_peer(self, peer_id: ID) -> None:
except Exception as error:
logger.debug("fail to add new peer %s, error %s", peer_id, error)
return

# Instead of using self.manager.run_daemon_task,
# spawn a background task using trio directly
# so that it is not tied to self.manager.wait_finished()
trio.lowlevel.spawn_system_task(self.handle_sending_message, peer_id, stream)
self.peers[peer_id] = stream

logger.debug("added new peer %s", peer_id)

    def _handle_dead_peer(self, peer_id: ID) -> None:
        if peer_id not in self.peers:
            # Even if the peer is not in self.peers, close its queue if one exists
            queue = self.peer_queue.get(peer_id)
            if queue is not None:
                queue.close()
            return
        del self.peers[peer_id]

@@ -490,6 +520,11 @@ def _handle_dead_peer(self, peer_id: ID) -> None:

self.router.remove_peer(peer_id)

        # Close the peer's queue if it exists
        queue = self.peer_queue.get(peer_id)
        if queue is not None:
            queue.close()

logger.debug("removed dead peer %s", peer_id)

async def handle_peer_queue(self) -> None:
@@ -859,4 +894,206 @@ async def write_msg(self, stream: INetStream, rpc_msg: rpc_pb2.RPC) -> bool:
peer_id = stream.muxed_conn.peer_id
logger.debug("Fail to write message to %s: stream closed", peer_id)
self._handle_dead_peer(peer_id)
return False


    async def handle_sending_message(self, to_peer: ID, stream: INetStream) -> None:
        if to_peer in self._sending_message_tasks:
            return
        self._sending_message_tasks[to_peer] = True
        try:
            if to_peer not in self.peer_queue:
                queue = RpcQueue(OutBoundQueueSize)
                self.peer_queue[to_peer] = queue
            else:
                queue = self.peer_queue[to_peer]

            while True:
                try:
                    rpc_msg: rpc_pb2.RPC = await queue.pop()
                    await self.write_msg(stream, rpc_msg)
                except QueueClosed:
                    logger.error("The queue is already closed.")
                    break
                except Exception as e:
                    logger.exception(
                        "Exception in handle_sending_message for peer %s: %s",
                        to_peer,
                        e,
                    )
                    break
        finally:
            self._sending_message_tasks.pop(to_peer, None)

    def size_of_embedded_msg(self, msg_size: int) -> int:
        def sov_rpc(x: int) -> int:
            # Number of bytes in the protobuf varint encoding of x.
            if x == 0:
                return 1
            return (x.bit_length() + 6) // 7

        prefix_size = sov_rpc(msg_size)
        return prefix_size + msg_size
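
    # Worked example (editor's note, standard protobuf varint rules):
    #   (300).bit_length() == 9 and (9 + 6) // 7 == 2, so a 300-byte message
    #   gets a 2-byte length prefix: size_of_embedded_msg(300) == 302. The
    #   callers in split_rpc below add 1 more byte for the field tag, so such
    #   a publish message consumes 303 bytes of an RPC's size budget.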

    def split_rpc(self, pb_rpc: rpc_pb2.RPC, limit: int) -> list[rpc_pb2.RPC]:
        """
        Split the given pb_rpc into a list of RPCs, each not exceeding the
        given size limit.

        If a single sub-message is too large to fit, it is returned as its
        own oversized RPC, which the caller (send_rpc) then drops.
        """
        result: list[rpc_pb2.RPC] = []

        def base_rpc() -> rpc_pb2.RPC:
            return rpc_pb2.RPC()

        # Split publish messages
        publish_msgs = pb_rpc.publish
        n = len(publish_msgs)
        if n > 0:
            msg_sizes = [msg.ByteSize() for msg in publish_msgs]
            incr_sizes = [1 + self.size_of_embedded_msg(sz) for sz in msg_sizes]
            i = 0
            while i < n:
                new_rpc = base_rpc()
                size = 0
                j = i
                while j < n and size + incr_sizes[j] <= limit:
                    size += incr_sizes[j]
                    j += 1
                if j == i:
                    # A single message larger than the limit: emit it alone
                    # so the caller can drop it as oversized.
                    j = i + 1
                new_rpc.publish.extend(publish_msgs[i:j])
                result.append(new_rpc)
                i = j

            # if the rest of the RPC (without publish) fits, add it
            rest_rpc = base_rpc()
            rest_rpc.CopyFrom(pb_rpc)
            rest_rpc.ClearField("publish")
            if 0 < rest_rpc.ByteSize() < limit:
                result.append(rest_rpc)
            return result

        # Split subscriptions
        subs = pb_rpc.subscriptions
        n = len(subs)
        if n > 0:
            sub_sizes = [sub.ByteSize() for sub in subs]
            incr_sizes = [1 + self.size_of_embedded_msg(sz) for sz in sub_sizes]
            i = 0
            while i < n:
                new_rpc = base_rpc()
                size = 0
                j = i
                while j < n and size + incr_sizes[j] <= limit:
                    size += incr_sizes[j]
                    j += 1
                if j == i:
                    # A single oversized subscription: emit it alone.
                    j = i + 1
                new_rpc.subscriptions.extend(subs[i:j])
                result.append(new_rpc)
                i = j

        # Split control grafts
        ctl = pb_rpc.control
        if ctl is not None and ctl.ByteSize() > 0:
            grafts = list(ctl.graft)
            i = 0
            while i < len(grafts):
                new_rpc = base_rpc()
                new_rpc.control.CopyFrom(rpc_pb2.ControlMessage())
                size = 0
                j = i
                while j < len(grafts):
                    new_rpc.control.graft.extend([grafts[j]])
                    incremental_size = new_rpc.ByteSize()
                    if size + incremental_size > limit:
                        if len(new_rpc.control.graft) > 1:
                            new_rpc.control.graft.pop()
                        else:
                            # A single oversized graft: keep it and move on.
                            j += 1
                        break
                    size += incremental_size
                    j += 1
                if len(new_rpc.control.graft) > 0:
                    result.append(new_rpc)
                i = j

            # Split control prunes
            prunes = list(ctl.prune)
            i = 0
            while i < len(prunes):
                new_rpc = base_rpc()
                new_rpc.control.CopyFrom(rpc_pb2.ControlMessage())
                size = 0
                j = i
                while j < len(prunes):
                    new_rpc.control.prune.extend([prunes[j]])
                    incremental_size = new_rpc.ByteSize()
                    if size + incremental_size > limit:
                        if len(new_rpc.control.prune) > 1:
                            new_rpc.control.prune.pop()
                        else:
                            # A single oversized prune: keep it and move on.
                            j += 1
                        break
                    size += incremental_size
                    j += 1
                if len(new_rpc.control.prune) > 0:
                    result.append(new_rpc)
                i = j

            # Split control iwant
            iwants = list(ctl.iwant)
            all_msg_ids = []
            for iwant in iwants:
                all_msg_ids.extend(iwant.messageIDs)

            k = 0
            while k < len(all_msg_ids):
                new_rpc = base_rpc()
                new_rpc.control.CopyFrom(rpc_pb2.ControlMessage())
                new_iwant = rpc_pb2.ControlIWant()
                size = 0
                current_index = k
                while current_index < len(all_msg_ids):
                    new_iwant.messageIDs.append(all_msg_ids[current_index])
                    incremental_size = new_rpc.ByteSize() + new_iwant.ByteSize()
                    if size + incremental_size > limit:
                        if len(new_iwant.messageIDs) > 1:
                            new_iwant.messageIDs.pop()
                        else:
                            # A single oversized message ID: keep it and move on.
                            current_index += 1
                        break
                    size += incremental_size
                    current_index += 1
                if new_iwant.messageIDs:
                    new_rpc.control.iwant.extend([new_iwant])
                    result.append(new_rpc)
                k = current_index

            # Split control ihave
            ihaves = list(ctl.ihave)
            for ihave in ihaves:
                topic_id = ihave.topicID
                msg_ids = list(ihave.messageIDs)
                k = 0
                while k < len(msg_ids):
                    new_rpc = base_rpc()
                    new_rpc.control.CopyFrom(rpc_pb2.ControlMessage())
                    new_ihave = rpc_pb2.ControlIHave()
                    new_ihave.topicID = topic_id
                    size = 0
                    current_index = k
                    while current_index < len(msg_ids):
                        new_ihave.messageIDs.append(msg_ids[current_index])
                        incremental_size = new_rpc.ByteSize()
                        if size + incremental_size > limit:
                            if len(new_ihave.messageIDs) > 1:
                                new_ihave.messageIDs.pop()
                            else:
                                # A single oversized message ID: keep it and move on.
                                current_index += 1
                            break
                        size += incremental_size
                        current_index += 1
                    if new_ihave.messageIDs:
                        new_rpc.control.ihave.extend([new_ihave])
                        result.append(new_rpc)
                    k = current_index

        # If nothing was added, but the original RPC is non-empty, add it as is
        if not result and pb_rpc.ByteSize() > 0:
            result.append(pb_rpc)
        if result and result[-1].ByteSize() == 0:
            result.pop()

        return result
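
Editor's note: to illustrate the splitting path end to end, here is a hypothetical usage sketch (not from the PR). It assumes a constructed Pubsub instance named pubsub and the generated protobuf module at libp2p.pubsub.pb.rpc_pb2; build_big_rpc is an illustrative helper, not a real API.

from libp2p.pubsub.pb import rpc_pb2


def build_big_rpc(n_msgs: int, payload_size: int) -> rpc_pb2.RPC:
    # Build an RPC whose publish messages cannot all fit into one 1 MiB frame.
    rpc = rpc_pb2.RPC()
    for _ in range(n_msgs):
        msg = rpc.publish.add()
        msg.data = b"x" * payload_size
        msg.topicIDs.append("demo-topic")
    return rpc


# Eight ~256 KiB messages exceed DefaultMaxMessageSize (1 MiB), so split_rpc
# should return several RPCs, each within the limit; gossipsub's send_rpc then
# enqueues each chunk for the peer via do_send_rpc.
chunks = pubsub.split_rpc(
    pb_rpc=build_big_rpc(8, 256 * 1024), limit=pubsub.maxMessageSize
)
assert all(c.ByteSize() <= pubsub.maxMessageSize for c in chunks)
assert sum(len(c.publish) for c in chunks) == 8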