diff --git a/doc/newsfragments/close-failed-handshake-connections.bugfix b/doc/newsfragments/close-failed-handshake-connections.bugfix new file mode 100644 index 00000000..0dc8c16c --- /dev/null +++ b/doc/newsfragments/close-failed-handshake-connections.bugfix @@ -0,0 +1,6 @@ +SECURITY ISSUE + +Barrier will now correctly close connections when the app-level handshake fails (fixes CVE-2021-42075). + +Previously repeated failing connections would leak file descriptors leading to Barrier being unable +to receive new connections from clients. diff --git a/doc/newsfragments/enforce-maximum-message-length.bugfix b/doc/newsfragments/enforce-maximum-message-length.bugfix new file mode 100644 index 00000000..81ec2ba0 --- /dev/null +++ b/doc/newsfragments/enforce-maximum-message-length.bugfix @@ -0,0 +1,6 @@ +SECURITY ISSUE + +Barrier will now enforce a maximum length of input messages (fixes CVE-2021-42076). + +Previously it was possible for a malicious client or server to send excessive length messages +leading to denial of service by resource exhaustion. diff --git a/doc/newsfragments/fix-crash-on-ssl-hello.bugfix b/doc/newsfragments/fix-crash-on-ssl-hello.bugfix new file mode 100644 index 00000000..30bb0603 --- /dev/null +++ b/doc/newsfragments/fix-crash-on-ssl-hello.bugfix @@ -0,0 +1,4 @@ +SECURITY ISSUE + +Fixed a bug which caused Barrier to crash when disconnecting a TCP session just after sending Hello message. +This bug allowed an unauthenticated attacker to crash Barrier with only network access. diff --git a/doc/newsfragments/ssl-corrupted-data.bugfix b/doc/newsfragments/ssl-corrupted-data.bugfix new file mode 100644 index 00000000..db8bbf86 --- /dev/null +++ b/doc/newsfragments/ssl-corrupted-data.bugfix @@ -0,0 +1,2 @@ +Fixed a bug in SSL implementation that caused invalid data occasionally being sent to clients +under heavy load. diff --git a/src/lib/barrier/PacketStreamFilter.cpp b/src/lib/barrier/PacketStreamFilter.cpp index 16f0fe76..b6befd66 100644 --- a/src/lib/barrier/PacketStreamFilter.cpp +++ b/src/lib/barrier/PacketStreamFilter.cpp @@ -17,6 +17,7 @@ */ #include "barrier/PacketStreamFilter.h" +#include "barrier/protocol_types.h" #include "base/IEventQueue.h" #include "mt/Lock.h" #include "base/TMethodEventJob.h" @@ -133,8 +134,7 @@ PacketStreamFilter::isReadyNoLock() const return (m_size != 0 && m_buffer.getSize() >= m_size); } -void -PacketStreamFilter::readPacketSize() +bool PacketStreamFilter::readPacketSize() { // note -- m_mutex must be locked on entry @@ -146,7 +146,13 @@ PacketStreamFilter::readPacketSize() ((UInt32)buffer[1] << 16) | ((UInt32)buffer[2] << 8) | (UInt32)buffer[3]; + + if (m_size > PROTOCOL_MAX_MESSAGE_LENGTH) { + m_events->addEvent(Event(m_events->forIStream().inputFormatError(), getEventTarget())); + return false; + } } + return true; } bool @@ -160,13 +166,17 @@ PacketStreamFilter::readMore() UInt32 n = getStream()->read(buffer, sizeof(buffer)); while (n > 0) { m_buffer.write(buffer, n); + + // if we don't yet have the next packet size then get it, if possible. + // Note that we can't wait for whole pending data to arrive because it may be huge in + // case of malicious or erroneous peer. + if (!readPacketSize()) { + break; + } + n = getStream()->read(buffer, sizeof(buffer)); } - // if we don't yet have the next packet size then get it, - // if possible. - readPacketSize(); - // note if we now have a whole packet bool isReady = isReadyNoLock(); diff --git a/src/lib/barrier/PacketStreamFilter.h b/src/lib/barrier/PacketStreamFilter.h index bcbd604b..e6f1a37d 100644 --- a/src/lib/barrier/PacketStreamFilter.h +++ b/src/lib/barrier/PacketStreamFilter.h @@ -47,7 +47,9 @@ protected: private: bool isReadyNoLock() const; - void readPacketSize(); + + // returns false on erroneous packet size + bool readPacketSize(); bool readMore(); private: diff --git a/src/lib/barrier/ProtocolUtil.cpp b/src/lib/barrier/ProtocolUtil.cpp index e742687f..c1be0321 100644 --- a/src/lib/barrier/ProtocolUtil.cpp +++ b/src/lib/barrier/ProtocolUtil.cpp @@ -19,6 +19,8 @@ #include "barrier/ProtocolUtil.h" #include "io/IStream.h" #include "base/Log.h" +#include "barrier/protocol_types.h" +#include "barrier/XBarrier.h" #include "common/stdvector.h" #include "base/String.h" @@ -159,6 +161,10 @@ ProtocolUtil::vreadf(barrier::IStream* stream, const char* fmt, va_list args) (static_cast(buffer[2]) << 8) | static_cast(buffer[3]); + if (n > PROTOCOL_MAX_LIST_LENGTH) { + throw XBadClient("Too long message received"); + } + // convert it void* v = va_arg(args, void*); switch (len) { @@ -211,6 +217,10 @@ ProtocolUtil::vreadf(barrier::IStream* stream, const char* fmt, va_list args) (static_cast(buffer[2]) << 8) | static_cast(buffer[3]); + if (len > PROTOCOL_MAX_STRING_LENGTH) { + throw XBadClient("Too long message received"); + } + // use a fixed size buffer if its big enough const bool useFixed = (len <= sizeof(buffer)); diff --git a/src/lib/barrier/protocol_types.h b/src/lib/barrier/protocol_types.h index bc5e0377..6acee26f 100644 --- a/src/lib/barrier/protocol_types.h +++ b/src/lib/barrier/protocol_types.h @@ -20,6 +20,8 @@ #include "base/EventTypes.h" +#include + // protocol version number // 1.0: initial protocol // 1.1: adds KeyCode to key press, release, and repeat @@ -51,6 +53,12 @@ static const double kKeepAlivesUntilDeath = 3.0; static const double kHeartRate = -1.0; static const double kHeartBeatsUntilDeath = 3.0; +// Messages of very large size indicate a likely protocol error. We don't parse such messages and +// drop connection instead. Note that e.g. the clipboard messages are already limited to 32kB. +static constexpr std::uint32_t PROTOCOL_MAX_MESSAGE_LENGTH = 4 * 1024 * 1024; +static constexpr std::uint32_t PROTOCOL_MAX_LIST_LENGTH = 1024 * 1024; +static constexpr std::uint32_t PROTOCOL_MAX_STRING_LENGTH = 1024 * 1024; + // direction constants enum EDirection { kNoDirection, diff --git a/src/lib/base/EventTypes.cpp b/src/lib/base/EventTypes.cpp index 2ba20778..7a41ded2 100644 --- a/src/lib/base/EventTypes.cpp +++ b/src/lib/base/EventTypes.cpp @@ -56,6 +56,7 @@ REGISTER_EVENT(IStream, outputFlushed) REGISTER_EVENT(IStream, outputError) REGISTER_EVENT(IStream, inputShutdown) REGISTER_EVENT(IStream, outputShutdown) +REGISTER_EVENT(IStream, inputFormatError) // // IpcClient diff --git a/src/lib/base/EventTypes.h b/src/lib/base/EventTypes.h index f81617e0..148fa2c8 100644 --- a/src/lib/base/EventTypes.h +++ b/src/lib/base/EventTypes.h @@ -133,6 +133,11 @@ public: */ Event::Type outputShutdown(); + /** Get input format error event type + + This is sent when a stream receives an irrecoverable input format error. + */ + Event::Type inputFormatError(); //@} private: @@ -141,6 +146,7 @@ private: Event::Type m_outputError; Event::Type m_inputShutdown; Event::Type m_outputShutdown; + Event::Type m_inputFormatError; }; class IpcClientEvents : public EventTypes { diff --git a/src/lib/client/ServerProxy.cpp b/src/lib/client/ServerProxy.cpp index c067f132..1e5c339b 100644 --- a/src/lib/client/ServerProxy.cpp +++ b/src/lib/client/ServerProxy.cpp @@ -26,6 +26,7 @@ #include "barrier/ProtocolUtil.h" #include "barrier/option_types.h" #include "barrier/protocol_types.h" +#include "barrier/XBarrier.h" #include "io/IStream.h" #include "base/Log.h" #include "base/IEventQueue.h" @@ -124,17 +125,27 @@ ServerProxy::handleData(const Event&, void*) // parse message LOG((CLOG_DEBUG2 "msg from server: %c%c%c%c", code[0], code[1], code[2], code[3])); - switch ((this->*m_parser)(code)) { - case kOkay: - break; + try { + switch ((this->*m_parser)(code)) { + case kOkay: + break; - case kUnknown: - LOG((CLOG_ERR "invalid message from server: %c%c%c%c", code[0], code[1], code[2], code[3])); + case kUnknown: + LOG((CLOG_ERR "invalid message from server: %c%c%c%c", code[0], code[1], code[2], code[3])); + m_client->disconnect("invalid message from server"); + return; + + case kDisconnect: + return; + } + } catch (const XBadClient& e) { + // TODO: disconnect handling is currently dispersed across both parseMessage() and + // handleData() functions, we should collect that to a single place + + LOG((CLOG_ERR "protocol error from server: %s", e.what())); + ProtocolUtil::writef(m_stream, kMsgEBad); m_client->disconnect("invalid message from server"); return; - - case kDisconnect: - return; } // next message diff --git a/src/lib/net/SecureSocket.cpp b/src/lib/net/SecureSocket.cpp index 855e16bb..f31574f0 100644 --- a/src/lib/net/SecureSocket.cpp +++ b/src/lib/net/SecureSocket.cpp @@ -40,6 +40,7 @@ #define MAX_ERROR_SIZE 65535 +static const std::size_t MAX_INPUT_BUFFER_SIZE = 1024 * 1024; static const float s_retryDelay = 0.01f; enum { @@ -103,6 +104,8 @@ SecureSocket::close() void SecureSocket::freeSSLResources() { + std::lock_guard ssl_lock{ssl_mutex_}; + if (m_ssl->m_ssl != NULL) { SSL_shutdown(m_ssl->m_ssl); SSL_free(m_ssl->m_ssl); @@ -156,7 +159,7 @@ SecureSocket::secureAccept() TCPSocket::EJobResult SecureSocket::doRead() { - static UInt8 buffer[4096]; + UInt8 buffer[4096]; memset(buffer, 0, sizeof(buffer)); int bytesRead = 0; int status = 0; @@ -180,7 +183,11 @@ SecureSocket::doRead() // slurp up as much as possible do { m_inputBuffer.write(buffer, bytesRead); - + + if (m_inputBuffer.getSize() > MAX_INPUT_BUFFER_SIZE) { + break; + } + status = secureRead(buffer, sizeof(buffer), bytesRead); if (status < 0) { return kBreak; @@ -211,11 +218,6 @@ SecureSocket::doRead() TCPSocket::EJobResult SecureSocket::doWrite() { - static bool s_retry = false; - static int s_retrySize = 0; - static std::unique_ptr s_staticBuffer; - static std::size_t s_staticBufferSize = 0; - // write data int bufferSize = 0; int bytesWrote = 0; @@ -224,16 +226,16 @@ SecureSocket::doWrite() if (!isSecureReady()) return kRetry; - if (s_retry) { - bufferSize = s_retrySize; + if (do_write_retry_) { + bufferSize = do_write_retry_size_; } else { bufferSize = m_outputBuffer.getSize(); - if (bufferSize > s_staticBufferSize) { - s_staticBuffer.reset(new char[bufferSize]); - s_staticBufferSize = bufferSize; + if (bufferSize > do_write_retry_buffer_size_) { + do_write_retry_buffer_.reset(new char[bufferSize]); + do_write_retry_buffer_size_ = bufferSize; } if (bufferSize > 0) { - memcpy(s_staticBuffer.get(), m_outputBuffer.peek(bufferSize), bufferSize); + std::memcpy(do_write_retry_buffer_.get(), m_outputBuffer.peek(bufferSize), bufferSize); } } @@ -241,14 +243,14 @@ SecureSocket::doWrite() return kRetry; } - status = secureWrite(s_staticBuffer.get(), bufferSize, bytesWrote); + status = secureWrite(do_write_retry_buffer_.get(), bufferSize, bytesWrote); if (status > 0) { - s_retry = false; + do_write_retry_ = false; } else if (status < 0) { return kBreak; } else if (status == 0) { - s_retry = true; - s_retrySize = bufferSize; + do_write_retry_ = true; + do_write_retry_size_ = bufferSize; return kNew; } @@ -263,16 +265,16 @@ SecureSocket::doWrite() int SecureSocket::secureRead(void* buffer, int size, int& read) { + std::lock_guard ssl_lock{ssl_mutex_}; + if (m_ssl->m_ssl != NULL) { LOG((CLOG_DEBUG2 "reading secure socket")); read = SSL_read(m_ssl->m_ssl, buffer, size); - - static int retry; // Check result will cleanup the connection in the case of a fatal - checkResult(read, retry); - - if (retry) { + checkResult(read, secure_read_retry_); + + if (secure_read_retry_) { return 0; } @@ -289,17 +291,17 @@ SecureSocket::secureRead(void* buffer, int size, int& read) int SecureSocket::secureWrite(const void* buffer, int size, int& wrote) { + std::lock_guard ssl_lock{ssl_mutex_}; + if (m_ssl->m_ssl != NULL) { LOG((CLOG_DEBUG2 "writing secure socket:%p", this)); wrote = SSL_write(m_ssl->m_ssl, buffer, size); - - static int retry; // Check result will cleanup the connection in the case of a fatal - checkResult(wrote, retry); + checkResult(wrote, secure_write_retry_); - if (retry) { + if (secure_write_retry_) { return 0; } @@ -322,6 +324,8 @@ SecureSocket::isSecureReady() void SecureSocket::initSsl(bool server) { + std::lock_guard ssl_lock{ssl_mutex_}; + m_ssl = new Ssl(); m_ssl->m_context = NULL; m_ssl->m_ssl = NULL; @@ -331,6 +335,8 @@ SecureSocket::initSsl(bool server) bool SecureSocket::loadCertificates(std::string& filename) { + std::lock_guard ssl_lock{ssl_mutex_}; + if (filename.empty()) { showError("ssl certificate is not specified"); return false; @@ -373,6 +379,8 @@ bool SecureSocket::loadCertificates(std::string& filename) void SecureSocket::initContext(bool server) { + // ssl_mutex_ is assumed to be acquired + SSL_library_init(); const SSL_METHOD* method; @@ -410,6 +418,8 @@ SecureSocket::initContext(bool server) void SecureSocket::createSSL() { + // ssl_mutex_ is assumed to be acquired + // I assume just one instance is needed // get new SSL state with context if (m_ssl->m_ssl == NULL) { @@ -421,6 +431,8 @@ SecureSocket::createSSL() int SecureSocket::secureAccept(int socket) { + std::lock_guard ssl_lock{ssl_mutex_}; + createSSL(); // set connection socket to SSL state @@ -428,10 +440,8 @@ SecureSocket::secureAccept(int socket) LOG((CLOG_DEBUG2 "accepting secure socket")); int r = SSL_accept(m_ssl->m_ssl); - - static int retry; - checkResult(r, retry); + checkResult(r, secure_accept_retry_); if (isFatal()) { // tell user and sleep so the socket isn't hammered. @@ -439,12 +449,12 @@ SecureSocket::secureAccept(int socket) LOG((CLOG_INFO "client connection may not be secure")); m_secureReady = false; ARCH->sleep(1); - retry = 0; + secure_accept_retry_ = 0; return -1; // Failed, error out } // If not fatal and no retry, state is good - if (retry == 0) { + if (secure_accept_retry_ == 0) { m_secureReady = true; LOG((CLOG_INFO "accepted secure socket")); if (CLOG->getFilter() >= kDEBUG1) { @@ -455,7 +465,7 @@ SecureSocket::secureAccept(int socket) } // If not fatal and retry is set, not ready, and return retry - if (retry > 0) { + if (secure_accept_retry_ > 0) { LOG((CLOG_DEBUG2 "retry accepting secure socket")); m_secureReady = false; ARCH->sleep(s_retryDelay); @@ -470,6 +480,8 @@ SecureSocket::secureAccept(int socket) int SecureSocket::secureConnect(int socket) { + std::lock_guard ssl_lock{ssl_mutex_}; + createSSL(); // attach the socket descriptor @@ -477,26 +489,24 @@ SecureSocket::secureConnect(int socket) LOG((CLOG_DEBUG2 "connecting secure socket")); int r = SSL_connect(m_ssl->m_ssl); - - static int retry; - checkResult(r, retry); + checkResult(r, secure_connect_retry_); if (isFatal()) { LOG((CLOG_ERR "failed to connect secure socket")); - retry = 0; + secure_connect_retry_ = 0; return -1; } // If we should retry, not ready and return 0 - if (retry > 0) { + if (secure_connect_retry_ > 0) { LOG((CLOG_DEBUG2 "retry connect secure socket")); m_secureReady = false; ARCH->sleep(s_retryDelay); return 0; } - retry = 0; + secure_connect_retry_ = 0; // No error, set ready, process and return ok m_secureReady = true; if (verifyCertFingerprint()) { @@ -522,6 +532,7 @@ SecureSocket::secureConnect(int socket) bool SecureSocket::showCertificate() { + // ssl_mutex_ is assumed to be acquired X509* cert; char* line; @@ -544,6 +555,8 @@ SecureSocket::showCertificate() void SecureSocket::checkResult(int status, int& retry) { + // ssl_mutex_ is assumed to be acquired + // ssl errors are a little quirky. the "want" errors are normal and // should result in a retry. @@ -680,6 +693,8 @@ void SecureSocket::formatFingerprint(std::string& fingerprint, bool hex, bool se bool SecureSocket::verifyCertFingerprint() { + // ssl_mutex_ is assumed to be acquired + // calculate received certificate fingerprint X509 *cert = cert = SSL_get_peer_certificate(m_ssl->m_ssl); EVP_MD* tempDigest; @@ -822,6 +837,8 @@ showCipherStackDesc(STACK_OF(SSL_CIPHER) * stack) { void SecureSocket::showSecureCipherInfo() { + // ssl_mutex_ is assumed to be acquired + STACK_OF(SSL_CIPHER) * sStack = SSL_get_ciphers(m_ssl->m_ssl); if (sStack == NULL) { @@ -864,6 +881,8 @@ SecureSocket::showSecureLibInfo() void SecureSocket::showSecureConnectInfo() { + // ssl_mutex_ is assumed to be acquired + const SSL_CIPHER* cipher = SSL_get_current_cipher(m_ssl->m_ssl); if (cipher != NULL) { diff --git a/src/lib/net/SecureSocket.h b/src/lib/net/SecureSocket.h index c602e2da..0d9d1f21 100644 --- a/src/lib/net/SecureSocket.h +++ b/src/lib/net/SecureSocket.h @@ -19,6 +19,7 @@ #include "net/TCPSocket.h" #include "net/XSocket.h" +#include class IEventQueue; class SocketMultiplexer; @@ -59,31 +60,48 @@ public: private: // SSL - void initContext(bool server); - void createSSL(); + void initContext(bool server); // may only be called with ssl_mutex_ acquired + void createSSL(); // may only be called with ssl_mutex_ acquired. int secureAccept(int s); int secureConnect(int s); - bool showCertificate(); - void checkResult(int n, int& retry); + bool showCertificate(); // may only be called with ssl_mutex_ acquired + void checkResult(int n, int& retry); // may only be called with m_ssl_mutex_ acquired. void showError(const char* reason = NULL); std::string getError(); void disconnect(); void formatFingerprint(std::string& fingerprint, bool hex = true, bool separator = true); - bool verifyCertFingerprint(); + bool verifyCertFingerprint(); // may only be called with ssl_mutex_ acquired MultiplexerJobStatus serviceConnect(ISocketMultiplexerJob*, bool, bool, bool); MultiplexerJobStatus serviceAccept(ISocketMultiplexerJob*, bool, bool, bool); - void showSecureConnectInfo(); - void showSecureLibInfo(); - void showSecureCipherInfo(); - + void showSecureConnectInfo(); // may only be called with ssl_mutex_ acquired + void showSecureLibInfo(); + void showSecureCipherInfo(); // may only be called with ssl_mutex_ acquired + void handleTCPConnected(const Event& event, void*); void freeSSLResources(); private: + // all accesses to m_ssl must be protected by this mutex. The only function that is called + // from outside SocketMultiplexer thread is close(), so we mostly care about things accessed + // by it. + std::mutex ssl_mutex_; + Ssl* m_ssl; bool m_secureReady; bool m_fatal; + + int secure_accept_retry_ = 0; // used only in secureAccept() + int secure_connect_retry_ = 0; // used only in secureConnect() + int secure_read_retry_ = 0; // used only in secureRead() + int secure_write_retry_ = 0; // used only in secureWrite() + + // The following are used only from doWrite() + // FIXME: using std::vector would simplify logic significantly. + bool do_write_retry_ = false; + int do_write_retry_size_ = 0; + std::unique_ptr do_write_retry_buffer_; + std::size_t do_write_retry_buffer_size_ = 0; }; diff --git a/src/lib/net/TCPSocket.cpp b/src/lib/net/TCPSocket.cpp index 09a8f17e..ddc0e6f9 100644 --- a/src/lib/net/TCPSocket.cpp +++ b/src/lib/net/TCPSocket.cpp @@ -33,9 +33,7 @@ #include #include -// -// TCPSocket -// +static const std::size_t MAX_INPUT_BUFFER_SIZE = 1024 * 1024; TCPSocket::TCPSocket(IEventQueue* events, SocketMultiplexer* socketMultiplexer, IArchNetwork::EAddressFamily family) : IDataSocket(events), @@ -345,6 +343,10 @@ TCPSocket::doRead() do { m_inputBuffer.write(buffer, (UInt32)bytesRead); + if (m_inputBuffer.getSize() > MAX_INPUT_BUFFER_SIZE) { + break; + } + bytesRead = ARCH->readSocket(m_socket, buffer, sizeof(buffer)); } while (bytesRead > 0); diff --git a/src/lib/server/ClientListener.cpp b/src/lib/server/ClientListener.cpp index 00067bab..a29ae26f 100644 --- a/src/lib/server/ClientListener.cpp +++ b/src/lib/server/ClientListener.cpp @@ -184,7 +184,6 @@ ClientListener::handleUnknownClient(const Event&, void* vclient) // get the real client proxy and install it ClientProxy* client = unknownClient->orphanClientProxy(); - bool handshakeOk = true; if (client != NULL) { // handshake was successful m_waitingClients.push_back(client); @@ -196,20 +195,17 @@ ClientListener::handleUnknownClient(const Event&, void* vclient) new TMethodEventJob(this, &ClientListener::handleClientDisconnected, client)); - } - else { - handshakeOk = false; + } else { + auto* stream = unknownClient->getStream(); + if (stream) { + stream->close(); + } } // now finished with unknown client m_events->removeHandler(m_events->forClientProxyUnknown().success(), client); m_events->removeHandler(m_events->forClientProxyUnknown().failure(), client); m_newClients.erase(unknownClient); - PacketStreamFilter* streamFileter = dynamic_cast(unknownClient->getStream()); - IDataSocket* socket = NULL; - if (streamFileter != NULL) { - socket = dynamic_cast(streamFileter->getStream()); - } delete unknownClient; } diff --git a/src/lib/server/ClientProxy1_0.cpp b/src/lib/server/ClientProxy1_0.cpp index 5cbaac2a..079d2283 100644 --- a/src/lib/server/ClientProxy1_0.cpp +++ b/src/lib/server/ClientProxy1_0.cpp @@ -51,6 +51,10 @@ ClientProxy1_0::ClientProxy1_0(const std::string& name, barrier::IStream* stream stream->getEventTarget(), new TMethodEventJob(this, &ClientProxy1_0::handleDisconnect, NULL)); + m_events->adoptHandler(m_events->forIStream().inputFormatError(), + stream->getEventTarget(), + new TMethodEventJob(this, + &ClientProxy1_0::handleDisconnect, NULL)); m_events->adoptHandler(m_events->forIStream().outputShutdown(), stream->getEventTarget(), new TMethodEventJob(this, @@ -90,6 +94,8 @@ ClientProxy1_0::removeHandlers() getStream()->getEventTarget()); m_events->removeHandler(m_events->forIStream().outputShutdown(), getStream()->getEventTarget()); + m_events->removeHandler(m_events->forIStream().inputFormatError(), + getStream()->getEventTarget()); m_events->removeHandler(Event::kTimer, this); // remove timer @@ -148,9 +154,18 @@ ClientProxy1_0::handleData(const Event&, void*) } // parse message - LOG((CLOG_DEBUG2 "msg from \"%s\": %c%c%c%c", getName().c_str(), code[0], code[1], code[2], code[3])); - if (!(this->*m_parser)(code)) { - LOG((CLOG_ERR "invalid message from client \"%s\": %c%c%c%c", getName().c_str(), code[0], code[1], code[2], code[3])); + try { + LOG((CLOG_DEBUG2 "msg from \"%s\": %c%c%c%c", getName().c_str(), code[0], code[1], code[2], code[3])); + if (!(this->*m_parser)(code)) { + LOG((CLOG_ERR "invalid message from client \"%s\": %c%c%c%c", getName().c_str(), code[0], code[1], code[2], code[3])); + disconnect(); + return; + } + } catch (const XBadClient& e) { + // TODO: disconnect handling is currently dispersed across both parseMessage() and + // handleData() functions, we should collect that to a single place + + LOG((CLOG_ERR "protocol error from client: %s", e.what())); disconnect(); return; } diff --git a/src/lib/server/ClientProxyUnknown.cpp b/src/lib/server/ClientProxyUnknown.cpp index dc79da7d..de6e233e 100644 --- a/src/lib/server/ClientProxyUnknown.cpp +++ b/src/lib/server/ClientProxyUnknown.cpp @@ -118,6 +118,10 @@ ClientProxyUnknown::addStreamHandlers() m_stream->getEventTarget(), new TMethodEventJob(this, &ClientProxyUnknown::handleDisconnect)); + m_events->adoptHandler(m_events->forIStream().inputFormatError(), + m_stream->getEventTarget(), + new TMethodEventJob(this, + &ClientProxyUnknown::handleDisconnect)); m_events->adoptHandler(m_events->forIStream().outputShutdown(), m_stream->getEventTarget(), new TMethodEventJob(this, @@ -149,6 +153,8 @@ ClientProxyUnknown::removeHandlers() m_stream->getEventTarget()); m_events->removeHandler(m_events->forIStream().inputShutdown(), m_stream->getEventTarget()); + m_events->removeHandler(m_events->forIStream().inputFormatError(), + m_stream->getEventTarget()); m_events->removeHandler(m_events->forIStream().outputShutdown(), m_stream->getEventTarget()); }