From 327af03d3da032d40e12f1755c8a7ccdfdd48680 Mon Sep 17 00:00:00 2001 From: crs Date: Fri, 21 Jun 2002 16:19:08 +0000 Subject: [PATCH] fixed CTCPSocket::connect() to allow cancellation. --- net/CNetwork.cpp | 37 ++++++++++++++++++++++++++++++--- net/CNetwork.h | 7 +++++++ net/CTCPListenSocket.cpp | 1 + net/CTCPSocket.cpp | 45 +++++++++++++++++++++++++++++++++++++--- 4 files changed, 84 insertions(+), 6 deletions(-) diff --git a/net/CNetwork.cpp b/net/CNetwork.cpp index 2122a74d..0dd37ee8 100644 --- a/net/CNetwork.cpp +++ b/net/CNetwork.cpp @@ -36,6 +36,7 @@ struct protoent FAR * (PASCAL FAR *CNetwork::getprotobynumber)(int proto); struct protoent FAR * (PASCAL FAR *CNetwork::getprotobyname)(const char FAR * name); int (PASCAL FAR *CNetwork::getsockerror)(void); int (PASCAL FAR *CNetwork::gethosterror)(void); +int (PASCAL FAR *CNetwork::setblocking)(CNetwork::Socket s, bool blocking); #if WINDOWS_LIKE @@ -207,9 +208,10 @@ CNetwork::init2( setfunc(WSACleanup, WSACleanup, int (PASCAL FAR *)(void)); setfunc(__WSAFDIsSet, __WSAFDIsSet, int (PASCAL FAR *)(CNetwork::Socket, fd_set FAR *)); setfunc(select, select, int (PASCAL FAR *)(int nfds, fd_set FAR *readfds, fd_set FAR *writefds, fd_set FAR *exceptfds, const struct timeval FAR *timeout)); - poll = poll2; - read = read2; - write = write2; + poll = poll2; + read = read2; + write = write2; + setblocking = setblocking2; s_networkModule = module; } @@ -295,11 +297,19 @@ CNetwork::write2(Socket s, const void FAR* buf, size_t len) return send(s, buf, len, 0); } +int PASCAL FAR +CNetwork::setblocking2(CNetwork::Socket s, bool blocking) +{ + int flag = blocking ? 0 : 1; + return ioctlsocket(s, FIONBIO, &flag); +} + #endif #if UNIX_LIKE #include +#include #include #include @@ -352,6 +362,26 @@ mygethostname(char* name, int namelen) return gethostname(name, namelen); } +static +int +mysetblocking(CNetwork::Socket s, bool blocking) +{ + int mode = fcntl(s, F_GETFL, 0); + if (mode == -1) { + return -1; + } + if (blocking) { + mode &= ~O_NDELAY; + } + else { + mode |= O_NDELAY; + } + if (fcntl(s, F_SETFL, mode) < 0) { + return -1; + } + return 0; +} + const int CNetwork::Error = -1; const CNetwork::Socket CNetwork::Null = -1; @@ -388,6 +418,7 @@ CNetwork::init() setfunc(getprotobyname, getprotobyname, struct protoent FAR * (PASCAL FAR *)(const char FAR * name)); setfunc(getsockerror, myerrno, int (PASCAL FAR *)(void)); setfunc(gethosterror, myherrno, int (PASCAL FAR *)(void)); + setfunc(setblocking, mysetblocking, int (PASCAL FAR *)(Socket, bool)); } void diff --git a/net/CNetwork.h b/net/CNetwork.h index cf090ea1..2f2a7883 100644 --- a/net/CNetwork.h +++ b/net/CNetwork.h @@ -84,8 +84,10 @@ public: enum { #if WINDOWS_LIKE kEADDRINUSE = WSAEADDRINUSE, + kECONNECTING = WSAEWOULDBLOCK, #elif UNIX_LIKE kEADDRINUSE = EADDRINUSE, + kECONNECTING = EINPROGRESS, #endif kNone = 0 }; @@ -139,12 +141,17 @@ public: static int (PASCAL FAR *getsockerror)(void); static int (PASCAL FAR *gethosterror)(void); + // convenience functions (only available after init()) + + static int (PASCAL FAR *setblocking)(CNetwork::Socket s, bool blocking); + #if WINDOWS_LIKE private: static void init2(HMODULE); static int PASCAL FAR poll2(PollEntry[], int nfds, int timeout); static ssize_t PASCAL FAR read2(Socket s, void FAR * buf, size_t len); static ssize_t PASCAL FAR write2(Socket s, const void FAR * buf, size_t len); + static int PASCAL FAR setblocking2(CNetwork::Socket s, bool blocking); static int (PASCAL FAR *WSACleanup)(void); static int (PASCAL FAR *__WSAFDIsSet)(CNetwork::Socket, fd_set FAR *); static int (PASCAL FAR *select)(int nfds, fd_set FAR *readfds, fd_set FAR *writefds, fd_set FAR *exceptfds, const struct timeval FAR *timeout); diff --git a/net/CTCPListenSocket.cpp b/net/CTCPListenSocket.cpp index fa7950ce..8e8f36d9 100644 --- a/net/CTCPListenSocket.cpp +++ b/net/CTCPListenSocket.cpp @@ -45,6 +45,7 @@ CTCPListenSocket::bind(const CNetworkAddress& addr) IDataSocket* CTCPListenSocket::accept() { + // accept asynchronously so we can check for cancellation CNetwork::PollEntry pfds[1]; pfds[0].fd = m_fd; pfds[0].events = CNetwork::kPOLLIN; diff --git a/net/CTCPSocket.cpp b/net/CTCPSocket.cpp index 8b3c5bc9..c6e1970e 100644 --- a/net/CTCPSocket.cpp +++ b/net/CTCPSocket.cpp @@ -109,13 +109,52 @@ CTCPSocket::close() void CTCPSocket::connect(const CNetworkAddress& addr) { - CThread::testCancel(); + // connect asynchronously so we can check for cancellation + CNetwork::setblocking(m_fd, false); if (CNetwork::connect(m_fd, addr.getAddress(), addr.getAddressLength()) == CNetwork::Error) { - CThread::testCancel(); - throw XSocketConnect(); + // check for failure + if (CNetwork::getsockerror() != CNetwork::kECONNECTING) { + CNetwork::setblocking(m_fd, true); + throw XSocketConnect(); + } + + // wait for connection or failure + CNetwork::PollEntry pfds[1]; + pfds[0].fd = m_fd; + pfds[0].events = CNetwork::kPOLLOUT; + for (;;) { + CThread::testCancel(); + const int status = CNetwork::poll(pfds, 1, 10); + if (status > 0) { + if ((pfds[0].revents & (CNetwork::kPOLLERR | + CNetwork::kPOLLNVAL)) != 0) { + // connection failed + CNetwork::setblocking(m_fd, true); + throw XSocketConnect(); + } + if ((pfds[0].revents & CNetwork::kPOLLOUT) != 0) { + int error; + CNetwork::AddressLength size = sizeof(error); + if (CNetwork::getsockopt(m_fd, SOL_SOCKET, SO_ERROR, + reinterpret_cast(&error), + &size) == CNetwork::Error || + error != 0) { + // connection failed + CNetwork::setblocking(m_fd, true); + throw XSocketConnect(); + } + + // connected! + break; + } + } + } } + // back to blocking + CNetwork::setblocking(m_fd, true); + // start servicing the socket m_connected = kReadWrite; m_thread = new CThread(new TMethodJob(