Skip to content

Commit 2b8f082

Browse files
committed
opts: Add ssl-related query parameters
1 parent 991069a commit 2b8f082

File tree

1 file changed

+79
-1
lines changed

1 file changed

+79
-1
lines changed

src/opts/mod.rs

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1078,6 +1078,10 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result<MysqlOpts, UrlError> {
10781078
let (mut opts, query_pairs): (MysqlOpts, _) = from_url_basic(url)?;
10791079
let mut pool_min = DEFAULT_POOL_CONSTRAINTS.min;
10801080
let mut pool_max = DEFAULT_POOL_CONSTRAINTS.max;
1081+
1082+
let mut skip_domain_validation = false;
1083+
let mut accept_invalid_certs = false;
1084+
10811085
for (key, value) in query_pairs {
10821086
if key == "pool_min" {
10831087
match usize::from_str(&*value) {
@@ -1240,6 +1244,40 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result<MysqlOpts, UrlError> {
12401244
value,
12411245
});
12421246
}
1247+
} else if key == "require_ssl" {
1248+
match bool::from_str(&*value) {
1249+
Ok(x) => opts.ssl_opts = x.then(SslOpts::default),
1250+
_ => {
1251+
return Err(UrlError::InvalidParamValue {
1252+
param: "require_ssl".into(),
1253+
value,
1254+
});
1255+
}
1256+
}
1257+
} else if key == "verify_ca" {
1258+
match bool::from_str(&*value) {
1259+
Ok(x) => {
1260+
accept_invalid_certs = !x;
1261+
}
1262+
_ => {
1263+
return Err(UrlError::InvalidParamValue {
1264+
param: "verify_ca".into(),
1265+
value,
1266+
});
1267+
}
1268+
}
1269+
} else if key == "verify_identity" {
1270+
match bool::from_str(&*value) {
1271+
Ok(x) => {
1272+
skip_domain_validation = !x;
1273+
}
1274+
_ => {
1275+
return Err(UrlError::InvalidParamValue {
1276+
param: "verify_identity".into(),
1277+
value,
1278+
});
1279+
}
1280+
}
12431281
} else {
12441282
return Err(UrlError::UnknownParameter { param: key });
12451283
}
@@ -1254,6 +1292,11 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result<MysqlOpts, UrlError> {
12541292
});
12551293
}
12561294

1295+
if let Some(ref mut ssl_opts) = opts.ssl_opts.as_mut() {
1296+
ssl_opts.accept_invalid_certs = accept_invalid_certs;
1297+
ssl_opts.skip_domain_validation = skip_domain_validation;
1298+
}
1299+
12571300
Ok(opts)
12581301
}
12591302

@@ -1276,7 +1319,7 @@ impl<'a> TryFrom<&'a str> for Opts {
12761319
#[cfg(test)]
12771320
mod test {
12781321
use super::{HostPortOrUrl, MysqlOpts, Opts, Url};
1279-
use crate::error::UrlError::InvalidParamValue;
1322+
use crate::{error::UrlError::InvalidParamValue, SslOpts};
12801323

12811324
use std::str::FromStr;
12821325

@@ -1345,6 +1388,41 @@ mod test {
13451388
assert_eq!(opts.ip_or_hostname(), "[::1]");
13461389
}
13471390

1391+
#[test]
1392+
fn should_parse_ssl_params() {
1393+
const URL1: &str = "mysql://localhost/foo?require_ssl=false";
1394+
let opts = Opts::from_url(URL1).unwrap();
1395+
assert_eq!(opts.ssl_opts(), None);
1396+
1397+
const URL2: &str = "mysql://localhost/foo?require_ssl=true";
1398+
let opts = Opts::from_url(URL2).unwrap();
1399+
assert_eq!(opts.ssl_opts(), Some(&SslOpts::default()));
1400+
1401+
const URL3: &str = "mysql://localhost/foo?require_ssl=true&verify_ca=false";
1402+
let opts = Opts::from_url(URL3).unwrap();
1403+
assert_eq!(
1404+
opts.ssl_opts(),
1405+
Some(&SslOpts::default().with_danger_accept_invalid_certs(true))
1406+
);
1407+
1408+
const URL4: &str =
1409+
"mysql://localhost/foo?require_ssl=true&verify_ca=false&verify_identity=false";
1410+
let opts = Opts::from_url(URL4).unwrap();
1411+
assert_eq!(
1412+
opts.ssl_opts(),
1413+
Some(
1414+
&SslOpts::default()
1415+
.with_danger_accept_invalid_certs(true)
1416+
.with_danger_skip_domain_validation(true)
1417+
)
1418+
);
1419+
1420+
const URL5: &str =
1421+
"mysql://localhost/foo?require_ssl=false&verify_ca=false&verify_identity=false";
1422+
let opts = Opts::from_url(URL5).unwrap();
1423+
assert_eq!(opts.ssl_opts(), None);
1424+
}
1425+
13481426
#[test]
13491427
#[should_panic]
13501428
fn should_panic_on_invalid_url() {

0 commit comments

Comments
 (0)