diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index e8d0561d4..84b4cf4f6 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -1,5 +1,6 @@ from collections.abc import AsyncGenerator from contextlib import asynccontextmanager +import time from types import ( TracebackType, ) @@ -100,6 +101,12 @@ async def write_lock(self) -> AsyncGenerator[None, None]: self.release_write() +class MplexStreamTimeout(Exception): + """Raised when a stream operation exceeds its deadline.""" + + pass + + class MplexStream(IMuxedStream): """ reference: https://github.com/libp2p/go-mplex/blob/master/stream.go @@ -111,8 +118,8 @@ class MplexStream(IMuxedStream): # class of IMuxedConn. Ignoring this type assignment should not pose # any risk. muxed_conn: "Mplex" # type: ignore[assignment] - read_deadline: int | None - write_deadline: int | None + read_deadline: float | None + write_deadline: float | None rw_lock: ReadWriteLock close_lock: trio.Lock @@ -156,6 +163,30 @@ def __init__( def is_initiator(self) -> bool: return self.stream_id.is_initiator + def _check_read_deadline(self) -> None: + """Check if read deadline has expired and raise timeout if needed.""" + if self.read_deadline is not None and time.time() > self.read_deadline: + raise MplexStreamTimeout("Read operation exceeded deadline") + + def _check_write_deadline(self) -> None: + """Check if write deadline has expired and raise timeout if needed.""" + if self.write_deadline is not None and time.time() > self.write_deadline: + raise MplexStreamTimeout("Write operation exceeded deadline") + + def _get_read_timeout(self) -> float | None: + """Calculate remaining time until read deadline.""" + if self.read_deadline is None: + return None + remaining = self.read_deadline - time.time() + return max(0.0, remaining) if remaining > 0 else 0 + + def _get_write_timeout(self) -> float | None: + """Calculate remaining time until write deadline.""" + if self.write_deadline is None: + return None + remaining = self.write_deadline - time.time() + return max(0.0, remaining) if remaining > 0 else 0 + async def _read_until_eof(self) -> bytes: async for data in self.incoming_data_channel: self._buf.extend(data) @@ -182,6 +213,9 @@ async def read(self, n: int | None = None) -> bytes: :param n: number of bytes to read :return: bytes actually read """ + # check deadline before starting + self._check_read_deadline() + async with self.rw_lock.read_lock(): if n is not None and n < 0: raise ValueError( @@ -192,8 +226,13 @@ async def read(self, n: int | None = None) -> bytes: raise MplexStreamReset if n is None: return await self._read_until_eof() + + # check deadline again before potentially blocking operation + self._check_read_deadline() + if len(self._buf) == 0: data: bytes + timeout = self._get_read_timeout() # Peek whether there is data available. If yes, we just read until # there is no data, then return. try: @@ -207,6 +246,20 @@ async def read(self, n: int | None = None) -> bytes: try: data = await self.incoming_data_channel.receive() self._buf.extend(data) + if timeout is not None and timeout <= 0: + raise MplexStreamTimeout( + "Read deadline exceeded while waiting for data" + ) + + if timeout is not None: + with trio.fail_after(timeout): + data = await self.incoming_data_channel.receive() + else: + data = await self.incoming_data_channel.receive() + + self._buf.extend(data) + except trio.TooSlowError: + raise MplexStreamTimeout("Read operation timed out") except trio.EndOfChannel: if self.event_reset.is_set(): raise MplexStreamReset @@ -226,15 +279,43 @@ async def read(self, n: int | None = None) -> bytes: self._buf = self._buf[len(payload) :] return bytes(payload) + async def _read_until_eof_with_timeout(self) -> bytes: + """Read until EOF with timeout support.""" + timeout = self._get_read_timeout() + + try: + if timeout is not None: + with trio.fail_after(timeout): + async for data in self.incoming_data_channel: + self._buf.extend(data) + else: + async for data in self.incoming_data_channel: + self._buf.extend(data) + except trio.TooSlowError: + raise MplexStreamTimeout("Read until EOF operation timed out") + + payload = self._buf + self._buf = self._buf[len(payload) :] + return bytes(payload) + async def write(self, data: bytes) -> None: """ Write to stream. :return: number of bytes written """ + # Check deadline before starting + self._check_write_deadline() + async with self.rw_lock.write_lock(): if self.event_local_closed.is_set(): raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}") + + # Check deadline again after acquiring lock + timeout = self._get_write_timeout() + if timeout is not None and timeout <= 0: + raise MplexStreamTimeout("Write deadline exceeded") + flag = ( HeaderTags.MessageInitiator if self.is_initiator @@ -315,8 +396,9 @@ def set_deadline(self, ttl: int) -> bool: :return: True if successful """ - self.read_deadline = ttl - self.write_deadline = ttl + deadline = time.time() + ttl + self.read_deadline = deadline + self.write_deadline = deadline return True def set_read_deadline(self, ttl: int) -> bool: @@ -325,7 +407,7 @@ def set_read_deadline(self, ttl: int) -> bool: :return: True if successful """ - self.read_deadline = ttl + self.read_deadline = time.time() + ttl return True def set_write_deadline(self, ttl: int) -> bool: @@ -334,7 +416,7 @@ def set_write_deadline(self, ttl: int) -> bool: :return: True if successful """ - self.write_deadline = ttl + self.write_deadline = ttl + time.time() return True def get_remote_address(self) -> tuple[str, int] | None: