Skip to content

Commit c5a2836

Browse files
committed
stream_muxer(yamux): add ReadWriteLock to YamuxStream to prevent concurrent read/write corruption
Introduce a read/write lock abstraction and integrate it into `YamuxStream` so that simultaneous reads and writes do not interleave, eliminating potential data corruption and race conditions. Major changes: - Abstract `ReadWriteLock` into its own util module - Integrate locking into YamuxStream for `write` operations - Ensure tests pass for lock correctness - Fix lint & type issues discovered during review Closes #793
1 parent 74f4aaf commit c5a2836

File tree

5 files changed

+148
-133
lines changed

5 files changed

+148
-133
lines changed

libp2p/stream_muxer/mplex/mplex_stream.py

Lines changed: 1 addition & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from collections.abc import AsyncGenerator
2-
from contextlib import asynccontextmanager
31
from types import (
42
TracebackType,
53
)
@@ -15,6 +13,7 @@
1513
from libp2p.stream_muxer.exceptions import (
1614
MuxedConnUnavailable,
1715
)
16+
from libp2p.stream_muxer.rw_lock import ReadWriteLock
1817

1918
from .constants import (
2019
HeaderTags,
@@ -34,72 +33,6 @@
3433
)
3534

3635

37-
class ReadWriteLock:
38-
"""
39-
A read-write lock that allows multiple concurrent readers
40-
or one exclusive writer, implemented using Trio primitives.
41-
"""
42-
43-
def __init__(self) -> None:
44-
self._readers = 0
45-
self._readers_lock = trio.Lock() # Protects access to _readers count
46-
self._writer_lock = trio.Semaphore(1) # Allows only one writer at a time
47-
48-
async def acquire_read(self) -> None:
49-
"""Acquire a read lock. Multiple readers can hold it simultaneously."""
50-
try:
51-
async with self._readers_lock:
52-
if self._readers == 0:
53-
await self._writer_lock.acquire()
54-
self._readers += 1
55-
except trio.Cancelled:
56-
raise
57-
58-
async def release_read(self) -> None:
59-
"""Release a read lock."""
60-
async with self._readers_lock:
61-
if self._readers == 1:
62-
self._writer_lock.release()
63-
self._readers -= 1
64-
65-
async def acquire_write(self) -> None:
66-
"""Acquire an exclusive write lock."""
67-
try:
68-
await self._writer_lock.acquire()
69-
except trio.Cancelled:
70-
raise
71-
72-
def release_write(self) -> None:
73-
"""Release the exclusive write lock."""
74-
self._writer_lock.release()
75-
76-
@asynccontextmanager
77-
async def read_lock(self) -> AsyncGenerator[None, None]:
78-
"""Context manager for acquiring and releasing a read lock safely."""
79-
acquire = False
80-
try:
81-
await self.acquire_read()
82-
acquire = True
83-
yield
84-
finally:
85-
if acquire:
86-
with trio.CancelScope() as scope:
87-
scope.shield = True
88-
await self.release_read()
89-
90-
@asynccontextmanager
91-
async def write_lock(self) -> AsyncGenerator[None, None]:
92-
"""Context manager for acquiring and releasing a write lock safely."""
93-
acquire = False
94-
try:
95-
await self.acquire_write()
96-
acquire = True
97-
yield
98-
finally:
99-
if acquire:
100-
self.release_write()
101-
102-
10336
class MplexStream(IMuxedStream):
10437
"""
10538
reference: https://github.com/libp2p/go-mplex/blob/master/stream.go

libp2p/stream_muxer/rw_lock.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from collections.abc import AsyncGenerator
2+
from contextlib import asynccontextmanager
3+
4+
import trio
5+
6+
7+
class ReadWriteLock:
8+
"""
9+
A read-write lock that allows multiple concurrent readers
10+
or one exclusive writer, implemented using Trio primitives.
11+
"""
12+
13+
def __init__(self) -> None:
14+
self._readers = 0
15+
self._readers_lock = trio.Lock() # Protects access to _readers count
16+
self._writer_lock = trio.Semaphore(1) # Allows only one writer at a time
17+
18+
async def acquire_read(self) -> None:
19+
"""Acquire a read lock. Multiple readers can hold it simultaneously."""
20+
try:
21+
async with self._readers_lock:
22+
if self._readers == 0:
23+
await self._writer_lock.acquire()
24+
self._readers += 1
25+
except trio.Cancelled:
26+
raise
27+
28+
async def release_read(self) -> None:
29+
"""Release a read lock."""
30+
async with self._readers_lock:
31+
if self._readers == 1:
32+
self._writer_lock.release()
33+
self._readers -= 1
34+
35+
async def acquire_write(self) -> None:
36+
"""Acquire an exclusive write lock."""
37+
try:
38+
await self._writer_lock.acquire()
39+
except trio.Cancelled:
40+
raise
41+
42+
def release_write(self) -> None:
43+
"""Release the exclusive write lock."""
44+
self._writer_lock.release()
45+
46+
@asynccontextmanager
47+
async def read_lock(self) -> AsyncGenerator[None, None]:
48+
"""Context manager for acquiring and releasing a read lock safely."""
49+
acquire = False
50+
try:
51+
await self.acquire_read()
52+
acquire = True
53+
yield
54+
finally:
55+
if acquire:
56+
with trio.CancelScope() as scope:
57+
scope.shield = True
58+
await self.release_read()
59+
60+
@asynccontextmanager
61+
async def write_lock(self) -> AsyncGenerator[None, None]:
62+
"""Context manager for acquiring and releasing a write lock safely."""
63+
acquire = False
64+
try:
65+
await self.acquire_write()
66+
acquire = True
67+
yield
68+
finally:
69+
if acquire:
70+
self.release_write()

libp2p/stream_muxer/yamux/yamux.py

Lines changed: 70 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
MuxedStreamError,
4545
MuxedStreamReset,
4646
)
47+
from libp2p.stream_muxer.rw_lock import ReadWriteLock
4748

4849
# Configure logger for this module
4950
logger = logging.getLogger("libp2p.stream_muxer.yamux")
@@ -80,6 +81,8 @@ def __init__(self, stream_id: int, conn: "Yamux", is_initiator: bool) -> None:
8081
self.send_window = DEFAULT_WINDOW_SIZE
8182
self.recv_window = DEFAULT_WINDOW_SIZE
8283
self.window_lock = trio.Lock()
84+
self.rw_lock = ReadWriteLock()
85+
self.close_lock = trio.Lock()
8386

8487
async def __aenter__(self) -> "YamuxStream":
8588
"""Enter the async context manager."""
@@ -95,52 +98,54 @@ async def __aexit__(
9598
await self.close()
9699

97100
async def write(self, data: bytes) -> None:
98-
if self.send_closed:
99-
raise MuxedStreamError("Stream is closed for sending")
100-
101-
# Flow control: Check if we have enough send window
102-
total_len = len(data)
103-
sent = 0
104-
logger.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ")
105-
while sent < total_len:
106-
# Wait for available window with timeout
107-
timeout = False
108-
async with self.window_lock:
109-
if self.send_window == 0:
110-
logger.debug(
111-
f"Stream {self.stream_id}: Window is zero, waiting for update"
112-
)
113-
# Release lock and wait with timeout
114-
self.window_lock.release()
115-
# To avoid re-acquiring the lock immediately,
116-
with trio.move_on_after(5.0) as cancel_scope:
117-
while self.send_window == 0 and not self.closed:
118-
await trio.sleep(0.01)
119-
# If we timed out, cancel the scope
120-
timeout = cancel_scope.cancelled_caught
121-
# Re-acquire lock
122-
await self.window_lock.acquire()
123-
124-
# If we timed out waiting for window update, raise an error
125-
if timeout:
126-
raise MuxedStreamError(
127-
"Timed out waiting for window update after 5 seconds."
128-
)
101+
async with self.rw_lock.write_lock():
102+
if self.send_closed:
103+
raise MuxedStreamError("Stream is closed for sending")
104+
105+
# Flow control: Check if we have enough send window
106+
total_len = len(data)
107+
sent = 0
108+
logger.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ")
109+
while sent < total_len:
110+
# Wait for available window with timeout
111+
timeout = False
112+
async with self.window_lock:
113+
if self.send_window == 0:
114+
logger.debug(
115+
f"Stream {self.stream_id}: "
116+
"Window is zero, waiting for update"
117+
)
118+
# Release lock and wait with timeout
119+
self.window_lock.release()
120+
# To avoid re-acquiring the lock immediately,
121+
with trio.move_on_after(5.0) as cancel_scope:
122+
while self.send_window == 0 and not self.closed:
123+
await trio.sleep(0.01)
124+
# If we timed out, cancel the scope
125+
timeout = cancel_scope.cancelled_caught
126+
# Re-acquire lock
127+
await self.window_lock.acquire()
128+
129+
# If we timed out waiting for window update, raise an error
130+
if timeout:
131+
raise MuxedStreamError(
132+
"Timed out waiting for window update after 5 seconds."
133+
)
129134

130-
if self.closed:
131-
raise MuxedStreamError("Stream is closed")
135+
if self.closed:
136+
raise MuxedStreamError("Stream is closed")
132137

133-
# Calculate how much we can send now
134-
to_send = min(self.send_window, total_len - sent)
135-
chunk = data[sent : sent + to_send]
136-
self.send_window -= to_send
138+
# Calculate how much we can send now
139+
to_send = min(self.send_window, total_len - sent)
140+
chunk = data[sent : sent + to_send]
141+
self.send_window -= to_send
137142

138-
# Send the data
139-
header = struct.pack(
140-
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, 0, self.stream_id, len(chunk)
141-
)
142-
await self.conn.secured_conn.write(header + chunk)
143-
sent += to_send
143+
# Send the data
144+
header = struct.pack(
145+
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, 0, self.stream_id, len(chunk)
146+
)
147+
await self.conn.secured_conn.write(header + chunk)
148+
sent += to_send
144149

145150
async def send_window_update(self, increment: int, skip_lock: bool = False) -> None:
146151
"""
@@ -257,30 +262,32 @@ async def read(self, n: int | None = -1) -> bytes:
257262
return data
258263

