Skip to content

Commit 134cbf8

Browse files
committed
Implement Conn::change_user
1 parent ad90c52 commit 134cbf8

File tree

9 files changed

+451
-54
lines changed

9 files changed

+451
-54
lines changed

src/conn/mod.rs

Lines changed: 206 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ use crate::{
4444
transaction::TxStatus,
4545
BinaryProtocol, Queryable, TextProtocol,
4646
},
47-
BinlogStream, InfileData, OptsBuilder,
47+
BinlogStream, ChangeUserOpts, InfileData, OptsBuilder,
4848
};
4949

5050
use self::routines::Routine;
@@ -102,13 +102,15 @@ struct ConnInner {
102102
pool: Option<Pool>,
103103
pending_result: std::result::Result<Option<PendingResult>, ServerError>,
104104
tx_status: TxStatus,
105+
reset_upon_returning_to_a_pool: bool,
105106
opts: Opts,
106107
last_io: Instant,
107108
wait_timeout: Duration,
108109
stmt_cache: StmtCache,
109110
nonce: Vec<u8>,
110111
auth_plugin: AuthPlugin<'static>,
111112
auth_switched: bool,
113+
server_key: Option<Vec<u8>>,
112114
/// Connection is already disconnected.
113115
pub(crate) disconnected: bool,
114116
/// One-time connection-level infile handler.
@@ -126,6 +128,8 @@ impl fmt::Debug for ConnInner {
126128
.field("tx_status", &self.tx_status)
127129
.field("stream", &self.stream)
128130
.field("options", &self.opts)
131+
.field("server_key", &self.server_key)
132+
.field("auth_plugin", &self.auth_plugin)
129133
.finish()
130134
}
131135
}
@@ -154,7 +158,9 @@ impl ConnInner {
154158
auth_plugin: AuthPlugin::MysqlNativePassword,
155159
auth_switched: false,
156160
disconnected: false,
161+
server_key: None,
157162
infile_handler: None,
163+
reset_upon_returning_to_a_pool: false,
158164
}
159165
}
160166

