Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions src/connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -421,4 +421,104 @@ void EthernetConnection::recvMessages() {
}
}

// IndirectConnection

// VortexScheduler: pipelined copy scheduler that stages data through a
// forwarding GPU.
//
// Allocates two staging buffers of `granularity` bytes each on the forwarding
// device and registers them with `context` for CudaIpc transport.  The
// caller's current CUDA device is restored before returning — including on
// failure, where any partially-allocated staging buffer is freed instead of
// leaked.
//
// Throws std::runtime_error when `device` is not a GPU, and rethrows CUDA
// errors surfaced by MSCCLPP_CUDATHROW.
VortexScheduler::VortexScheduler(std::shared_ptr<Context> context, uint64_t granularity, Device device)
    : granularity_(granularity),
      streams_({std::make_shared<CudaStreamWithFlags>(0), std::make_shared<CudaStreamWithFlags>(0)}),
      forwarding_device_(device) {
  if (device.type != DeviceType::GPU) {
    throw std::runtime_error("The forwarding device must be a GPU");
  }
  int origin_device;
  MSCCLPP_CUDATHROW(cudaGetDevice(&origin_device));
  MSCCLPP_CUDATHROW(cudaSetDevice(device.id));
  // Initialize to nullptr so the failure path can tell which allocations
  // actually succeeded.
  void* buf1 = nullptr;
  void* buf2 = nullptr;
  try {
    MSCCLPP_CUDATHROW(cudaMalloc(&buf1, granularity));
    MSCCLPP_CUDATHROW(cudaMalloc(&buf2, granularity));
    buf_ptr_ = std::make_shared<mscclpp::DoubleBuffer>(
        context->registerMemory(buf1, granularity, mscclpp::Transport::CudaIpc),
        context->registerMemory(buf2, granularity, mscclpp::Transport::CudaIpc));
  } catch (...) {
    // Do not leak staging buffers or leave the caller on the wrong device.
    if (buf2 != nullptr) cudaFree(buf2);
    if (buf1 != nullptr) cudaFree(buf1);
    cudaSetDevice(origin_device);
    throw;
  }
  MSCCLPP_CUDATHROW(cudaSetDevice(origin_device));
}

// Frees the two staging buffers on the forwarding device.
// CUDA return codes are deliberately ignored: destructors must not throw,
// and the CUDA runtime may already be tearing down at process exit.
VortexScheduler::~VortexScheduler() {
  if (buf_ptr_) {
    int origin_device;
    cudaGetDevice(&origin_device);
    cudaSetDevice(forwarding_device_.id);
    // next_get()/next_put() together cover both slots of the double buffer,
    // so both allocations are freed regardless of the current toggle state.
    cudaFree(buf_ptr_->next_get().data());
    cudaFree(buf_ptr_->next_put().data());
    // Restore the caller's device selection.
    cudaSetDevice(origin_device);
  }
}

std::vector<IOTask> VortexScheduler::produce_tasks(void *dst, void *src, uint64_t size) {
std::vector<IOTask> tasks_;
for (uint64_t i = 0; i < size; i += granularity_) {
tasks_.push_back({dst + i, src + i, std::min(granularity_, size - i)});
}
return tasks_;
}

