Skip to content

Commit 0d62e1c

Browse files
authored
Merge branch 'master' into update-rustls
2 parents 3e8868f + 129b8d8 commit 0d62e1c

File tree

5 files changed

+101
-44
lines changed

5 files changed

+101
-44
lines changed

Cargo.toml

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ license = "MIT/Apache-2.0"
77
name = "mysql_async"
88
readme = "README.md"
99
repository = "https://github.com/blackbeam/mysql_async"
10-
version = "0.34.1"
10+
version = "0.34.2"
1111
exclude = ["test/*"]
1212
edition = "2021"
1313
categories = ["asynchronous", "database"]
@@ -20,25 +20,22 @@ futures-core = "0.3"
2020
futures-util = "0.3"
2121
futures-sink = "0.3"
2222
keyed_priority_queue = "0.4"
23-
lazy_static = "1"
2423
lru = "0.12.0"
25-
mio = { version = "0.8.0", features = ["os-poll", "net"] }
26-
mysql_common = { version = "0.32", default-features = false }
27-
once_cell = "1.7.2"
24+
mysql_common = { version = "0.33", default-features = false }
2825
pem = "3.0"
2926
percent-encoding = "2.1.0"
3027
pin-project = "1.0.2"
3128
rand = "0.8.5"
3229
serde = "1"
3330
serde_json = "1"
3431
socket2 = "0.5.2"
35-
thiserror = "1.0.4"
32+
thiserror = "2"
3633
tokio = { version = "1.0", features = ["io-util", "fs", "net", "time", "rt"] }
3734
tokio-util = { version = "0.7.2", features = ["codec", "io"] }
3835
tracing = { version = "0.1.37", default-features = false, features = [
3936
"attributes",
4037
], optional = true }
41-
twox-hash = "1"
38+
twox-hash = "2"
4239
url = "2.1"
4340

4441
[dependencies.tokio-rustls]

src/conn/mod.rs

