Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions clickhouse/base/socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,27 @@ void SetNonBlock(SOCKET fd, bool value) {

void SetTimeout(SOCKET fd, const SocketTimeoutParams& timeout_params) {
#if defined(_unix_)
timeval recv_timeout { .tv_sec = timeout_params.recv_timeout.count(), .tv_usec = 0 };
setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &recv_timeout, sizeof(recv_timeout));
timeval recv_timeout { timeout_params.recv_timeout_s.count(), static_cast<int>(timeout_params.recv_timeout_us.count()) };
auto recv_ret = setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &recv_timeout, sizeof(recv_timeout));

timeval send_timeout { .tv_sec = timeout_params.send_timeout.count(), .tv_usec = 0 };
setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &send_timeout, sizeof(send_timeout));
timeval send_timeout { timeout_params.send_timeout_s.count(), static_cast<int>(timeout_params.send_timeout_us.count()) };
auto send_ret = setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &send_timeout, sizeof(send_timeout));

if (recv_ret == -1 || send_ret == -1) {
throw std::system_error(getSocketErrorCode(), getErrorCategory(), "fail to set socket timeout");
}
#else
const struct timeval recv_tv { timeout_params.recv_timeout_s.count(), timeout_params.recv_timeout_us.count()};
DWORD recv_timeout = recv_tv.tv_sec * 1000 + recv_tv.tv_usec / 1000;
auto recv_ret = setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, (const char*)&recv_timeout, sizeof(DWORD));

const struct timeval send_tv { timeout_params.send_timeout_s.count(), timeout_params.send_timeout_us.count()};
DWORD send_timeout = send_tv.tv_sec * 1000 + send_tv.tv_usec / 1000;
auto send_ret = setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, (const char*)&send_timeout, sizeof(DWORD));

if (recv_ret == SOCKET_ERROR || send_ret == SOCKET_ERROR) {
throw std::system_error(getSocketErrorCode(), getErrorCategory(), "fail to set socket timeout");
}
#endif
};

Expand Down Expand Up @@ -244,6 +260,10 @@ Socket::Socket(const NetworkAddress& addr, const SocketTimeoutParams& timeout_pa
: handle_(SocketConnect(addr, timeout_params))
{}

Socket::Socket(const NetworkAddress & addr)
: handle_(SocketConnect(addr, SocketTimeoutParams{}))
{}

Socket::Socket(Socket&& other) noexcept
: handle_(other.handle_)
{
Expand Down
7 changes: 5 additions & 2 deletions clickhouse/base/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,16 @@ class SocketFactory {


struct SocketTimeoutParams {
const std::chrono::seconds recv_timeout {0};
const std::chrono::seconds send_timeout {0};
const std::chrono::seconds recv_timeout_s {0};
const std::chrono::seconds send_timeout_s {0};
const std::chrono::microseconds recv_timeout_us{ 0 };
const std::chrono::microseconds send_timeout_us{ 0 };
};

class Socket : public SocketBase {
public:
Socket(const NetworkAddress& addr, const SocketTimeoutParams& timeout_params);
Socket(const NetworkAddress& addr);
Socket(Socket&& other) noexcept;
Socket& operator=(Socket&& other) noexcept;

Expand Down
16 changes: 11 additions & 5 deletions ut/socket_ut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ TEST(Socketcase, connecterror) {

std::this_thread::sleep_for(std::chrono::seconds(1));
try {
Socket socket(addr, SocketTimeoutParams {});
Socket socket(addr);
} catch (const std::system_error& e) {
FAIL();
}

std::this_thread::sleep_for(std::chrono::seconds(1));
server.stop();
try {
Socket socket(addr, SocketTimeoutParams {});
Socket socket(addr);
FAIL();
} catch (const std::system_error& e) {
ASSERT_NE(EINPROGRESS,e.code().value());
Expand All @@ -43,14 +43,20 @@ TEST(Socketcase, timeoutrecv) {

std::this_thread::sleep_for(std::chrono::seconds(1));
try {
Socket socket(addr, SocketTimeoutParams { .recv_timeout = Seconds(5), .send_timeout = Seconds(5) });
Socket socket(addr, SocketTimeoutParams { Seconds(5), Seconds(5) });

std::unique_ptr<InputStream> ptr_input_stream = socket.makeInputStream();
char buf[1024];
ptr_input_stream->Read(buf, sizeof(buf));

} catch (const std::system_error& e) {
ASSERT_EQ(EAGAIN, e.code().value());
}
catch (const std::system_error& e) {
#if defined(_unix_)
auto expected = EAGAIN;
#else
auto expected = WSAETIMEDOUT;
#endif
ASSERT_EQ(expected, e.code().value());
}

std::this_thread::sleep_for(std::chrono::seconds(1));
Expand Down