// Enqueues the chunked copy pipeline through the double buffer.
//
// Two streams ping-pong: streams_[0] stages each chunk's source into
// next_put() ("produce"), while streams_[1] drains the previously staged
// chunk from next_get() to its destination ("consume").  event0 marks
// "stage of chunk i complete"; event1 marks "drain of chunk i-1 complete".
// The cross-stream waits guarantee a slot is never overwritten before it has
// been drained, nor read before it has been filled.
//
// Asynchronous: returns after enqueueing; call sync() to wait for completion.
//
// NOTE(review): if an MSCCLPP_CUDATHROW fires mid-pipeline, event0/event1
// leak — an RAII event wrapper would fix this.  Destroying the events while
// work is still queued is safe: the CUDA runtime releases event resources
// asynchronously once the recorded work completes.
void VortexScheduler::launch(const std::vector<IOTask>& tasks) {
  if (tasks.empty()) {
    return;
  }

  cudaEvent_t event0, event1;
  MSCCLPP_CUDATHROW(cudaEventCreateWithFlags(&event0, cudaEventBlockingSync | cudaEventDisableTiming));
  MSCCLPP_CUDATHROW(cudaEventCreateWithFlags(&event1, cudaEventBlockingSync | cudaEventDisableTiming));

  // Prime the pipeline: stage the first chunk into the writable slot.
  MSCCLPP_CUDATHROW(cudaMemcpyAsync(buf_ptr_->next_put().data(), tasks.front().src, tasks.front().size, cudaMemcpyDefault, *streams_[0]));
  MSCCLPP_CUDATHROW(cudaEventRecord(event0, *streams_[0]));
  buf_ptr_->produce();

  for (uint64_t i = 1; i < tasks.size(); ++i) {
    // Drain stream waits for the stage of chunk i-1; stage stream waits for
    // the previous drain so the slot it writes is free.  On the first
    // iteration event1 has not been recorded yet, so that wait is a no-op.
    MSCCLPP_CUDATHROW(cudaStreamWaitEvent(*streams_[1], event0, 0));
    MSCCLPP_CUDATHROW(cudaStreamWaitEvent(*streams_[0], event1, 0));
    MSCCLPP_CUDATHROW(cudaMemcpyAsync(tasks[i - 1].dst, buf_ptr_->next_get().data(), tasks[i - 1].size, cudaMemcpyDefault, *streams_[1]));
    MSCCLPP_CUDATHROW(cudaEventRecord(event1, *streams_[1]));
    MSCCLPP_CUDATHROW(cudaMemcpyAsync(buf_ptr_->next_put().data(), tasks[i].src, tasks[i].size, cudaMemcpyDefault, *streams_[0]));
    MSCCLPP_CUDATHROW(cudaEventRecord(event0, *streams_[0]));

    // Advance the double buffer: retire the drained slot, flip the filled one.
    buf_ptr_->consume();
    buf_ptr_->produce();
  }

  // Drain the final staged chunk.
  MSCCLPP_CUDATHROW(cudaStreamWaitEvent(*streams_[1], event0, 0));
  MSCCLPP_CUDATHROW(cudaMemcpyAsync(tasks.back().dst, buf_ptr_->next_get().data(), tasks.back().size, cudaMemcpyDefault, *streams_[1]));
  MSCCLPP_CUDATHROW(cudaEventRecord(event1, *streams_[1]));
  buf_ptr_->consume();

  MSCCLPP_CUDATHROW(cudaEventDestroy(event0));
  MSCCLPP_CUDATHROW(cudaEventDestroy(event1));
}

void VortexScheduler::sync() {
MSCCLPP_CUDATHROW(cudaStreamSynchronize(*streams_[0]));
MSCCLPP_CUDATHROW(cudaStreamSynchronize(*streams_[1]));
}

// Copies `size` bytes from src+srcOffset to dst+dstOffset through the
// scheduler's forwarding-device staging buffers.  Asynchronous: returns once
// the pipeline is enqueued; call flush() to wait for completion.
//
// Throws Error(InvalidUsage) when either range falls outside its registered
// memory region.
void IndirectConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
                               uint64_t size) {
  // Overflow-safe bounds checks: compare against the remaining space rather
  // than computing offset + size, which could wrap around uint64_t.
  if (dstOffset > dst.size() || size > dst.size() - dstOffset ||
      srcOffset > src.size() || size > src.size() - srcOffset) {
    throw Error("IndirectConnection::write out of bounds", ErrorCode::InvalidUsage);
  }
  // void* does not support pointer arithmetic in standard C++; apply the
  // offsets as byte (char*) arithmetic.
  auto tasks = scheduler_ptr_->produce_tasks(static_cast<char*>(dst.data()) + dstOffset,
                                             static_cast<char*>(src.data()) + srcOffset, size);
  scheduler_ptr_->launch(tasks);
}

// Waits for every write issued on this connection to finish.
// Timeouts are not implemented; any value other than the default -1 is
// rejected with std::runtime_error.
void IndirectConnection::flush(int64_t timeoutUsec) {
  if (timeoutUsec == -1) {
    scheduler_ptr_->sync();
    return;
  }
  throw std::runtime_error("IndirectConnection does not support timeout in flush");
}
// Local transport: the staging path is driven over CUDA IPC.
Transport IndirectConnection::transport() const {
  return Transport::CudaIpc;
}
// Remote transport: mirrors transport(); both sides use CUDA IPC.
Transport IndirectConnection::remoteTransport() const {
  return Transport::CudaIpc;
}


} // namespace mscclpp
67 changes: 67 additions & 0 deletions src/include/connection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,73 @@ class EthernetConnection : public Connection {
void flush(int64_t timeoutUsec) override;
};

