Skip to content

Commit 7cad355

Browse files
committed
Socket RAII wrapper to prevent leaking socket
1 parent c1bf280 commit 7cad355

File tree

1 file changed

+47
-21
lines changed

1 file changed

+47
-21
lines changed

clickhouse/base/socket.cpp

Lines changed: 47 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -144,19 +144,51 @@ ssize_t Poll(struct pollfd* fds, int nfds, int timeout) noexcept {
144144
#endif
145145
}
146146

147+
const SOCKET INVALID_SOCKET = -1;
148+
149+
void CloseSocket(SOCKET socket) {
150+
if (socket == INVALID_SOCKET)
151+
return;
152+
153+
#if defined(_win_)
154+
closesocket(socket);
155+
#else
156+
close(socket);
157+
#endif
158+
}
159+
160+
struct SocketRAIIWrapper {
161+
SOCKET socket = INVALID_SOCKET;
162+
163+
~SocketRAIIWrapper() {
164+
CloseSocket(socket);
165+
}
166+
167+
SOCKET operator*() const {
168+
return socket;
169+
}
170+
171+
SOCKET release() {
172+
auto result = socket;
173+
socket = INVALID_SOCKET;
174+
175+
return result;
176+
}
177+
};
178+
147179
SOCKET SocketConnect(const NetworkAddress& addr, const SocketTimeoutParams& timeout_params) {
148180
int last_err = 0;
149181
for (auto res = addr.Info(); res != nullptr; res = res->ai_next) {
150-
SOCKET s(socket(res->ai_family, res->ai_socktype, res->ai_protocol));
182+
SocketRAIIWrapper s{socket(res->ai_family, res->ai_socktype, res->ai_protocol)};
151183

152-
if (s == -1) {
184+
if (*s == INVALID_SOCKET) {
153185
continue;
154186
}
155187

156-
SetNonBlock(s, true);
157-
SetTimeout(s, timeout_params);
188+
SetNonBlock(*s, true);
189+
SetTimeout(*s, timeout_params);
158190

159-
if (connect(s, res->ai_addr, (int)res->ai_addrlen) != 0) {
191+
if (connect(*s, res->ai_addr, (int)res->ai_addrlen) != 0) {
160192
int err = getSocketErrorCode();
161193
if (
162194
err == EINPROGRESS || err == EAGAIN || err == EWOULDBLOCK
@@ -165,7 +197,7 @@ SOCKET SocketConnect(const NetworkAddress& addr, const SocketTimeoutParams& time
165197
#endif
166198
) {
167199
pollfd fd;
168-
fd.fd = s;
200+
fd.fd = *s;
169201
fd.events = POLLOUT;
170202
fd.revents = 0;
171203
ssize_t rval = Poll(&fd, 1, 5000);
@@ -175,18 +207,18 @@ SOCKET SocketConnect(const NetworkAddress& addr, const SocketTimeoutParams& time
175207
}
176208
if (rval > 0) {
177209
socklen_t len = sizeof(err);
178-
getsockopt(s, SOL_SOCKET, SO_ERROR, (char*)&err, &len);
210+
getsockopt(*s, SOL_SOCKET, SO_ERROR, (char*)&err, &len);
179211

180212
if (!err) {
181-
SetNonBlock(s, false);
182-
return s;
213+
SetNonBlock(*s, false);
214+
return s.release();
183215
}
184216
last_err = err;
185217
}
186218
}
187219
} else {
188-
SetNonBlock(s, false);
189-
return s;
220+
SetNonBlock(*s, false);
221+
return s.release();
190222
}
191223
}
192224
if (last_err > 0) {
@@ -265,15 +297,15 @@ Socket::Socket(const NetworkAddress & addr)
265297
Socket::Socket(Socket&& other) noexcept
266298
: handle_(other.handle_)
267299
{
268-
other.handle_ = -1;
300+
other.handle_ = INVALID_SOCKET;
269301
}
270302

271303
Socket& Socket::operator=(Socket&& other) noexcept {
272304
if (this != &other) {
273305
Close();
274306

275307
handle_ = other.handle_;
276-
other.handle_ = -1;
308+
other.handle_ = INVALID_SOCKET;
277309
}
278310

279311
return *this;
@@ -284,14 +316,8 @@ Socket::~Socket() {
284316
}
285317

286318
void Socket::Close() {
287-
if (handle_ != -1) {
288-
#if defined(_win_)
289-
closesocket(handle_);
290-
#else
291-
close(handle_);
292-
#endif
293-
handle_ = -1;
294-
}
319+
CloseSocket(handle_);
320+
handle_ = INVALID_SOCKET;
295321
}
296322

297323
void Socket::SetTcpKeepAlive(int idle, int intvl, int cnt) noexcept {

0 commit comments

Comments
 (0)