Skip to content

Commit e676a83

Browse files
Merge pull request #2874 from lucas-clemente/fix-accept-stream-race
fix race condition when accepting streams
2 parents 629272c + 46991ae commit e676a83

File tree

4 files changed

+66
-9
lines changed

4 files changed

+66
-9
lines changed

streams_map_incoming_bidi.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,12 @@ type incomingBidiStreamsMap struct {
2222
streamsToDelete map[protocol.StreamNum]struct{} // used as a set
2323

2424
nextStreamToAccept protocol.StreamNum // the next stream that will be returned by AcceptStream()
25-
nextStreamToOpen protocol.StreamNum // the highest stream that the peer openend
25+
nextStreamToOpen protocol.StreamNum // the highest stream that the peer opened
2626
maxStream protocol.StreamNum // the highest stream that the peer is allowed to open
2727
maxNumStreams uint64 // maximum number of streams
2828

2929
newStream func(protocol.StreamNum) streamI
3030
queueMaxStreamID func(*wire.MaxStreamsFrame)
31-
// streamNumToID func(protocol.StreamNum) protocol.StreamID // only used for generating errors
3231

3332
closeErr error
3433
}
@@ -39,7 +38,7 @@ func newIncomingBidiStreamsMap(
3938
queueControlFrame func(wire.Frame),
4039
) *incomingBidiStreamsMap {
4140
return &incomingBidiStreamsMap{
42-
newStreamChan: make(chan struct{}),
41+
newStreamChan: make(chan struct{}, 1),
4342
streams: make(map[protocol.StreamNum]streamI),
4443
streamsToDelete: make(map[protocol.StreamNum]struct{}),
4544
maxStream: protocol.StreamNum(maxStreams),
@@ -52,6 +51,12 @@ func newIncomingBidiStreamsMap(
5251
}
5352

5453
func (m *incomingBidiStreamsMap) AcceptStream(ctx context.Context) (streamI, error) {
54+
// drain the newStreamChan, so we don't check the map twice if the stream doesn't exist
55+
select {
56+
case <-m.newStreamChan:
57+
default:
58+
}
59+
5560
m.mutex.Lock()
5661

5762
var num protocol.StreamNum

streams_map_incoming_generic.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,12 @@ type incomingItemsMap struct {
2020
streamsToDelete map[protocol.StreamNum]struct{} // used as a set
2121

2222
nextStreamToAccept protocol.StreamNum // the next stream that will be returned by AcceptStream()
23-
nextStreamToOpen protocol.StreamNum // the highest stream that the peer openend
23+
nextStreamToOpen protocol.StreamNum // the highest stream that the peer opened
2424
maxStream protocol.StreamNum // the highest stream that the peer is allowed to open
2525
maxNumStreams uint64 // maximum number of streams
2626

2727
newStream func(protocol.StreamNum) item
2828
queueMaxStreamID func(*wire.MaxStreamsFrame)
29-
// streamNumToID func(protocol.StreamNum) protocol.StreamID // only used for generating errors
3029

3130
closeErr error
3231
}
@@ -37,7 +36,7 @@ func newIncomingItemsMap(
3736
queueControlFrame func(wire.Frame),
3837
) *incomingItemsMap {
3938
return &incomingItemsMap{
40-
newStreamChan: make(chan struct{}),
39+
newStreamChan: make(chan struct{}, 1),
4140
streams: make(map[protocol.StreamNum]item),
4241
streamsToDelete: make(map[protocol.StreamNum]struct{}),
4342
maxStream: protocol.StreamNum(maxStreams),
@@ -50,6 +49,12 @@ func newIncomingItemsMap(
5049
}
5150

5251
func (m *incomingItemsMap) AcceptStream(ctx context.Context) (item, error) {
52+
// drain the newStreamChan, so we don't check the map twice if the stream doesn't exist
53+
select {
54+
case <-m.newStreamChan:
55+
default:
56+
}
57+
5358
m.mutex.Lock()
5459

5560
var num protocol.StreamNum

streams_map_incoming_generic_test.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import (
44
"bytes"
55
"context"
66
"errors"
7+
"math/rand"
8+
"time"
79

810
"github.com/golang/mock/gomock"
911
"github.com/lucas-clemente/quic-go/internal/protocol"
@@ -257,4 +259,44 @@ var _ = Describe("Streams Map (incoming)", func() {
257259
Expect(m.DeleteStream(1)).To(Succeed())
258260
})
259261
})
262+
263+
Context("randomized tests", func() {
264+
const num = 1000
265+
266+
BeforeEach(func() { maxNumStreams = num })
267+
268+
It("opens and accepts streams", func() {
269+
rand.Seed(GinkgoRandomSeed())
270+
ids := make([]protocol.StreamNum, num)
271+
for i := 0; i < num; i++ {
272+
ids[i] = protocol.StreamNum(i + 1)
273+
}
274+
rand.Shuffle(len(ids), func(i, j int) { ids[i], ids[j] = ids[j], ids[i] })
275+
276+
const timeout = 5 * time.Second
277+
done := make(chan struct{}, 2)
278+
go func() {
279+
defer GinkgoRecover()
280+
ctx, cancel := context.WithTimeout(context.Background(), timeout)
281+
defer cancel()
282+
for i := 0; i < num; i++ {
283+
_, err := m.AcceptStream(ctx)
284+
Expect(err).ToNot(HaveOccurred())
285+
}
286+
done <- struct{}{}
287+
}()
288+
289+
go func() {
290+
defer GinkgoRecover()
291+
for i := 0; i < num; i++ {
292+
_, err := m.GetOrOpenStream(ids[i])
293+
Expect(err).ToNot(HaveOccurred())
294+
}
295+
done <- struct{}{}
296+
}()
297+
298+
Eventually(done, timeout*3/2).Should(Receive())
299+
Eventually(done, timeout*3/2).Should(Receive())
300+
})
301+
})
260302
})

streams_map_incoming_uni.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,12 @@ type incomingUniStreamsMap struct {
2222
streamsToDelete map[protocol.StreamNum]struct{} // used as a set
2323

2424
nextStreamToAccept protocol.StreamNum // the next stream that will be returned by AcceptStream()
25-
nextStreamToOpen protocol.StreamNum // the highest stream that the peer openend
25+
nextStreamToOpen protocol.StreamNum // the highest stream that the peer opened
2626
maxStream protocol.StreamNum // the highest stream that the peer is allowed to open
2727
maxNumStreams uint64 // maximum number of streams
2828

2929
newStream func(protocol.StreamNum) receiveStreamI
3030
queueMaxStreamID func(*wire.MaxStreamsFrame)
31-
// streamNumToID func(protocol.StreamNum) protocol.StreamID // only used for generating errors
3231

3332
closeErr error
3433
}
@@ -39,7 +38,7 @@ func newIncomingUniStreamsMap(
3938
queueControlFrame func(wire.Frame),
4039
) *incomingUniStreamsMap {
4140
return &incomingUniStreamsMap{
42-
newStreamChan: make(chan struct{}),
41+
newStreamChan: make(chan struct{}, 1),
4342
streams: make(map[protocol.StreamNum]receiveStreamI),
4443
streamsToDelete: make(map[protocol.StreamNum]struct{}),
4544
maxStream: protocol.StreamNum(maxStreams),
@@ -52,6 +51,12 @@ func newIncomingUniStreamsMap(
5251
}
5352

5453
func (m *incomingUniStreamsMap) AcceptStream(ctx context.Context) (receiveStreamI, error) {
54+
// drain the newStreamChan, so we don't check the map twice if the stream doesn't exist
55+
select {
56+
case <-m.newStreamChan:
57+
default:
58+
}
59+
5560
m.mutex.Lock()
5661

5762
var num protocol.StreamNum

0 commit comments

Comments
 (0)