44
44
MuxedStreamError ,
45
45
MuxedStreamReset ,
46
46
)
47
+ from libp2p .stream_muxer .rw_lock import ReadWriteLock
47
48
48
49
# Configure logger for this module
49
50
logger = logging .getLogger ("libp2p.stream_muxer.yamux" )
@@ -80,6 +81,8 @@ def __init__(self, stream_id: int, conn: "Yamux", is_initiator: bool) -> None:
80
81
self .send_window = DEFAULT_WINDOW_SIZE
81
82
self .recv_window = DEFAULT_WINDOW_SIZE
82
83
self .window_lock = trio .Lock ()
84
+ self .rw_lock = ReadWriteLock ()
85
+ self .close_lock = trio .Lock ()
83
86
84
87
async def __aenter__ (self ) -> "YamuxStream" :
85
88
"""Enter the async context manager."""
@@ -95,52 +98,54 @@ async def __aexit__(
95
98
await self .close ()
96
99
97
100
async def write (self , data : bytes ) -> None :
98
- if self .send_closed :
99
- raise MuxedStreamError ("Stream is closed for sending" )
100
-
101
- # Flow control: Check if we have enough send window
102
- total_len = len (data )
103
- sent = 0
104
- logger .debug (f"Stream { self .stream_id } : Starts writing { total_len } bytes " )
105
- while sent < total_len :
106
- # Wait for available window with timeout
107
- timeout = False
108
- async with self .window_lock :
109
- if self .send_window == 0 :
110
- logger .debug (
111
- f"Stream { self .stream_id } : Window is zero, waiting for update"
112
- )
113
- # Release lock and wait with timeout
114
- self .window_lock .release ()
115
- # To avoid re-acquiring the lock immediately,
116
- with trio .move_on_after (5.0 ) as cancel_scope :
117
- while self .send_window == 0 and not self .closed :
118
- await trio .sleep (0.01 )
119
- # If we timed out, cancel the scope
120
- timeout = cancel_scope .cancelled_caught
121
- # Re-acquire lock
122
- await self .window_lock .acquire ()
123
-
124
- # If we timed out waiting for window update, raise an error
125
- if timeout :
126
- raise MuxedStreamError (
127
- "Timed out waiting for window update after 5 seconds."
128
- )
101
+ async with self .rw_lock .write_lock ():
102
+ if self .send_closed :
103
+ raise MuxedStreamError ("Stream is closed for sending" )
104
+
105
+ # Flow control: Check if we have enough send window
106
+ total_len = len (data )
107
+ sent = 0
108
+ logger .debug (f"Stream { self .stream_id } : Starts writing { total_len } bytes " )
109
+ while sent < total_len :
110
+ # Wait for available window with timeout
111
+ timeout = False
112
+ async with self .window_lock :
113
+ if self .send_window == 0 :
114
+ logger .debug (
115
+ f"Stream { self .stream_id } : "
116
+ "Window is zero, waiting for update"
117
+ )
118
+ # Release lock and wait with timeout
119
+ self .window_lock .release ()
120
+ # To avoid re-acquiring the lock immediately,
121
+ with trio .move_on_after (5.0 ) as cancel_scope :
122
+ while self .send_window == 0 and not self .closed :
123
+ await trio .sleep (0.01 )
124
+ # If we timed out, cancel the scope
125
+ timeout = cancel_scope .cancelled_caught
126
+ # Re-acquire lock
127
+ await self .window_lock .acquire ()
128
+
129
+ # If we timed out waiting for window update, raise an error
130
+ if timeout :
131
+ raise MuxedStreamError (
132
+ "Timed out waiting for window update after 5 seconds."
133
+ )
129
134
130
- if self .closed :
131
- raise MuxedStreamError ("Stream is closed" )
135
+ if self .closed :
136
+ raise MuxedStreamError ("Stream is closed" )
132
137
133
- # Calculate how much we can send now
134
- to_send = min (self .send_window , total_len - sent )
135
- chunk = data [sent : sent + to_send ]
136
- self .send_window -= to_send
138
+ # Calculate how much we can send now
139
+ to_send = min (self .send_window , total_len - sent )
140
+ chunk = data [sent : sent + to_send ]
141
+ self .send_window -= to_send
137
142
138
- # Send the data
139
- header = struct .pack (
140
- YAMUX_HEADER_FORMAT , 0 , TYPE_DATA , 0 , self .stream_id , len (chunk )
141
- )
142
- await self .conn .secured_conn .write (header + chunk )
143
- sent += to_send
143
+ # Send the data
144
+ header = struct .pack (
145
+ YAMUX_HEADER_FORMAT , 0 , TYPE_DATA , 0 , self .stream_id , len (chunk )
146
+ )
147
+ await self .conn .secured_conn .write (header + chunk )
148
+ sent += to_send
144
149
145
150
async def send_window_update (self , increment : int , skip_lock : bool = False ) -> None :
146
151
"""
@@ -257,30 +262,32 @@ async def read(self, n: int | None = -1) -> bytes:
257
262
return data
258
263
259
264
async def close (self ) -> None :
260
- if not self .send_closed :
261
- logger .debug (f"Half-closing stream { self .stream_id } (local end)" )
262
- header = struct .pack (
263
- YAMUX_HEADER_FORMAT , 0 , TYPE_DATA , FLAG_FIN , self .stream_id , 0
264
- )
265
- await self .conn .secured_conn .write (header )
266
- self .send_closed = True
265
+ async with self .close_lock :
266
+ if not self .send_closed :
267
+ logger .debug (f"Half-closing stream { self .stream_id } (local end)" )
268
+ header = struct .pack (
269
+ YAMUX_HEADER_FORMAT , 0 , TYPE_DATA , FLAG_FIN , self .stream_id , 0
270
+ )
271
+ await self .conn .secured_conn .write (header )
272
+ self .send_closed = True
267
273
268
- # Only set fully closed if both directions are closed
269
- if self .send_closed and self .recv_closed :
270
- self .closed = True
271
- else :
272
- # Stream is half-closed but not fully closed
273
- self .closed = False
274
+ # Only set fully closed if both directions are closed
275
+ if self .send_closed and self .recv_closed :
276
+ self .closed = True
277
+ else :
278
+ # Stream is half-closed but not fully closed
279
+ self .closed = False
274
280
275
281
async def reset (self ) -> None :
276
282
if not self .closed :
277
- logger .debug (f"Resetting stream { self .stream_id } " )
278
- header = struct .pack (
279
- YAMUX_HEADER_FORMAT , 0 , TYPE_DATA , FLAG_RST , self .stream_id , 0
280
- )
281
- await self .conn .secured_conn .write (header )
282
- self .closed = True
283
- self .reset_received = True # Mark as reset
283
+ async with self .close_lock :
284
+ logger .debug (f"Resetting stream { self .stream_id } " )
285
+ header = struct .pack (
286
+ YAMUX_HEADER_FORMAT , 0 , TYPE_DATA , FLAG_RST , self .stream_id , 0
287
+ )
288
+ await self .conn .secured_conn .write (header )
289
+ self .closed = True
290
+ self .reset_received = True # Mark as reset
284
291
285
292
def set_deadline (self , ttl : int ) -> bool :
286
293
"""
0 commit comments