Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 50 additions & 35 deletions api/crypto/frame_crypto_transformer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ const EVP_CIPHER* GetAesCbcAlgorithmFromKeySize(size_t key_size_bytes) {
}

inline bool FrameIsH264(webrtc::TransformableFrameInterface* frame,
webrtc::FrameCryptorTransformer::MediaType type) {
webrtc::FrameCryptorTransformer::MediaType type) {
switch (type) {
case webrtc::FrameCryptorTransformer::MediaType::kVideoFrame: {
auto videoFrame =
Expand Down Expand Up @@ -314,11 +314,18 @@ FrameCryptorTransformer::FrameCryptorTransformer(
Algorithm algorithm,
rtc::scoped_refptr<KeyProvider> key_provider)
: signaling_thread_(signaling_thread),
thread_(rtc::Thread::Create()),
participant_id_(participant_id),
type_(type),
algorithm_(algorithm),
key_provider_(key_provider) {
RTC_DCHECK(key_provider_ != nullptr);
thread_->SetName("FrameCryptorTransformer", this);
thread_->Start();
}

FrameCryptorTransformer::~FrameCryptorTransformer() {
thread_->Stop();
}

void FrameCryptorTransformer::Transform(
Expand All @@ -333,10 +340,16 @@ void FrameCryptorTransformer::Transform(
// do encrypt or decrypt here...
switch (frame->GetDirection()) {
case webrtc::TransformableFrameInterface::Direction::kSender:
encryptFrame(std::move(frame));
RTC_DCHECK(thread_ != nullptr);
thread_->PostTask([frame = std::move(frame), this]() mutable {
encryptFrame(std::move(frame));
});
break;
case webrtc::TransformableFrameInterface::Direction::kReceiver:
decryptFrame(std::move(frame));
RTC_DCHECK(thread_ != nullptr);
thread_->PostTask([frame = std::move(frame), this]() mutable {
decryptFrame(std::move(frame));
});
break;
case webrtc::TransformableFrameInterface::Direction::kUnknown:
// do nothing
Expand Down Expand Up @@ -371,6 +384,8 @@ void FrameCryptorTransformer::encryptFrame(

rtc::ArrayView<const uint8_t> date_in = frame->GetData();
if (date_in.size() == 0 || !enabled_cryption) {
RTC_LOG(LS_WARNING) << "FrameCryptorTransformer::encryptFrame() "
"date_in.size() == 0 || enabled_cryption == false";
sink_callback->OnTransformedFrame(std::move(frame));
return;
}
Expand Down Expand Up @@ -425,7 +440,8 @@ void FrameCryptorTransformer::encryptFrame(
data_out.AppendData(frame_header);

if (FrameIsH264(frame.get(), type_)) {
H264::WriteRbsp(data_without_header.data(),data_without_header.size(), &data_out);
H264::WriteRbsp(data_without_header.data(), data_without_header.size(),
&data_out);
} else {
data_out.AppendData(data_without_header);
RTC_CHECK_EQ(data_out.size(), frame_header.size() +
Expand Down Expand Up @@ -490,34 +506,31 @@ void FrameCryptorTransformer::decryptFrame(
rtc::ArrayView<const uint8_t> date_in = frame->GetData();

if (date_in.size() == 0 || !enabled_cryption) {
RTC_LOG(LS_WARNING) << "FrameCryptorTransformer::decryptFrame() "
"date_in.size() == 0 || enabled_cryption == false";
sink_callback->OnTransformedFrame(std::move(frame));
return;
}

auto uncrypted_magic_bytes = key_provider_->options().uncrypted_magic_bytes;
if (uncrypted_magic_bytes.size() > 0 &&
date_in.size() >= uncrypted_magic_bytes.size() + 1) {
auto tmp =
date_in.subview(date_in.size() - (uncrypted_magic_bytes.size() + 1),
uncrypted_magic_bytes.size());

if (uncrypted_magic_bytes == std::vector<uint8_t>(tmp.begin(), tmp.end())) {
date_in.size() >= uncrypted_magic_bytes.size()) {
auto tmp = date_in.subview(date_in.size() - (uncrypted_magic_bytes.size()),
uncrypted_magic_bytes.size());
auto data = std::vector<uint8_t>(tmp.begin(), tmp.end());
if (uncrypted_magic_bytes == data) {
RTC_CHECK_EQ(tmp.size(), uncrypted_magic_bytes.size());
auto frame_type = date_in.subview(date_in.size() - 1, 1);
RTC_CHECK_EQ(frame_type.size(), 1);

RTC_LOG(LS_INFO)
<< "FrameCryptorTransformer::uncrypted_magic_bytes( type "
<< frame_type[0] << ", tmp " << to_hex(tmp.data(), tmp.size())
<< ", magic bytes "
<< to_hex(uncrypted_magic_bytes.data(), uncrypted_magic_bytes.size())
<< ")";
RTC_LOG(LS_INFO) << "FrameCryptorTransformer::uncrypted_magic_bytes( tmp "
<< to_hex(tmp.data(), tmp.size()) << ", magic bytes "
<< to_hex(uncrypted_magic_bytes.data(),
uncrypted_magic_bytes.size())
<< ")";

// magic bytes detected, this is a non-encrypted frame, skip frame
// decryption.
rtc::Buffer data_out;
data_out.AppendData(date_in.subview(
0, date_in.size() - uncrypted_magic_bytes.size() - 1));
data_out.AppendData(
date_in.subview(0, date_in.size() - uncrypted_magic_bytes.size()));
frame->SetData(data_out);
sink_callback->OnTransformedFrame(std::move(frame));
return;
Expand All @@ -539,8 +552,8 @@ void FrameCryptorTransformer::decryptFrame(

if (ivLength != getIvSize()) {
RTC_LOG(LS_WARNING) << "FrameCryptorTransformer::decryptFrame() ivLength["
<< static_cast<int>(ivLength) << "] != getIvSize()["
<< static_cast<int>(getIvSize()) << "]";
<< static_cast<int>(ivLength) << "] != getIvSize()["
<< static_cast<int>(getIvSize()) << "]";
if (last_dec_error_ != FrameCryptionState::kDecryptionFailed) {
last_dec_error_ = FrameCryptionState::kDecryptionFailed;
onFrameCryptionStateChanged(last_dec_error_);
Expand Down Expand Up @@ -585,7 +598,8 @@ void FrameCryptorTransformer::decryptFrame(

if (FrameIsH264(frame.get(), type_) &&
NeedsRbspUnescaping(encrypted_buffer.data(), encrypted_buffer.size())) {
encrypted_buffer.SetData(H264::ParseRbsp(encrypted_buffer.data(), encrypted_buffer.size()));
encrypted_buffer.SetData(
H264::ParseRbsp(encrypted_buffer.data(), encrypted_buffer.size()));
}

rtc::Buffer encrypted_payload(encrypted_buffer.size() - ivLength - 2);
Expand Down Expand Up @@ -665,10 +679,11 @@ void FrameCryptorTransformer::decryptFrame(
}

if (!decryption_success) {
if (last_dec_error_ != FrameCryptionState::kDecryptionFailed) {
last_dec_error_ = FrameCryptionState::kDecryptionFailed;
key_handler->DecryptionFailure();
onFrameCryptionStateChanged(last_dec_error_);
if (key_handler->DecryptionFailure()) {
if (last_dec_error_ != FrameCryptionState::kDecryptionFailed) {
last_dec_error_ = FrameCryptionState::kDecryptionFailed;
onFrameCryptionStateChanged(last_dec_error_);
}
}
return;
}
Expand All @@ -686,15 +701,15 @@ void FrameCryptorTransformer::decryptFrame(
sink_callback->OnTransformedFrame(std::move(frame));
}

void FrameCryptorTransformer::onFrameCryptionStateChanged(FrameCryptionState state) {
void FrameCryptorTransformer::onFrameCryptionStateChanged(
FrameCryptionState state) {
webrtc::MutexLock lock(&mutex_);
if(observer_) {
if (observer_) {
RTC_DCHECK(signaling_thread_ != nullptr);
signaling_thread_->PostTask(
[observer = observer_, state = state, participant_id = participant_id_]() mutable {
observer->OnFrameCryptionStateChanged(participant_id, state);
}
);
signaling_thread_->PostTask([observer = observer_, state = state,
participant_id = participant_id_]() mutable {
observer->OnFrameCryptionStateChanged(participant_id, state);
});
}
}

Expand Down
79 changes: 46 additions & 33 deletions api/crypto/frame_crypto_transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ struct KeyProviderOptions {
std::vector<uint8_t> uncrypted_magic_bytes;
int ratchet_window_size;
int failure_tolerance;
KeyProviderOptions() : shared_key(false), ratchet_window_size(0), failure_tolerance(-1) {}
KeyProviderOptions()
: shared_key(false), ratchet_window_size(0), failure_tolerance(-1) {}
KeyProviderOptions(KeyProviderOptions& copy)
: shared_key(copy.shared_key),
ratchet_salt(copy.ratchet_salt),
Expand All @@ -55,10 +56,10 @@ struct KeyProviderOptions {

class KeyProvider : public rtc::RefCountInterface {
public:

virtual bool SetSharedKey(int key_index, std::vector<uint8_t> key) = 0;

virtual const rtc::scoped_refptr<ParticipantKeyHandler> GetSharedKey(const std::string participant_id) = 0;
virtual const rtc::scoped_refptr<ParticipantKeyHandler> GetSharedKey(
const std::string participant_id) = 0;

virtual const std::vector<uint8_t> RatchetSharedKey(int key_index) = 0;

Expand Down Expand Up @@ -94,8 +95,10 @@ class ParticipantKeyHandler : public rtc::RefCountInterface {
KeySet(std::vector<uint8_t> material, std::vector<uint8_t> encryptionKey)
: material(material), encryption_key(encryptionKey) {}
};

public:
ParticipantKeyHandler(KeyProvider* key_provider) : key_provider_(key_provider) {
ParticipantKeyHandler(KeyProvider* key_provider)
: key_provider_(key_provider) {
crypto_key_ring_.resize(KEYRING_SIZE);
}

Expand All @@ -116,7 +119,8 @@ class ParticipantKeyHandler : public rtc::RefCountInterface {
}
auto current_material = key_set->material;
std::vector<uint8_t> new_material;
if (DerivePBKDF2KeyFromRawKey(current_material, key_provider_->options().ratchet_salt, 256,
if (DerivePBKDF2KeyFromRawKey(current_material,
key_provider_->options().ratchet_salt, 256,
&new_material) != 0) {
return std::vector<uint8_t>();
}
Expand All @@ -139,16 +143,17 @@ class ParticipantKeyHandler : public rtc::RefCountInterface {
std::vector<uint8_t> RatchetKeyMaterial(
std::vector<uint8_t> current_material) {
std::vector<uint8_t> new_material;
if (DerivePBKDF2KeyFromRawKey(current_material, key_provider_->options().ratchet_salt, 256,
if (DerivePBKDF2KeyFromRawKey(current_material,
key_provider_->options().ratchet_salt, 256,
&new_material) != 0) {
return std::vector<uint8_t>();
}
return new_material;
}

rtc::scoped_refptr<KeySet> DeriveKeys(std::vector<uint8_t> password,
std::vector<uint8_t> ratchet_salt,
unsigned int optional_length_bits) {
std::vector<uint8_t> ratchet_salt,
unsigned int optional_length_bits) {
std::vector<uint8_t> derived_key;
if (DerivePBKDF2KeyFromRawKey(password, ratchet_salt, optional_length_bits,
&derived_key) == 0) {
Expand Down Expand Up @@ -177,16 +182,19 @@ class ParticipantKeyHandler : public rtc::RefCountInterface {
DeriveKeys(password, key_provider_->options().ratchet_salt, 128);
}

void DecryptionFailure() {
bool DecryptionFailure() {
webrtc::MutexLock lock(&mutex_);
if (key_provider_->options().failure_tolerance < 0) {
return;
return false;
}
decryption_failure_count_ += 1;

if (decryption_failure_count_ > key_provider_->options().failure_tolerance) {
if (decryption_failure_count_ >
key_provider_->options().failure_tolerance) {
has_valid_key_ = false;
return true;
}
return false;
}

private:
Expand All @@ -206,16 +214,16 @@ class DefaultKeyProviderImpl : public KeyProvider {
/// Set the shared key.
bool SetSharedKey(int key_index, std::vector<uint8_t> key) override {
webrtc::MutexLock lock(&mutex_);
if(options_.shared_key) {
if (options_.shared_key) {
if (keys_.find("shared") == keys_.end()) {
keys_["shared"] = rtc::make_ref_counted<ParticipantKeyHandler>(this);
}

auto key_handler = keys_["shared"];
key_handler->SetKey(key, key_index);

for(auto& key_pair : keys_) {
if(key_pair.first != "shared") {
for (auto& key_pair : keys_) {
if (key_pair.first != "shared") {
key_pair.second->SetKey(key, key_index);
}
}
Expand All @@ -227,13 +235,13 @@ class DefaultKeyProviderImpl : public KeyProvider {
const std::vector<uint8_t> RatchetSharedKey(int key_index) override {
webrtc::MutexLock lock(&mutex_);
auto it = keys_.find("shared");
if(it == keys_.end()) {
if (it == keys_.end()) {
return std::vector<uint8_t>();
}
auto new_key = it->second->RatchetKey(key_index);
if(options_.shared_key) {
for(auto& key_pair : keys_) {
if(key_pair.first != "shared") {
if (options_.shared_key) {
for (auto& key_pair : keys_) {
if (key_pair.first != "shared") {
key_pair.second->SetKey(new_key, key_index);
}
}
Expand All @@ -244,19 +252,20 @@ class DefaultKeyProviderImpl : public KeyProvider {
const std::vector<uint8_t> ExportSharedKey(int key_index) const override {
webrtc::MutexLock lock(&mutex_);
auto it = keys_.find("shared");
if(it == keys_.end()) {
if (it == keys_.end()) {
return std::vector<uint8_t>();
}
auto key_set = it->second->GetKeySet(key_index);
if(key_set) {
if (key_set) {
return key_set->material;
}
return std::vector<uint8_t>();
}

const rtc::scoped_refptr<ParticipantKeyHandler> GetSharedKey(const std::string participant_id) override {
const rtc::scoped_refptr<ParticipantKeyHandler> GetSharedKey(
const std::string participant_id) override {
webrtc::MutexLock lock(&mutex_);
if(options_.shared_key && keys_.find("shared") != keys_.end()) {
if (options_.shared_key && keys_.find("shared") != keys_.end()) {
auto shared_key_handler = keys_["shared"];
if (keys_.find(participant_id) != keys_.end()) {
return keys_[participant_id];
Expand All @@ -276,7 +285,8 @@ class DefaultKeyProviderImpl : public KeyProvider {
webrtc::MutexLock lock(&mutex_);

if (keys_.find(participant_id) == keys_.end()) {
keys_[participant_id] = rtc::make_ref_counted<ParticipantKeyHandler>(this);
keys_[participant_id] =
rtc::make_ref_counted<ParticipantKeyHandler>(this);
}

auto key_handler = keys_[participant_id];
Expand Down Expand Up @@ -326,7 +336,8 @@ class DefaultKeyProviderImpl : public KeyProvider {
private:
mutable webrtc::Mutex mutex_;
KeyProviderOptions options_;
std::unordered_map<std::string, rtc::scoped_refptr<ParticipantKeyHandler>> keys_;
std::unordered_map<std::string, rtc::scoped_refptr<ParticipantKeyHandler>>
keys_;
};

enum FrameCryptionState {
Expand Down Expand Up @@ -361,19 +372,20 @@ class RTC_EXPORT FrameCryptorTransformer
kAesCbc,
};

explicit FrameCryptorTransformer(rtc::Thread* signaling_thread,
const std::string participant_id,
MediaType type,
Algorithm algorithm,
rtc::scoped_refptr<KeyProvider> key_provider);

explicit FrameCryptorTransformer(
rtc::Thread* signaling_thread,
const std::string participant_id,
MediaType type,
Algorithm algorithm,
rtc::scoped_refptr<KeyProvider> key_provider);
~FrameCryptorTransformer();
virtual void RegisterFrameCryptorTransformerObserver(
rtc::scoped_refptr<FrameCryptorTransformerObserver> observer) {
rtc::scoped_refptr<FrameCryptorTransformerObserver> observer) {
webrtc::MutexLock lock(&mutex_);
observer_ = observer;
}

virtual void UnRegisterFrameCryptorTransformerObserver() {
virtual void UnRegisterFrameCryptorTransformerObserver() {
webrtc::MutexLock lock(&mutex_);
observer_ = nullptr;
}
Expand Down Expand Up @@ -431,6 +443,7 @@ class RTC_EXPORT FrameCryptorTransformer

private:
TaskQueueBase* const signaling_thread_;
std::unique_ptr<rtc::Thread> thread_;
std::string participant_id_;
mutable webrtc::Mutex mutex_;
mutable webrtc::Mutex sink_mutex_;
Expand All @@ -445,7 +458,7 @@ class RTC_EXPORT FrameCryptorTransformer
rtc::scoped_refptr<KeyProvider> key_provider_;
rtc::scoped_refptr<FrameCryptorTransformerObserver> observer_;
FrameCryptionState last_enc_error_ = FrameCryptionState::kNew;
FrameCryptionState last_dec_error_ = FrameCryptionState::kNew;
FrameCryptionState last_dec_error_ = FrameCryptionState::kNew;
};

} // namespace webrtc
Expand Down