// Abstract rotation scheme for staging buffers used to pipeline copies
// through an intermediate device.
class BufferResource {
 public:
  virtual ~BufferResource() = default;
  // Buffer currently holding produced data, ready to be read.
  virtual RegisteredMemory next_get() = 0;
  // Buffer currently free, ready to be written.
  virtual RegisteredMemory next_put() = 0;
  // Mark the next_put() buffer as filled.
  virtual void produce() = 0;
  // Mark the next_get() buffer as drained.
  virtual void consume() = 0;
};

// Two-slot BufferResource: at any moment one slot is readable and the other
// writable.  produce() swaps the roles; consume() is a no-op because the swap
// alone retires the readable slot.
class DoubleBuffer : public BufferResource {
  std::array<RegisteredMemory, 2> slots_;
  int readable_{0};  // index of the slot currently holding produced data

 public:
  DoubleBuffer(RegisteredMemory buf1, RegisteredMemory buf2) : slots_({buf1, buf2}) {}
  // Slot with the most recently produced data.
  RegisteredMemory next_get() override { return slots_[readable_]; }
  // Slot free for the next write.
  RegisteredMemory next_put() override { return slots_[1 - readable_]; }
  // Flip which slot is readable.
  void produce() override { readable_ = 1 - readable_; }
  // Nothing to do: produce() already handed the drained slot back.
  void consume() override {}
};

// One chunk of a larger transfer: copy `size` bytes from `src` to `dst`.
struct IOTask {
  void* dst;
  void* src;
  uint64_t size;
  IOTask(void* d, void* s, uint64_t n) : dst(d), src(s), size(n) {}
};

// Abstract copy scheduler: splits a transfer into IOTask chunks, enqueues
// them asynchronously, and waits for completion on demand.
class Scheduler {
 public:
  // Polymorphic base: a virtual destructor is required so deleting a derived
  // scheduler through a Scheduler* is well-defined.
  virtual ~Scheduler() = default;
  // Split [src, src+size) -> [dst, dst+size) into scheduler-sized chunks.
  virtual std::vector<IOTask> produce_tasks(void *dst, void *src, uint64_t size) = 0;
  // Enqueue the chunked copies; asynchronous.
  virtual void launch(const std::vector<IOTask>& tasks) = 0;
  // Block until all launched copies are complete.
  virtual void sync() = 0;
};

// Scheduler that forwards copies through a staging GPU: chunks of
// `granularity` bytes are double-buffered across two streams (stage on one,
// drain on the other) to overlap the two hops of each transfer.
class VortexScheduler : public Scheduler {
  std::shared_ptr<DoubleBuffer> buf_ptr_;  // staging buffers on the forwarding GPU
  uint64_t granularity_;                   // chunk size in bytes
  std::array<std::shared_ptr<CudaStreamWithFlags>, 2> streams_;  // [0]=stage, [1]=drain
  Device forwarding_device_;               // GPU that owns the staging buffers

 public:
  // `device` must be a GPU; allocates 2 * granularity bytes of staging
  // memory on it and registers both buffers with `context`.
  VortexScheduler(std::shared_ptr<Context> context, uint64_t granularity, Device device);
  ~VortexScheduler();
  std::vector<IOTask> produce_tasks(void *dst, void *src, uint64_t size) override;
  void launch(const std::vector<IOTask>& tasks) override;
  void sync() override;
};

// Connection that routes writes through an intermediate device via a
// Scheduler instead of copying directly between the two endpoints.
class IndirectConnection : public Connection {
  std::shared_ptr<Scheduler> scheduler_ptr_;

 public:
  IndirectConnection(std::shared_ptr<Context> context,
                     Endpoint localEndpoint,
                     std::shared_ptr<Scheduler> scheduler) : Connection(context, localEndpoint), scheduler_ptr_(scheduler) {}
  // Asynchronous chunked copy through the scheduler; completion is observed
  // via flush().
  void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
             uint64_t size) override;
  // Waits for all outstanding writes; only timeoutUsec == -1 is supported.
  void flush(int64_t timeoutUsec = -1) override;
  Transport transport() const override;
  Transport remoteTransport() const override;

  // Not supported for indirect connections; always throws.
  virtual void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) override {
    throw std::runtime_error("IndirectConnection does not support updateAndSync");
  }
};

} // namespace mscclpp

#endif // MSCCLPP_CONNECTION_HPP_
10 changes: 10 additions & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,13 @@ add_subdirectory(mscclpp-test)

# Performance tests
add_subdirectory(perf)

# Indirect connection tests
add_executable(indirect_connection_test indirect_connection_test.cc)
# NOTE(review): leftover debug output — consider removing before merge.
message (STATUS "Using CMAKE_SOURCE_DIR: ${CMAKE_SOURCE_DIR}")

# NOTE(review): this hard-coded path ("include/mscclpp/src/include") does not
# look like a real layout under CMAKE_SOURCE_DIR and overlaps the
# TEST_INC_INTERNAL include set added below — verify it is actually needed.
target_include_directories(indirect_connection_test PRIVATE
${CMAKE_SOURCE_DIR}/include/mscclpp/src/include
)
# TEST_LIBS_*/TEST_INC_* follow the conventions of the sibling test targets;
# the variables are assumed to carry their own scope keywords.
target_link_libraries(indirect_connection_test ${TEST_LIBS_COMMON} ${TEST_LIBS_GTEST})
target_include_directories(indirect_connection_test ${TEST_INC_COMMON} ${TEST_INC_INTERNAL})
151 changes: 151 additions & 0 deletions test/indirect_connection_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <algorithm>
#include <connection.hpp>
#include <random>
#include <vector>

// Shared gtest fixture: creates a fresh mscclpp context for each test case.
class IndirectConnectionTest : public ::testing::Test {
 protected:
  void SetUp() override { ctx = mscclpp::Context::create(); }
  std::shared_ptr<mscclpp::Context> ctx;
  // When true, tests verify the transferred data element-wise (slow);
  // disabled by default so the tests act as bandwidth benchmarks.
  bool validate_answer = false;
};

// Host -> GPU transfer staged through a forwarding GPU; reports bandwidth.
// Uses hard-coded device ids 1 (destination) and 2 (forwarder), so the test
// is skipped on machines with fewer GPUs instead of crashing.
TEST_F(IndirectConnectionTest, CPUGPUDataTransfer) {
  mscclpp::Device dst(mscclpp::DeviceType::GPU, 1);
  mscclpp::Device fwd(mscclpp::DeviceType::GPU, 2);

  // Guard the hard-coded device ids: skip gracefully when absent.
  int deviceCount = 0;
  ASSERT_EQ(cudaSuccess, cudaGetDeviceCount(&deviceCount));
  int maxId = std::max(dst.id, fwd.id);
  if (deviceCount <= maxId) {
    GTEST_SKIP() << "requires " << maxId + 1 << " GPUs, found " << deviceCount;
  }

  uint64_t granularity = 20'000'000;
  uint64_t n = uint64_t(granularity * 30);

  // Generate random data in pinned host memory (needed for async copies).
  int* dummy;
  ASSERT_EQ(cudaSuccess, cudaMallocHost((void**)&dummy, n * sizeof(int)));
  std::mt19937 gen(std::random_device{}());
  std::generate(dummy, dummy + n, [&]() { return gen() % (1 << 30); });

  // Reserve memory on the destination GPU.
  int* dummy_device;
  ASSERT_EQ(cudaSuccess, cudaSetDevice(dst.id));
  ASSERT_EQ(cudaSuccess, cudaMalloc((void**)&dummy_device, n * sizeof(int)));

  // Enable peer access so the forwarding GPU can write to the destination GPU.
  ASSERT_EQ(cudaSuccess, cudaSetDevice(fwd.id));
  int canAccess = 0;
  ASSERT_EQ(cudaSuccess, cudaDeviceCanAccessPeer(&canAccess, fwd.id, dst.id));
  if (canAccess) {
    std::cout << "Enabling peer access from " << fwd.id << " to " << dst.id << std::endl;
    cudaDeviceEnablePeerAccess(dst.id, 0);
  }

  auto localEndpoint = ctx->createEndpoint(mscclpp::EndpointConfig());
  auto scheduler_ptr = std::make_shared<mscclpp::VortexScheduler>(ctx, granularity, fwd);

  // Launch writes and measure; repeat to amortize warm-up effects.
  for (int iter = 0; iter < 4; ++iter) {
    auto connection = mscclpp::IndirectConnection(ctx, localEndpoint, scheduler_ptr);
    auto start = std::chrono::high_resolution_clock::now();
    connection.write(
        ctx->registerMemory(dummy_device, n * sizeof(int), mscclpp::Transport::CudaIpc), 0,
        ctx->registerMemory(dummy, n * sizeof(int), mscclpp::NoTransports), 0,
        n * sizeof(int));
    connection.flush();  // flush() synchronizes, so timing covers the copy
    auto end = std::chrono::high_resolution_clock::now();
    auto duration = std::chrono::duration<double>(end - start);

    std::cout << "Time: " << duration.count() << " seconds" << std::endl;
    std::cout << "Bandwidth: " << (n * sizeof(int)) / duration.count() / (1e9) << " GB/s" << std::endl;
  }

  // Optionally validate the transferred data element-wise.
  if (validate_answer) {
    int* validate;
    ASSERT_EQ(cudaSuccess, cudaMallocHost((void**)&validate, n * sizeof(int)));
    ASSERT_EQ(cudaSuccess, cudaMemcpy(validate, dummy_device, n * sizeof(int), cudaMemcpyDefault));
    for (uint64_t i = 0; i < n; ++i) {
      EXPECT_EQ(validate[i], dummy[i]) << "Mismatch at index " << i;
    }
    cudaFreeHost(validate);
  }

  // Cleanup.
  cudaFree(dummy_device);
  cudaFreeHost(dummy);
}