@@ -416,16 +422,33 @@ impl Conn {
416422
/// Returns true if io stream is encrypted.
417423
fn is_secure(&self) -> bool {
418424
#[cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))]
419-
if let Some(ref stream) = self.inner.stream {
420-
stream.is_secure()
421-
} else {
422-
false
425+
{
426+
self.inner
427+
.stream
428+
.as_ref()
429+
.map(|x| x.is_secure())
430+
.unwrap_or_default()
423431
}
424432

425433
#[cfg(not(any(feature = "native-tls-tls", feature = "rustls-tls")))]
426434
false
427435
}
428436

437+
/// Returns true if io stream is socket.
438+
fn is_socket(&self) -> bool {
439+
#[cfg(unix)]
440+
{
441+
self.inner
442+
.stream
443+
.as_ref()
444+
.map(|x| x.is_socket())
445+
.unwrap_or_default()
446+
}
447+
448+
#[cfg(not(unix))]
449+
false
450+
}
451+
429452
/// Hacky way to move connection through &mut. `self` becomes unusable.
430453
fn take(&mut self) -> Conn {
431454
mem::replace(self, Conn::empty(Default::default()))
@@ -663,16 +686,21 @@ impl Conn {
663686
let mut pass = crate::BUFFER_POOL.get_with(pass.as_bytes());
664687
pass.as_mut().push(0);
665688

666-
if self.is_secure() {
689+
if self.is_secure() || self.is_socket() {
667690
self.write_packet(pass).await?;
668691
} else {
669-
self.write_bytes(&[0x02][..]).await?;
670-
let packet = self.read_packet().await?;
671-
let key = &packet[1..];
692+
if self.inner.server_key.is_none() {
693+
self.write_bytes(&[0x02][..]).await?;
694+
let packet = self.read_packet().await?;
695+
self.inner.server_key = Some(packet[1..].to_vec());
696+
}
672697
for (i, byte) in pass.as_mut().iter_mut().enumerate() {
673698
*byte ^= self.inner.nonce[i % self.inner.nonce.len()];
674699
}
675-
let encrypted_pass = crypto::encrypt(&*pass, key);
700+
let encrypted_pass = crypto::encrypt(
701+
&*pass,
702+
self.inner.server_key.as_deref().expect("unreachable"),
703+
);
676704
self.write_bytes(&*encrypted_pass).await?;
677705
};
678706
self.drop_packet().await?;
@@ -958,12 +986,13 @@ impl Conn {
958986
self.inner.last_io.elapsed()
959987
}
960988

961-
/// Executes `COM_RESET_CONNECTION` on `self`.
989+
/// Executes [`COM_RESET_CONNECTION`][1].
962990
///
963-
/// If server version is older than 5.7.2, then it'll reconnect.
964-
pub async fn reset(&mut self) -> Result<()> {
965-
let pool = self.inner.pool.clone();
966-
991+
/// Returns `false` if command is not supported (requires MySql >5.7.2, MariaDb >10.2.3).
992+
/// For older versions consider using [`Conn::change_user`].
993+
///
994+
/// [1]: https://dev.mysql.com/doc/c-api/5.7/en/mysql-reset-connection.html
995+
pub async fn reset(&mut self) -> Result<bool> {
967996
let supports_com_reset_connection = if self.inner.is_mariadb {
968997
self.inner.version >= (10, 2, 4)
969998
} else {
@@ -973,19 +1002,62 @@ impl Conn {
9731002

9741003
if supports_com_reset_connection {
9751004
self.routine(routines::ResetRoutine).await?;
976-
} else {
977-
let opts = self.inner.opts.clone();
978-
let old_conn = std::mem::replace(self, Conn::new(opts).await?);
979-
// tidy up the old connection
980-
old_conn.close_conn().await?;
981-
};
1005+
self.inner.stmt_cache.clear();
1006+
self.inner.infile_handler = None;
1007+
}
1008+
1009+
Ok(supports_com_reset_connection)
1010+
}
9821011

1012+
/// Executes [`COM_CHANGE_USER`][1].
1013+
///
1014+
/// This might be used as an older and slower alternative to `COM_RESET_CONNECTION` that
1015+
/// works on MySql prior to 5.7.3 (MariaDb prior ot 10.2.4).
1016+
///
1017+
/// ## Note
1018+
///
1019+
/// * Using non-default `opts` for a pooled connection is discouraging.
1020+
/// * Connection options will be permanently updated.
1021+
///
1022+
/// [1]: https://dev.mysql.com/doc/c-api/5.7/en/mysql-change-user.html
1023+
pub async fn change_user(&mut self, opts: ChangeUserOpts) -> Result<()> {
1024+
// We'll kick this connection from a pool if opts are changed.
1025+
if opts != ChangeUserOpts::default() {
1026+
let mut opts_changed = false;
1027+
if let Some(user) = opts.user() {
1028+
opts_changed |= user != self.opts().user()
1029+
};
1030+
if let Some(pass) = opts.pass() {
1031+
opts_changed |= pass != self.opts().pass()
1032+
};
1033+
if let Some(db_name) = opts.db_name() {
1034+
opts_changed |= db_name != self.opts().db_name()
1035+
};
1036+
if opts_changed {
1037+
if let Some(pool) = self.inner.pool.take() {
1038+
pool.cancel_connection();
1039+
}
1040+
}
1041+
}
1042+
1043+
let conn_opts = &mut self.inner.opts;
1044+
opts.update_opts(conn_opts);
1045+
self.routine(routines::ChangeUser).await?;
9831046
self.inner.stmt_cache.clear();
9841047
self.inner.infile_handler = None;
985-
self.inner.pool = pool;
9861048
Ok(())
9871049
}
9881050

1051+
/// Resets the connection upon returning it to a pool.
1052+
///
1053+
/// Will invoke `COM_CHANGE_USER` if `COM_RESET_CONNECTION` is not supported.
1054+
async fn reset_for_pool(mut self) -> Result<Self> {
1055+
if !self.reset().await? {
1056+
self.change_user(Default::default()).await?;
1057+
}
1058+
Ok(self)
1059+
}
1060+
9891061
/// Requires that `self.inner.tx_status != TxStatus::None`
9901062
async fn rollback_transaction(&mut self) -> Result<()> {
9911063
debug_assert_ne!(self.inner.tx_status, TxStatus::None);
@@ -1094,13 +1166,14 @@ mod test {
10941166
use bytes::Bytes;
10951167
use futures_util::stream::{self, StreamExt};
10961168
use mysql_common::{binlog::events::EventData, constants::MAX_PAYLOAD_LEN};
1169+
use rand::Fill;
10971170
use tokio::time::timeout;
10981171

10991172
use std::time::Duration;
11001173

11011174
use crate::{
1102-
from_row, params, prelude::*, test_misc::get_opts, BinlogDumpFlags, BinlogRequest, Conn,
1103-
Error, OptsBuilder, Pool, WhiteListFsHandler,
1175+
from_row, params, prelude::*, test_misc::get_opts, BinlogDumpFlags, BinlogRequest,
1176+
ChangeUserOpts, Conn, Error, OptsBuilder, Pool, Value, WhiteListFsHandler,
11041177
};
11051178

11061179
async fn gen_dummy_data() -> super::Result<()> {
@@ -1471,9 +1544,115 @@ mod test {
14711544
#[tokio::test]
14721545
async fn should_reset_the_connection() -> super::Result<()> {
14731546
let mut conn = Conn::new(get_opts()).await?;
1474-
conn.exec_drop("SELECT ?", (1_u8,)).await?;
1475-
conn.reset().await?;
1476-
conn.exec_drop("SELECT ?", (1_u8,)).await?;
1547+
let max_execution_time = conn
1548+
.query_first::<u64, _>("SELECT @@max_execution_time")
1549+
.await?
1550+
.unwrap();
1551+
1552+
conn.exec_drop(
1553+
"SET SESSION max_execution_time = ?",
1554+
(max_execution_time + 1,),
1555+
)
1556+
.await?;
1557+
1558+
assert_eq!(
1559+
conn.query_first::<u64, _>("SELECT @@max_execution_time")
1560+
.await?,
1561+
Some(max_execution_time + 1)
1562+
);
1563+
1564+
if conn.reset().await? {
1565+
assert_eq!(
1566+
conn.query_first::<u64, _>("SELECT @@max_execution_time")
1567+
.await?,
1568+
Some(max_execution_time)
1569+
);
1570+
} else {
1571+
assert_eq!(
1572+
conn.query_first::<u64, _>("SELECT @@max_execution_time")
1573+
.await?,
1574+
Some(max_execution_time + 1)
1575+
);
1576+
}
1577+
1578+
conn.disconnect().await?;
1579+
Ok(())
1580+
}
1581+
1582+
#[tokio::test]
1583+
async fn should_change_user() -> super::Result<()> {
1584+
let mut conn = Conn::new(get_opts()).await?;
1585+
let max_execution_time = conn
1586+
.query_first::<u64, _>("SELECT @@max_execution_time")
1587+
.await?
1588+
.unwrap();
1589+
1590+
conn.exec_drop(
1591+
"SET SESSION max_execution_time = ?",
1592+
(max_execution_time + 1,),
1593+
)
1594+
.await?;
1595+
1596+
assert_eq!(
1597+
conn.query_first::<u64, _>("SELECT @@max_execution_time")
1598+
.await?,
1599+
Some(max_execution_time + 1)
1600+
);
1601+
1602+
conn.change_user(Default::default()).await?;
1603+
assert_eq!(
1604+
conn.query_first::<u64, _>("SELECT @@max_execution_time")
1605+
.await?,
1606+
Some(max_execution_time)
1607+
);
1608+
1609+
let plugins: &[&str] = if !conn.inner.is_mariadb && conn.server_version() >= (5, 8, 0) {
1610+
&["mysql_native_password", "caching_sha2_password"]
1611+
} else {
1612+
&["mysql_native_password"]
1613+
};
1614+
1615+
for plugin in plugins {
1616+
let mut conn2 = Conn::new(get_opts()).await.unwrap();
1617+
1618+
let mut rng = rand::thread_rng();
1619+
let mut pass = [0u8; 10];
1620+
pass.try_fill(&mut rng).unwrap();
1621+
let pass: String = IntoIterator::into_iter(pass)
1622+
.map(|x| ((x % (123 - 97)) + 97) as char)
1623+
.collect();
1624+
conn.query_drop("DROP USER IF EXISTS __mysql_async_test_user")
1625+
.await
1626+
.unwrap();
1627+
conn.query_drop(format!(
1628+
"CREATE USER '__mysql_async_test_user'@'%' IDENTIFIED WITH {} BY {}",
1629+
plugin,
1630+
Value::from(pass.clone()).as_sql(false)
1631+
))
1632+
.await
1633+
.unwrap();
1634+
conn.query_drop("FLUSH PRIVILEGES").await.unwrap();
1635+
1636+
conn2
1637+
.change_user(
1638+
ChangeUserOpts::default()
1639+
.with_db_name(None)
1640+
.with_user(Some("__mysql_async_test_user".into()))
1641+
.with_pass(Some(pass)),
1642+
)
1643+
.await
1644+
.unwrap();
1645+
assert_eq!(
1646+
conn2
1647+
.query_first::<(Option<String>, String), _>("SELECT DATABASE(), USER();")
1648+
.await
1649+
.unwrap(),
1650+
Some((None, String::from("__mysql_async_test_user@localhost"))),
1651+
);
1652+
1653+
conn2.disconnect().await.unwrap();
1654+
}
1655+
14771656
conn.disconnect().await?;
14781657
Ok(())
14791658
}

src/conn/pool/futures/get_conn.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,16 +69,18 @@ pub struct GetConn {
6969
pub(crate) queue_id: Option<QueueId>,
7070
pub(crate) pool: Option<Pool>,
7171
pub(crate) inner: GetConnInner,
72+
reset_upon_returning_to_a_pool: bool,
7273
#[cfg(feature = "tracing")]
7374
span: Arc<Span>,
7475
}
7576

7677
impl GetConn {
77-
pub(crate) fn new(pool: &Pool) -> GetConn {
78+
pub(crate) fn new(pool: &Pool, reset_upon_returning_to_a_pool: bool) -> GetConn {
7879
GetConn {
7980
queue_id: None,
8081
pool: Some(pool.clone()),
8182
inner: GetConnInner::New,
83+
reset_upon_returning_to_a_pool,
8284
#[cfg(feature = "tracing")]
8385
span: Arc::new(debug_span!("mysql_async::get_conn")),
8486
}
@@ -141,6 +143,8 @@ impl Future for GetConn {
141143
return match result {
142144
Ok(mut c) => {
143145
c.inner.pool = Some(pool);
146+
c.inner.reset_upon_returning_to_a_pool =
147+
self.reset_upon_returning_to_a_pool;
144148
Poll::Ready(Ok(c))
145149
}
146150
Err(e) => {
@@ -152,12 +156,14 @@ impl Future for GetConn {
152156
GetConnInner::Checking(ref mut f) => {
153157
let result = ready!(Pin::new(f).poll(cx));
154158
match result {
155-
Ok(mut checked_conn) => {
159+
Ok(mut c) => {
156160
self.inner = GetConnInner::Done;
157161

158162
let pool = self.pool_take();
159-
checked_conn.inner.pool = Some(pool);
160-
return Poll::Ready(Ok(checked_conn));
163+
c.inner.pool = Some(pool);
164+
c.inner.reset_upon_returning_to_a_pool =
165+
self.reset_upon_returning_to_a_pool;
166+
return Poll::Ready(Ok(c));
161167
}
162168
Err(_) => {
163169
// Idling connection is broken. We'll drop it and try again.

0 commit comments

Comments
 (0)