Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flipt-client-js/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -118,4 +118,4 @@
"type": "git",
"url": "https://github.com/flipt-io/flipt-client-sdks/tree/main/flipt-client-js"
}
}
}
250 changes: 250 additions & 0 deletions flipt-engine-ffi/src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ use tokio_util::io::StreamReader;
use fliptevaluation::error::Error;
use fliptevaluation::models::source;

use crate::TlsConfig;
use base64::prelude::BASE64_STANDARD;
use base64::Engine as Base64Engine;

#[derive(Debug, Clone, Default, Deserialize)]
#[cfg_attr(test, derive(PartialEq))]
#[serde(rename_all = "snake_case")]
Expand Down Expand Up @@ -111,6 +115,7 @@ pub struct HTTPFetcherBuilder {
request_timeout: Option<Duration>,
update_interval: Duration,
mode: FetchMode,
tls_config: Option<TlsConfig>,
}

#[derive(Deserialize)]
Expand All @@ -136,6 +141,7 @@ impl HTTPFetcherBuilder {
request_timeout: None,
update_interval: Duration::from_secs(120),
mode: FetchMode::default(),
tls_config: None,
}
}

Expand Down Expand Up @@ -174,6 +180,11 @@ impl HTTPFetcherBuilder {
self
}

pub fn tls_config(mut self, tls_config: TlsConfig) -> Self {
self.tls_config = Some(tls_config);
self
}

pub fn build(self) -> Result<HTTPFetcher, Error> {
let retry_policy = ExponentialBackoff::builder()
.retry_bounds(Duration::from_secs(1), Duration::from_secs(30))
Expand Down Expand Up @@ -205,6 +216,11 @@ impl HTTPFetcherBuilder {
}
}

// Apply TLS configuration if provided
if let Some(tls_config) = &self.tls_config {
client_builder = configure_tls(client_builder, tls_config)?;
}

let client = client_builder
.build()
.map_err(|e| Error::Internal(format!("failed to create client: {e}")))?;
Expand Down Expand Up @@ -476,14 +492,75 @@ impl HTTPFetcher {
}
}

fn configure_tls(
mut builder: reqwest::ClientBuilder,
tls_config: &TlsConfig,
) -> Result<reqwest::ClientBuilder, Error> {
// Handle insecure mode
if tls_config.insecure_skip_verify.unwrap_or(false) {
builder = builder.danger_accept_invalid_certs(true);
}

// Handle custom CA certificates
if let Some(ca_cert_data) = &tls_config.ca_cert_data {
let cert_bytes = BASE64_STANDARD
.decode(ca_cert_data)
.map_err(|e| Error::Internal(format!("Invalid CA cert data: {e}")))?;
let cert = reqwest::Certificate::from_pem(&cert_bytes)
.map_err(|e| Error::Internal(format!("Invalid CA certificate: {e}")))?;
builder = builder.add_root_certificate(cert);
} else if let Some(ca_cert_file) = &tls_config.ca_cert_file {
let cert_bytes = std::fs::read(ca_cert_file)
.map_err(|e| Error::Internal(format!("Failed to read CA cert file: {e}")))?;
let cert = reqwest::Certificate::from_pem(&cert_bytes)
.map_err(|e| Error::Internal(format!("Invalid CA certificate file: {e}")))?;
builder = builder.add_root_certificate(cert);
}

// Handle client certificates for mutual TLS
if let (Some(cert_data), Some(key_data)) =
(&tls_config.client_cert_data, &tls_config.client_key_data)
{
let cert_bytes = BASE64_STANDARD
.decode(cert_data)
.map_err(|e| Error::Internal(format!("Invalid client cert data: {e}")))?;
let key_bytes = BASE64_STANDARD
.decode(key_data)
.map_err(|e| Error::Internal(format!("Invalid client key data: {e}")))?;
let mut combined = cert_bytes.clone();
combined.extend_from_slice(&key_bytes);
let identity = reqwest::Identity::from_pem(&combined)
.map_err(|e| Error::Internal(format!("Invalid client certificate: {e}")))?;
builder = builder.identity(identity);
} else if let (Some(cert_file), Some(key_file)) =
(&tls_config.client_cert_file, &tls_config.client_key_file)
{
let cert_bytes = std::fs::read(cert_file)
.map_err(|e| Error::Internal(format!("Failed to read client cert file: {e}")))?;
let key_bytes = std::fs::read(key_file)
.map_err(|e| Error::Internal(format!("Failed to read client key file: {e}")))?;
let mut combined = cert_bytes.clone();
combined.extend_from_slice(&key_bytes);
let identity = reqwest::Identity::from_pem(&combined)
.map_err(|e| Error::Internal(format!("Invalid client certificate files: {e}")))?;
builder = builder.identity(identity);
}

Ok(builder)
}