// GPU -> host transfer staged through a forwarding GPU; reports bandwidth.
// Uses hard-coded device ids 1 (source) and 2 (forwarder), so the test is
// skipped on machines with fewer GPUs instead of crashing.
TEST_F(IndirectConnectionTest, GPUCPUDataTransfer) {
  mscclpp::Device src(mscclpp::DeviceType::GPU, 1);
  mscclpp::Device fwd(mscclpp::DeviceType::GPU, 2);

  // Guard the hard-coded device ids: skip gracefully when absent.
  int deviceCount = 0;
  ASSERT_EQ(cudaSuccess, cudaGetDeviceCount(&deviceCount));
  int maxId = std::max(src.id, fwd.id);
  if (deviceCount <= maxId) {
    GTEST_SKIP() << "requires " << maxId + 1 << " GPUs, found " << deviceCount;
  }

  uint64_t granularity = 20'000'000;
  uint64_t n = uint64_t(granularity * 30);

  // Pinned host buffers: random source pattern and the transfer destination.
  int* dst_host, *dummy;
  ASSERT_EQ(cudaSuccess, cudaMallocHost((void**)&dummy, n * sizeof(int)));
  ASSERT_EQ(cudaSuccess, cudaMallocHost((void**)&dst_host, n * sizeof(int)));
  std::mt19937 gen(std::random_device{}());
  std::generate(dummy, dummy + n, [&]() { return gen() % (1 << 30); });

  // Upload the pattern to the source GPU.
  int* dummy_device;
  ASSERT_EQ(cudaSuccess, cudaSetDevice(src.id));
  ASSERT_EQ(cudaSuccess, cudaMalloc((void**)&dummy_device, n * sizeof(int)));
  ASSERT_EQ(cudaSuccess, cudaMemcpy(dummy_device, dummy, n * sizeof(int), cudaMemcpyHostToDevice));

  // Enable peer access between the source and forwarding GPUs.
  ASSERT_EQ(cudaSuccess, cudaSetDevice(src.id));
  int canAccess = 0;
  ASSERT_EQ(cudaSuccess, cudaDeviceCanAccessPeer(&canAccess, src.id, fwd.id));
  if (canAccess) {
    std::cout << "Enabling peer access from " << src.id << " to " << fwd.id << std::endl;
    cudaDeviceEnablePeerAccess(fwd.id, 0);
  }
  auto localEndpoint = ctx->createEndpoint(mscclpp::EndpointConfig());
  auto scheduler_ptr = std::make_shared<mscclpp::VortexScheduler>(ctx, granularity, fwd);

  // Launch writes and measure; repeat to amortize warm-up effects.
  for (int iter = 0; iter < 4; ++iter) {
    auto connection = mscclpp::IndirectConnection(ctx, localEndpoint, scheduler_ptr);
    auto start = std::chrono::high_resolution_clock::now();
    connection.write(
        ctx->registerMemory(dst_host, n * sizeof(int), mscclpp::NoTransports), 0,
        ctx->registerMemory(dummy_device, n * sizeof(int), mscclpp::Transport::CudaIpc), 0,
        n * sizeof(int));
    connection.flush();  // flush() synchronizes, so timing covers the copy
    auto end = std::chrono::high_resolution_clock::now();
    auto duration = std::chrono::duration<double>(end - start);

    std::cout << "Time: " << duration.count() << " seconds" << std::endl;
    std::cout << "Bandwidth: " << (n * sizeof(int)) / duration.count() / (1e9) << " GB/s" << std::endl;
  }

  // Optionally validate: flush() has synchronized, so dst_host is complete.
  if (validate_answer) {
    for (uint64_t i = 0; i < n; ++i) {
      EXPECT_EQ(dummy[i], dst_host[i]) << "Mismatch at index " << i;
    }
  }

  // Cleanup.
  cudaFree(dummy_device);
  cudaFreeHost(dummy);
  cudaFreeHost(dst_host);
}
Loading