7
7
// modified, or distributed except according to those terms.
8
8
9
9
use futures_util:: FutureExt ;
10
+ use priority_queue:: PriorityQueue ;
10
11
use tokio:: sync:: mpsc;
11
12
12
13
use std:: {
13
14
cmp:: { Ordering , Reverse } ,
14
- collections:: { BinaryHeap , VecDeque } ,
15
+ collections:: VecDeque ,
15
16
convert:: TryFrom ,
17
+ hash:: { Hash , Hasher } ,
16
18
pin:: Pin ,
17
19
str:: FromStr ,
18
20
sync:: { atomic, Arc , Mutex } ,
@@ -63,7 +65,7 @@ impl From<Conn> for IdlingConn {
63
65
/// This is fine as long as we never do expensive work while holding the lock!
64
66
#[ derive( Debug ) ]
65
67
struct Exchange {
66
- waiting : BinaryHeap < QueuedWaker > ,
68
+ waiting : Waitlist ,
67
69
available : VecDeque < IdlingConn > ,
68
70
exist : usize ,
69
71
// only used to spawn the recycler the first time we're in async context
@@ -88,9 +90,45 @@ impl Exchange {
88
90
}
89
91
}
90
92
93
+ #[ derive( Default , Debug ) ]
94
+ struct Waitlist {
95
+ queue : PriorityQueue < QueuedWaker , QueueId > ,
96
+ }
97
+
98
+ impl Waitlist {
99
+ fn push ( & mut self , w : Waker , queue_id : QueueId ) {
100
+ self . queue . push (
101
+ QueuedWaker {
102
+ queue_id,
103
+ waker : Some ( w) ,
104
+ } ,
105
+ queue_id,
106
+ ) ;
107
+ }
108
+
109
+ fn pop ( & mut self ) -> Option < Waker > {
110
+ match self . queue . pop ( ) {
111
+ Some ( ( qw, _) ) => Some ( qw. waker . unwrap ( ) ) ,
112
+ None => None ,
113
+ }
114
+ }
115
+
116
+ fn remove ( & mut self , id : QueueId ) {
117
+ let tmp = QueuedWaker {
118
+ queue_id : id,
119
+ waker : None ,
120
+ } ;
121
+ self . queue . remove ( & tmp) ;
122
+ }
123
+
124
+ fn is_empty ( & self ) -> bool {
125
+ self . queue . is_empty ( )
126
+ }
127
+ }
128
+
91
129
const QUEUE_END_ID : QueueId = QueueId ( Reverse ( u64:: MAX ) ) ;
92
130
93
- #[ derive( Debug , Copy , Clone , Eq , PartialEq , Ord , PartialOrd ) ]
131
+ #[ derive( Debug , Copy , Clone , Eq , PartialEq , Ord , PartialOrd , Hash ) ]
94
132
pub ( crate ) struct QueueId ( Reverse < u64 > ) ;
95
133
96
134
impl QueueId {
@@ -104,13 +142,7 @@ impl QueueId {
104
142
#[ derive( Debug ) ]
105
143
struct QueuedWaker {
106
144
queue_id : QueueId ,
107
- waker : Waker ,
108
- }
109
-
110
- impl QueuedWaker {
111
- fn new ( queue_id : QueueId , waker : Waker ) -> Self {
112
- QueuedWaker { queue_id, waker }
113
- }
145
+ waker : Option < Waker > ,
114
146
}
115
147
116
148
impl Eq for QueuedWaker { }
@@ -133,6 +165,12 @@ impl PartialOrd for QueuedWaker {
133
165
}
134
166
}
135
167
168
+ impl Hash for QueuedWaker {
169
+ fn hash < H : Hasher > ( & self , state : & mut H ) {
170
+ self . queue_id . hash ( state)
171
+ }
172
+ }
173
+
136
174
/// Connection pool data.
137
175
#[ derive( Debug ) ]
138
176
pub struct Inner {
@@ -177,7 +215,7 @@ impl Pool {
177
215
closed : false . into ( ) ,
178
216
exchange : Mutex :: new ( Exchange {
179
217
available : VecDeque :: with_capacity ( pool_opts. constraints ( ) . max ( ) ) ,
180
- waiting : BinaryHeap :: new ( ) ,
218
+ waiting : Waitlist :: default ( ) ,
181
219
exist : 0 ,
182
220
recycler : Some ( ( rx, pool_opts) ) ,
183
221
} ) ,
@@ -227,8 +265,8 @@ impl Pool {
227
265
let mut exchange = self . inner . exchange . lock ( ) . unwrap ( ) ;
228
266
if exchange. available . len ( ) < self . opts . pool_opts ( ) . active_bound ( ) {
229
267
exchange. available . push_back ( conn. into ( ) ) ;
230
- if let Some ( qw ) = exchange. waiting . pop ( ) {
231
- qw . waker . wake ( ) ;
268
+ if let Some ( w ) = exchange. waiting . pop ( ) {
269
+ w . wake ( ) ;
232
270
}
233
271
return ;
234
272
}
@@ -262,8 +300,8 @@ impl Pool {
262
300
let mut exchange = self . inner . exchange . lock ( ) . unwrap ( ) ;
263
301
exchange. exist -= 1 ;
264
302
// we just enabled the creation of a new connection!
265
- if let Some ( qw ) = exchange. waiting . pop ( ) {
266
- qw . waker . wake ( ) ;
303
+ if let Some ( w ) = exchange. waiting . pop ( ) {
304
+ w . wake ( ) ;
267
305
}
268
306
}
269
307
@@ -296,9 +334,7 @@ impl Pool {
296
334
297
335
// Check if others are waiting and we're not queued.
298
336
if !exchange. waiting . is_empty ( ) && !queued {
299
- exchange
300
- . waiting
301
- . push ( QueuedWaker :: new ( queue_id, cx. waker ( ) . clone ( ) ) ) ;
337
+ exchange. waiting . push ( cx. waker ( ) . clone ( ) , queue_id) ;
302
338
return Poll :: Pending ;
303
339
}
304
340
@@ -328,11 +364,14 @@ impl Pool {
328
364
}
329
365
330
366
// Polled, but no conn available? Back into the queue.
331
- exchange
332
- . waiting
333
- . push ( QueuedWaker :: new ( queue_id, cx. waker ( ) . clone ( ) ) ) ;
367
+ exchange. waiting . push ( cx. waker ( ) . clone ( ) , queue_id) ;
334
368
Poll :: Pending
335
369
}
370
+
371
+ fn unqueue ( & self , queue_id : QueueId ) {
372
+ let mut exchange = self . inner . exchange . lock ( ) . unwrap ( ) ;
373
+ exchange. waiting . remove ( queue_id) ;
374
+ }
336
375
}
337
376
338
377
impl Drop for Conn {
@@ -363,12 +402,20 @@ mod test {
363
402
try_join, FutureExt ,
364
403
} ;
365
404
use mysql_common:: row:: Row ;
366
- use tokio:: time:: sleep;
405
+ use tokio:: time:: { sleep, timeout } ;
367
406
368
- use std:: time:: Duration ;
407
+ use std:: {
408
+ cmp:: Reverse ,
409
+ task:: { RawWaker , RawWakerVTable , Waker } ,
410
+ time:: Duration ,
411
+ } ;
369
412
370
413
use crate :: {
371
- conn:: pool:: Pool , opts:: PoolOpts , prelude:: * , test_misc:: get_opts, PoolConstraints , TxOpts ,
414
+ conn:: pool:: { Pool , QueueId , Waitlist , QUEUE_END_ID } ,
415
+ opts:: PoolOpts ,
416
+ prelude:: * ,
417
+ test_misc:: get_opts,
418
+ PoolConstraints , TxOpts ,
372
419
} ;
373
420
374
421
macro_rules! conn_ex_field {
@@ -824,6 +871,27 @@ mod test {
824
871
Ok ( ( ) )
825
872
}
826
873
874
+ #[ tokio:: test]
875
+ async fn should_remove_waker_of_cancelled_task ( ) {
876
+ let pool_constraints = PoolConstraints :: new ( 1 , 1 ) . unwrap ( ) ;
877
+ let pool_opts = PoolOpts :: default ( ) . with_constraints ( pool_constraints) ;
878
+
879
+ let pool = Pool :: new ( get_opts ( ) . pool_opts ( pool_opts) ) ;
880
+ let only_conn = pool. get_conn ( ) . await . unwrap ( ) ;
881
+
882
+ let join_handle = tokio:: spawn ( timeout ( Duration :: from_secs ( 1 ) , pool. get_conn ( ) ) ) ;
883
+
884
+ sleep ( Duration :: from_secs ( 2 ) ) . await ;
885
+
886
+ match join_handle. await . unwrap ( ) {
887
+ Err ( _elapsed) => ( ) ,
888
+ _ => panic ! ( "unexpected Ok()" ) ,
889
+ }
890
+ drop ( only_conn) ;
891
+
892
+ assert_eq ! ( 0 , pool. inner. exchange. lock( ) . unwrap( ) . waiting. queue. len( ) ) ;
893
+ }
894
+
827
895
#[ tokio:: test]
828
896
async fn should_work_if_pooled_connection_operation_is_cancelled ( ) -> super :: Result < ( ) > {
829
897
let pool = Pool :: new ( get_opts ( ) ) ;
@@ -868,6 +936,40 @@ mod test {
868
936
Ok ( ( ) )
869
937
}
870
938
939
+ #[ test]
940
+ fn waitlist_integrity ( ) {
941
+ const DATA : * const ( ) = & ( ) ;
942
+ const NOOP_CLONE_FN : unsafe fn ( * const ( ) ) -> RawWaker = |_| RawWaker :: new ( DATA , & RW_VTABLE ) ;
943
+ const NOOP_FN : unsafe fn ( * const ( ) ) = |_| { } ;
944
+ static RW_VTABLE : RawWakerVTable =
945
+ RawWakerVTable :: new ( NOOP_CLONE_FN , NOOP_FN , NOOP_FN , NOOP_FN ) ;
946
+ let w = unsafe { Waker :: from_raw ( RawWaker :: new ( DATA , & RW_VTABLE ) ) } ;
947
+
948
+ let mut waitlist = Waitlist :: default ( ) ;
949
+ assert_eq ! ( 0 , waitlist. queue. len( ) ) ;
950
+
951
+ waitlist. push ( w. clone ( ) , QueueId ( Reverse ( 4 ) ) ) ;
952
+ waitlist. push ( w. clone ( ) , QueueId ( Reverse ( 2 ) ) ) ;
953
+ waitlist. push ( w. clone ( ) , QueueId ( Reverse ( 8 ) ) ) ;
954
+ waitlist. push ( w. clone ( ) , QUEUE_END_ID ) ;
955
+ waitlist. push ( w. clone ( ) , QueueId ( Reverse ( 10 ) ) ) ;
956
+
957
+ waitlist. remove ( QueueId ( Reverse ( 8 ) ) ) ;
958
+
959
+ assert_eq ! ( 4 , waitlist. queue. len( ) ) ;
960
+
961
+ let ( _, id) = waitlist. queue . pop ( ) . unwrap ( ) ;
962
+ assert_eq ! ( 2 , id. 0 . 0 ) ;
963
+ let ( _, id) = waitlist. queue . pop ( ) . unwrap ( ) ;
964
+ assert_eq ! ( 4 , id. 0 . 0 ) ;
965
+ let ( _, id) = waitlist. queue . pop ( ) . unwrap ( ) ;
966
+ assert_eq ! ( 10 , id. 0 . 0 ) ;
967
+ let ( _, id) = waitlist. queue . pop ( ) . unwrap ( ) ;
968
+ assert_eq ! ( QUEUE_END_ID , id) ;
969
+
970
+ assert_eq ! ( 0 , waitlist. queue. len( ) ) ;
971
+ }
972
+
871
973
#[ cfg( feature = "nightly" ) ]
872
974
mod bench {
873
975
use futures_util:: future:: { FutureExt , TryFutureExt } ;
0 commit comments