Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions clickhouse/base/socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,25 @@ void SetNonBlock(SOCKET fd, bool value) {

void SetTimeout(SOCKET fd, const SocketTimeoutParams& timeout_params) {
#if defined(_unix_)
timeval recv_timeout { .tv_sec = timeout_params.recv_timeout.count(), .tv_usec = 0 };
setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &recv_timeout, sizeof(recv_timeout));
timeval recv_timeout{ timeout_params.recv_timeout.count() / 1000, static_cast<int>(timeout_params.recv_timeout.count() % 1000 * 1000) };
auto recv_ret = setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &recv_timeout, sizeof(recv_timeout));

timeval send_timeout { .tv_sec = timeout_params.send_timeout.count(), .tv_usec = 0 };
setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &send_timeout, sizeof(send_timeout));
timeval send_timeout{ timeout_params.send_timeout.count() / 1000, static_cast<int>(timeout_params.send_timeout.count() % 1000 * 1000) };
auto send_ret = setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &send_timeout, sizeof(send_timeout));

if (recv_ret == -1 || send_ret == -1) {
throw std::system_error(getSocketErrorCode(), getErrorCategory(), "fail to set socket timeout");
}
#else
DWORD recv_timeout = static_cast<DWORD>(timeout_params.recv_timeout.count());
auto recv_ret = setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, (const char*)&recv_timeout, sizeof(DWORD));

DWORD send_timeout = static_cast<DWORD>(timeout_params.send_timeout.count());
auto send_ret = setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, (const char*)&send_timeout, sizeof(DWORD));

if (recv_ret == SOCKET_ERROR || send_ret == SOCKET_ERROR) {
throw std::system_error(getSocketErrorCode(), getErrorCategory(), "fail to set socket timeout");
}
#endif
};

Expand Down Expand Up @@ -244,6 +258,10 @@ Socket::Socket(const NetworkAddress& addr, const SocketTimeoutParams& timeout_pa
: handle_(SocketConnect(addr, timeout_params))
{}

Socket::Socket(const NetworkAddress & addr)
: handle_(SocketConnect(addr, SocketTimeoutParams{}))
{}

Socket::Socket(Socket&& other) noexcept
: handle_(other.handle_)
{
Expand Down
5 changes: 3 additions & 2 deletions clickhouse/base/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,14 @@ class SocketFactory {


struct SocketTimeoutParams {
const std::chrono::seconds recv_timeout {0};
const std::chrono::seconds send_timeout {0};
std::chrono::milliseconds recv_timeout{ 0 };
std::chrono::milliseconds send_timeout{ 0 };
};

class Socket : public SocketBase {
public:
Socket(const NetworkAddress& addr, const SocketTimeoutParams& timeout_params);
Socket(const NetworkAddress& addr);
Socket(Socket&& other) noexcept;
Socket& operator=(Socket&& other) noexcept;

Expand Down
16 changes: 11 additions & 5 deletions ut/socket_ut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ TEST(Socketcase, connecterror) {

std::this_thread::sleep_for(std::chrono::seconds(1));
try {
Socket socket(addr, SocketTimeoutParams {});
Socket socket(addr);
} catch (const std::system_error& e) {
FAIL();
}

std::this_thread::sleep_for(std::chrono::seconds(1));
server.stop();
try {
Socket socket(addr, SocketTimeoutParams {});
Socket socket(addr);
FAIL();
} catch (const std::system_error& e) {
ASSERT_NE(EINPROGRESS,e.code().value());
Expand All @@ -43,14 +43,20 @@ TEST(Socketcase, timeoutrecv) {

std::this_thread::sleep_for(std::chrono::seconds(1));
try {
Socket socket(addr, SocketTimeoutParams { .recv_timeout = Seconds(5), .send_timeout = Seconds(5) });
Socket socket(addr, SocketTimeoutParams { Seconds(5), Seconds(5) });

std::unique_ptr<InputStream> ptr_input_stream = socket.makeInputStream();
char buf[1024];
ptr_input_stream->Read(buf, sizeof(buf));

} catch (const std::system_error& e) {
ASSERT_EQ(EAGAIN, e.code().value());
}
catch (const std::system_error& e) {
#if defined(_unix_)
auto expected = EAGAIN;
#else
auto expected = WSAETIMEDOUT;
#endif
ASSERT_EQ(expected, e.code().value());
}

std::this_thread::sleep_for(std::chrono::seconds(1));
Expand Down