#[cfg(test)]
mod tests {
use futures::FutureExt;
use mockito::Server;

use crate::http::configure_tls;
use crate::http::Authentication;
use crate::http::FetchMode;
use crate::http::HTTPFetcherBuilder;
use crate::TlsConfig;
use base64::prelude::BASE64_STANDARD;
use base64::Engine as Base64Engine;
use tokio::sync::mpsc;

#[tokio::test]
Expand Down Expand Up @@ -774,4 +851,177 @@ mod tests {

assert_eq!(unwrapped_string, Authentication::JwtToken("secret".into()));
}

#[test]
fn test_tls_config_insecure_skip_verify() {
let tls_config = TlsConfig {
insecure_skip_verify: Some(true),
ca_cert_file: None,
ca_cert_data: None,
client_cert_file: None,
client_key_file: None,
client_cert_data: None,
client_key_data: None,
};

let builder = reqwest::Client::builder();
let result = configure_tls(builder, &tls_config);
assert!(result.is_ok());
}

#[test]
fn test_tls_config_custom_ca_cert_data() {
// Use the existing localhost.crt for testing
let cert_pem = include_str!("testdata/localhost.crt");
let cert_b64 = BASE64_STANDARD.encode(cert_pem);

let tls_config = TlsConfig {
ca_cert_data: Some(cert_b64),
insecure_skip_verify: None,
ca_cert_file: None,
client_cert_file: None,
client_key_file: None,
client_cert_data: None,
client_key_data: None,
};

let builder = reqwest::Client::builder();
let result = configure_tls(builder, &tls_config);
assert!(result.is_ok());
}

#[test]
fn test_tls_config_custom_ca_cert_file() {
let tls_config = TlsConfig {
ca_cert_file: Some("src/testdata/localhost.crt".to_string()),
insecure_skip_verify: None,
ca_cert_data: None,
client_cert_file: None,
client_key_file: None,
client_cert_data: None,
client_key_data: None,
};

let builder = reqwest::Client::builder();
let result = configure_tls(builder, &tls_config);
assert!(result.is_ok());
}

#[test]
fn test_tls_config_client_certificates_data() {
let cert_pem = include_str!("testdata/localhost.crt");
let key_pem = include_str!("testdata/localhost.key");
let cert_b64 = BASE64_STANDARD.encode(cert_pem);
let key_b64 = BASE64_STANDARD.encode(key_pem);

let tls_config = TlsConfig {
client_cert_data: Some(cert_b64),
client_key_data: Some(key_b64),
insecure_skip_verify: None,
ca_cert_file: None,
ca_cert_data: None,
client_cert_file: None,
client_key_file: None,
};

let builder = reqwest::Client::builder();
let result = configure_tls(builder, &tls_config);
assert!(result.is_ok());
}

#[test]
fn test_tls_config_client_certificates_files() {
let tls_config = TlsConfig {
client_cert_file: Some("src/testdata/localhost.crt".to_string()),
client_key_file: Some("src/testdata/localhost.key".to_string()),
insecure_skip_verify: None,
ca_cert_file: None,
ca_cert_data: None,
client_cert_data: None,
client_key_data: None,
};

let builder = reqwest::Client::builder();
let result = configure_tls(builder, &tls_config);
assert!(result.is_ok());
}

#[test]
fn test_tls_config_invalid_ca_cert_data() {
let tls_config = TlsConfig {
ca_cert_data: Some("invalid_base64".to_string()),
insecure_skip_verify: None,
ca_cert_file: None,
client_cert_file: None,
client_key_file: None,
client_cert_data: None,
client_key_data: None,
};

let builder = reqwest::Client::builder();
let result = configure_tls(builder, &tls_config);
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("Invalid CA cert data"));
}

#[test]
fn test_tls_config_invalid_ca_cert_file() {
let tls_config = TlsConfig {
ca_cert_file: Some("nonexistent.crt".to_string()),
insecure_skip_verify: None,
ca_cert_data: None,
client_cert_file: None,
client_key_file: None,
client_cert_data: None,
client_key_data: None,
};

let builder = reqwest::Client::builder();
let result = configure_tls(builder, &tls_config);
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("Failed to read CA cert file"));
}

#[test]
fn test_tls_config_combined_options() {
let cert_pem = include_str!("testdata/localhost.crt");
let cert_b64 = BASE64_STANDARD.encode(cert_pem);

let tls_config = TlsConfig {
ca_cert_data: Some(cert_b64),
insecure_skip_verify: Some(true),
ca_cert_file: None,
client_cert_file: None,
client_key_file: None,
client_cert_data: None,
client_key_data: None,
};

let builder = reqwest::Client::builder();
let result = configure_tls(builder, &tls_config);
assert!(result.is_ok());
}

#[test]
fn test_tls_config_empty() {
let tls_config = TlsConfig {
insecure_skip_verify: None,
ca_cert_file: None,
ca_cert_data: None,
client_cert_file: None,
client_key_file: None,
client_cert_data: None,
client_key_data: None,
};

let builder = reqwest::Client::builder();
let result = configure_tls(builder, &tls_config);
assert!(result.is_ok());
}
}
Loading
Loading