69 changes: 1 addition & 68 deletions libp2p/stream_muxer/mplex/mplex_stream.py
@@ -1,5 +1,3 @@
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from types import (
TracebackType,
)
@@ -15,6 +13,7 @@
from libp2p.stream_muxer.exceptions import (
MuxedConnUnavailable,
)
from libp2p.stream_muxer.rw_lock import ReadWriteLock

from .constants import (
HeaderTags,
@@ -34,72 +33,6 @@
)


class ReadWriteLock:
"""
A read-write lock that allows multiple concurrent readers
or one exclusive writer, implemented using Trio primitives.
"""

def __init__(self) -> None:
self._readers = 0
self._readers_lock = trio.Lock() # Protects access to _readers count
self._writer_lock = trio.Semaphore(1) # Allows only one writer at a time

async def acquire_read(self) -> None:
"""Acquire a read lock. Multiple readers can hold it simultaneously."""
try:
async with self._readers_lock:
if self._readers == 0:
await self._writer_lock.acquire()
self._readers += 1
except trio.Cancelled:
raise

async def release_read(self) -> None:
"""Release a read lock."""
async with self._readers_lock:
if self._readers == 1:
self._writer_lock.release()
self._readers -= 1

async def acquire_write(self) -> None:
"""Acquire an exclusive write lock."""
try:
await self._writer_lock.acquire()
except trio.Cancelled:
raise

def release_write(self) -> None:
"""Release the exclusive write lock."""
self._writer_lock.release()

@asynccontextmanager
async def read_lock(self) -> AsyncGenerator[None, None]:
"""Context manager for acquiring and releasing a read lock safely."""
acquire = False
try:
await self.acquire_read()
acquire = True
yield
finally:
if acquire:
with trio.CancelScope() as scope:
scope.shield = True
await self.release_read()

@asynccontextmanager
async def write_lock(self) -> AsyncGenerator[None, None]:
"""Context manager for acquiring and releasing a write lock safely."""
acquire = False
try:
await self.acquire_write()
acquire = True
yield
finally:
if acquire:
self.release_write()


class MplexStream(IMuxedStream):
"""
reference: https://github.com/libp2p/go-mplex/blob/master/stream.go
70 changes: 70 additions & 0 deletions libp2p/stream_muxer/rw_lock.py
@@ -0,0 +1,70 @@
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager

import trio


class ReadWriteLock:
"""
A read-write lock that allows multiple concurrent readers
or one exclusive writer, implemented using Trio primitives.
"""

def __init__(self) -> None:
self._readers = 0
self._readers_lock = trio.Lock() # Protects access to _readers count
self._writer_lock = trio.Semaphore(1) # Allows only one writer at a time

async def acquire_read(self) -> None:
"""Acquire a read lock. Multiple readers can hold it simultaneously."""
try:
async with self._readers_lock:
if self._readers == 0:
await self._writer_lock.acquire()
self._readers += 1
except trio.Cancelled:
raise

async def release_read(self) -> None:
"""Release a read lock."""
async with self._readers_lock:
if self._readers == 1:
self._writer_lock.release()
self._readers -= 1

async def acquire_write(self) -> None:
"""Acquire an exclusive write lock."""
try:
await self._writer_lock.acquire()
except trio.Cancelled:
raise

def release_write(self) -> None:
"""Release the exclusive write lock."""
self._writer_lock.release()

@asynccontextmanager
async def read_lock(self) -> AsyncGenerator[None, None]:
"""Context manager for acquiring and releasing a read lock safely."""
acquire = False
try:
await self.acquire_read()
acquire = True
yield
finally:
if acquire:
with trio.CancelScope() as scope:
scope.shield = True
await self.release_read()

@asynccontextmanager
async def write_lock(self) -> AsyncGenerator[None, None]:
"""Context manager for acquiring and releasing a write lock safely."""
acquire = False
try:
await self.acquire_write()
acquire = True
yield
finally:
if acquire:
self.release_write()
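
A minimal usage sketch of the new `ReadWriteLock` context managers, assuming the module path `libp2p.stream_muxer.rw_lock` introduced by this PR; the reader/writer task names and sleep durations are illustrative only. Readers started concurrently may hold the lock together, while the writer waits for exclusive access.

```python
import trio

from libp2p.stream_muxer.rw_lock import ReadWriteLock


async def main() -> None:
    lock = ReadWriteLock()

    async def reader(name: str) -> None:
        # Multiple readers can hold the lock at the same time.
        async with lock.read_lock():
            print(f"{name} reading")
            await trio.sleep(0.1)

    async def writer() -> None:
        # The writer acquires the underlying semaphore exclusively,
        # so it runs only after all readers have released the lock.
        async with lock.write_lock():
            print("writer writing")
            await trio.sleep(0.1)

    async with trio.open_nursery() as nursery:
        nursery.start_soon(reader, "reader-1")
        nursery.start_soon(reader, "reader-2")
        nursery.start_soon(writer)


trio.run(main)
```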
133 changes: 70 additions & 63 deletions libp2p/stream_muxer/yamux/yamux.py
@@ -44,6 +44,7 @@
MuxedStreamError,
MuxedStreamReset,
)
from libp2p.stream_muxer.rw_lock import ReadWriteLock

# Configure logger for this module
logger = logging.getLogger("libp2p.stream_muxer.yamux")
@@ -80,6 +81,8 @@ def __init__(self, stream_id: int, conn: "Yamux", is_initiator: bool) -> None:
self.send_window = DEFAULT_WINDOW_SIZE
self.recv_window = DEFAULT_WINDOW_SIZE
self.window_lock = trio.Lock()
self.rw_lock = ReadWriteLock()
self.close_lock = trio.Lock()

async def __aenter__(self) -> "YamuxStream":
"""Enter the async context manager."""
@@ -95,52 +98,54 @@ async def __aexit__(
await self.close()

async def write(self, data: bytes) -> None:
if self.send_closed:
raise MuxedStreamError("Stream is closed for sending")

# Flow control: Check if we have enough send window
total_len = len(data)
sent = 0
logger.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ")
while sent < total_len:
# Wait for available window with timeout
timeout = False
async with self.window_lock:
if self.send_window == 0:
logger.debug(
f"Stream {self.stream_id}: Window is zero, waiting for update"
)
# Release lock and wait with timeout
self.window_lock.release()
# To avoid re-acquiring the lock immediately,
with trio.move_on_after(5.0) as cancel_scope:
while self.send_window == 0 and not self.closed:
await trio.sleep(0.01)
# If we timed out, cancel the scope
timeout = cancel_scope.cancelled_caught
# Re-acquire lock
await self.window_lock.acquire()

# If we timed out waiting for window update, raise an error
if timeout:
raise MuxedStreamError(
"Timed out waiting for window update after 5 seconds."
)
async with self.rw_lock.write_lock():
if self.send_closed:
raise MuxedStreamError("Stream is closed for sending")

# Flow control: Check if we have enough send window
total_len = len(data)
sent = 0
logger.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ")
while sent < total_len:
# Wait for available window with timeout
timeout = False
async with self.window_lock:
if self.send_window == 0:
logger.debug(
f"Stream {self.stream_id}: "
"Window is zero, waiting for update"
)
# Release lock and wait with timeout
self.window_lock.release()
# To avoid re-acquiring the lock immediately,
with trio.move_on_after(5.0) as cancel_scope:
while self.send_window == 0 and not self.closed:
await trio.sleep(0.01)
# If we timed out, cancel the scope
timeout = cancel_scope.cancelled_caught
# Re-acquire lock
await self.window_lock.acquire()

# If we timed out waiting for window update, raise an error
if timeout:
raise MuxedStreamError(
"Timed out waiting for window update after 5 seconds."
)

if self.closed:
raise MuxedStreamError("Stream is closed")
if self.closed:
raise MuxedStreamError("Stream is closed")

# Calculate how much we can send now
to_send = min(self.send_window, total_len - sent)
chunk = data[sent : sent + to_send]
self.send_window -= to_send
# Calculate how much we can send now
to_send = min(self.send_window, total_len - sent)
chunk = data[sent : sent + to_send]
self.send_window -= to_send

# Send the data
header = struct.pack(
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, 0, self.stream_id, len(chunk)
)
await self.conn.secured_conn.write(header + chunk)
sent += to_send
# Send the data
header = struct.pack(
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, 0, self.stream_id, len(chunk)
)
await self.conn.secured_conn.write(header + chunk)
sent += to_send

async def send_window_update(self, increment: int, skip_lock: bool = False) -> None:
"""
@@ -257,30 +262,32 @@ async def read(self, n: int | None = -1) -> bytes:
return data

async def close(self) -> None:
if not self.send_closed:
logger.debug(f"Half-closing stream {self.stream_id} (local end)")
header = struct.pack(
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_FIN, self.stream_id, 0
)
await self.conn.secured_conn.write(header)
self.send_closed = True
async with self.close_lock:
if not self.send_closed:
logger.debug(f"Half-closing stream {self.stream_id} (local end)")
header = struct.pack(
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_FIN, self.stream_id, 0
)
await self.conn.secured_conn.write(header)
self.send_closed = True

# Only set fully closed if both directions are closed
if self.send_closed and self.recv_closed:
self.closed = True
else:
# Stream is half-closed but not fully closed
self.closed = False
# Only set fully closed if both directions are closed
if self.send_closed and self.recv_closed:
self.closed = True
else:
# Stream is half-closed but not fully closed
self.closed = False

async def reset(self) -> None:
if not self.closed:
logger.debug(f"Resetting stream {self.stream_id}")
header = struct.pack(
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_RST, self.stream_id, 0
)
await self.conn.secured_conn.write(header)
self.closed = True
self.reset_received = True # Mark as reset
async with self.close_lock:
logger.debug(f"Resetting stream {self.stream_id}")
header = struct.pack(
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_RST, self.stream_id, 0
)
await self.conn.secured_conn.write(header)
self.closed = True
self.reset_received = True # Mark as reset

def set_deadline(self, ttl: int) -> bool:
"""
6 changes: 6 additions & 0 deletions newsfragments/897.bugfix.rst
@@ -0,0 +1,6 @@
enhancement: Add write lock to `YamuxStream` to prevent concurrent write race conditions

- Implements ReadWriteLock for `YamuxStream` write operations
- Prevents data corruption from concurrent write operations
- Read operations remain lock-free due to existing `Yamux` architecture
- Resolves race conditions identified in Issue #793
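
To illustrate the behavior this newsfragment describes, here is a hedged sketch of two tasks writing to the same stream concurrently; `stream` is assumed to be an already-open `YamuxStream` obtained from a muxed connection, and how it is created is outside the scope of this example. With the write lock in place, each call to `write()` frames and flushes its payload before the other begins, instead of interleaving bytes on the wire.

```python
import trio


async def send_concurrently(stream) -> None:
    # `stream` is a hypothetical, already-open YamuxStream.
    async def send(payload: bytes) -> None:
        # write() now runs under the stream's ReadWriteLock write lock,
        # so concurrent calls are serialized rather than interleaved.
        await stream.write(payload)

    async with trio.open_nursery() as nursery:
        nursery.start_soon(send, b"A" * 4096)
        nursery.start_soon(send, b"B" * 4096)
```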
3 changes: 1 addition & 2 deletions tests/conftest.py
@@ -1,6 +1,5 @@
import pytest


@pytest.fixture
def security_protocol():
return None
return None