Skip to content

Commit 595c47b

Browse files
committed
refactor: replace atomic operations with atomic types
1 parent b3a2091 commit 595c47b

File tree

6 files changed

+50
-51
lines changed

6 files changed

+50
-51
lines changed

db_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -717,9 +717,9 @@ var _ = Describe("CopyFrom/CopyTo", func() {
717717
Expect(res.RowsAffected()).To(Equal(n))
718718

719719
st := db.PoolStats()
720-
Expect(st.Hits).To(Equal(uint32(4)))
721-
Expect(st.Misses).To(Equal(uint32(1)))
722-
Expect(st.Timeouts).To(Equal(uint32(0)))
720+
Expect(st.Hits.Load()).To(Equal(uint32(4)))
721+
Expect(st.Misses.Load()).To(Equal(uint32(1)))
722+
Expect(st.Timeouts.Load()).To(Equal(uint32(0)))
723723
Expect(st.TotalConns).To(Equal(uint32(1)))
724724
Expect(st.IdleConns).To(Equal(uint32(1)))
725725

@@ -736,9 +736,9 @@ var _ = Describe("CopyFrom/CopyTo", func() {
736736
Expect(res).To(BeNil())
737737

738738
st := db.Pool().Stats()
739-
Expect(st.Hits).To(Equal(uint32(3)))
740-
Expect(st.Misses).To(Equal(uint32(1)))
741-
Expect(st.Timeouts).To(Equal(uint32(0)))
739+
Expect(st.Hits.Load()).To(Equal(uint32(3)))
740+
Expect(st.Misses.Load()).To(Equal(uint32(1)))
741+
Expect(st.Timeouts.Load()).To(Equal(uint32(0)))
742742
Expect(st.TotalConns).To(Equal(uint32(1)))
743743
Expect(st.IdleConns).To(Equal(uint32(1)))
744744

internal/pool/conn.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ type Conn struct {
2020
lastID int64
2121

2222
createdAt time.Time
23-
usedAt uint32 // atomic
23+
usedAt atomic.Uint32
2424
pooled bool
2525
Inited bool
2626
}
@@ -36,12 +36,12 @@ func NewConn(netConn net.Conn, pool *ConnPool) *Conn {
3636
}
3737

3838
func (cn *Conn) UsedAt() time.Time {
39-
unix := atomic.LoadUint32(&cn.usedAt)
39+
unix := cn.usedAt.Load()
4040
return time.Unix(int64(unix), 0)
4141
}
4242

4343
func (cn *Conn) SetUsedAt(tm time.Time) {
44-
atomic.StoreUint32(&cn.usedAt, uint32(tm.Unix()))
44+
cn.usedAt.Store(uint32(tm.Unix()))
4545
}
4646

4747
func (cn *Conn) RemoteAddr() net.Addr {

internal/pool/pool.go

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,13 @@ var timers = sync.Pool{
2626

2727
// Stats contains pool state information and accumulated stats.
2828
type Stats struct {
29-
Hits uint32 // number of times free connection was found in the pool
30-
Misses uint32 // number of times free connection was NOT found in the pool
31-
Timeouts uint32 // number of times a wait timeout occurred
29+
Hits atomic.Uint32 // number of times free connection was found in the pool
30+
Misses atomic.Uint32 // number of times free connection was NOT found in the pool
31+
Timeouts atomic.Uint32 // number of times a wait timeout occurred
3232

33-
TotalConns uint32 // number of total connections in the pool
34-
IdleConns uint32 // number of idle connections in the pool
35-
StaleConns uint32 // number of stale connections removed from the pool
33+
TotalConns uint32 // number of total connections in the pool
34+
IdleConns uint32 // number of idle connections in the pool
35+
StaleConns atomic.Uint32 // number of stale connections removed from the pool
3636
}
3737

3838
type Pooler interface {
@@ -71,9 +71,9 @@ type Options struct {
7171
type ConnPool struct {
7272
opt *Options
7373

74-
dialErrorsNum uint32 // atomic
74+
dialErrorsNum atomic.Uint32
7575

76-
_closed uint32 // atomic
76+
_closed atomic.Bool
7777

7878
lastDialErrorMu sync.RWMutex
7979
lastDialError error
@@ -188,14 +188,14 @@ func (p *ConnPool) dialConn(c context.Context, pooled bool) (*Conn, error) {
188188
return nil, ErrClosed
189189
}
190190

191-
if atomic.LoadUint32(&p.dialErrorsNum) >= uint32(p.opt.PoolSize) {
191+
if p.dialErrorsNum.Load() >= uint32(p.opt.PoolSize) {
192192
return nil, p.getLastDialError()
193193
}
194194

195195
netConn, err := p.opt.Dialer(c)
196196
if err != nil {
197197
p.setLastDialError(err)
198-
if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.opt.PoolSize) {
198+
if p.dialErrorsNum.Add(1) == uint32(p.opt.PoolSize) {
199199
go p.tryDial()
200200
}
201201
return nil, err
@@ -219,7 +219,7 @@ func (p *ConnPool) tryDial() {
219219
continue
220220
}
221221

222-
atomic.StoreUint32(&p.dialErrorsNum, 0)
222+
p.dialErrorsNum.Store(0)
223223
_ = conn.Close()
224224
return
225225
}
@@ -263,11 +263,11 @@ func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
263263
continue
264264
}
265265

266-
atomic.AddUint32(&p.stats.Hits, 1)
266+
p.stats.Hits.Add(1)
267267
return cn, nil
268268
}
269269

270-
atomic.AddUint32(&p.stats.Misses, 1)
270+
p.stats.Misses.Add(1)
271271

272272
newcn, err := p.newConn(ctx, true)
273273
if err != nil {
@@ -313,7 +313,7 @@ func (p *ConnPool) waitTurn(c context.Context) error {
313313
return nil
314314
case <-timer.C:
315315
timers.Put(timer)
316-
atomic.AddUint32(&p.stats.Timeouts, 1)
316+
p.stats.Timeouts.Add(1)
317317
return ErrPoolTimeout
318318
}
319319
}
@@ -402,20 +402,19 @@ func (p *ConnPool) IdleLen() int {
402402
}
403403

404404
func (p *ConnPool) Stats() *Stats {
405-
idleLen := p.IdleLen()
406-
return &Stats{
407-
Hits: atomic.LoadUint32(&p.stats.Hits),
408-
Misses: atomic.LoadUint32(&p.stats.Misses),
409-
Timeouts: atomic.LoadUint32(&p.stats.Timeouts),
410-
405+
stats := &Stats{
411406
TotalConns: uint32(p.Len()),
412-
IdleConns: uint32(idleLen),
413-
StaleConns: atomic.LoadUint32(&p.stats.StaleConns),
407+
IdleConns: uint32(p.IdleLen()),
414408
}
409+
stats.Hits.Store(p.stats.Hits.Load())
410+
stats.Misses.Store(p.stats.Misses.Load())
411+
stats.Timeouts.Store(p.stats.Timeouts.Load())
412+
stats.StaleConns.Store(p.stats.StaleConns.Load())
413+
return stats
415414
}
416415

417416
func (p *ConnPool) closed() bool {
418-
return atomic.LoadUint32(&p._closed) == 1
417+
return p._closed.Load()
419418
}
420419

421420
func (p *ConnPool) Filter(fn func(*Conn) bool) error {
@@ -433,7 +432,7 @@ func (p *ConnPool) Filter(fn func(*Conn) bool) error {
433432
}
434433

435434
func (p *ConnPool) Close() error {
436-
if !atomic.CompareAndSwapUint32(&p._closed, 0, 1) {
435+
if !p._closed.CompareAndSwap(false, true) {
437436
return ErrClosed
438437
}
439438

@@ -466,7 +465,7 @@ func (p *ConnPool) reaper(frequency time.Duration) {
466465
internal.Logger.Printf(context.TODO(), "ReapStaleConns failed: %s", err)
467466
continue
468467
}
469-
atomic.AddUint32(&p.stats.StaleConns, uint32(n))
468+
p.stats.StaleConns.Add(uint32(n))
470469
}
471470
}
472471

internal/pool/pool_sticky.go

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ func (e BadConnError) Unwrap() error {
3535

3636
type StickyConnPool struct {
3737
pool Pooler
38-
shared int32 // atomic
38+
shared atomic.Int32
3939

40-
state uint32 // atomic
40+
state atomic.Uint32
4141
ch chan *Conn
4242

4343
_badConnError atomic.Value
@@ -53,7 +53,7 @@ func NewStickyConnPool(pool Pooler) *StickyConnPool {
5353
ch: make(chan *Conn, 1),
5454
}
5555
}
56-
atomic.AddInt32(&p.shared, 1)
56+
p.shared.Add(1)
5757
return p
5858
}
5959

@@ -68,13 +68,13 @@ func (p *StickyConnPool) CloseConn(cn *Conn) error {
6868
func (p *StickyConnPool) Get(ctx context.Context) (*Conn, error) {
6969
// In worst case this races with Close which is not a very common operation.
7070
for i := 0; i < 1000; i++ {
71-
switch atomic.LoadUint32(&p.state) {
71+
switch p.state.Load() {
7272
case stateDefault:
7373
cn, err := p.pool.Get(ctx)
7474
if err != nil {
7575
return nil, err
7676
}
77-
if atomic.CompareAndSwapUint32(&p.state, stateDefault, stateInited) {
77+
if p.state.CompareAndSwap(stateDefault, stateInited) {
7878
return cn, nil
7979
}
8080
p.pool.Remove(ctx, cn, ErrClosed)
@@ -124,16 +124,16 @@ func (p *StickyConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
124124
}
125125

126126
func (p *StickyConnPool) Close() error {
127-
if shared := atomic.AddInt32(&p.shared, -1); shared > 0 {
127+
if shared := p.shared.Add(-1); shared > 0 {
128128
return nil
129129
}
130130

131131
for i := 0; i < 1000; i++ {
132-
state := atomic.LoadUint32(&p.state)
132+
state := p.state.Load()
133133
if state == stateClosed {
134134
return ErrClosed
135135
}
136-
if atomic.CompareAndSwapUint32(&p.state, state, stateClosed) {
136+
if p.state.CompareAndSwap(state, stateClosed) {
137137
close(p.ch)
138138
cn, ok := <-p.ch
139139
if ok {
@@ -162,8 +162,8 @@ func (p *StickyConnPool) Reset(ctx context.Context) error {
162162
return errors.New("pg: StickyConnPool does not have a Conn")
163163
}
164164

165-
if !atomic.CompareAndSwapUint32(&p.state, stateInited, stateDefault) {
166-
state := atomic.LoadUint32(&p.state)
165+
if !p.state.CompareAndSwap(stateInited, stateDefault) {
166+
state := p.state.Load()
167167
return fmt.Errorf("pg: invalid StickyConnPool state: %d", state)
168168
}
169169

@@ -181,7 +181,7 @@ func (p *StickyConnPool) badConnError() error {
181181
}
182182

183183
func (p *StickyConnPool) Len() int {
184-
switch atomic.LoadUint32(&p.state) {
184+
switch p.state.Load() {
185185
case stateDefault:
186186
return 0
187187
case stateInited:

listener_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ var _ = Context("Listener", func() {
4343
}
4444

4545
st := db.PoolStats()
46-
Expect(st.Hits).To(Equal(uint32(0)))
47-
Expect(st.Misses).To(Equal(uint32(0)))
48-
Expect(st.Timeouts).To(Equal(uint32(0)))
46+
Expect(st.Hits.Load()).To(Equal(uint32(0)))
47+
Expect(st.Misses.Load()).To(Equal(uint32(0)))
48+
Expect(st.Timeouts.Load()).To(Equal(uint32(0)))
4949
Expect(st.TotalConns).To(Equal(uint32(1)))
5050
Expect(st.IdleConns).To(Equal(uint32(0)))
5151
})

tx.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ type Tx struct {
3333
stmtsMu sync.Mutex
3434
stmts []*Stmt
3535

36-
_closed int32
36+
_closed atomic.Bool
3737
}
3838

3939
var _ orm.DB = (*Tx)(nil)
@@ -368,7 +368,7 @@ func (tx *Tx) CloseContext(ctx context.Context) error {
368368
}
369369

370370
func (tx *Tx) close() {
371-
if !atomic.CompareAndSwapInt32(&tx._closed, 0, 1) {
371+
if !tx._closed.CompareAndSwap(false, true) {
372372
return
373373
}
374374

@@ -384,5 +384,5 @@ func (tx *Tx) close() {
384384
}
385385

386386
func (tx *Tx) closed() bool {
387-
return atomic.LoadInt32(&tx._closed) == 1
387+
return tx._closed.Load()
388388
}

0 commit comments

Comments
 (0)