Skip to content

Commit 83731ce

Browse files
authored
fix(middleware/session): mutex for thread safety (#3049)
* fix(middleware/session): mutex for thread safety * chore: Remove extra release and acquire ctx calls in session_test.go * feat: Remove unnecessary session mutex lock in decodeSessionData function
1 parent dbba6cf commit 83731ce

File tree

3 files changed

+139
-24
lines changed

3 files changed

+139
-24
lines changed

middleware/session/session.go

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
)
1414

1515
type Session struct {
16+
mu sync.RWMutex // Mutex to protect non-data fields
1617
id string // session id
1718
fresh bool // if new session
1819
ctx fiber.Ctx // fiber context
@@ -56,11 +57,15 @@ func releaseSession(s *Session) {
5657

5758
// Fresh is true if the current session is new
5859
func (s *Session) Fresh() bool {
60+
s.mu.RLock()
61+
defer s.mu.RUnlock()
5962
return s.fresh
6063
}
6164

6265
// ID returns the session id
6366
func (s *Session) ID() string {
67+
s.mu.RLock()
68+
defer s.mu.RUnlock()
6469
return s.id
6570
}
6671

@@ -101,6 +106,9 @@ func (s *Session) Destroy() error {
101106
// Reset local data
102107
s.data.Reset()
103108

109+
s.mu.Lock()
110+
defer s.mu.Unlock()
111+
104112
// Use external Storage if exist
105113
if err := s.config.Storage.Delete(s.id); err != nil {
106114
return err
@@ -113,6 +121,9 @@ func (s *Session) Destroy() error {
113121

114122
// Regenerate generates a new session id and delete the old one from Storage
115123
func (s *Session) Regenerate() error {
124+
s.mu.Lock()
125+
defer s.mu.Unlock()
126+
116127
// Delete old id from storage
117128
if err := s.config.Storage.Delete(s.id); err != nil {
118129
return err
@@ -137,6 +148,9 @@ func (s *Session) Reset() error {
137148
// Reset expiration
138149
s.exp = 0
139150

151+
s.mu.Lock()
152+
defer s.mu.Unlock()
153+
140154
// Delete old id from storage
141155
if err := s.config.Storage.Delete(s.id); err != nil {
142156
return err
@@ -153,10 +167,7 @@ func (s *Session) Reset() error {
153167

154168
// refresh generates a new session, and set session.fresh to be true
155169
func (s *Session) refresh() {
156-
// Create a new id
157170
s.id = s.config.KeyGenerator()
158-
159-
// We assign a new id to the session, so the session must be fresh
160171
s.fresh = true
161172
}
162173

@@ -167,6 +178,9 @@ func (s *Session) Save() error {
167178
return nil
168179
}
169180

181+
s.mu.Lock()
182+
defer s.mu.Unlock()
183+
170184
// Check if session has your own expiration, otherwise use default value
171185
if s.exp <= 0 {
172186
s.exp = s.config.Expiration
@@ -176,25 +190,23 @@ func (s *Session) Save() error {
176190
s.setSession()
177191

178192
// Convert data to bytes
179-
mux.Lock()
180-
defer mux.Unlock()
181193
encCache := gob.NewEncoder(s.byteBuffer)
182194
err := encCache.Encode(&s.data.Data)
183195
if err != nil {
184196
return fmt.Errorf("failed to encode data: %w", err)
185197
}
186198

187-
// copy the data in buffer
199+
// Copy the data in buffer
188200
encodedBytes := make([]byte, s.byteBuffer.Len())
189201
copy(encodedBytes, s.byteBuffer.Bytes())
190202

191-
// pass copied bytes with session id to provider
203+
// Pass copied bytes with session id to provider
192204
if err := s.config.Storage.Set(s.id, encodedBytes, s.exp); err != nil {
193205
return err
194206
}
195207

196208
// Release session
197-
// TODO: It's not safe to use the Session after called Save()
209+
// TODO: It's not safe to use the Session after calling Save()
198210
releaseSession(s)
199211

200212
return nil
@@ -210,6 +222,8 @@ func (s *Session) Keys() []string {
210222

211223
// SetExpiry sets a specific expiration for this session
212224
func (s *Session) SetExpiry(exp time.Duration) {
225+
s.mu.Lock()
226+
defer s.mu.Unlock()
213227
s.exp = exp
214228
}
215229

@@ -275,3 +289,13 @@ func (s *Session) delSession() {
275289
fasthttp.ReleaseCookie(fcookie)
276290
}
277291
}
292+
293+
// decodeSessionData decodes the session data from raw bytes.
294+
func (s *Session) decodeSessionData(rawData []byte) error {
295+
_, _ = s.byteBuffer.Write(rawData)
296+
encCache := gob.NewDecoder(s.byteBuffer)
297+
if err := encCache.Decode(&s.data.Data); err != nil {
298+
return fmt.Errorf("failed to decode session data: %w", err)
299+
}
300+
return nil
301+
}

middleware/session/session_test.go

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package session
22

33
import (
4+
"errors"
5+
"sync"
46
"testing"
57
"time"
68

@@ -856,3 +858,108 @@ func Benchmark_Session_Asserted_Parallel(b *testing.B) {
856858
})
857859
})
858860
}
861+
862+
// go test -v -race -run Test_Session_Concurrency ./...
863+
func Test_Session_Concurrency(t *testing.T) {
864+
t.Parallel()
865+
app := fiber.New()
866+
store := New()
867+
868+
var wg sync.WaitGroup
869+
errChan := make(chan error, 10) // Buffered channel to collect errors
870+
const numGoroutines = 10 // Number of concurrent goroutines to test
871+
872+
// Start numGoroutines goroutines
873+
for i := 0; i < numGoroutines; i++ {
874+
wg.Add(1)
875+
go func() {
876+
defer wg.Done()
877+
878+
localCtx := app.AcquireCtx(&fasthttp.RequestCtx{})
879+
880+
sess, err := store.Get(localCtx)
881+
if err != nil {
882+
errChan <- err
883+
return
884+
}
885+
886+
// Set a value
887+
sess.Set("name", "john")
888+
889+
// get the session id
890+
id := sess.ID()
891+
892+
// Check if the session is fresh
893+
if !sess.Fresh() {
894+
errChan <- errors.New("session should be fresh")
895+
return
896+
}
897+
898+
// Save the session
899+
if err := sess.Save(); err != nil {
900+
errChan <- err
901+
return
902+
}
903+
904+
// Release the context
905+
app.ReleaseCtx(localCtx)
906+
907+
// Acquire a new context
908+
localCtx = app.AcquireCtx(&fasthttp.RequestCtx{})
909+
defer app.ReleaseCtx(localCtx)
910+
911+
// Set the session id in the header
912+
localCtx.Request().Header.SetCookie(store.sessionName, id)
913+
914+
// Get the session
915+
sess, err = store.Get(localCtx)
916+
if err != nil {
917+
errChan <- err
918+
return
919+
}
920+
921+
// Get the value
922+
name := sess.Get("name")
923+
if name != "john" {
924+
errChan <- errors.New("name should be john")
925+
return
926+
}
927+
928+
// Get ID from the session
929+
if sess.ID() != id {
930+
errChan <- errors.New("id should be the same")
931+
return
932+
}
933+
934+
// Check if the session is fresh
935+
if sess.Fresh() {
936+
errChan <- errors.New("session should not be fresh")
937+
return
938+
}
939+
940+
// Delete the key
941+
sess.Delete("name")
942+
943+
// Get the value
944+
name = sess.Get("name")
945+
if name != nil {
946+
errChan <- errors.New("name should be nil")
947+
return
948+
}
949+
950+
// Destroy the session
951+
if err := sess.Destroy(); err != nil {
952+
errChan <- err
953+
return
954+
}
955+
}()
956+
}
957+
958+
wg.Wait() // Wait for all goroutines to finish
959+
close(errChan) // Close the channel to signal no more errors will be sent
960+
961+
// Check for errors sent to errChan
962+
for err := range errChan {
963+
require.NoError(t, err)
964+
}
965+
}

middleware/session/store.go

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"encoding/gob"
55
"errors"
66
"fmt"
7-
"sync"
87

98
"github.com/gofiber/fiber/v3"
109
"github.com/gofiber/fiber/v3/internal/storage/memory"
@@ -14,9 +13,6 @@ import (
1413
// ErrEmptySessionID is an error that occurs when the session ID is empty.
1514
var ErrEmptySessionID = errors.New("session id cannot be empty")
1615

17-
// mux is a global mutex for session operations.
18-
var mux sync.Mutex
19-
2016
// sessionIDKey is the local key type used to store and retrieve the session ID in context.
2117
type sessionIDKey int
2218

@@ -132,15 +128,3 @@ func (s *Store) Delete(id string) error {
132128
}
133129
return s.Storage.Delete(id)
134130
}
135-
136-
// decodeSessionData decodes the session data from raw bytes.
137-
func (s *Session) decodeSessionData(rawData []byte) error {
138-
mux.Lock()
139-
defer mux.Unlock()
140-
_, _ = s.byteBuffer.Write(rawData)
141-
encCache := gob.NewDecoder(s.byteBuffer)
142-
if err := encCache.Decode(&s.data.Data); err != nil {
143-
return fmt.Errorf("failed to decode session data: %w", err)
144-
}
145-
return nil
146-
}

0 commit comments

Comments
 (0)