@@ -44,7 +44,7 @@ use crate::{
44
44
transaction:: TxStatus ,
45
45
BinaryProtocol , Queryable , TextProtocol ,
46
46
} ,
47
- BinlogStream , InfileData , OptsBuilder ,
47
+ BinlogStream , ChangeUserOpts , InfileData , OptsBuilder ,
48
48
} ;
49
49
50
50
use self :: routines:: Routine ;
@@ -102,13 +102,15 @@ struct ConnInner {
102
102
pool : Option < Pool > ,
103
103
pending_result : std:: result:: Result < Option < PendingResult > , ServerError > ,
104
104
tx_status : TxStatus ,
105
+ reset_upon_returning_to_a_pool : bool ,
105
106
opts : Opts ,
106
107
last_io : Instant ,
107
108
wait_timeout : Duration ,
108
109
stmt_cache : StmtCache ,
109
110
nonce : Vec < u8 > ,
110
111
auth_plugin : AuthPlugin < ' static > ,
111
112
auth_switched : bool ,
113
+ server_key : Option < Vec < u8 > > ,
112
114
/// Connection is already disconnected.
113
115
pub ( crate ) disconnected : bool ,
114
116
/// One-time connection-level infile handler.
@@ -126,6 +128,8 @@ impl fmt::Debug for ConnInner {
126
128
. field ( "tx_status" , & self . tx_status )
127
129
. field ( "stream" , & self . stream )
128
130
. field ( "options" , & self . opts )
131
+ . field ( "server_key" , & self . server_key )
132
+ . field ( "auth_plugin" , & self . auth_plugin )
129
133
. finish ( )
130
134
}
131
135
}
@@ -154,7 +158,9 @@ impl ConnInner {
154
158
auth_plugin : AuthPlugin :: MysqlNativePassword ,
155
159
auth_switched : false ,
156
160
disconnected : false ,
161
+ server_key : None ,
157
162
infile_handler : None ,
163
+ reset_upon_returning_to_a_pool : false ,
158
164
}
159
165
}
160
166
@@ -416,16 +422,33 @@ impl Conn {
416
422
/// Returns true if io stream is encrypted.
417
423
fn is_secure ( & self ) -> bool {
418
424
#[ cfg( any( feature = "native-tls-tls" , feature = "rustls-tls" ) ) ]
419
- if let Some ( ref stream) = self . inner . stream {
420
- stream. is_secure ( )
421
- } else {
422
- false
425
+ {
426
+ self . inner
427
+ . stream
428
+ . as_ref ( )
429
+ . map ( |x| x. is_secure ( ) )
430
+ . unwrap_or_default ( )
423
431
}
424
432
425
433
#[ cfg( not( any( feature = "native-tls-tls" , feature = "rustls-tls" ) ) ) ]
426
434
false
427
435
}
428
436
437
+ /// Returns true if io stream is socket.
438
+ fn is_socket ( & self ) -> bool {
439
+ #[ cfg( unix) ]
440
+ {
441
+ self . inner
442
+ . stream
443
+ . as_ref ( )
444
+ . map ( |x| x. is_socket ( ) )
445
+ . unwrap_or_default ( )
446
+ }
447
+
448
+ #[ cfg( not( unix) ) ]
449
+ false
450
+ }
451
+
429
452
/// Hacky way to move connection through &mut. `self` becomes unusable.
430
453
fn take ( & mut self ) -> Conn {
431
454
mem:: replace ( self , Conn :: empty ( Default :: default ( ) ) )
@@ -663,16 +686,21 @@ impl Conn {
663
686
let mut pass = crate :: BUFFER_POOL . get_with ( pass. as_bytes ( ) ) ;
664
687
pass. as_mut ( ) . push ( 0 ) ;
665
688
666
- if self . is_secure ( ) {
689
+ if self . is_secure ( ) || self . is_socket ( ) {
667
690
self . write_packet ( pass) . await ?;
668
691
} else {
669
- self . write_bytes ( & [ 0x02 ] [ ..] ) . await ?;
670
- let packet = self . read_packet ( ) . await ?;
671
- let key = & packet[ 1 ..] ;
692
+ if self . inner . server_key . is_none ( ) {
693
+ self . write_bytes ( & [ 0x02 ] [ ..] ) . await ?;
694
+ let packet = self . read_packet ( ) . await ?;
695
+ self . inner . server_key = Some ( packet[ 1 ..] . to_vec ( ) ) ;
696
+ }
672
697
for ( i, byte) in pass. as_mut ( ) . iter_mut ( ) . enumerate ( ) {
673
698
* byte ^= self . inner . nonce [ i % self . inner . nonce . len ( ) ] ;
674
699
}
675
- let encrypted_pass = crypto:: encrypt ( & * pass, key) ;
700
+ let encrypted_pass = crypto:: encrypt (
701
+ & * pass,
702
+ self . inner . server_key . as_deref ( ) . expect ( "unreachable" ) ,
703
+ ) ;
676
704
self . write_bytes ( & * encrypted_pass) . await ?;
677
705
} ;
678
706
self . drop_packet ( ) . await ?;
@@ -958,12 +986,13 @@ impl Conn {
958
986
self . inner . last_io . elapsed ( )
959
987
}
960
988
961
- /// Executes `COM_RESET_CONNECTION` on `self` .
989
+ /// Executes [ `COM_RESET_CONNECTION`][1] .
962
990
///
963
- /// If server version is older than 5.7.2, then it'll reconnect.
964
- pub async fn reset ( & mut self ) -> Result < ( ) > {
965
- let pool = self . inner . pool . clone ( ) ;
966
-
991
+ /// Returns `false` if command is not supported (requires MySql >5.7.2, MariaDb >10.2.3).
992
+ /// For older versions consider using [`Conn::change_user`].
993
+ ///
994
+ /// [1]: https://dev.mysql.com/doc/c-api/5.7/en/mysql-reset-connection.html
995
+ pub async fn reset ( & mut self ) -> Result < bool > {
967
996
let supports_com_reset_connection = if self . inner . is_mariadb {
968
997
self . inner . version >= ( 10 , 2 , 4 )
969
998
} else {
@@ -973,19 +1002,62 @@ impl Conn {
973
1002
974
1003
if supports_com_reset_connection {
975
1004
self . routine ( routines:: ResetRoutine ) . await ?;
976
- } else {
977
- let opts = self . inner . opts . clone ( ) ;
978
- let old_conn = std :: mem :: replace ( self , Conn :: new ( opts ) . await ? ) ;
979
- // tidy up the old connection
980
- old_conn . close_conn ( ) . await ? ;
981
- } ;
1005
+ self . inner . stmt_cache . clear ( ) ;
1006
+ self . inner . infile_handler = None ;
1007
+ }
1008
+
1009
+ Ok ( supports_com_reset_connection )
1010
+ }
982
1011
1012
+ /// Executes [`COM_CHANGE_USER`][1].
1013
+ ///
1014
+ /// This might be used as an older and slower alternative to `COM_RESET_CONNECTION` that
1015
+ /// works on MySql prior to 5.7.3 (MariaDb prior ot 10.2.4).
1016
+ ///
1017
+ /// ## Note
1018
+ ///
1019
+ /// * Using non-default `opts` for a pooled connection is discouraging.
1020
+ /// * Connection options will be permanently updated.
1021
+ ///
1022
+ /// [1]: https://dev.mysql.com/doc/c-api/5.7/en/mysql-change-user.html
1023
+ pub async fn change_user ( & mut self , opts : ChangeUserOpts ) -> Result < ( ) > {
1024
+ // We'll kick this connection from a pool if opts are changed.
1025
+ if opts != ChangeUserOpts :: default ( ) {
1026
+ let mut opts_changed = false ;
1027
+ if let Some ( user) = opts. user ( ) {
1028
+ opts_changed |= user != self . opts ( ) . user ( )
1029
+ } ;
1030
+ if let Some ( pass) = opts. pass ( ) {
1031
+ opts_changed |= pass != self . opts ( ) . pass ( )
1032
+ } ;
1033
+ if let Some ( db_name) = opts. db_name ( ) {
1034
+ opts_changed |= db_name != self . opts ( ) . db_name ( )
1035
+ } ;
1036
+ if opts_changed {
1037
+ if let Some ( pool) = self . inner . pool . take ( ) {
1038
+ pool. cancel_connection ( ) ;
1039
+ }
1040
+ }
1041
+ }
1042
+
1043
+ let conn_opts = & mut self . inner . opts ;
1044
+ opts. update_opts ( conn_opts) ;
1045
+ self . routine ( routines:: ChangeUser ) . await ?;
983
1046
self . inner . stmt_cache . clear ( ) ;
984
1047
self . inner . infile_handler = None ;
985
- self . inner . pool = pool;
986
1048
Ok ( ( ) )
987
1049
}
988
1050
1051
+ /// Resets the connection upon returning it to a pool.
1052
+ ///
1053
+ /// Will invoke `COM_CHANGE_USER` if `COM_RESET_CONNECTION` is not supported.
1054
+ async fn reset_for_pool ( mut self ) -> Result < Self > {
1055
+ if !self . reset ( ) . await ? {
1056
+ self . change_user ( Default :: default ( ) ) . await ?;
1057
+ }
1058
+ Ok ( self )
1059
+ }
1060
+
989
1061
/// Requires that `self.inner.tx_status != TxStatus::None`
990
1062
async fn rollback_transaction ( & mut self ) -> Result < ( ) > {
991
1063
debug_assert_ne ! ( self . inner. tx_status, TxStatus :: None ) ;
@@ -1094,13 +1166,14 @@ mod test {
1094
1166
use bytes:: Bytes ;
1095
1167
use futures_util:: stream:: { self , StreamExt } ;
1096
1168
use mysql_common:: { binlog:: events:: EventData , constants:: MAX_PAYLOAD_LEN } ;
1169
+ use rand:: Fill ;
1097
1170
use tokio:: time:: timeout;
1098
1171
1099
1172
use std:: time:: Duration ;
1100
1173
1101
1174
use crate :: {
1102
- from_row, params, prelude:: * , test_misc:: get_opts, BinlogDumpFlags , BinlogRequest , Conn ,
1103
- Error , OptsBuilder , Pool , WhiteListFsHandler ,
1175
+ from_row, params, prelude:: * , test_misc:: get_opts, BinlogDumpFlags , BinlogRequest ,
1176
+ ChangeUserOpts , Conn , Error , OptsBuilder , Pool , Value , WhiteListFsHandler ,
1104
1177
} ;
1105
1178
1106
1179
async fn gen_dummy_data ( ) -> super :: Result < ( ) > {
@@ -1471,9 +1544,115 @@ mod test {
1471
1544
#[ tokio:: test]
1472
1545
async fn should_reset_the_connection ( ) -> super :: Result < ( ) > {
1473
1546
let mut conn = Conn :: new ( get_opts ( ) ) . await ?;
1474
- conn. exec_drop ( "SELECT ?" , ( 1_u8 , ) ) . await ?;
1475
- conn. reset ( ) . await ?;
1476
- conn. exec_drop ( "SELECT ?" , ( 1_u8 , ) ) . await ?;
1547
+ let max_execution_time = conn
1548
+ . query_first :: < u64 , _ > ( "SELECT @@max_execution_time" )
1549
+ . await ?
1550
+ . unwrap ( ) ;
1551
+
1552
+ conn. exec_drop (
1553
+ "SET SESSION max_execution_time = ?" ,
1554
+ ( max_execution_time + 1 , ) ,
1555
+ )
1556
+ . await ?;
1557
+
1558
+ assert_eq ! (
1559
+ conn. query_first:: <u64 , _>( "SELECT @@max_execution_time" )
1560
+ . await ?,
1561
+ Some ( max_execution_time + 1 )
1562
+ ) ;
1563
+
1564
+ if conn. reset ( ) . await ? {
1565
+ assert_eq ! (
1566
+ conn. query_first:: <u64 , _>( "SELECT @@max_execution_time" )
1567
+ . await ?,
1568
+ Some ( max_execution_time)
1569
+ ) ;
1570
+ } else {
1571
+ assert_eq ! (
1572
+ conn. query_first:: <u64 , _>( "SELECT @@max_execution_time" )
1573
+ . await ?,
1574
+ Some ( max_execution_time + 1 )
1575
+ ) ;
1576
+ }
1577
+
1578
+ conn. disconnect ( ) . await ?;
1579
+ Ok ( ( ) )
1580
+ }
1581
+
1582
+ #[ tokio:: test]
1583
+ async fn should_change_user ( ) -> super :: Result < ( ) > {
1584
+ let mut conn = Conn :: new ( get_opts ( ) ) . await ?;
1585
+ let max_execution_time = conn
1586
+ . query_first :: < u64 , _ > ( "SELECT @@max_execution_time" )
1587
+ . await ?
1588
+ . unwrap ( ) ;
1589
+
1590
+ conn. exec_drop (
1591
+ "SET SESSION max_execution_time = ?" ,
1592
+ ( max_execution_time + 1 , ) ,
1593
+ )
1594
+ . await ?;
1595
+
1596
+ assert_eq ! (
1597
+ conn. query_first:: <u64 , _>( "SELECT @@max_execution_time" )
1598
+ . await ?,
1599
+ Some ( max_execution_time + 1 )
1600
+ ) ;
1601
+
1602
+ conn. change_user ( Default :: default ( ) ) . await ?;
1603
+ assert_eq ! (
1604
+ conn. query_first:: <u64 , _>( "SELECT @@max_execution_time" )
1605
+ . await ?,
1606
+ Some ( max_execution_time)
1607
+ ) ;
1608
+
1609
+ let plugins: & [ & str ] = if !conn. inner . is_mariadb && conn. server_version ( ) >= ( 5 , 8 , 0 ) {
1610
+ & [ "mysql_native_password" , "caching_sha2_password" ]
1611
+ } else {
1612
+ & [ "mysql_native_password" ]
1613
+ } ;
1614
+
1615
+ for plugin in plugins {
1616
+ let mut conn2 = Conn :: new ( get_opts ( ) ) . await . unwrap ( ) ;
1617
+
1618
+ let mut rng = rand:: thread_rng ( ) ;
1619
+ let mut pass = [ 0u8 ; 10 ] ;
1620
+ pass. try_fill ( & mut rng) . unwrap ( ) ;
1621
+ let pass: String = IntoIterator :: into_iter ( pass)
1622
+ . map ( |x| ( ( x % ( 123 - 97 ) ) + 97 ) as char )
1623
+ . collect ( ) ;
1624
+ conn. query_drop ( "DROP USER IF EXISTS __mysql_async_test_user" )
1625
+ . await
1626
+ . unwrap ( ) ;
1627
+ conn. query_drop ( format ! (
1628
+ "CREATE USER '__mysql_async_test_user'@'%' IDENTIFIED WITH {} BY {}" ,
1629
+ plugin,
1630
+ Value :: from( pass. clone( ) ) . as_sql( false )
1631
+ ) )
1632
+ . await
1633
+ . unwrap ( ) ;
1634
+ conn. query_drop ( "FLUSH PRIVILEGES" ) . await . unwrap ( ) ;
1635
+
1636
+ conn2
1637
+ . change_user (
1638
+ ChangeUserOpts :: default ( )
1639
+ . with_db_name ( None )
1640
+ . with_user ( Some ( "__mysql_async_test_user" . into ( ) ) )
1641
+ . with_pass ( Some ( pass) ) ,
1642
+ )
1643
+ . await
1644
+ . unwrap ( ) ;
1645
+ assert_eq ! (
1646
+ conn2
1647
+ . query_first:: <( Option <String >, String ) , _>( "SELECT DATABASE(), USER();" )
1648
+ . await
1649
+ . unwrap( ) ,
1650
+ Some ( ( None , String :: from( "__mysql_async_test_user@localhost" ) ) ) ,
1651
+ ) ;
1652
+
1653
+ conn2. disconnect ( ) . await . unwrap ( ) ;
1654
+ }
1655
+
1477
1656
conn. disconnect ( ) . await ?;
1478
1657
Ok ( ( ) )
1479
1658
}
0 commit comments