@@ -102,6 +102,7 @@ struct ConnInner {
102
102
status : StatusFlags ,
103
103
last_ok_packet : Option < OkPacket < ' static > > ,
104
104
last_err_packet : Option < mysql_common:: packets:: ServerError < ' static > > ,
105
+ handshake_complete : bool ,
105
106
pool : Option < Pool > ,
106
107
pending_result : std:: result:: Result < Option < PendingResult > , ServerError > ,
107
108
tx_status : TxStatus ,
@@ -147,6 +148,7 @@ impl ConnInner {
147
148
status : StatusFlags :: empty ( ) ,
148
149
last_ok_packet : None ,
149
150
last_err_packet : None ,
151
+ handshake_complete : false ,
150
152
stream : None ,
151
153
is_mariadb : false ,
152
154
version : ( 0 , 0 , 0 ) ,
@@ -581,10 +583,11 @@ impl Conn {
581
583
) ;
582
584
583
585
// Serialize here to satisfy borrow checker.
584
- let mut buf = crate :: BUFFER_POOL . get ( ) ;
586
+ let mut buf = crate :: buffer_pool ( ) . get ( ) ;
585
587
handshake_response. serialize ( buf. as_mut ( ) ) ;
586
588
587
589
self . write_packet ( buf) . await ?;
590
+ self . inner . handshake_complete = true ;
588
591
Ok ( ( ) )
589
592
}
590
593
@@ -633,7 +636,7 @@ impl Conn {
633
636
if let Some ( plugin_data) = plugin_data {
634
637
self . write_struct ( & plugin_data. into_owned ( ) ) . await ?;
635
638
} else {
636
- self . write_packet ( crate :: BUFFER_POOL . get ( ) ) . await ?;
639
+ self . write_packet ( crate :: buffer_pool ( ) . get ( ) ) . await ?;
637
640
}
638
641
639
642
self . continue_auth ( ) . await ?;
@@ -701,7 +704,7 @@ impl Conn {
701
704
}
702
705
Some ( 0x04 ) => {
703
706
let pass = self . inner . opts . pass ( ) . unwrap_or_default ( ) ;
704
- let mut pass = crate :: BUFFER_POOL . get_with ( pass. as_bytes ( ) ) ;
707
+ let mut pass = crate :: buffer_pool ( ) . get_with ( pass. as_bytes ( ) ) ;
705
708
pass. as_mut ( ) . push ( 0 ) ;
706
709
707
710
if self . is_secure ( ) || self . is_socket ( ) {
@@ -789,7 +792,19 @@ impl Conn {
789
792
if let Ok ( ok_packet) = ok_packet {
790
793
self . handle_ok ( ok_packet. into_owned ( ) ) ;
791
794
} else {
792
- let err_packet = ParseBuf ( packet) . parse :: < ErrPacket > ( self . capabilities ( ) ) ;
795
+ // If we haven't completed the handshake the server will not be aware of our
796
+ // capabilities and so it will behave as if we have none. In particular, the error
797
+ // packet will not contain a SQL State field even if our capabilities do contain the
798
+ // `CLIENT_PROTOCOL_41` flag. Therefore it is necessary to parse an incoming packet
799
+ // with no capability assumptions if we have not completed the handshake.
800
+ //
801
+ // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase.html
802
+ let capabilities = if self . inner . handshake_complete {
803
+ self . capabilities ( )
804
+ } else {
805
+ CapabilityFlags :: empty ( )
806
+ } ;
807
+ let err_packet = ParseBuf ( packet) . parse :: < ErrPacket > ( capabilities) ;
793
808
if let Ok ( err_packet) = err_packet {
794
809
self . handle_err ( err_packet) ?;
795
810
return Ok ( true ) ;
@@ -838,13 +853,13 @@ impl Conn {
838
853
839
854
/// Writes bytes to a server.
840
855
pub ( crate ) async fn write_bytes ( & mut self , bytes : & [ u8 ] ) -> Result < ( ) > {
841
- let buf = crate :: BUFFER_POOL . get_with ( bytes) ;
856
+ let buf = crate :: buffer_pool ( ) . get_with ( bytes) ;
842
857
self . write_packet ( buf) . await
843
858
}
844
859
845
860
/// Sends a serializable structure to a server.
846
861
pub ( crate ) async fn write_struct < T : MySerialize > ( & mut self , x : & T ) -> Result < ( ) > {
847
- let mut buf = crate :: BUFFER_POOL . get ( ) ;
862
+ let mut buf = crate :: buffer_pool ( ) . get ( ) ;
848
863
x. serialize ( buf. as_mut ( ) ) ;
849
864
self . write_packet ( buf) . await
850
865
}
@@ -870,7 +885,7 @@ impl Conn {
870
885
T : AsRef < [ u8 ] > ,
871
886
{
872
887
let cmd_data = cmd_data. as_ref ( ) ;
873
- let mut buf = crate :: BUFFER_POOL . get ( ) ;
888
+ let mut buf = crate :: buffer_pool ( ) . get ( ) ;
874
889
let body = buf. as_mut ( ) ;
875
890
body. push ( cmd as u8 ) ;
876
891
body. extend_from_slice ( cmd_data) ;
@@ -1270,10 +1285,11 @@ mod test {
1270
1285
use futures_util:: stream:: { self , StreamExt } ;
1271
1286
use mysql_common:: constants:: MAX_PAYLOAD_LEN ;
1272
1287
use rand:: Fill ;
1288
+ use tokio:: { io:: AsyncWriteExt , net:: TcpListener } ;
1273
1289
1274
1290
use crate :: {
1275
1291
from_row, params, prelude:: * , test_misc:: get_opts, ChangeUserOpts , Conn , Error ,
1276
- OptsBuilder , Pool , Value , WhiteListFsHandler ,
1292
+ OptsBuilder , Pool , ServerError , Value , WhiteListFsHandler ,
1277
1293
} ;
1278
1294
1279
1295
#[ tokio:: test]
@@ -1400,16 +1416,18 @@ mod test {
1400
1416
. filter ( |variant| plugins. iter ( ) . any ( |p| p == variant. 0 ) ) ;
1401
1417
1402
1418
for ( plug, val, pass) in variants {
1419
+ dbg ! ( ( plug, val, pass, conn. inner. version) ) ;
1420
+
1421
+ if plug == "mysql_native_password" && conn. inner . version >= ( 9 , 0 , 0 ) {
1422
+ continue ;
1423
+ }
1424
+
1403
1425
let _ = conn. query_drop ( "DROP USER 'test_user'@'%'" ) . await ;
1404
1426
1405
1427
let query = format ! ( "CREATE USER 'test_user'@'%' IDENTIFIED WITH {}" , plug) ;
1406
1428
conn. query_drop ( query) . await . unwrap ( ) ;
1407
1429
1408
- if ( 8 , 0 , 11 ) <= conn. inner . version && conn. inner . version <= ( 9 , 0 , 0 ) {
1409
- conn. query_drop ( format ! ( "SET PASSWORD FOR 'test_user'@'%' = '{}'" , pass) )
1410
- . await
1411
- . unwrap ( ) ;
1412
- } else {
1430
+ if conn. inner . version <= ( 8 , 0 , 11 ) {
1413
1431
conn. query_drop ( format ! ( "SET old_passwords = {}" , val) )
1414
1432
. await
1415
1433
. unwrap ( ) ;
@@ -1419,6 +1437,10 @@ mod test {
1419
1437
) )
1420
1438
. await
1421
1439
. unwrap ( ) ;
1440
+ } else {
1441
+ conn. query_drop ( format ! ( "SET PASSWORD FOR 'test_user'@'%' = '{}'" , pass) )
1442
+ . await
1443
+ . unwrap ( ) ;
1422
1444
} ;
1423
1445
1424
1446
let opts = get_opts ( )
@@ -1548,6 +1570,10 @@ mod test {
1548
1570
} ;
1549
1571
1550
1572
for ( i, plugin) in plugins. iter ( ) . enumerate ( ) {
1573
+ if * plugin == "mysql_native_password" && conn. server_version ( ) >= ( 9 , 0 , 0 ) {
1574
+ continue ;
1575
+ }
1576
+
1551
1577
let mut rng = rand:: thread_rng ( ) ;
1552
1578
let mut pass = [ 0u8 ; 10 ] ;
1553
1579
pass. try_fill ( & mut rng) . unwrap ( ) ;
@@ -2189,6 +2215,45 @@ mod test {
2189
2215
Ok ( ( ) )
2190
2216
}
2191
2217
2218
+ #[ tokio:: test]
2219
+ async fn should_handle_initial_error_packet ( ) {
2220
+ let header = [
2221
+ 0x68 , 0x00 , 0x00 , // packet_length
2222
+ 0x00 , // sequence
2223
+ 0xff , // error_header
2224
+ 0x69 , 0x04 , // error_code
2225
+ ] ;
2226
+ let error_message = "Host '172.17.0.1' is blocked because of many connection errors; unblock with 'mysqladmin flush-hosts'" ;
2227
+
2228
+ // Create a fake MySQL server that immediately replies with an error packet.
2229
+ let listener = TcpListener :: bind ( "127.0.0.1:0000" ) . await . unwrap ( ) ;
2230
+
2231
+ let listen_addr = listener. local_addr ( ) . unwrap ( ) ;
2232
+
2233
+ tokio:: task:: spawn ( async move {
2234
+ let ( mut stream, _) = listener. accept ( ) . await . unwrap ( ) ;
2235
+ stream. write_all ( & header) . await . unwrap ( ) ;
2236
+ stream. write_all ( error_message. as_bytes ( ) ) . await . unwrap ( ) ;
2237
+ stream. shutdown ( ) . await . unwrap ( ) ;
2238
+ } ) ;
2239
+
2240
+ let opts = OptsBuilder :: default ( )
2241
+ . ip_or_hostname ( listen_addr. ip ( ) . to_string ( ) )
2242
+ . tcp_port ( listen_addr. port ( ) ) ;
2243
+ let server_err = match Conn :: new ( opts) . await {
2244
+ Err ( Error :: Server ( server_err) ) => server_err,
2245
+ other => panic ! ( "expected server error but got: {:?}" , other) ,
2246
+ } ;
2247
+ assert_eq ! (
2248
+ server_err,
2249
+ ServerError {
2250
+ code: 1129 ,
2251
+ state: "HY000" . to_owned( ) ,
2252
+ message: error_message. to_owned( ) ,
2253
+ }
2254
+ ) ;
2255
+ }
2256
+
2192
2257
#[ cfg( feature = "nightly" ) ]
2193
2258
mod bench {
2194
2259
use crate :: { conn:: Conn , queryable:: Queryable , test_misc:: get_opts} ;
0 commit comments