diff --git a/examples/identify_push/identify_push_demo.py b/examples/identify_push/identify_push_demo.py index 98e1e937f..7d5baa34e 100644 --- a/examples/identify_push/identify_push_demo.py +++ b/examples/identify_push/identify_push_demo.py @@ -272,12 +272,8 @@ async def main() -> None: print("\nšŸ“¤ Host 1 pushing identify information to Host 2...") try: - success = await push_identify_to_peer(host_1, host_2.get_id()) - - if success: - print("āœ… Identify push completed successfully!") - else: - print("āš ļø Identify push didn't complete successfully") + await push_identify_to_peer(host_1, host_2.get_id()) + print("āœ… Identify push completed successfully!") except Exception as e: print(f"āŒ Error during identify push: {str(e)}") diff --git a/examples/identify_push/identify_push_listener_dialer.py b/examples/identify_push/identify_push_listener_dialer.py index 079457a22..fe783278a 100644 --- a/examples/identify_push/identify_push_listener_dialer.py +++ b/examples/identify_push/identify_push_listener_dialer.py @@ -321,23 +321,17 @@ async def run_dialer( print("\nPushing identify information to listener...") try: - # Call push_identify_to_peer which returns a boolean - success = await push_identify_to_peer( + # Call push_identify_to_peer which now returns None and raises + # exceptions + await push_identify_to_peer( host, peer_info.peer_id, use_varint_format=use_varint_format ) - if success: - logger.info("Identify push completed successfully!") - print("āœ… Identify push completed successfully!") + logger.info("Identify push completed successfully!") + print("āœ… Identify push completed successfully!") - logger.info("Example completed successfully!") - print("\nExample completed successfully!") - else: - logger.warning("Identify push didn't complete successfully.") - print("\nWarning: Identify push didn't complete successfully.") - - logger.warning("Example completed with warnings.") - print("Example completed with warnings.") + logger.info("Example completed successfully!") + print("\nExample completed successfully!") except Exception as e: error_msg = str(e) logger.error(f"Error during identify push: {error_msg}") diff --git a/libp2p/abc.py b/libp2p/abc.py index 964c74546..02e0488e2 100644 --- a/libp2p/abc.py +++ b/libp2p/abc.py @@ -253,13 +253,12 @@ async def reset(self) -> None: """ @abstractmethod - def set_deadline(self, ttl: int) -> bool: + def set_deadline(self, ttl: int) -> None: """ Set a deadline for the stream. :param ttl: Time-to-live for the stream in seconds. - :return: True if the deadline was set successfully, - otherwise False. + :raises MuxedStreamError: if setting the deadline fails. """ @abstractmethod diff --git a/libp2p/host/autonat/autonat.py b/libp2p/host/autonat/autonat.py index ae4663f12..885e4dc62 100644 --- a/libp2p/host/autonat/autonat.py +++ b/libp2p/host/autonat/autonat.py @@ -13,6 +13,9 @@ from libp2p.host.basic_host import ( BasicHost, ) +from libp2p.host.exceptions import ( + HostException, +) from libp2p.network.stream.net_stream import ( NetStream, ) @@ -154,7 +157,11 @@ async def _handle_dial(self, message: Message) -> Message: if peer_id in self.dial_results: success = self.dial_results[peer_id] else: - success = await self._try_dial(peer_id) + try: + await self._try_dial(peer_id) + success = True + except HostException: + success = False self.dial_results[peer_id] = success peer_info = PeerInfo() @@ -166,7 +173,7 @@ async def _handle_dial(self, message: Message) -> Message: response.dial_response.CopyFrom(dial_response) return response - async def _try_dial(self, peer_id: ID) -> bool: + async def _try_dial(self, peer_id: ID) -> None: """ Attempt to establish a connection with a peer. @@ -175,19 +182,17 @@ async def _try_dial(self, peer_id: ID) -> bool: peer_id : ID The identifier of the peer to attempt to dial. - Returns - ------- - bool - True if the connection was successfully established, - False if the connection attempt failed. + Raises + ------ + HostException + If the connection attempt failed. """ try: stream = await self.host.new_stream(peer_id, [AUTONAT_PROTOCOL_ID]) await stream.close() - return True - except Exception: - return False + except Exception as e: + raise HostException(f"Failed to dial peer {peer_id}: {e}") from e def get_status(self) -> int: """ diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index 6b7eb1d35..bc98da5a7 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -35,6 +35,7 @@ get_default_protocols, ) from libp2p.host.exceptions import ( + HostException, StreamFailure, ) from libp2p.peer.id import ( @@ -206,8 +207,21 @@ def set_stream_handler( :param protocol_id: protocol id used on stream :param stream_handler: a stream handler function + :raises HostException: if setting the stream handler fails """ - self.multiselect.add_handler(protocol_id, stream_handler) + try: + if not protocol_id or ( + isinstance(protocol_id, str) and not protocol_id.strip() + ): + raise HostException("Protocol ID cannot be empty") + if stream_handler is None: + raise HostException("Stream handler cannot be None") + + self.multiselect.add_handler(protocol_id, stream_handler) + except HostException: + raise + except Exception as e: + raise HostException(f"Failed to set stream handler: {e}") from e async def new_stream( self, diff --git a/libp2p/identity/identify_push/identify_push.py b/libp2p/identity/identify_push/identify_push.py index 5b23851b1..b2713eed9 100644 --- a/libp2p/identity/identify_push/identify_push.py +++ b/libp2p/identity/identify_push/identify_push.py @@ -17,6 +17,9 @@ StreamHandlerFn, TProtocol, ) +from libp2p.host.exceptions import ( + HostException, +) from libp2p.network.stream.exceptions import ( StreamClosed, ) @@ -172,7 +175,7 @@ async def push_identify_to_peer( observed_multiaddr: Multiaddr | None = None, limit: trio.Semaphore = trio.Semaphore(CONCURRENCY_LIMIT), use_varint_format: bool = True, -) -> bool: +) -> None: """ Push an identify message to a specific peer. @@ -186,8 +189,8 @@ async def push_identify_to_peer( limit: Semaphore for concurrency control. use_varint_format: True=length-prefixed, False=raw protobuf. - Returns: - bool: True if the push was successful, False otherwise. + Raises: + HostException: If the identify push fails. """ async with limit: @@ -224,10 +227,42 @@ async def push_identify_to_peer( await stream.close() logger.debug("Successfully pushed identify to peer %s", peer_id) - return True except Exception as e: logger.error("Error pushing identify to peer %s: %s", peer_id, e) - return False + raise HostException( + f"Failed to push identify to peer {peer_id}: {e}" + ) from e + + +async def _safe_push_identify_to_peer( + host: IHost, + peer_id: ID, + observed_multiaddr: Multiaddr | None = None, + limit: trio.Semaphore = trio.Semaphore(CONCURRENCY_LIMIT), + use_varint_format: bool = True, +) -> None: + """ + Safely push identify information to a specific peer, catching and logging + exceptions. + + This is a wrapper around push_identify_to_peer that catches exceptions and + logs them instead of letting them propagate, which is useful when calling + from a nursery. + + Args: + host: The libp2p host. + peer_id: The peer ID to push to. + observed_multiaddr: The observed multiaddress (optional). + limit: Semaphore for concurrency control. + use_varint_format: True=length-prefixed, False=raw protobuf. + + """ + try: + await push_identify_to_peer( + host, peer_id, observed_multiaddr, limit, use_varint_format + ) + except Exception as e: + logger.debug("Failed to push identify to peer %s: %s", peer_id, e) async def push_identify_to_peers( @@ -260,7 +295,7 @@ async def push_identify_to_peers( async with trio.open_nursery() as nursery: for peer_id in peer_ids: nursery.start_soon( - push_identify_to_peer, + _safe_push_identify_to_peer, host, peer_id, observed_multiaddr, diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 150ae9dd0..5764930ac 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -12,6 +12,7 @@ ) from libp2p.stream_muxer.exceptions import ( MuxedConnUnavailable, + MuxedStreamError, ) from libp2p.stream_muxer.rw_lock import ReadWriteLock @@ -90,8 +91,13 @@ def is_initiator(self) -> bool: return self.stream_id.is_initiator async def _read_until_eof(self) -> bytes: - async for data in self.incoming_data_channel: - self._buf.extend(data) + if self.read_deadline is not None: + with trio.fail_after(self.read_deadline): + async for data in self.incoming_data_channel: + self._buf.extend(data) + else: + async for data in self.incoming_data_channel: + self._buf.extend(data) payload = self._buf self._buf = self._buf[len(payload) :] return bytes(payload) @@ -138,8 +144,16 @@ async def read(self, n: int | None = None) -> bytes: # We know `receive` will be blocked here. Wait for data here with # `receive` and catch all kinds of errors here. try: - data = await self.incoming_data_channel.receive() - self._buf.extend(data) + # Apply read deadline if set + if self.read_deadline is not None: + with trio.fail_after(self.read_deadline): + data = await self.incoming_data_channel.receive() + self._buf.extend(data) + else: + data = await self.incoming_data_channel.receive() + self._buf.extend(data) + except trio.TooSlowError: + raise MuxedStreamError("Read operation timed out") except trio.EndOfChannel: if self.event_reset.is_set(): raise MplexStreamReset @@ -173,7 +187,15 @@ async def write(self, data: bytes) -> None: if self.is_initiator else HeaderTags.MessageReceiver ) - await self.muxed_conn.send_message(flag, data, self.stream_id) + try: + # Apply write deadline if set + if self.write_deadline is not None: + with trio.fail_after(self.write_deadline): + await self.muxed_conn.send_message(flag, data, self.stream_id) + else: + await self.muxed_conn.send_message(flag, data, self.stream_id) + except trio.TooSlowError: + raise MuxedStreamError("Write operation timed out") async def close(self) -> None: """ @@ -241,34 +263,48 @@ async def reset(self) -> None: if self.muxed_conn.streams is not None: self.muxed_conn.streams.pop(self.stream_id, None) - # TODO deadline not in use - def set_deadline(self, ttl: int) -> bool: + def set_deadline(self, ttl: int) -> None: """ Set deadline for muxed stream. - :return: True if successful + :param ttl: Time-to-live for the stream in seconds + :raises MuxedStreamError: if setting the deadline fails """ - self.read_deadline = ttl - self.write_deadline = ttl - return True - - def set_read_deadline(self, ttl: int) -> bool: + try: + if ttl < 0: + raise ValueError("Deadline cannot be negative") + self.read_deadline = ttl + self.write_deadline = ttl + except Exception as e: + raise MuxedStreamError(f"Failed to set deadline: {e}") from e + + def set_read_deadline(self, ttl: int) -> None: """ Set read deadline for muxed stream. - :return: True if successful + :param ttl: Time-to-live for the read deadline in seconds + :raises MuxedStreamError: if setting the read deadline fails """ - self.read_deadline = ttl - return True + try: + if ttl < 0: + raise ValueError("Read deadline cannot be negative") + self.read_deadline = ttl + except Exception as e: + raise MuxedStreamError(f"Failed to set read deadline: {e}") from e - def set_write_deadline(self, ttl: int) -> bool: + def set_write_deadline(self, ttl: int) -> None: """ Set write deadline for muxed stream. - :return: True if successful + :param ttl: Time-to-live for the write deadline in seconds + :raises MuxedStreamError: if setting the write deadline fails """ - self.write_deadline = ttl - return True + try: + if ttl < 0: + raise ValueError("Write deadline cannot be negative") + self.write_deadline = ttl + except Exception as e: + raise MuxedStreamError(f"Failed to set write deadline: {e}") from e def get_remote_address(self) -> tuple[str, int] | None: """Delegate to the parent Mplex connection.""" diff --git a/libp2p/stream_muxer/yamux/yamux.py b/libp2p/stream_muxer/yamux/yamux.py index bb84a5db6..631ddb811 100644 --- a/libp2p/stream_muxer/yamux/yamux.py +++ b/libp2p/stream_muxer/yamux/yamux.py @@ -289,15 +289,15 @@ async def reset(self) -> None: self.closed = True self.reset_received = True # Mark as reset - def set_deadline(self, ttl: int) -> bool: + def set_deadline(self, ttl: int) -> None: """ Set a deadline for the stream. Yamux does not support deadlines natively, - so this method always returns False to indicate the operation is unsupported. + so this method raises an exception to indicate the operation is unsupported. :param ttl: Time-to-live in seconds (ignored). - :return: False, as deadlines are not supported. + :raises NotImplementedError: as deadlines are not supported in Yamux. """ - raise NotImplementedError("Yamux does not support setting read deadlines") + raise NotImplementedError("Yamux does not support setting deadlines") def get_remote_address(self) -> tuple[str, int] | None: """ diff --git a/libp2p/transport/quic/stream.py b/libp2p/transport/quic/stream.py index dac8925ec..5834b58ed 100644 --- a/libp2p/transport/quic/stream.py +++ b/libp2p/transport/quic/stream.py @@ -632,15 +632,15 @@ async def __aexit__( logger.debug("Exiting the context and closing the stream") await self.close() - def set_deadline(self, ttl: int) -> bool: + def set_deadline(self, ttl: int) -> None: """ Set a deadline for the stream. QUIC does not support deadlines natively, - so this method always returns False to indicate the operation is unsupported. + so this method raises an exception to indicate the operation is unsupported. :param ttl: Time-to-live in seconds (ignored). - :return: False, as deadlines are not supported. + :raises NotImplementedError: as deadlines are not supported in QUIC. """ - raise NotImplementedError("QUIC does not support setting read deadlines") + raise NotImplementedError("QUIC does not support setting deadlines") # String representation for debugging diff --git a/libp2p/transport/tcp/tcp.py b/libp2p/transport/tcp/tcp.py index 1598ea42a..8e9bf9288 100644 --- a/libp2p/transport/tcp/tcp.py +++ b/libp2p/transport/tcp/tcp.py @@ -47,7 +47,7 @@ async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool: Put listener in listening mode and wait for incoming connections. :param maddr: maddr of peer - :return: return True if successful + :raises OpenConnectionError: if listening fails """ async def serve_tcp( @@ -76,17 +76,19 @@ async def handler(stream: trio.SocketStream) -> None: tcp_port_str = maddr.value_for_protocol("tcp") if tcp_port_str is None: - logger.error(f"Cannot listen: TCP port is missing in multiaddress {maddr}") - return False + error_msg = f"Cannot listen: TCP port is missing in multiaddress {maddr}" + logger.error(error_msg) + raise OpenConnectionError(error_msg) try: tcp_port = int(tcp_port_str) except ValueError: - logger.error( + error_msg = ( f"Cannot listen: Invalid TCP port '{tcp_port_str}' " f"in multiaddress {maddr}" ) - return False + logger.error(error_msg) + raise OpenConnectionError(error_msg) ip4_host_str = maddr.value_for_protocol("ip4") # For trio.serve_tcp, ip4_host_str (as host argument) can be None, @@ -102,13 +104,14 @@ async def handler(stream: trio.SocketStream) -> None: if started_listeners is None: # This implies that task_status.started() was not called within serve_tcp, # likely because trio.serve_tcp itself failed to start (e.g., port in use). - logger.error( + error_msg = ( f"Failed to start TCP listener for {maddr}: " f"`nursery.start` returned None. " "This might be due to issues like the port already " "being in use or invalid host." ) - return False + logger.error(error_msg) + raise OpenConnectionError(error_msg) self.listeners.extend(started_listeners) return True diff --git a/newsfragments/194.bugfix.rst b/newsfragments/194.bugfix.rst new file mode 100644 index 000000000..d198f9841 --- /dev/null +++ b/newsfragments/194.bugfix.rst @@ -0,0 +1,6 @@ +Replaced return True/False pattern with try/catch pattern for better error handling. + +- Updated set_stream_handler method in BasicHost to raise HostException instead of returning boolean +- Enhanced exception handling in mplex stream deadline methods with proper error chaining +- Improved identify/push protocol error handling with comprehensive exception management +- Added proper exception documentation and error messages throughout affected methods diff --git a/tests/core/examples/test_examples.py b/tests/core/examples/test_examples.py index d60327b68..be82c18a0 100644 --- a/tests/core/examples/test_examples.py +++ b/tests/core/examples/test_examples.py @@ -281,8 +281,11 @@ async def identify_push_demo(host_a, host_b): logger.debug("Host A protocols before push: %s", host_a_protocols) # Push identify information from host_a to host_b - success = await push_identify_to_peer(host_a, host_b.get_id()) - assert success is True + try: + await push_identify_to_peer(host_a, host_b.get_id()) + # If we get here, the push was successful + except Exception as e: + pytest.fail(f"push_identify_to_peer failed: {e}") # Add a small delay to allow processing await trio.sleep(0.1) diff --git a/tests/core/host/test_autonat.py b/tests/core/host/test_autonat.py index 4c6dbacab..6934fe072 100644 --- a/tests/core/host/test_autonat.py +++ b/tests/core/host/test_autonat.py @@ -18,6 +18,9 @@ Status, Type, ) +from libp2p.host.exceptions import ( + HostException, +) from libp2p.network.stream.exceptions import ( StreamError, ) @@ -108,9 +111,8 @@ async def test_try_dial(): mock_stream = AsyncMock(spec=NetStream) mock_new_stream.return_value = mock_stream - result = await service._try_dial(peer_id) - - assert result is True + # Should not raise an exception for successful dial + await service._try_dial(peer_id) mock_new_stream.assert_called_once_with(peer_id, [AUTONAT_PROTOCOL_ID]) mock_stream.close.assert_called_once() @@ -120,9 +122,9 @@ async def test_try_dial(): ) as mock_new_stream: mock_new_stream.side_effect = Exception("Connection failed") - result = await service._try_dial(peer_id) - - assert result is False + # Should raise HostException for failed dial + with pytest.raises(HostException, match="Failed to dial peer"): + await service._try_dial(peer_id) mock_new_stream.assert_called_once_with(peer_id, [AUTONAT_PROTOCOL_ID]) diff --git a/tests/core/host/test_basic_host.py b/tests/core/host/test_basic_host.py index 635f28632..2531104e8 100644 --- a/tests/core/host/test_basic_host.py +++ b/tests/core/host/test_basic_host.py @@ -11,6 +11,9 @@ from libp2p.crypto.rsa import ( create_new_key_pair, ) +from libp2p.custom_types import ( + TProtocol, +) from libp2p.host.basic_host import ( BasicHost, ) @@ -18,6 +21,7 @@ get_default_protocols, ) from libp2p.host.exceptions import ( + HostException, StreamFailure, ) @@ -59,3 +63,154 @@ async def fake_negotiate(comm, timeout): # Ensure reset was called since negotiation failed net_stream.reset.assert_awaited() + + +def test_set_stream_handler_success(): + """Test successful stream handler setting.""" + key_pair = create_new_key_pair() + swarm = new_swarm(key_pair) + host = BasicHost(swarm) + + async def mock_handler(stream): + pass + + protocol = TProtocol("/test/protocol") + host.set_stream_handler(protocol, mock_handler) + + assert protocol in host.multiselect.handlers + assert host.multiselect.handlers[protocol] == mock_handler + + +def test_set_stream_handler_empty_protocol(): + """Test set_stream_handler raises exception when protocol_id is empty.""" + key_pair = create_new_key_pair() + swarm = new_swarm(key_pair) + host = BasicHost(swarm) + + async def mock_handler(stream): + pass + + with pytest.raises(HostException, match="Protocol ID cannot be empty"): + host.set_stream_handler(TProtocol(""), mock_handler) + + +def test_set_stream_handler_none_handler(): + """Test set_stream_handler raises exception when stream_handler is None.""" + key_pair = create_new_key_pair() + swarm = new_swarm(key_pair) + host = BasicHost(swarm) + + with pytest.raises(HostException, match="Stream handler cannot be None"): + host.set_stream_handler(TProtocol("/test/protocol"), None) # type: ignore + + +def test_set_stream_handler_exception_handling(): + """Test set_stream_handler properly handles exceptions from multiselect.""" + key_pair = create_new_key_pair() + swarm = new_swarm(key_pair) + host = BasicHost(swarm) + + async def mock_handler(stream): + pass + + original_add_handler = host.multiselect.add_handler + host.multiselect.add_handler = MagicMock(side_effect=RuntimeError("Test error")) + + with pytest.raises(HostException, match="Failed to set stream handler"): + host.set_stream_handler(TProtocol("/test/protocol"), mock_handler) + + host.multiselect.add_handler = original_add_handler + + +def test_set_stream_handler_multiple_exceptions(): + """Test set_stream_handler handles different types of exceptions.""" + key_pair = create_new_key_pair() + swarm = new_swarm(key_pair) + host = BasicHost(swarm) + + async def mock_handler(stream): + pass + + # Test with ValueError + original_add_handler = host.multiselect.add_handler + host.multiselect.add_handler = MagicMock(side_effect=ValueError("Invalid value")) + + with pytest.raises(HostException, match="Failed to set stream handler"): + host.set_stream_handler(TProtocol("/test/protocol"), mock_handler) + + # Test with KeyError + host.multiselect.add_handler = MagicMock(side_effect=KeyError("Missing key")) + + with pytest.raises(HostException, match="Failed to set stream handler"): + host.set_stream_handler(TProtocol("/test/protocol"), mock_handler) + + host.multiselect.add_handler = original_add_handler + + +def test_set_stream_handler_preserves_exception_chain(): + """Test that set_stream_handler preserves the original exception chain.""" + key_pair = create_new_key_pair() + swarm = new_swarm(key_pair) + host = BasicHost(swarm) + + async def mock_handler(stream): + pass + + original_add_handler = host.multiselect.add_handler + original_error = RuntimeError("Original error") + host.multiselect.add_handler = MagicMock(side_effect=original_error) + + with pytest.raises(HostException) as exc_info: + host.set_stream_handler(TProtocol("/test/protocol"), mock_handler) + + # Check that the original exception is preserved in the chain + assert exc_info.value.__cause__ is original_error + assert "Failed to set stream handler" in str(exc_info.value) + + host.multiselect.add_handler = original_add_handler + + +def test_set_stream_handler_success_with_valid_inputs(): + """Test set_stream_handler succeeds with various valid protocol IDs.""" + key_pair = create_new_key_pair() + swarm = new_swarm(key_pair) + host = BasicHost(swarm) + + async def mock_handler(stream): + pass + + # Test with different valid protocol IDs + valid_protocols = [ + TProtocol("/test/protocol"), + TProtocol("/ipfs/id/1.0.0"), + TProtocol("/libp2p/autonat/1.0.0"), + TProtocol("/multistream/1.0.0"), + TProtocol("/test/protocol/with/version/1.0.0"), + ] + + for protocol_id in valid_protocols: + host.set_stream_handler(protocol_id, mock_handler) + assert protocol_id in host.multiselect.handlers + assert host.multiselect.handlers[protocol_id] == mock_handler + + +def test_set_stream_handler_edge_cases(): + """Test set_stream_handler with edge case inputs.""" + key_pair = create_new_key_pair() + swarm = new_swarm(key_pair) + host = BasicHost(swarm) + + async def mock_handler(stream): + pass + + # Test with whitespace-only protocol ID + with pytest.raises(HostException, match="Protocol ID cannot be empty"): + host.set_stream_handler(TProtocol(" "), mock_handler) + + # Test with None protocol ID + with pytest.raises(HostException, match="Protocol ID cannot be empty"): + host.set_stream_handler(None, mock_handler) # type: ignore + + # Test with empty string protocol ID + with pytest.raises(HostException, match="Protocol ID cannot be empty"): + host.set_stream_handler(TProtocol(""), mock_handler) diff --git a/tests/core/identity/identify_push/test_identify_push.py b/tests/core/identity/identify_push/test_identify_push.py index a1e2e472f..f66978926 100644 --- a/tests/core/identity/identify_push/test_identify_push.py +++ b/tests/core/identity/identify_push/test_identify_push.py @@ -619,8 +619,11 @@ async def test_identify_push_default_varint_format(security_protocol): host_b.set_stream_handler(ID_PUSH, identify_push_handler_for(host_b)) # Push identify information from host_a to host_b using default settings - success = await push_identify_to_peer(host_a, host_b.get_id()) - assert success, "Identify push should succeed with default varint format" + try: + await push_identify_to_peer(host_a, host_b.get_id()) + # If we get here, the push was successful + except Exception as e: + pytest.fail(f"Identify push should succeed with default varint format: {e}") # Wait a bit for the push to complete await trio.sleep(0.1) @@ -669,10 +672,13 @@ async def test_identify_push_legacy_raw_format(security_protocol): ) # Push identify information from host_a to host_b using legacy format - success = await push_identify_to_peer( - host_a, host_b.get_id(), use_varint_format=False - ) - assert success, "Identify push should succeed with legacy raw format" + try: + await push_identify_to_peer( + host_a, host_b.get_id(), use_varint_format=False + ) + # If we get here, the push was successful + except Exception as e: + pytest.fail(f"Identify push should succeed with legacy raw format: {e}") # Wait a bit for the push to complete await trio.sleep(0.1) diff --git a/tests/core/identity/identify_push/test_identify_push_exception_handling.py b/tests/core/identity/identify_push/test_identify_push_exception_handling.py new file mode 100644 index 000000000..a008de08e --- /dev/null +++ b/tests/core/identity/identify_push/test_identify_push_exception_handling.py @@ -0,0 +1,131 @@ +import pytest + +from libp2p.host.exceptions import ( + HostException, +) +from libp2p.identity.identify.pb.identify_pb2 import ( + Identify, +) +from libp2p.identity.identify_push.identify_push import ( + _update_peerstore_from_identify, + push_identify_to_peer, + push_identify_to_peers, +) +from libp2p.peer.peerinfo import ( + PeerInfo, +) +from tests.utils.factories import ( + host_pair_factory, +) + + +@pytest.mark.trio +async def test_push_identify_to_peer_exception_handling(): + """Test that push_identify_to_peer properly handles exceptions.""" + async with host_pair_factory() as (host_a, host_b): + # Connect the hosts + peer_info = PeerInfo(host_b.get_id(), host_b.get_addrs()) + await host_a.connect(peer_info) + + # Mock new_stream to raise an exception + original_new_stream = host_a.new_stream + + async def mock_new_stream(*args, **kwargs): + raise RuntimeError("Mock stream creation error") + + host_a.new_stream = mock_new_stream # type: ignore # type: ignore + + # Test that push_identify_to_peer raises HostException + with pytest.raises(HostException, match="Failed to push identify to peer"): + await push_identify_to_peer(host_a, host_b.get_id()) + + # Restore original method + host_a.new_stream = original_new_stream + + +@pytest.mark.trio +async def test_push_identify_to_peer_preserves_exception_chain(): + """Test that push_identify_to_peer preserves the original exception chain.""" + async with host_pair_factory() as (host_a, host_b): + # Connect the hosts + peer_info = PeerInfo(host_b.get_id(), host_b.get_addrs()) + await host_a.connect(peer_info) + + original_error = ConnectionError("Original connection error") + + # Mock new_stream to raise a specific exception + async def mock_new_stream(*args, **kwargs): + raise original_error + + host_a.new_stream = mock_new_stream # type: ignore + + # Test that the original exception is preserved in the chain + with pytest.raises(HostException) as exc_info: + await push_identify_to_peer(host_a, host_b.get_id()) + + # Check that the original exception is preserved in the chain + assert exc_info.value.__cause__ is original_error + assert "Failed to push identify to peer" in str(exc_info.value) + + +@pytest.mark.trio +async def test_update_peerstore_from_identify_exception_handling(): + """Test that _update_peerstore_from_identify handles exceptions gracefully.""" + async with host_pair_factory() as (host_a, host_b): + # Create a mock peerstore that raises exceptions + class MockPeerstore: + def add_protocols(self, peer_id, protocols): + raise ValueError("Mock protocol error") + + def add_pubkey(self, peer_id, pubkey): + raise KeyError("Mock pubkey error") + + def add_addr(self, peer_id, addr, ttl): + raise RuntimeError("Mock addr error") + + def consume_peer_record(self, envelope, ttl): + raise ConnectionError("Mock record error") + + mock_peerstore = MockPeerstore() + + # Create an identify message with various fields + identify_msg = Identify() + identify_msg.public_key = b"mock_public_key" + identify_msg.listen_addrs.extend([b"/ip4/127.0.0.1/tcp/4001"]) + identify_msg.protocols.extend(["/test/protocol/1.0.0"]) + identify_msg.observed_addr = b"/ip4/127.0.0.1/tcp/4002" + identify_msg.signedPeerRecord = b"mock_signed_record" + + # Test that the function handles exceptions gracefully + # It should not raise any exceptions, just log errors + await _update_peerstore_from_identify( + mock_peerstore, # type: ignore + host_b.get_id(), + identify_msg, + ) + + # If we get here without exceptions, the test passes + + +@pytest.mark.trio +async def test_push_identify_to_peers_exception_handling(): + """Test that push_identify_to_peers handles exceptions gracefully.""" + async with host_pair_factory() as (host_a, host_b): + # Connect the hosts + peer_info = PeerInfo(host_b.get_id(), host_b.get_addrs()) + await host_a.connect(peer_info) + + # Mock new_stream to raise an exception for one peer + original_new_stream = host_a.new_stream + + async def mock_new_stream(*args, **kwargs): + raise ConnectionError("Mock connection error") + + host_a.new_stream = mock_new_stream # type: ignore + + # Test that push_identify_to_peers handles exceptions gracefully + # It should not raise any exceptions, just log errors + await push_identify_to_peers(host_a, {host_b.get_id()}) + + # Restore original method + host_a.new_stream = original_new_stream diff --git a/tests/core/identity/identify_push/test_identify_push_integration.py b/tests/core/identity/identify_push/test_identify_push_integration.py index 9ee38b10c..d09498731 100644 --- a/tests/core/identity/identify_push/test_identify_push_integration.py +++ b/tests/core/identity/identify_push/test_identify_push_integration.py @@ -150,13 +150,12 @@ async def wrapped_handler(stream): host_a.set_stream_handler(ID_PUSH, wrapped_handler) # Host B pushes with varint format (should fail gracefully) - success = await push_identify_to_peer( - host_b, host_a.get_id(), use_varint_format=True - ) - # This should fail due to format mismatch - # Note: The format detection might be more robust than expected - # so we just check that the operation completes - assert isinstance(success, bool) + try: + await push_identify_to_peer(host_b, host_a.get_id(), use_varint_format=True) + # If we get here, the operation succeeded (unexpected but acceptable) + except Exception: + # This is expected due to format mismatch, which is fine + pass @pytest.mark.trio @@ -182,13 +181,14 @@ async def wrapped_handler(stream): host_a.set_stream_handler(ID_PUSH, wrapped_handler) # Host B pushes with raw format (should fail gracefully) - success = await push_identify_to_peer( - host_b, host_a.get_id(), use_varint_format=False - ) - # This should fail due to format mismatch - # Note: The format detection might be more robust than expected - # so we just check that the operation completes - assert isinstance(success, bool) + try: + await push_identify_to_peer( + host_b, host_a.get_id(), use_varint_format=False + ) + # If we get here, the operation succeeded (unexpected but acceptable) + except Exception: + # This is expected due to format mismatch, which is fine + pass @pytest.mark.trio @@ -280,10 +280,11 @@ async def dummy_handler(stream): ) # Push identify information from host_b to host_a - success = await push_identify_to_peer( - host_b, host_a.get_id(), use_varint_format=True - ) - assert success + try: + await push_identify_to_peer(host_b, host_a.get_id(), use_varint_format=True) + # If we get here, the push was successful + except Exception as e: + pytest.fail(f"Identify push should succeed with large message: {e}") # Wait a bit for the push to complete await trio.sleep(0.1) @@ -359,8 +360,11 @@ async def dummy_handler(stream): results = [] async def push_identify(): - result = await push_identify_to_peer(host_b, host_a.get_id()) - results.append(result) + try: + await push_identify_to_peer(host_b, host_a.get_id()) + results.append(True) # Success + except Exception: + results.append(False) # Failure # Run multiple concurrent pushes using nursery async with trio.open_nursery() as nursery: @@ -400,8 +404,11 @@ async def dummy_handler(stream): host_a.set_stream_handler(ID_PUSH, identify_push_handler_for(host_a)) # Push identify information from host_b to host_a - success = await push_identify_to_peer(host_b, host_a.get_id()) - assert success + try: + await push_identify_to_peer(host_b, host_a.get_id()) + # If we get here, the push was successful + except Exception as e: + pytest.fail(f"Identify push should succeed with stream handling: {e}") # Wait a bit for the push to complete await trio.sleep(0.1) @@ -434,8 +441,11 @@ async def error_handler(stream): host_a.set_stream_handler(ID_PUSH, error_handler) # Push should complete (message sent) but handler should fail gracefully - success = await push_identify_to_peer(host_b, host_a.get_id()) - assert success # The push operation itself succeeds (message sent) + try: + await push_identify_to_peer(host_b, host_a.get_id()) + # If we get here, the push was successful + except Exception as e: + pytest.fail(f"Identify push should succeed even with error handler: {e}") # Wait a bit for the handler to process await trio.sleep(0.1) diff --git a/tests/core/stream_muxer/test_mplex_stream.py b/tests/core/stream_muxer/test_mplex_stream.py index 1d9c22340..4ca237ac1 100644 --- a/tests/core/stream_muxer/test_mplex_stream.py +++ b/tests/core/stream_muxer/test_mplex_stream.py @@ -4,6 +4,9 @@ wait_all_tasks_blocked, ) +from libp2p.stream_muxer.exceptions import ( + MuxedStreamError, +) from libp2p.stream_muxer.mplex.exceptions import ( MplexStreamClosed, MplexStreamEOF, @@ -250,3 +253,104 @@ def raise_unavailable(*args, **kwargs): with pytest.raises(RuntimeError, match="Failed to send close message"): await stream_0.close() + + +@pytest.mark.trio +async def test_mplex_stream_set_deadline_success(mplex_stream_pair): + """Test successful deadline setting.""" + stream_0, _ = mplex_stream_pair + + # Test setting deadline + stream_0.set_deadline(30) + assert stream_0.read_deadline == 30 + assert stream_0.write_deadline == 30 + + +@pytest.mark.trio +async def test_mplex_stream_set_read_deadline_success(mplex_stream_pair): + """Test successful read deadline setting.""" + stream_0, _ = mplex_stream_pair + + # Test setting read deadline + stream_0.set_read_deadline(15) + assert stream_0.read_deadline == 15 + + +@pytest.mark.trio +async def test_mplex_stream_set_write_deadline_success(mplex_stream_pair): + """Test successful write deadline setting.""" + stream_0, _ = mplex_stream_pair + + # Test setting write deadline + stream_0.set_write_deadline(20) + assert stream_0.write_deadline == 20 + + +@pytest.mark.trio +async def test_mplex_stream_set_deadline_exception_handling( + monkeypatch, mplex_stream_pair +): + """Test deadline setting handles exceptions properly.""" + stream_0, _ = mplex_stream_pair + + # Test with negative deadline value to trigger validation error + original_read_deadline = stream_0.read_deadline + original_write_deadline = stream_0.write_deadline + + with pytest.raises(MuxedStreamError, match="Failed to set deadline"): + stream_0.set_deadline(-1) + + # Verify original values are preserved + assert stream_0.read_deadline == original_read_deadline + assert stream_0.write_deadline == original_write_deadline + + +@pytest.mark.trio +async def test_mplex_stream_set_read_deadline_exception_handling( + monkeypatch, mplex_stream_pair +): + """Test read deadline setting handles exceptions properly.""" + stream_0, _ = mplex_stream_pair + + original_read_deadline = stream_0.read_deadline + + # Test with negative deadline value to trigger validation error + with pytest.raises(MuxedStreamError, match="Failed to set read deadline"): + stream_0.set_read_deadline(-1) + + # Verify original value is preserved + assert stream_0.read_deadline == original_read_deadline + + +@pytest.mark.trio +async def test_mplex_stream_set_write_deadline_exception_handling( + monkeypatch, mplex_stream_pair +): + """Test write deadline setting handles exceptions properly.""" + stream_0, _ = mplex_stream_pair + + original_write_deadline = stream_0.write_deadline + + # Test with negative deadline value to trigger validation error + with pytest.raises(MuxedStreamError, match="Failed to set write deadline"): + stream_0.set_write_deadline(-1) + + # Verify original value is preserved + assert stream_0.write_deadline == original_write_deadline + + +@pytest.mark.trio +async def test_mplex_stream_deadline_preserves_exception_chain( + monkeypatch, mplex_stream_pair +): + """Test that deadline methods preserve the original exception chain.""" + stream_0, _ = mplex_stream_pair + + # Test with negative deadline value to trigger validation error + with pytest.raises(MuxedStreamError) as exc_info: + stream_0.set_deadline(-1) + + # Check that the original exception is preserved in the chain + assert isinstance(exc_info.value.__cause__, ValueError) + assert "Deadline cannot be negative" in str(exc_info.value.__cause__) + assert "Failed to set deadline" in str(exc_info.value) diff --git a/tests/core/stream_muxer/test_yamux.py b/tests/core/stream_muxer/test_yamux.py index 444288518..77de47c31 100644 --- a/tests/core/stream_muxer/test_yamux.py +++ b/tests/core/stream_muxer/test_yamux.py @@ -258,7 +258,7 @@ async def test_yamux_deadlines_raise_not_implemented(yamux_pair): stream = await client_yamux.open_stream() with trio.move_on_after(2): with pytest.raises( - NotImplementedError, match="Yamux does not support setting read deadlines" + NotImplementedError, match="Yamux does not support setting deadlines" ): stream.set_deadline(60) logging.debug("test_yamux_deadlines_raise_not_implemented complete") diff --git a/tests/interop/go_libp2p/hole_punching/go_node/hole-punch-client b/tests/interop/go_libp2p/hole_punching/go_node/hole-punch-client new file mode 100755 index 000000000..0d544882f Binary files /dev/null and b/tests/interop/go_libp2p/hole_punching/go_node/hole-punch-client differ diff --git a/tests/interop/go_libp2p/hole_punching/go_node/hole-punch-server b/tests/interop/go_libp2p/hole_punching/go_node/hole-punch-server new file mode 100755 index 000000000..60e83f48d Binary files /dev/null and b/tests/interop/go_libp2p/hole_punching/go_node/hole-punch-server differ diff --git a/tests/interop/go_libp2p/hole_punching/go_node/relay-server b/tests/interop/go_libp2p/hole_punching/go_node/relay-server new file mode 100755 index 000000000..514fbb5a7 Binary files /dev/null and b/tests/interop/go_libp2p/hole_punching/go_node/relay-server differ