From c5a2836829b78ae25533fc4e09dbc40df4c4d800 Mon Sep 17 00:00:00 2001 From: Michael Eze Date: Thu, 4 Sep 2025 10:51:52 +0100 Subject: [PATCH] stream_muxer(yamux): add ReadWriteLock to YamuxStream to prevent concurrent read/write corruption Move the read/write lock abstraction into a shared module and take its write lock in `YamuxStream.write` so that concurrent writes on a stream do not interleave, eliminating potential data corruption and race conditions. `YamuxStream.close` and `YamuxStream.reset` are additionally guarded by a new `close_lock`; reads remain lock-free. Major changes: - Abstract `ReadWriteLock` into its own util module - Integrate locking into YamuxStream for `write` operations - Ensure tests pass for lock correctness - Fix lint & type issues discovered during review Closes #793 --- libp2p/stream_muxer/mplex/mplex_stream.py | 69 +---------- libp2p/stream_muxer/rw_lock.py | 70 ++++++++++++ libp2p/stream_muxer/yamux/yamux.py | 133 ++++++++++++---------- newsfragments/897.bugfix.rst | 6 + tests/conftest.py | 3 +- 5 files changed, 148 insertions(+), 133 deletions(-) create mode 100644 libp2p/stream_muxer/rw_lock.py create mode 100644 newsfragments/897.bugfix.rst diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index e8d0561d4..150ae9dd0 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -1,5 +1,3 @@ -from collections.abc import AsyncGenerator -from contextlib import asynccontextmanager from types import ( TracebackType, ) @@ -15,6 +13,7 @@ from libp2p.stream_muxer.exceptions import ( MuxedConnUnavailable, ) +from libp2p.stream_muxer.rw_lock import ReadWriteLock from .constants import ( HeaderTags, @@ -34,72 +33,6 @@ ) -class ReadWriteLock: - """ - A read-write lock that allows multiple concurrent readers - or one exclusive writer, implemented using Trio primitives. 
- """ - - def __init__(self) -> None: - self._readers = 0 - self._readers_lock = trio.Lock() # Protects access to _readers count - self._writer_lock = trio.Semaphore(1) # Allows only one writer at a time - - async def acquire_read(self) -> None: - """Acquire a read lock. Multiple readers can hold it simultaneously.""" - try: - async with self._readers_lock: - if self._readers == 0: - await self._writer_lock.acquire() - self._readers += 1 - except trio.Cancelled: - raise - - async def release_read(self) -> None: - """Release a read lock.""" - async with self._readers_lock: - if self._readers == 1: - self._writer_lock.release() - self._readers -= 1 - - async def acquire_write(self) -> None: - """Acquire an exclusive write lock.""" - try: - await self._writer_lock.acquire() - except trio.Cancelled: - raise - - def release_write(self) -> None: - """Release the exclusive write lock.""" - self._writer_lock.release() - - @asynccontextmanager - async def read_lock(self) -> AsyncGenerator[None, None]: - """Context manager for acquiring and releasing a read lock safely.""" - acquire = False - try: - await self.acquire_read() - acquire = True - yield - finally: - if acquire: - with trio.CancelScope() as scope: - scope.shield = True - await self.release_read() - - @asynccontextmanager - async def write_lock(self) -> AsyncGenerator[None, None]: - """Context manager for acquiring and releasing a write lock safely.""" - acquire = False - try: - await self.acquire_write() - acquire = True - yield - finally: - if acquire: - self.release_write() - - class MplexStream(IMuxedStream): """ reference: https://github.com/libp2p/go-mplex/blob/master/stream.go diff --git a/libp2p/stream_muxer/rw_lock.py b/libp2p/stream_muxer/rw_lock.py new file mode 100644 index 000000000..7910a1449 --- /dev/null +++ b/libp2p/stream_muxer/rw_lock.py @@ -0,0 +1,70 @@ +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager + +import trio + + +class ReadWriteLock: + """ + A 
read-write lock that allows multiple concurrent readers + or one exclusive writer, implemented using Trio primitives. + """ + + def __init__(self) -> None: + self._readers = 0 + self._readers_lock = trio.Lock() # Protects access to _readers count + self._writer_lock = trio.Semaphore(1) # Allows only one writer at a time + + async def acquire_read(self) -> None: + """Acquire a read lock. Multiple readers can hold it simultaneously.""" + try: + async with self._readers_lock: + if self._readers == 0: + await self._writer_lock.acquire() + self._readers += 1 + except trio.Cancelled: + raise + + async def release_read(self) -> None: + """Release a read lock.""" + async with self._readers_lock: + if self._readers == 1: + self._writer_lock.release() + self._readers -= 1 + + async def acquire_write(self) -> None: + """Acquire an exclusive write lock.""" + try: + await self._writer_lock.acquire() + except trio.Cancelled: + raise + + def release_write(self) -> None: + """Release the exclusive write lock.""" + self._writer_lock.release() + + @asynccontextmanager + async def read_lock(self) -> AsyncGenerator[None, None]: + """Context manager for acquiring and releasing a read lock safely.""" + acquire = False + try: + await self.acquire_read() + acquire = True + yield + finally: + if acquire: + with trio.CancelScope() as scope: + scope.shield = True + await self.release_read() + + @asynccontextmanager + async def write_lock(self) -> AsyncGenerator[None, None]: + """Context manager for acquiring and releasing a write lock safely.""" + acquire = False + try: + await self.acquire_write() + acquire = True + yield + finally: + if acquire: + self.release_write() diff --git a/libp2p/stream_muxer/yamux/yamux.py b/libp2p/stream_muxer/yamux/yamux.py index b2711e1a8..bb84a5db6 100644 --- a/libp2p/stream_muxer/yamux/yamux.py +++ b/libp2p/stream_muxer/yamux/yamux.py @@ -44,6 +44,7 @@ MuxedStreamError, MuxedStreamReset, ) +from libp2p.stream_muxer.rw_lock import ReadWriteLock # Configure 
logger for this module logger = logging.getLogger("libp2p.stream_muxer.yamux") @@ -80,6 +81,8 @@ def __init__(self, stream_id: int, conn: "Yamux", is_initiator: bool) -> None: self.send_window = DEFAULT_WINDOW_SIZE self.recv_window = DEFAULT_WINDOW_SIZE self.window_lock = trio.Lock() + self.rw_lock = ReadWriteLock() + self.close_lock = trio.Lock() async def __aenter__(self) -> "YamuxStream": """Enter the async context manager.""" @@ -95,52 +98,54 @@ async def __aexit__( await self.close() async def write(self, data: bytes) -> None: - if self.send_closed: - raise MuxedStreamError("Stream is closed for sending") - - # Flow control: Check if we have enough send window - total_len = len(data) - sent = 0 - logger.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ") - while sent < total_len: - # Wait for available window with timeout - timeout = False - async with self.window_lock: - if self.send_window == 0: - logger.debug( - f"Stream {self.stream_id}: Window is zero, waiting for update" - ) - # Release lock and wait with timeout - self.window_lock.release() - # To avoid re-acquiring the lock immediately, - with trio.move_on_after(5.0) as cancel_scope: - while self.send_window == 0 and not self.closed: - await trio.sleep(0.01) - # If we timed out, cancel the scope - timeout = cancel_scope.cancelled_caught - # Re-acquire lock - await self.window_lock.acquire() - - # If we timed out waiting for window update, raise an error - if timeout: - raise MuxedStreamError( - "Timed out waiting for window update after 5 seconds." 
- ) + async with self.rw_lock.write_lock(): + if self.send_closed: + raise MuxedStreamError("Stream is closed for sending") + + # Flow control: Check if we have enough send window + total_len = len(data) + sent = 0 + logger.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ") + while sent < total_len: + # Wait for available window with timeout + timeout = False + async with self.window_lock: + if self.send_window == 0: + logger.debug( + f"Stream {self.stream_id}: " + "Window is zero, waiting for update" + ) + # Release lock and wait with timeout + self.window_lock.release() + # To avoid re-acquiring the lock immediately, + with trio.move_on_after(5.0) as cancel_scope: + while self.send_window == 0 and not self.closed: + await trio.sleep(0.01) + # If we timed out, cancel the scope + timeout = cancel_scope.cancelled_caught + # Re-acquire lock + await self.window_lock.acquire() + + # If we timed out waiting for window update, raise an error + if timeout: + raise MuxedStreamError( + "Timed out waiting for window update after 5 seconds." 
+ ) - if self.closed: - raise MuxedStreamError("Stream is closed") + if self.closed: + raise MuxedStreamError("Stream is closed") - # Calculate how much we can send now - to_send = min(self.send_window, total_len - sent) - chunk = data[sent : sent + to_send] - self.send_window -= to_send + # Calculate how much we can send now + to_send = min(self.send_window, total_len - sent) + chunk = data[sent : sent + to_send] + self.send_window -= to_send - # Send the data - header = struct.pack( - YAMUX_HEADER_FORMAT, 0, TYPE_DATA, 0, self.stream_id, len(chunk) - ) - await self.conn.secured_conn.write(header + chunk) - sent += to_send + # Send the data + header = struct.pack( + YAMUX_HEADER_FORMAT, 0, TYPE_DATA, 0, self.stream_id, len(chunk) + ) + await self.conn.secured_conn.write(header + chunk) + sent += to_send async def send_window_update(self, increment: int, skip_lock: bool = False) -> None: """ @@ -257,30 +262,32 @@ async def read(self, n: int | None = -1) -> bytes: return data async def close(self) -> None: - if not self.send_closed: - logger.debug(f"Half-closing stream {self.stream_id} (local end)") - header = struct.pack( - YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_FIN, self.stream_id, 0 - ) - await self.conn.secured_conn.write(header) - self.send_closed = True + async with self.close_lock: + if not self.send_closed: + logger.debug(f"Half-closing stream {self.stream_id} (local end)") + header = struct.pack( + YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_FIN, self.stream_id, 0 + ) + await self.conn.secured_conn.write(header) + self.send_closed = True - # Only set fully closed if both directions are closed - if self.send_closed and self.recv_closed: - self.closed = True - else: - # Stream is half-closed but not fully closed - self.closed = False + # Only set fully closed if both directions are closed + if self.send_closed and self.recv_closed: + self.closed = True + else: + # Stream is half-closed but not fully closed + self.closed = False async def reset(self) -> None: if 
not self.closed: - logger.debug(f"Resetting stream {self.stream_id}") - header = struct.pack( - YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_RST, self.stream_id, 0 - ) - await self.conn.secured_conn.write(header) - self.closed = True - self.reset_received = True # Mark as reset + async with self.close_lock: + logger.debug(f"Resetting stream {self.stream_id}") + header = struct.pack( + YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_RST, self.stream_id, 0 + ) + await self.conn.secured_conn.write(header) + self.closed = True + self.reset_received = True # Mark as reset def set_deadline(self, ttl: int) -> bool: """ diff --git a/newsfragments/897.bugfix.rst b/newsfragments/897.bugfix.rst new file mode 100644 index 000000000..575b5769b --- /dev/null +++ b/newsfragments/897.bugfix.rst @@ -0,0 +1,6 @@ +Add a write lock to `YamuxStream` to prevent concurrent write race conditions + +- Implements ReadWriteLock for `YamuxStream` write operations +- Prevents data corruption from concurrent write operations +- Read operations remain lock-free due to existing `Yamux` architecture +- Resolves race conditions identified in Issue #793 diff --git a/tests/conftest.py b/tests/conftest.py index ba3b7da0c..343a03d92 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,5 @@ import pytest - @pytest.fixture def security_protocol(): - return None + return None \ No newline at end of file