259264
async def close(self) -> None:
260-
if not self.send_closed:
261-
logger.debug(f"Half-closing stream {self.stream_id} (local end)")
262-
header = struct.pack(
263-
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_FIN, self.stream_id, 0
264-
)
265-
await self.conn.secured_conn.write(header)
266-
self.send_closed = True
265+
async with self.close_lock:
266+
if not self.send_closed:
267+
logger.debug(f"Half-closing stream {self.stream_id} (local end)")
268+
header = struct.pack(
269+
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_FIN, self.stream_id, 0
270+
)
271+
await self.conn.secured_conn.write(header)
272+
self.send_closed = True
267273

268-
# Only set fully closed if both directions are closed
269-
if self.send_closed and self.recv_closed:
270-
self.closed = True
271-
else:
272-
# Stream is half-closed but not fully closed
273-
self.closed = False
274+
# Only set fully closed if both directions are closed
275+
if self.send_closed and self.recv_closed:
276+
self.closed = True
277+
else:
278+
# Stream is half-closed but not fully closed
279+
self.closed = False
274280

275281
async def reset(self) -> None:
276282
if not self.closed:
277-
logger.debug(f"Resetting stream {self.stream_id}")
278-
header = struct.pack(
279-
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_RST, self.stream_id, 0
280-
)
281-
await self.conn.secured_conn.write(header)
282-
self.closed = True
283-
self.reset_received = True # Mark as reset
283+
async with self.close_lock:
284+
logger.debug(f"Resetting stream {self.stream_id}")
285+
header = struct.pack(
286+
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_RST, self.stream_id, 0
287+
)
288+
await self.conn.secured_conn.write(header)
289+
self.closed = True
290+
self.reset_received = True # Mark as reset
284291

285292
def set_deadline(self, ttl: int) -> bool:
286293
"""

newsfragments/897.bugfix.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
enhancement: Add write lock to `YamuxStream` to prevent concurrent write race conditions
2+
3+
- Implements ReadWriteLock for `YamuxStream` write operations
4+
- Prevents data corruption from concurrent write operations
5+
- Read operations remain lock-free due to existing `Yamux` architecture
6+
- Resolves race conditions identified in Issue #793

tests/conftest.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import pytest
22

3-
43
@pytest.fixture
54
def security_protocol():
6-
return None
5+
return None

0 commit comments

Comments
 (0)