Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 88 additions & 6 deletions libp2p/stream_muxer/mplex/mplex_stream.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
import time
from types import (
TracebackType,
)
Expand Down Expand Up @@ -100,6 +101,12 @@ async def write_lock(self) -> AsyncGenerator[None, None]:
self.release_write()


class MplexStreamTimeout(Exception):
"""Raised when a stream operation exceeds its deadline."""

pass


class MplexStream(IMuxedStream):
"""
reference: https://github.com/libp2p/go-mplex/blob/master/stream.go
Expand All @@ -111,8 +118,8 @@ class MplexStream(IMuxedStream):
# class of IMuxedConn. Ignoring this type assignment should not pose
# any risk.
muxed_conn: "Mplex" # type: ignore[assignment]
read_deadline: int | None
write_deadline: int | None
read_deadline: float | None
write_deadline: float | None

rw_lock: ReadWriteLock
close_lock: trio.Lock
Expand Down Expand Up @@ -156,6 +163,30 @@ def __init__(
def is_initiator(self) -> bool:
return self.stream_id.is_initiator

def _check_read_deadline(self) -> None:
"""Check if read deadline has expired and raise timeout if needed."""
if self.read_deadline is not None and time.time() > self.read_deadline:
raise MplexStreamTimeout("Read operation exceeded deadline")

def _check_write_deadline(self) -> None:
"""Check if write deadline has expired and raise timeout if needed."""
if self.write_deadline is not None and time.time() > self.write_deadline:
raise MplexStreamTimeout("Write operation exceeded deadline")

def _get_read_timeout(self) -> float | None:
"""Calculate remaining time until read deadline."""
if self.read_deadline is None:
return None
remaining = self.read_deadline - time.time()
return max(0.0, remaining) if remaining > 0 else 0

def _get_write_timeout(self) -> float | None:
"""Calculate remaining time until write deadline."""
if self.write_deadline is None:
return None
remaining = self.write_deadline - time.time()
return max(0.0, remaining) if remaining > 0 else 0

async def _read_until_eof(self) -> bytes:
async for data in self.incoming_data_channel:
self._buf.extend(data)
Expand All @@ -182,6 +213,9 @@ async def read(self, n: int | None = None) -> bytes:
:param n: number of bytes to read
:return: bytes actually read
"""
# check deadline before starting
self._check_read_deadline()

async with self.rw_lock.read_lock():
if n is not None and n < 0:
raise ValueError(
Expand All @@ -192,8 +226,13 @@ async def read(self, n: int | None = None) -> bytes:
raise MplexStreamReset
if n is None:
return await self._read_until_eof()

# check deadline again before potentially blocking operation
self._check_read_deadline()

if len(self._buf) == 0:
data: bytes
timeout = self._get_read_timeout()
# Peek whether there is data available. If yes, we just read until
# there is no data, then return.
try:
Expand All @@ -207,6 +246,20 @@ async def read(self, n: int | None = None) -> bytes:
try:
data = await self.incoming_data_channel.receive()
self._buf.extend(data)
if timeout is not None and timeout <= 0:
raise MplexStreamTimeout(
"Read deadline exceeded while waiting for data"
)

if timeout is not None:
with trio.fail_after(timeout):
data = await self.incoming_data_channel.receive()
else:
data = await self.incoming_data_channel.receive()

self._buf.extend(data)
except trio.TooSlowError:
raise MplexStreamTimeout("Read operation timed out")
except trio.EndOfChannel:
if self.event_reset.is_set():
raise MplexStreamReset
Expand All @@ -226,15 +279,43 @@ async def read(self, n: int | None = None) -> bytes:
self._buf = self._buf[len(payload) :]
return bytes(payload)

async def _read_until_eof_with_timeout(self) -> bytes:
"""Read until EOF with timeout support."""
timeout = self._get_read_timeout()

try:
if timeout is not None:
with trio.fail_after(timeout):
async for data in self.incoming_data_channel:
self._buf.extend(data)
else:
async for data in self.incoming_data_channel:
self._buf.extend(data)
except trio.TooSlowError:
raise MplexStreamTimeout("Read until EOF operation timed out")

payload = self._buf
self._buf = self._buf[len(payload) :]
return bytes(payload)

async def write(self, data: bytes) -> None:
"""
Write to stream.

:return: number of bytes written
"""
# Check deadline before starting
self._check_write_deadline()

async with self.rw_lock.write_lock():
if self.event_local_closed.is_set():
raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}")

# Check deadline again after acquiring lock
timeout = self._get_write_timeout()
if timeout is not None and timeout <= 0:
raise MplexStreamTimeout("Write deadline exceeded")

flag = (
HeaderTags.MessageInitiator
if self.is_initiator
Expand Down Expand Up @@ -315,8 +396,9 @@ def set_deadline(self, ttl: int) -> bool:

:return: True if successful
"""
self.read_deadline = ttl
self.write_deadline = ttl
deadline = time.time() + ttl
self.read_deadline = deadline
self.write_deadline = deadline
return True

def set_read_deadline(self, ttl: int) -> bool:
Expand All @@ -325,7 +407,7 @@ def set_read_deadline(self, ttl: int) -> bool:

:return: True if successful
"""
self.read_deadline = ttl
self.read_deadline = time.time() + ttl
return True

def set_write_deadline(self, ttl: int) -> bool:
Expand All @@ -334,7 +416,7 @@ def set_write_deadline(self, ttl: int) -> bool:

:return: True if successful
"""
self.write_deadline = ttl
self.write_deadline = ttl + time.time()
return True

def get_remote_address(self) -> tuple[str, int] | None:
Expand Down
Loading