Skip to content

Commit 9fdb36e

Browse files
authored
Merge branch 'main' into keyerror-fix
2 parents f80101c + 9370101 commit 9fdb36e

File tree

5 files changed

+219
-8
lines changed

5 files changed

+219
-8
lines changed

libp2p/custom_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,4 @@
3737
AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]]
3838
ValidatorFn = Union[SyncValidatorFn, AsyncValidatorFn]
3939
UnsubscribeFn = Callable[[], Awaitable[None]]
40+
MessageID = NewType("MessageID", str)

libp2p/pubsub/gossipsub.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
from ast import (
2-
literal_eval,
3-
)
41
from collections import (
52
defaultdict,
63
)
@@ -22,6 +19,7 @@
2219
IPubsubRouter,
2320
)
2421
from libp2p.custom_types import (
22+
MessageID,
2523
TProtocol,
2624
)
2725
from libp2p.peer.id import (
@@ -56,6 +54,10 @@
5654
from .pubsub import (
5755
Pubsub,
5856
)
57+
from .utils import (
58+
parse_message_id_safe,
59+
safe_parse_message_id,
60+
)
5961

6062
PROTOCOL_ID = TProtocol("/meshsub/1.0.0")
6163
PROTOCOL_ID_V11 = TProtocol("/meshsub/1.1.0")
@@ -795,8 +797,8 @@ async def handle_ihave(
795797

796798
# Add all unknown message ids (ids that appear in ihave_msg but not in
797799
# seen_seqnos) to list of messages we want to request
798-
msg_ids_wanted: list[str] = [
799-
msg_id
800+
msg_ids_wanted: list[MessageID] = [
801+
parse_message_id_safe(msg_id)
800802
for msg_id in ihave_msg.messageIDs
801803
if msg_id not in seen_seqnos_and_peers
802804
]
@@ -812,9 +814,9 @@ async def handle_iwant(
812814
Forwards all request messages that are present in mcache to the
813815
requesting peer.
814816
"""
815-
# FIXME: Update type of message ID
816-
# FIXME: Find a better way to parse the msg ids
817-
msg_ids: list[Any] = [literal_eval(msg) for msg in iwant_msg.messageIDs]
817+
msg_ids: list[tuple[bytes, bytes]] = [
818+
safe_parse_message_id(msg) for msg in iwant_msg.messageIDs
819+
]
818820
msgs_to_forward: list[rpc_pb2.Message] = []
819821
for msg_id_iwant in msg_ids:
820822
# Check if the wanted message ID is present in mcache

libp2p/pubsub/utils.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
import ast
12
import logging
23

34
from libp2p.abc import IHost
5+
from libp2p.custom_types import (
6+
MessageID,
7+
)
48
from libp2p.peer.envelope import consume_envelope
59
from libp2p.peer.id import ID
610
from libp2p.pubsub.pb.rpc_pb2 import RPC
@@ -48,3 +52,29 @@ def maybe_consume_signed_record(msg: RPC, host: IHost, peer_id: ID) -> bool:
4852
logger.error("Failed to update the Certified-Addr-Book: %s", e)
4953
return False
5054
return True
55+
56+
57+
def parse_message_id_safe(msg_id_str: str) -> MessageID:
58+
"""Safely handle message ID as string."""
59+
return MessageID(msg_id_str)
60+
61+
62+
def safe_parse_message_id(msg_id_str: str) -> tuple[bytes, bytes]:
63+
"""
64+
Safely parse message ID using ast.literal_eval with validation.
65+
:param msg_id_str: String representation of message ID
66+
:return: Tuple of (seqno, from_id) as bytes
67+
:raises ValueError: If parsing fails
68+
"""
69+
try:
70+
parsed = ast.literal_eval(msg_id_str)
71+
if not isinstance(parsed, tuple) or len(parsed) != 2:
72+
raise ValueError("Invalid message ID format")
73+
74+
seqno, from_id = parsed
75+
if not isinstance(seqno, bytes) or not isinstance(from_id, bytes):
76+
raise ValueError("Message ID components must be bytes")
77+
78+
return (seqno, from_id)
79+
except (ValueError, SyntaxError) as e:
80+
raise ValueError(f"Invalid message ID format: {e}")

newsfragments/843.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fixed message id type inconsistency in handle ihave and message id parsing improvement in handle iwant in pubsub module.

tests/core/pubsub/test_gossipsub.py

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
import random
2+
from unittest.mock import (
3+
AsyncMock,
4+
MagicMock,
5+
)
26

37
import pytest
48
import trio
@@ -7,6 +11,9 @@
711
PROTOCOL_ID,
812
GossipSub,
913
)
14+
from libp2p.pubsub.pb import (
15+
rpc_pb2,
16+
)
1017
from libp2p.tools.utils import (
1118
connect,
1219
)
@@ -754,3 +761,173 @@ async def test_single_host():
754761
assert connected_peers == 0, (
755762
f"Single host has {connected_peers} connections, expected 0"
756763
)
764+
765+
766+
@pytest.mark.trio
767+
async def test_handle_ihave(monkeypatch):
768+
async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub:
769+
gossipsub_routers = []
770+
for pubsub in pubsubs_gsub:
771+
if isinstance(pubsub.router, GossipSub):
772+
gossipsub_routers.append(pubsub.router)
773+
gossipsubs = tuple(gossipsub_routers)
774+
775+
index_alice = 0
776+
index_bob = 1
777+
id_bob = pubsubs_gsub[index_bob].my_id
778+
779+
# Connect Alice and Bob
780+
await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
781+
await trio.sleep(0.1) # Allow connections to establish
782+
783+
# Mock emit_iwant to capture calls
784+
mock_emit_iwant = AsyncMock()
785+
monkeypatch.setattr(gossipsubs[index_alice], "emit_iwant", mock_emit_iwant)
786+
787+
# Create a test message ID as a string representation of a (seqno, from) tuple
788+
test_seqno = b"1234"
789+
test_from = id_bob.to_bytes()
790+
test_msg_id = f"(b'{test_seqno.hex()}', b'{test_from.hex()}')"
791+
ihave_msg = rpc_pb2.ControlIHave(messageIDs=[test_msg_id])
792+
793+
# Mock seen_messages.cache to avoid false positives
794+
monkeypatch.setattr(pubsubs_gsub[index_alice].seen_messages, "cache", {})
795+
796+
# Simulate Bob sending IHAVE to Alice
797+
await gossipsubs[index_alice].handle_ihave(ihave_msg, id_bob)
798+
799+
# Check if emit_iwant was called with the correct message ID
800+
mock_emit_iwant.assert_called_once()
801+
called_args = mock_emit_iwant.call_args[0]
802+
assert called_args[0] == [test_msg_id] # Expected message IDs
803+
assert called_args[1] == id_bob # Sender peer ID
804+
805+
806+
@pytest.mark.trio
807+
async def test_handle_iwant(monkeypatch):
808+
async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub:
809+
gossipsub_routers = []
810+
for pubsub in pubsubs_gsub:
811+
if isinstance(pubsub.router, GossipSub):
812+
gossipsub_routers.append(pubsub.router)
813+
gossipsubs = tuple(gossipsub_routers)
814+
815+
index_alice = 0
816+
index_bob = 1
817+
id_alice = pubsubs_gsub[index_alice].my_id
818+
819+
# Connect Alice and Bob
820+
await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
821+
await trio.sleep(0.1) # Allow connections to establish
822+
823+
# Mock mcache.get to return a message
824+
test_message = rpc_pb2.Message(data=b"test_data")
825+
test_seqno = b"1234"
826+
test_from = id_alice.to_bytes()
827+
828+
# ✅ Correct: use raw tuple and str() to serialize, no hex()
829+
test_msg_id = str((test_seqno, test_from))
830+
831+
mock_mcache_get = MagicMock(return_value=test_message)
832+
monkeypatch.setattr(gossipsubs[index_bob].mcache, "get", mock_mcache_get)
833+
834+
# Mock write_msg to capture the sent packet
835+
mock_write_msg = AsyncMock()
836+
monkeypatch.setattr(gossipsubs[index_bob].pubsub, "write_msg", mock_write_msg)
837+
838+
# Simulate Alice sending IWANT to Bob
839+
iwant_msg = rpc_pb2.ControlIWant(messageIDs=[test_msg_id])
840+
await gossipsubs[index_bob].handle_iwant(iwant_msg, id_alice)
841+
842+
# Check if write_msg was called with the correct packet
843+
mock_write_msg.assert_called_once()
844+
packet = mock_write_msg.call_args[0][1]
845+
assert isinstance(packet, rpc_pb2.RPC)
846+
assert len(packet.publish) == 1
847+
assert packet.publish[0] == test_message
848+
849+
# Verify that mcache.get was called with the correct parsed message ID
850+
mock_mcache_get.assert_called_once()
851+
called_msg_id = mock_mcache_get.call_args[0][0]
852+
assert isinstance(called_msg_id, tuple)
853+
assert called_msg_id == (test_seqno, test_from)
854+
855+
856+
@pytest.mark.trio
857+
async def test_handle_iwant_invalid_msg_id(monkeypatch):
858+
"""
859+
Test that handle_iwant raises ValueError for malformed message IDs.
860+
"""
861+
async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub:
862+
gossipsub_routers = []
863+
for pubsub in pubsubs_gsub:
864+
if isinstance(pubsub.router, GossipSub):
865+
gossipsub_routers.append(pubsub.router)
866+
gossipsubs = tuple(gossipsub_routers)
867+
868+
index_alice = 0
869+
index_bob = 1
870+
id_alice = pubsubs_gsub[index_alice].my_id
871+
872+
await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
873+
await trio.sleep(0.1)
874+
875+
# Malformed message ID (not a tuple string)
876+
malformed_msg_id = "not_a_valid_msg_id"
877+
iwant_msg = rpc_pb2.ControlIWant(messageIDs=[malformed_msg_id])
878+
879+
# Mock mcache.get and write_msg to ensure they are not called
880+
mock_mcache_get = MagicMock()
881+
monkeypatch.setattr(gossipsubs[index_bob].mcache, "get", mock_mcache_get)
882+
mock_write_msg = AsyncMock()
883+
monkeypatch.setattr(gossipsubs[index_bob].pubsub, "write_msg", mock_write_msg)
884+
885+
with pytest.raises(ValueError):
886+
await gossipsubs[index_bob].handle_iwant(iwant_msg, id_alice)
887+
mock_mcache_get.assert_not_called()
888+
mock_write_msg.assert_not_called()
889+
890+
# Message ID that's a tuple string but not (bytes, bytes)
891+
invalid_tuple_msg_id = "('abc', 123)"
892+
iwant_msg = rpc_pb2.ControlIWant(messageIDs=[invalid_tuple_msg_id])
893+
with pytest.raises(ValueError):
894+
await gossipsubs[index_bob].handle_iwant(iwant_msg, id_alice)
895+
mock_mcache_get.assert_not_called()
896+
mock_write_msg.assert_not_called()
897+
898+
899+
@pytest.mark.trio
900+
async def test_handle_ihave_empty_message_ids(monkeypatch):
901+
"""
902+
Test that handle_ihave with an empty messageIDs list does not call emit_iwant.
903+
"""
904+
async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub:
905+
gossipsub_routers = []
906+
for pubsub in pubsubs_gsub:
907+
if isinstance(pubsub.router, GossipSub):
908+
gossipsub_routers.append(pubsub.router)
909+
gossipsubs = tuple(gossipsub_routers)
910+
911+
index_alice = 0
912+
index_bob = 1
913+
id_bob = pubsubs_gsub[index_bob].my_id
914+
915+
# Connect Alice and Bob
916+
await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
917+
await trio.sleep(0.1) # Allow connections to establish
918+
919+
# Mock emit_iwant to capture calls
920+
mock_emit_iwant = AsyncMock()
921+
monkeypatch.setattr(gossipsubs[index_alice], "emit_iwant", mock_emit_iwant)
922+
923+
# Empty messageIDs list
924+
ihave_msg = rpc_pb2.ControlIHave(messageIDs=[])
925+
926+
# Mock seen_messages.cache to avoid false positives
927+
monkeypatch.setattr(pubsubs_gsub[index_alice].seen_messages, "cache", {})
928+
929+
# Simulate Bob sending IHAVE to Alice
930+
await gossipsubs[index_alice].handle_ihave(ihave_msg, id_bob)
931+
932+
# emit_iwant should not be called since there are no message IDs
933+
mock_emit_iwant.assert_not_called()

0 commit comments

Comments
 (0)