Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ serde = "1"
serde_json = "1"
socket2 = "0.5.2"
thiserror = "2"
tokio = { version = "1.0", features = ["io-util", "fs", "net", "time", "rt"] }
tokio = { version = "1.0", features = ["io-util", "fs", "net", "time", "rt", "sync"] }
tokio-util = { version = "0.7.2", features = ["codec", "io"] }
tracing = { version = "0.1.37", default-features = false, features = [
"attributes",
Expand Down
8 changes: 6 additions & 2 deletions src/conn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -549,12 +549,16 @@ impl Conn {
);
self.write_struct(&ssl_request).await?;
let conn = self;
let ssl_opts = conn.opts().ssl_opts().cloned().expect("unreachable");
let ssl_opts = conn.opts().ssl_opts_and_connector().expect("unreachable");
let domain = ssl_opts
.ssl_opts()
.tls_hostname_override()
.unwrap_or_else(|| conn.opts().ip_or_hostname())
.into();
conn.stream_mut()?.make_secure(domain, ssl_opts).await?;
let tls_connector = ssl_opts.build_tls_connector().await?;
conn.stream_mut()?
.make_secure(domain, &tls_connector)
.await?;
Ok(())
} else {
Ok(())
Expand Down
20 changes: 5 additions & 15 deletions src/io/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,16 @@ use std::{
use crate::{
buffer_pool::PooledBuf,
error::IoError,
opts::{HostPortOrUrl, SslOpts, DEFAULT_PORT},
opts::{HostPortOrUrl, DEFAULT_PORT},
};

#[cfg(unix)]
use crate::io::socket::Socket;

mod tls;

pub(crate) use self::tls::TlsConnector;

macro_rules! with_interrupted {
($e:expr) => {
loop {
Expand Down Expand Up @@ -193,18 +195,6 @@ impl Endpoint {
matches!(self, Endpoint::Secure(_))
}

#[cfg(all(not(feature = "native-tls-tls"), not(feature = "rustls")))]
pub async fn make_secure(
&mut self,
_domain: String,
_ssl_opts: crate::SslOpts,
) -> crate::error::Result<()> {
panic!(
"Client had asked for TLS connection but TLS support is disabled. \
Please enable one of the following features: [\"native-tls-tls\", \"rustls-tls\"]"
)
}

pub fn set_tcp_nodelay(&self, val: bool) -> io::Result<()> {
match *self {
Endpoint::Plain(Some(ref stream)) => stream.set_nodelay(val)?,
Expand Down Expand Up @@ -415,11 +405,11 @@ impl Stream {
pub(crate) async fn make_secure(
&mut self,
domain: String,
ssl_opts: SslOpts,
tls_connector: &TlsConnector,
) -> crate::error::Result<()> {
let codec = self.codec.take().unwrap();
let FramedParts { mut io, codec, .. } = codec.into_parts();
io.make_secure(domain, ssl_opts).await?;
io.make_secure(domain, tls_connector).await?;
let codec = Framed::new(io, codec);
self.codec = Some(Box::new(codec));
Ok(())
Expand Down
13 changes: 11 additions & 2 deletions src/io/tls/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
#![cfg(any(feature = "native-tls-tls", feature = "rustls"))]

#[cfg(feature = "native-tls-tls")]
mod native_tls_io;
#[cfg(not(any(feature = "rustls-tls", feature = "native-tls-tls")))]
mod no_tls;
#[cfg(feature = "rustls-tls")]
mod rustls_io;

#[cfg(feature = "native-tls-tls")]
pub(crate) use self::native_tls_io::TlsConnector;
#[cfg(not(any(feature = "rustls-tls", feature = "native-tls-tls")))]
pub(crate) use self::no_tls::TlsConnector;
#[cfg(feature = "rustls-tls")]
pub(crate) use self::rustls_io::TlsConnector;
41 changes: 24 additions & 17 deletions src/io/tls/native_tls_io.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#![cfg(feature = "native-tls-tls")]

use native_tls::{Certificate, TlsConnector};
use tokio_native_tls::native_tls::{self, Certificate};

use crate::io::Endpoint;
use crate::{Result, SslOpts};

pub use tokio_native_tls::TlsConnector;

impl SslOpts {
async fn load_root_certs(&self) -> crate::Result<Vec<Certificate>> {
let mut output = Vec::new();
Expand All @@ -16,29 +16,36 @@ impl SslOpts {

Ok(output)
}

pub(crate) async fn build_tls_connector(&self) -> Result<TlsConnector> {
let mut builder = native_tls::TlsConnector::builder();
for root_cert in self.load_root_certs().await? {
builder.add_root_certificate(root_cert);
}

if let Some(client_identity) = self.client_identity() {
builder.identity(client_identity.load().await?);
}
builder.danger_accept_invalid_hostnames(self.skip_domain_validation());
builder.danger_accept_invalid_certs(self.accept_invalid_certs());
builder.disable_built_in_roots(self.disable_built_in_roots());
let tls_connector: TlsConnector = builder.build()?.into();
Ok(tls_connector)
}
}

impl Endpoint {
pub async fn make_secure(&mut self, domain: String, ssl_opts: SslOpts) -> Result<()> {
pub async fn make_secure(
&mut self,
domain: String,
tls_connector: &TlsConnector,
) -> Result<()> {
#[cfg(unix)]
if self.is_socket() {
// won't secure socket connection
return Ok(());
}

let mut builder = TlsConnector::builder();
for root_cert in ssl_opts.load_root_certs().await? {
builder.add_root_certificate(root_cert);
}

if let Some(client_identity) = ssl_opts.client_identity() {
builder.identity(client_identity.load().await?);
}
builder.danger_accept_invalid_hostnames(ssl_opts.skip_domain_validation());
builder.danger_accept_invalid_certs(ssl_opts.accept_invalid_certs());
builder.disable_built_in_roots(ssl_opts.disable_built_in_roots());
let tls_connector: tokio_native_tls::TlsConnector = builder.build()?.into();

*self = match self {
Endpoint::Plain(ref mut stream) => {
let stream = stream.take().unwrap();
Expand Down
24 changes: 24 additions & 0 deletions src/io/tls/no_tls.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
use crate::io::Endpoint;
use crate::{Result, SslOpts};

#[derive(Clone, Debug)]
pub(crate) struct TlsConnector;

impl SslOpts {
pub(crate) async fn build_tls_connector(&self) -> Result<TlsConnector> {
panic!(
"Client had asked for TLS connection but TLS support is disabled. \
Please enable one of the following features: [\"native-tls-tls\", \"rustls-tls\"]"
)
}
}

impl Endpoint {
pub async fn make_secure(
&mut self,
_domain: String,
_tls_connector: &TlsConnector,
) -> Result<()> {
unreachable!();
}
}
48 changes: 26 additions & 22 deletions src/io/tls/rustls_io.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
#![cfg(feature = "rustls-tls")]

use std::sync::Arc;

use rustls::{
Expand All @@ -12,7 +10,7 @@ use rustls::{
};

use rustls_pemfile::certs;
use tokio_rustls::TlsConnector;
pub(crate) use tokio_rustls::TlsConnector;

use crate::{io::Endpoint, Result, SslOpts, TlsError};

Expand All @@ -35,54 +33,60 @@ impl SslOpts {

Ok(output)
}
}

impl Endpoint {
pub async fn make_secure(&mut self, domain: String, ssl_opts: SslOpts) -> Result<()> {
#[cfg(unix)]
if self.is_socket() {
// won't secure socket connection
return Ok(());
}

pub(crate) async fn build_tls_connector(&self) -> Result<TlsConnector> {
let mut root_store = RootCertStore::empty();
if !ssl_opts.disable_built_in_roots() {
if !self.disable_built_in_roots() {
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().map(|x| x.to_owned()));
}

for cert in ssl_opts.load_root_certs().await? {
for cert in self.load_root_certs().await? {
root_store.add(cert)?;
}

let config_builder = ClientConfig::builder().with_root_certificates(root_store.clone());

let mut config = if let Some(identity) = ssl_opts.client_identity() {
let mut config = if let Some(identity) = self.client_identity() {
let (cert_chain, priv_key) = identity.load().await?;
config_builder.with_client_auth_cert(cert_chain, priv_key)?
} else {
config_builder.with_no_client_auth()
};

let server_name = ServerName::try_from(domain.as_str())
.map_err(|_| webpki::InvalidDnsNameError)?
.to_owned();
let mut dangerous = config.dangerous();
let web_pki_verifier = WebPkiServerVerifier::builder(Arc::new(root_store))
.build()
.map_err(TlsError::from)?;
let dangerous_verifier = DangerousVerifier::new(
ssl_opts.accept_invalid_certs(),
ssl_opts.skip_domain_validation(),
self.accept_invalid_certs(),
self.skip_domain_validation(),
web_pki_verifier,
);
dangerous.set_certificate_verifier(Arc::new(dangerous_verifier));
let client_config = Arc::new(config);
Ok(TlsConnector::from(client_config))
}
}

impl Endpoint {
pub async fn make_secure(
&mut self,
domain: String,
tls_connector: &TlsConnector,
) -> Result<()> {
#[cfg(unix)]
if self.is_socket() {
// won't secure socket connection
return Ok(());
}

*self = match self {
Endpoint::Plain(ref mut stream) => {
let stream = stream.take().unwrap();

let client_config = Arc::new(config);
let tls_connector = TlsConnector::from(client_config);
let server_name = ServerName::try_from(domain.as_str())
.map_err(|_| webpki::InvalidDnsNameError)?
.to_owned();
let connection = tls_connector.connect(server_name, stream).await?;

Endpoint::Secure(connection)
Expand Down
Loading