Lines changed: 78 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ struct ConnInner {
102102
status: StatusFlags,
103103
last_ok_packet: Option<OkPacket<'static>>,
104104
last_err_packet: Option<mysql_common::packets::ServerError<'static>>,
105+
handshake_complete: bool,
105106
pool: Option<Pool>,
106107
pending_result: std::result::Result<Option<PendingResult>, ServerError>,
107108
tx_status: TxStatus,
@@ -147,6 +148,7 @@ impl ConnInner {
147148
status: StatusFlags::empty(),
148149
last_ok_packet: None,
149150
last_err_packet: None,
151+
handshake_complete: false,
150152
stream: None,
151153
is_mariadb: false,
152154
version: (0, 0, 0),
@@ -581,10 +583,11 @@ impl Conn {
581583
);
582584

583585
// Serialize here to satisfy borrow checker.
584-
let mut buf = crate::BUFFER_POOL.get();
586+
let mut buf = crate::buffer_pool().get();
585587
handshake_response.serialize(buf.as_mut());
586588

587589
self.write_packet(buf).await?;
590+
self.inner.handshake_complete = true;
588591
Ok(())
589592
}
590593

@@ -633,7 +636,7 @@ impl Conn {
633636
if let Some(plugin_data) = plugin_data {
634637
self.write_struct(&plugin_data.into_owned()).await?;
635638
} else {
636-
self.write_packet(crate::BUFFER_POOL.get()).await?;
639+
self.write_packet(crate::buffer_pool().get()).await?;
637640
}
638641

639642
self.continue_auth().await?;
@@ -701,7 +704,7 @@ impl Conn {
701704
}
702705
Some(0x04) => {
703706
let pass = self.inner.opts.pass().unwrap_or_default();
704-
let mut pass = crate::BUFFER_POOL.get_with(pass.as_bytes());
707+
let mut pass = crate::buffer_pool().get_with(pass.as_bytes());
705708
pass.as_mut().push(0);
706709

707710
if self.is_secure() || self.is_socket() {
@@ -789,7 +792,19 @@ impl Conn {
789792
if let Ok(ok_packet) = ok_packet {
790793
self.handle_ok(ok_packet.into_owned());
791794
} else {
792-
let err_packet = ParseBuf(packet).parse::<ErrPacket>(self.capabilities());
795+
// If we haven't completed the handshake the server will not be aware of our
796+
// capabilities and so it will behave as if we have none. In particular, the error
797+
// packet will not contain a SQL State field even if our capabilities do contain the
798+
// `CLIENT_PROTOCOL_41` flag. Therefore it is necessary to parse an incoming packet
799+
// with no capability assumptions if we have not completed the handshake.
800+
//
801+
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase.html
802+
let capabilities = if self.inner.handshake_complete {
803+
self.capabilities()
804+
} else {
805+
CapabilityFlags::empty()
806+
};
807+
let err_packet = ParseBuf(packet).parse::<ErrPacket>(capabilities);
793808
if let Ok(err_packet) = err_packet {
794809
self.handle_err(err_packet)?;
795810
return Ok(true);
@@ -838,13 +853,13 @@ impl Conn {
838853

839854
/// Writes bytes to a server.
840855
pub(crate) async fn write_bytes(&mut self, bytes: &[u8]) -> Result<()> {
841-
let buf = crate::BUFFER_POOL.get_with(bytes);
856+
let buf = crate::buffer_pool().get_with(bytes);
842857
self.write_packet(buf).await
843858
}
844859

845860
/// Sends a serializable structure to a server.
846861
pub(crate) async fn write_struct<T: MySerialize>(&mut self, x: &T) -> Result<()> {
847-
let mut buf = crate::BUFFER_POOL.get();
862+
let mut buf = crate::buffer_pool().get();
848863
x.serialize(buf.as_mut());
849864
self.write_packet(buf).await
850865
}
@@ -870,7 +885,7 @@ impl Conn {
870885
T: AsRef<[u8]>,
871886
{
872887
let cmd_data = cmd_data.as_ref();
873-
let mut buf = crate::BUFFER_POOL.get();
888+
let mut buf = crate::buffer_pool().get();
874889
let body = buf.as_mut();
875890
body.push(cmd as u8);
876891
body.extend_from_slice(cmd_data);
@@ -1270,10 +1285,11 @@ mod test {
12701285
use futures_util::stream::{self, StreamExt};
12711286
use mysql_common::constants::MAX_PAYLOAD_LEN;
12721287
use rand::Fill;
1288+
use tokio::{io::AsyncWriteExt, net::TcpListener};
12731289

12741290
use crate::{
12751291
from_row, params, prelude::*, test_misc::get_opts, ChangeUserOpts, Conn, Error,
1276-
OptsBuilder, Pool, Value, WhiteListFsHandler,
1292+
OptsBuilder, Pool, ServerError, Value, WhiteListFsHandler,
12771293
};
12781294

12791295
#[tokio::test]
@@ -1400,16 +1416,18 @@ mod test {
14001416
.filter(|variant| plugins.iter().any(|p| p == variant.0));
14011417

14021418
for (plug, val, pass) in variants {
1419+
dbg!((plug, val, pass, conn.inner.version));
1420+
1421+
if plug == "mysql_native_password" && conn.inner.version >= (9, 0, 0) {
1422+
continue;
1423+
}
1424+
14031425
let _ = conn.query_drop("DROP USER 'test_user'@'%'").await;
14041426

14051427
let query = format!("CREATE USER 'test_user'@'%' IDENTIFIED WITH {}", plug);
14061428
conn.query_drop(query).await.unwrap();
14071429

1408-
if (8, 0, 11) <= conn.inner.version && conn.inner.version <= (9, 0, 0) {
1409-
conn.query_drop(format!("SET PASSWORD FOR 'test_user'@'%' = '{}'", pass))
1410-
.await
1411-
.unwrap();
1412-
} else {
1430+
if conn.inner.version <= (8, 0, 11) {
14131431
conn.query_drop(format!("SET old_passwords = {}", val))
14141432
.await
14151433
.unwrap();
@@ -1419,6 +1437,10 @@ mod test {
14191437
))
14201438
.await
14211439
.unwrap();
1440+
} else {
1441+
conn.query_drop(format!("SET PASSWORD FOR 'test_user'@'%' = '{}'", pass))
1442+
.await
1443+
.unwrap();
14221444
};
14231445

14241446
let opts = get_opts()
@@ -1548,6 +1570,10 @@ mod test {
15481570
};
15491571

15501572
for (i, plugin) in plugins.iter().enumerate() {
1573+
if *plugin == "mysql_native_password" && conn.server_version() >= (9, 0, 0) {
1574+
continue;
1575+
}
1576+
15511577
let mut rng = rand::thread_rng();
15521578
let mut pass = [0u8; 10];
15531579
pass.try_fill(&mut rng).unwrap();
@@ -2189,6 +2215,45 @@ mod test {
21892215
Ok(())
21902216
}
21912217

2218+
#[tokio::test]
2219+
async fn should_handle_initial_error_packet() {
2220+
let header = [
2221+
0x68, 0x00, 0x00, // packet_length
2222+
0x00, // sequence
2223+
0xff, // error_header
2224+
0x69, 0x04, // error_code
2225+
];
2226+
let error_message = "Host '172.17.0.1' is blocked because of many connection errors; unblock with 'mysqladmin flush-hosts'";
2227+
2228+
// Create a fake MySQL server that immediately replies with an error packet.
2229+
let listener = TcpListener::bind("127.0.0.1:0000").await.unwrap();
2230+
2231+
let listen_addr = listener.local_addr().unwrap();
2232+
2233+
tokio::task::spawn(async move {
2234+
let (mut stream, _) = listener.accept().await.unwrap();
2235+
stream.write_all(&header).await.unwrap();
2236+
stream.write_all(error_message.as_bytes()).await.unwrap();
2237+
stream.shutdown().await.unwrap();
2238+
});
2239+
2240+
let opts = OptsBuilder::default()
2241+
.ip_or_hostname(listen_addr.ip().to_string())
2242+
.tcp_port(listen_addr.port());
2243+
let server_err = match Conn::new(opts).await {
2244+
Err(Error::Server(server_err)) => server_err,
2245+
other => panic!("expected server error but got: {:?}", other),
2246+
};
2247+
assert_eq!(
2248+
server_err,
2249+
ServerError {
2250+
code: 1129,
2251+
state: "HY000".to_owned(),
2252+
message: error_message.to_owned(),
2253+
}
2254+
);
2255+
}
2256+
21922257
#[cfg(feature = "nightly")]
21932258
mod bench {
21942259
use crate::{conn::Conn, queryable::Queryable, test_misc::get_opts};

src/conn/stmt_cache.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
// modified, or distributed except according to those terms.
88

99
use lru::LruCache;
10-
use twox_hash::XxHash;
10+
use twox_hash::XxHash64;
1111

1212
use std::{
1313
borrow::Borrow,
@@ -42,7 +42,7 @@ pub struct Entry {
4242
pub struct StmtCache {
4343
cap: usize,
4444
cache: LruCache<u32, Entry>,
45-
query_map: HashMap<QueryString, u32, BuildHasherDefault<XxHash>>,
45+
query_map: HashMap<QueryString, u32, BuildHasherDefault<XxHash64>>,
4646
}
4747

4848
impl StmtCache {

src/io/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ impl Default for PacketCodec {
7575
fn default() -> Self {
7676
Self {
7777
inner: Default::default(),
78-
decode_buf: crate::BUFFER_POOL.get(),
78+
decode_buf: crate::buffer_pool().get(),
7979
}
8080
}
8181
}
@@ -100,7 +100,7 @@ impl Decoder for PacketCodec {
100100

101101
fn decode(&mut self, src: &mut BytesMut) -> std::result::Result<Option<Self::Item>, IoError> {
102102
if self.inner.decode(src, self.decode_buf.as_mut())? {
103-
let new_buf = crate::BUFFER_POOL.get();
103+
let new_buf = crate::buffer_pool().get();
104104
Ok(Some(replace(&mut self.decode_buf, new_buf)))
105105
} else {
106106
Ok(None)

src/lib.rs

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -453,8 +453,11 @@ mod queryable;
453453

454454
type BoxFuture<'a, T> = futures_core::future::BoxFuture<'a, Result<T>>;
455455

456-
static BUFFER_POOL: once_cell::sync::Lazy<Arc<crate::buffer_pool::BufferPool>> =
457-
once_cell::sync::Lazy::new(Default::default);
456+
fn buffer_pool() -> &'static Arc<crate::buffer_pool::BufferPool> {
457+
static BUFFER_POOL: std::sync::OnceLock<Arc<crate::buffer_pool::BufferPool>> =
458+
std::sync::OnceLock::new();
459+
BUFFER_POOL.get_or_init(Default::default)
460+
}
458461

459462
#[cfg(feature = "binlog")]
460463
#[doc(inline)]
@@ -562,6 +565,8 @@ pub mod prelude {
562565
#[doc(inline)]
563566
pub use crate::queryable::Queryable;
564567
#[doc(inline)]
568+
pub use mysql_common::prelude::ColumnIndex;
569+
#[doc(inline)]
565570
pub use mysql_common::prelude::FromRow;
566571
#[doc(inline)]
567572
pub use mysql_common::prelude::{FromValue, ToValue};
@@ -608,9 +613,8 @@ pub mod prelude {
608613

609614
#[doc(hidden)]
610615
pub mod test_misc {
611-
use lazy_static::lazy_static;
612-
613616
use std::env;
617+
use std::sync::OnceLock;
614618

615619
use crate::opts::{Opts, OptsBuilder, SslOpts};
616620

@@ -621,26 +625,17 @@ pub mod test_misc {
621625
_dummy(err);
622626
}
623627

624-
lazy_static! {
625-
pub static ref DATABASE_URL: String = {
628+
pub fn get_opts() -> OptsBuilder {
629+
static DATABASE_OPTS: OnceLock<Opts> = OnceLock::new();
630+
let database_opts = DATABASE_OPTS.get_or_init(|| {
626631
if let Ok(url) = env::var("DATABASE_URL") {
627-
let opts = Opts::from_url(&url).expect("DATABASE_URL invalid");
628-
if opts
629-
.db_name()
630-
.expect("a database name is required")
631-
.is_empty()
632-
{
633-
panic!("database name is empty");
634-
}
635-
url
632+
Opts::from_url(&url).expect("DATABASE_URL invalid")
636633
} else {
637-
"mysql://root:password@localhost:3307/mysql".into()
634+
Opts::from_url("mysql://root:password@localhost:3307/mysql").unwrap()
638635
}
639-
};
640-
}
636+
});
641637

642-
pub fn get_opts() -> OptsBuilder {
643-
let mut builder = OptsBuilder::from_opts(Opts::from_url(&DATABASE_URL).unwrap());
638+
let mut builder = OptsBuilder::from_opts(database_opts.clone());
644639
if test_ssl() {
645640
let ssl_opts = SslOpts::default()
646641
.with_danger_skip_domain_validation(true)

0 commit comments

Comments
 (0)