Merge pull request #1356 from p12tic/2.3-security-fixes

Backports of security fixes to 2.3.x
This commit is contained in:
Povilas Kanapickas 2021-11-01 19:28:18 +02:00 committed by GitHub
commit dcbd1f91b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 199 additions and 77 deletions

View File

@ -0,0 +1,6 @@
SECURITY ISSUE
Barrier will now correctly close connections when the app-level handshake fails (fixes CVE-2021-42075).
Previously repeated failing connections would leak file descriptors leading to Barrier being unable
to receive new connections from clients.

View File

@ -0,0 +1,6 @@
SECURITY ISSUE
Barrier will now enforce a maximum length of input messages (fixes CVE-2021-42076).
Previously it was possible for a malicious client or server to send excessive length messages
leading to denial of service by resource exhaustion.

View File

@ -0,0 +1,4 @@
SECURITY ISSUE
Fixed a bug which caused Barrier to crash when disconnecting a TCP session just after sending Hello message.
This bug allowed an unauthenticated attacker to crash Barrier with only network access.

View File

@ -0,0 +1,2 @@
Fixed a bug in SSL implementation that caused invalid data occasionally being sent to clients
under heavy load.

View File

@ -17,6 +17,7 @@
*/ */
#include "barrier/PacketStreamFilter.h" #include "barrier/PacketStreamFilter.h"
#include "barrier/protocol_types.h"
#include "base/IEventQueue.h" #include "base/IEventQueue.h"
#include "mt/Lock.h" #include "mt/Lock.h"
#include "base/TMethodEventJob.h" #include "base/TMethodEventJob.h"
@ -133,8 +134,7 @@ PacketStreamFilter::isReadyNoLock() const
return (m_size != 0 && m_buffer.getSize() >= m_size); return (m_size != 0 && m_buffer.getSize() >= m_size);
} }
void bool PacketStreamFilter::readPacketSize()
PacketStreamFilter::readPacketSize()
{ {
// note -- m_mutex must be locked on entry // note -- m_mutex must be locked on entry
@ -146,7 +146,13 @@ PacketStreamFilter::readPacketSize()
((UInt32)buffer[1] << 16) | ((UInt32)buffer[1] << 16) |
((UInt32)buffer[2] << 8) | ((UInt32)buffer[2] << 8) |
(UInt32)buffer[3]; (UInt32)buffer[3];
if (m_size > PROTOCOL_MAX_MESSAGE_LENGTH) {
m_events->addEvent(Event(m_events->forIStream().inputFormatError(), getEventTarget()));
return false;
} }
}
return true;
} }
bool bool
@ -160,12 +166,16 @@ PacketStreamFilter::readMore()
UInt32 n = getStream()->read(buffer, sizeof(buffer)); UInt32 n = getStream()->read(buffer, sizeof(buffer));
while (n > 0) { while (n > 0) {
m_buffer.write(buffer, n); m_buffer.write(buffer, n);
n = getStream()->read(buffer, sizeof(buffer));
// if we don't yet have the next packet size then get it, if possible.
// Note that we can't wait for whole pending data to arrive because it may be huge in
// case of malicious or erroneous peer.
if (!readPacketSize()) {
break;
} }
// if we don't yet have the next packet size then get it, n = getStream()->read(buffer, sizeof(buffer));
// if possible. }
readPacketSize();
// note if we now have a whole packet // note if we now have a whole packet
bool isReady = isReadyNoLock(); bool isReady = isReadyNoLock();

View File

@ -47,7 +47,9 @@ protected:
private: private:
bool isReadyNoLock() const; bool isReadyNoLock() const;
void readPacketSize();
// returns false on erroneous packet size
bool readPacketSize();
bool readMore(); bool readMore();
private: private:

View File

@ -19,6 +19,8 @@
#include "barrier/ProtocolUtil.h" #include "barrier/ProtocolUtil.h"
#include "io/IStream.h" #include "io/IStream.h"
#include "base/Log.h" #include "base/Log.h"
#include "barrier/protocol_types.h"
#include "barrier/XBarrier.h"
#include "common/stdvector.h" #include "common/stdvector.h"
#include "base/String.h" #include "base/String.h"
@ -159,6 +161,10 @@ ProtocolUtil::vreadf(barrier::IStream* stream, const char* fmt, va_list args)
(static_cast<UInt32>(buffer[2]) << 8) | (static_cast<UInt32>(buffer[2]) << 8) |
static_cast<UInt32>(buffer[3]); static_cast<UInt32>(buffer[3]);
if (n > PROTOCOL_MAX_LIST_LENGTH) {
throw XBadClient("Too long message received");
}
// convert it // convert it
void* v = va_arg(args, void*); void* v = va_arg(args, void*);
switch (len) { switch (len) {
@ -211,6 +217,10 @@ ProtocolUtil::vreadf(barrier::IStream* stream, const char* fmt, va_list args)
(static_cast<UInt32>(buffer[2]) << 8) | (static_cast<UInt32>(buffer[2]) << 8) |
static_cast<UInt32>(buffer[3]); static_cast<UInt32>(buffer[3]);
if (len > PROTOCOL_MAX_STRING_LENGTH) {
throw XBadClient("Too long message received");
}
// use a fixed size buffer if its big enough // use a fixed size buffer if its big enough
const bool useFixed = (len <= sizeof(buffer)); const bool useFixed = (len <= sizeof(buffer));

View File

@ -20,6 +20,8 @@
#include "base/EventTypes.h" #include "base/EventTypes.h"
#include <cstdint>
// protocol version number // protocol version number
// 1.0: initial protocol // 1.0: initial protocol
// 1.1: adds KeyCode to key press, release, and repeat // 1.1: adds KeyCode to key press, release, and repeat
@ -51,6 +53,12 @@ static const double kKeepAlivesUntilDeath = 3.0;
static const double kHeartRate = -1.0; static const double kHeartRate = -1.0;
static const double kHeartBeatsUntilDeath = 3.0; static const double kHeartBeatsUntilDeath = 3.0;
// Messages of very large size indicate a likely protocol error. We don't parse such messages and
// drop connection instead. Note that e.g. the clipboard messages are already limited to 32kB.
static constexpr std::uint32_t PROTOCOL_MAX_MESSAGE_LENGTH = 4 * 1024 * 1024;
static constexpr std::uint32_t PROTOCOL_MAX_LIST_LENGTH = 1024 * 1024;
static constexpr std::uint32_t PROTOCOL_MAX_STRING_LENGTH = 1024 * 1024;
// direction constants // direction constants
enum EDirection { enum EDirection {
kNoDirection, kNoDirection,

View File

@ -56,6 +56,7 @@ REGISTER_EVENT(IStream, outputFlushed)
REGISTER_EVENT(IStream, outputError) REGISTER_EVENT(IStream, outputError)
REGISTER_EVENT(IStream, inputShutdown) REGISTER_EVENT(IStream, inputShutdown)
REGISTER_EVENT(IStream, outputShutdown) REGISTER_EVENT(IStream, outputShutdown)
REGISTER_EVENT(IStream, inputFormatError)
// //
// IpcClient // IpcClient

View File

@ -133,6 +133,11 @@ public:
*/ */
Event::Type outputShutdown(); Event::Type outputShutdown();
/** Get input format error event type
This is sent when a stream receives an irrecoverable input format error.
*/
Event::Type inputFormatError();
//@} //@}
private: private:
@ -141,6 +146,7 @@ private:
Event::Type m_outputError; Event::Type m_outputError;
Event::Type m_inputShutdown; Event::Type m_inputShutdown;
Event::Type m_outputShutdown; Event::Type m_outputShutdown;
Event::Type m_inputFormatError;
}; };
class IpcClientEvents : public EventTypes { class IpcClientEvents : public EventTypes {

View File

@ -26,6 +26,7 @@
#include "barrier/ProtocolUtil.h" #include "barrier/ProtocolUtil.h"
#include "barrier/option_types.h" #include "barrier/option_types.h"
#include "barrier/protocol_types.h" #include "barrier/protocol_types.h"
#include "barrier/XBarrier.h"
#include "io/IStream.h" #include "io/IStream.h"
#include "base/Log.h" #include "base/Log.h"
#include "base/IEventQueue.h" #include "base/IEventQueue.h"
@ -124,6 +125,7 @@ ServerProxy::handleData(const Event&, void*)
// parse message // parse message
LOG((CLOG_DEBUG2 "msg from server: %c%c%c%c", code[0], code[1], code[2], code[3])); LOG((CLOG_DEBUG2 "msg from server: %c%c%c%c", code[0], code[1], code[2], code[3]));
try {
switch ((this->*m_parser)(code)) { switch ((this->*m_parser)(code)) {
case kOkay: case kOkay:
break; break;
@ -136,6 +138,15 @@ ServerProxy::handleData(const Event&, void*)
case kDisconnect: case kDisconnect:
return; return;
} }
} catch (const XBadClient& e) {
// TODO: disconnect handling is currently dispersed across both parseMessage() and
// handleData() functions, we should collect that to a single place
LOG((CLOG_ERR "protocol error from server: %s", e.what()));
ProtocolUtil::writef(m_stream, kMsgEBad);
m_client->disconnect("invalid message from server");
return;
}
// next message // next message
n = m_stream->read(code, 4); n = m_stream->read(code, 4);

View File

@ -40,6 +40,7 @@
#define MAX_ERROR_SIZE 65535 #define MAX_ERROR_SIZE 65535
static const std::size_t MAX_INPUT_BUFFER_SIZE = 1024 * 1024;
static const float s_retryDelay = 0.01f; static const float s_retryDelay = 0.01f;
enum { enum {
@ -103,6 +104,8 @@ SecureSocket::close()
void SecureSocket::freeSSLResources() void SecureSocket::freeSSLResources()
{ {
std::lock_guard<std::mutex> ssl_lock{ssl_mutex_};
if (m_ssl->m_ssl != NULL) { if (m_ssl->m_ssl != NULL) {
SSL_shutdown(m_ssl->m_ssl); SSL_shutdown(m_ssl->m_ssl);
SSL_free(m_ssl->m_ssl); SSL_free(m_ssl->m_ssl);
@ -156,7 +159,7 @@ SecureSocket::secureAccept()
TCPSocket::EJobResult TCPSocket::EJobResult
SecureSocket::doRead() SecureSocket::doRead()
{ {
static UInt8 buffer[4096]; UInt8 buffer[4096];
memset(buffer, 0, sizeof(buffer)); memset(buffer, 0, sizeof(buffer));
int bytesRead = 0; int bytesRead = 0;
int status = 0; int status = 0;
@ -181,6 +184,10 @@ SecureSocket::doRead()
do { do {
m_inputBuffer.write(buffer, bytesRead); m_inputBuffer.write(buffer, bytesRead);
if (m_inputBuffer.getSize() > MAX_INPUT_BUFFER_SIZE) {
break;
}
status = secureRead(buffer, sizeof(buffer), bytesRead); status = secureRead(buffer, sizeof(buffer), bytesRead);
if (status < 0) { if (status < 0) {
return kBreak; return kBreak;
@ -211,11 +218,6 @@ SecureSocket::doRead()
TCPSocket::EJobResult TCPSocket::EJobResult
SecureSocket::doWrite() SecureSocket::doWrite()
{ {
static bool s_retry = false;
static int s_retrySize = 0;
static std::unique_ptr<char[]> s_staticBuffer;
static std::size_t s_staticBufferSize = 0;
// write data // write data
int bufferSize = 0; int bufferSize = 0;
int bytesWrote = 0; int bytesWrote = 0;
@ -224,16 +226,16 @@ SecureSocket::doWrite()
if (!isSecureReady()) if (!isSecureReady())
return kRetry; return kRetry;
if (s_retry) { if (do_write_retry_) {
bufferSize = s_retrySize; bufferSize = do_write_retry_size_;
} else { } else {
bufferSize = m_outputBuffer.getSize(); bufferSize = m_outputBuffer.getSize();
if (bufferSize > s_staticBufferSize) { if (bufferSize > do_write_retry_buffer_size_) {
s_staticBuffer.reset(new char[bufferSize]); do_write_retry_buffer_.reset(new char[bufferSize]);
s_staticBufferSize = bufferSize; do_write_retry_buffer_size_ = bufferSize;
} }
if (bufferSize > 0) { if (bufferSize > 0) {
memcpy(s_staticBuffer.get(), m_outputBuffer.peek(bufferSize), bufferSize); std::memcpy(do_write_retry_buffer_.get(), m_outputBuffer.peek(bufferSize), bufferSize);
} }
} }
@ -241,14 +243,14 @@ SecureSocket::doWrite()
return kRetry; return kRetry;
} }
status = secureWrite(s_staticBuffer.get(), bufferSize, bytesWrote); status = secureWrite(do_write_retry_buffer_.get(), bufferSize, bytesWrote);
if (status > 0) { if (status > 0) {
s_retry = false; do_write_retry_ = false;
} else if (status < 0) { } else if (status < 0) {
return kBreak; return kBreak;
} else if (status == 0) { } else if (status == 0) {
s_retry = true; do_write_retry_ = true;
s_retrySize = bufferSize; do_write_retry_size_ = bufferSize;
return kNew; return kNew;
} }
@ -263,16 +265,16 @@ SecureSocket::doWrite()
int int
SecureSocket::secureRead(void* buffer, int size, int& read) SecureSocket::secureRead(void* buffer, int size, int& read)
{ {
std::lock_guard<std::mutex> ssl_lock{ssl_mutex_};
if (m_ssl->m_ssl != NULL) { if (m_ssl->m_ssl != NULL) {
LOG((CLOG_DEBUG2 "reading secure socket")); LOG((CLOG_DEBUG2 "reading secure socket"));
read = SSL_read(m_ssl->m_ssl, buffer, size); read = SSL_read(m_ssl->m_ssl, buffer, size);
static int retry;
// Check result will cleanup the connection in the case of a fatal // Check result will cleanup the connection in the case of a fatal
checkResult(read, retry); checkResult(read, secure_read_retry_);
if (retry) { if (secure_read_retry_) {
return 0; return 0;
} }
@ -289,17 +291,17 @@ SecureSocket::secureRead(void* buffer, int size, int& read)
int int
SecureSocket::secureWrite(const void* buffer, int size, int& wrote) SecureSocket::secureWrite(const void* buffer, int size, int& wrote)
{ {
std::lock_guard<std::mutex> ssl_lock{ssl_mutex_};
if (m_ssl->m_ssl != NULL) { if (m_ssl->m_ssl != NULL) {
LOG((CLOG_DEBUG2 "writing secure socket:%p", this)); LOG((CLOG_DEBUG2 "writing secure socket:%p", this));
wrote = SSL_write(m_ssl->m_ssl, buffer, size); wrote = SSL_write(m_ssl->m_ssl, buffer, size);
static int retry;
// Check result will cleanup the connection in the case of a fatal // Check result will cleanup the connection in the case of a fatal
checkResult(wrote, retry); checkResult(wrote, secure_write_retry_);
if (retry) { if (secure_write_retry_) {
return 0; return 0;
} }
@ -322,6 +324,8 @@ SecureSocket::isSecureReady()
void void
SecureSocket::initSsl(bool server) SecureSocket::initSsl(bool server)
{ {
std::lock_guard<std::mutex> ssl_lock{ssl_mutex_};
m_ssl = new Ssl(); m_ssl = new Ssl();
m_ssl->m_context = NULL; m_ssl->m_context = NULL;
m_ssl->m_ssl = NULL; m_ssl->m_ssl = NULL;
@ -331,6 +335,8 @@ SecureSocket::initSsl(bool server)
bool SecureSocket::loadCertificates(std::string& filename) bool SecureSocket::loadCertificates(std::string& filename)
{ {
std::lock_guard<std::mutex> ssl_lock{ssl_mutex_};
if (filename.empty()) { if (filename.empty()) {
showError("ssl certificate is not specified"); showError("ssl certificate is not specified");
return false; return false;
@ -373,6 +379,8 @@ bool SecureSocket::loadCertificates(std::string& filename)
void void
SecureSocket::initContext(bool server) SecureSocket::initContext(bool server)
{ {
// ssl_mutex_ is assumed to be acquired
SSL_library_init(); SSL_library_init();
const SSL_METHOD* method; const SSL_METHOD* method;
@ -410,6 +418,8 @@ SecureSocket::initContext(bool server)
void void
SecureSocket::createSSL() SecureSocket::createSSL()
{ {
// ssl_mutex_ is assumed to be acquired
// I assume just one instance is needed // I assume just one instance is needed
// get new SSL state with context // get new SSL state with context
if (m_ssl->m_ssl == NULL) { if (m_ssl->m_ssl == NULL) {
@ -421,6 +431,8 @@ SecureSocket::createSSL()
int int
SecureSocket::secureAccept(int socket) SecureSocket::secureAccept(int socket)
{ {
std::lock_guard<std::mutex> ssl_lock{ssl_mutex_};
createSSL(); createSSL();
// set connection socket to SSL state // set connection socket to SSL state
@ -429,9 +441,7 @@ SecureSocket::secureAccept(int socket)
LOG((CLOG_DEBUG2 "accepting secure socket")); LOG((CLOG_DEBUG2 "accepting secure socket"));
int r = SSL_accept(m_ssl->m_ssl); int r = SSL_accept(m_ssl->m_ssl);
static int retry; checkResult(r, secure_accept_retry_);
checkResult(r, retry);
if (isFatal()) { if (isFatal()) {
// tell user and sleep so the socket isn't hammered. // tell user and sleep so the socket isn't hammered.
@ -439,12 +449,12 @@ SecureSocket::secureAccept(int socket)
LOG((CLOG_INFO "client connection may not be secure")); LOG((CLOG_INFO "client connection may not be secure"));
m_secureReady = false; m_secureReady = false;
ARCH->sleep(1); ARCH->sleep(1);
retry = 0; secure_accept_retry_ = 0;
return -1; // Failed, error out return -1; // Failed, error out
} }
// If not fatal and no retry, state is good // If not fatal and no retry, state is good
if (retry == 0) { if (secure_accept_retry_ == 0) {
m_secureReady = true; m_secureReady = true;
LOG((CLOG_INFO "accepted secure socket")); LOG((CLOG_INFO "accepted secure socket"));
if (CLOG->getFilter() >= kDEBUG1) { if (CLOG->getFilter() >= kDEBUG1) {
@ -455,7 +465,7 @@ SecureSocket::secureAccept(int socket)
} }
// If not fatal and retry is set, not ready, and return retry // If not fatal and retry is set, not ready, and return retry
if (retry > 0) { if (secure_accept_retry_ > 0) {
LOG((CLOG_DEBUG2 "retry accepting secure socket")); LOG((CLOG_DEBUG2 "retry accepting secure socket"));
m_secureReady = false; m_secureReady = false;
ARCH->sleep(s_retryDelay); ARCH->sleep(s_retryDelay);
@ -470,6 +480,8 @@ SecureSocket::secureAccept(int socket)
int int
SecureSocket::secureConnect(int socket) SecureSocket::secureConnect(int socket)
{ {
std::lock_guard<std::mutex> ssl_lock{ssl_mutex_};
createSSL(); createSSL();
// attach the socket descriptor // attach the socket descriptor
@ -478,25 +490,23 @@ SecureSocket::secureConnect(int socket)
LOG((CLOG_DEBUG2 "connecting secure socket")); LOG((CLOG_DEBUG2 "connecting secure socket"));
int r = SSL_connect(m_ssl->m_ssl); int r = SSL_connect(m_ssl->m_ssl);
static int retry; checkResult(r, secure_connect_retry_);
checkResult(r, retry);
if (isFatal()) { if (isFatal()) {
LOG((CLOG_ERR "failed to connect secure socket")); LOG((CLOG_ERR "failed to connect secure socket"));
retry = 0; secure_connect_retry_ = 0;
return -1; return -1;
} }
// If we should retry, not ready and return 0 // If we should retry, not ready and return 0
if (retry > 0) { if (secure_connect_retry_ > 0) {
LOG((CLOG_DEBUG2 "retry connect secure socket")); LOG((CLOG_DEBUG2 "retry connect secure socket"));
m_secureReady = false; m_secureReady = false;
ARCH->sleep(s_retryDelay); ARCH->sleep(s_retryDelay);
return 0; return 0;
} }
retry = 0; secure_connect_retry_ = 0;
// No error, set ready, process and return ok // No error, set ready, process and return ok
m_secureReady = true; m_secureReady = true;
if (verifyCertFingerprint()) { if (verifyCertFingerprint()) {
@ -522,6 +532,7 @@ SecureSocket::secureConnect(int socket)
bool bool
SecureSocket::showCertificate() SecureSocket::showCertificate()
{ {
// ssl_mutex_ is assumed to be acquired
X509* cert; X509* cert;
char* line; char* line;
@ -544,6 +555,8 @@ SecureSocket::showCertificate()
void void
SecureSocket::checkResult(int status, int& retry) SecureSocket::checkResult(int status, int& retry)
{ {
// ssl_mutex_ is assumed to be acquired
// ssl errors are a little quirky. the "want" errors are normal and // ssl errors are a little quirky. the "want" errors are normal and
// should result in a retry. // should result in a retry.
@ -680,6 +693,8 @@ void SecureSocket::formatFingerprint(std::string& fingerprint, bool hex, bool se
bool bool
SecureSocket::verifyCertFingerprint() SecureSocket::verifyCertFingerprint()
{ {
// ssl_mutex_ is assumed to be acquired
// calculate received certificate fingerprint // calculate received certificate fingerprint
X509 *cert = cert = SSL_get_peer_certificate(m_ssl->m_ssl); X509 *cert = cert = SSL_get_peer_certificate(m_ssl->m_ssl);
EVP_MD* tempDigest; EVP_MD* tempDigest;
@ -822,6 +837,8 @@ showCipherStackDesc(STACK_OF(SSL_CIPHER) * stack) {
void void
SecureSocket::showSecureCipherInfo() SecureSocket::showSecureCipherInfo()
{ {
// ssl_mutex_ is assumed to be acquired
STACK_OF(SSL_CIPHER) * sStack = SSL_get_ciphers(m_ssl->m_ssl); STACK_OF(SSL_CIPHER) * sStack = SSL_get_ciphers(m_ssl->m_ssl);
if (sStack == NULL) { if (sStack == NULL) {
@ -864,6 +881,8 @@ SecureSocket::showSecureLibInfo()
void void
SecureSocket::showSecureConnectInfo() SecureSocket::showSecureConnectInfo()
{ {
// ssl_mutex_ is assumed to be acquired
const SSL_CIPHER* cipher = SSL_get_current_cipher(m_ssl->m_ssl); const SSL_CIPHER* cipher = SSL_get_current_cipher(m_ssl->m_ssl);
if (cipher != NULL) { if (cipher != NULL) {

View File

@ -19,6 +19,7 @@
#include "net/TCPSocket.h" #include "net/TCPSocket.h"
#include "net/XSocket.h" #include "net/XSocket.h"
#include <mutex>
class IEventQueue; class IEventQueue;
class SocketMultiplexer; class SocketMultiplexer;
@ -59,31 +60,48 @@ public:
private: private:
// SSL // SSL
void initContext(bool server); void initContext(bool server); // may only be called with ssl_mutex_ acquired
void createSSL(); void createSSL(); // may only be called with ssl_mutex_ acquired.
int secureAccept(int s); int secureAccept(int s);
int secureConnect(int s); int secureConnect(int s);
bool showCertificate(); bool showCertificate(); // may only be called with ssl_mutex_ acquired
void checkResult(int n, int& retry); void checkResult(int n, int& retry); // may only be called with m_ssl_mutex_ acquired.
void showError(const char* reason = NULL); void showError(const char* reason = NULL);
std::string getError(); std::string getError();
void disconnect(); void disconnect();
void formatFingerprint(std::string& fingerprint, bool hex = true, bool separator = true); void formatFingerprint(std::string& fingerprint, bool hex = true, bool separator = true);
bool verifyCertFingerprint(); bool verifyCertFingerprint(); // may only be called with ssl_mutex_ acquired
MultiplexerJobStatus serviceConnect(ISocketMultiplexerJob*, bool, bool, bool); MultiplexerJobStatus serviceConnect(ISocketMultiplexerJob*, bool, bool, bool);
MultiplexerJobStatus serviceAccept(ISocketMultiplexerJob*, bool, bool, bool); MultiplexerJobStatus serviceAccept(ISocketMultiplexerJob*, bool, bool, bool);
void showSecureConnectInfo(); void showSecureConnectInfo(); // may only be called with ssl_mutex_ acquired
void showSecureLibInfo(); void showSecureLibInfo();
void showSecureCipherInfo(); void showSecureCipherInfo(); // may only be called with ssl_mutex_ acquired
void handleTCPConnected(const Event& event, void*); void handleTCPConnected(const Event& event, void*);
void freeSSLResources(); void freeSSLResources();
private: private:
// all accesses to m_ssl must be protected by this mutex. The only function that is called
// from outside SocketMultiplexer thread is close(), so we mostly care about things accessed
// by it.
std::mutex ssl_mutex_;
Ssl* m_ssl; Ssl* m_ssl;
bool m_secureReady; bool m_secureReady;
bool m_fatal; bool m_fatal;
int secure_accept_retry_ = 0; // used only in secureAccept()
int secure_connect_retry_ = 0; // used only in secureConnect()
int secure_read_retry_ = 0; // used only in secureRead()
int secure_write_retry_ = 0; // used only in secureWrite()
// The following are used only from doWrite()
// FIXME: using std::vector would simplify logic significantly.
bool do_write_retry_ = false;
int do_write_retry_size_ = 0;
std::unique_ptr<char[]> do_write_retry_buffer_;
std::size_t do_write_retry_buffer_size_ = 0;
}; };

View File

@ -33,9 +33,7 @@
#include <cstdlib> #include <cstdlib>
#include <memory> #include <memory>
// static const std::size_t MAX_INPUT_BUFFER_SIZE = 1024 * 1024;
// TCPSocket
//
TCPSocket::TCPSocket(IEventQueue* events, SocketMultiplexer* socketMultiplexer, IArchNetwork::EAddressFamily family) : TCPSocket::TCPSocket(IEventQueue* events, SocketMultiplexer* socketMultiplexer, IArchNetwork::EAddressFamily family) :
IDataSocket(events), IDataSocket(events),
@ -345,6 +343,10 @@ TCPSocket::doRead()
do { do {
m_inputBuffer.write(buffer, (UInt32)bytesRead); m_inputBuffer.write(buffer, (UInt32)bytesRead);
if (m_inputBuffer.getSize() > MAX_INPUT_BUFFER_SIZE) {
break;
}
bytesRead = ARCH->readSocket(m_socket, buffer, sizeof(buffer)); bytesRead = ARCH->readSocket(m_socket, buffer, sizeof(buffer));
} while (bytesRead > 0); } while (bytesRead > 0);

View File

@ -184,7 +184,6 @@ ClientListener::handleUnknownClient(const Event&, void* vclient)
// get the real client proxy and install it // get the real client proxy and install it
ClientProxy* client = unknownClient->orphanClientProxy(); ClientProxy* client = unknownClient->orphanClientProxy();
bool handshakeOk = true;
if (client != NULL) { if (client != NULL) {
// handshake was successful // handshake was successful
m_waitingClients.push_back(client); m_waitingClients.push_back(client);
@ -196,20 +195,17 @@ ClientListener::handleUnknownClient(const Event&, void* vclient)
new TMethodEventJob<ClientListener>(this, new TMethodEventJob<ClientListener>(this,
&ClientListener::handleClientDisconnected, &ClientListener::handleClientDisconnected,
client)); client));
} else {
auto* stream = unknownClient->getStream();
if (stream) {
stream->close();
} }
else {
handshakeOk = false;
} }
// now finished with unknown client // now finished with unknown client
m_events->removeHandler(m_events->forClientProxyUnknown().success(), client); m_events->removeHandler(m_events->forClientProxyUnknown().success(), client);
m_events->removeHandler(m_events->forClientProxyUnknown().failure(), client); m_events->removeHandler(m_events->forClientProxyUnknown().failure(), client);
m_newClients.erase(unknownClient); m_newClients.erase(unknownClient);
PacketStreamFilter* streamFileter = dynamic_cast<PacketStreamFilter*>(unknownClient->getStream());
IDataSocket* socket = NULL;
if (streamFileter != NULL) {
socket = dynamic_cast<IDataSocket*>(streamFileter->getStream());
}
delete unknownClient; delete unknownClient;
} }

View File

@ -51,6 +51,10 @@ ClientProxy1_0::ClientProxy1_0(const std::string& name, barrier::IStream* stream
stream->getEventTarget(), stream->getEventTarget(),
new TMethodEventJob<ClientProxy1_0>(this, new TMethodEventJob<ClientProxy1_0>(this,
&ClientProxy1_0::handleDisconnect, NULL)); &ClientProxy1_0::handleDisconnect, NULL));
m_events->adoptHandler(m_events->forIStream().inputFormatError(),
stream->getEventTarget(),
new TMethodEventJob<ClientProxy1_0>(this,
&ClientProxy1_0::handleDisconnect, NULL));
m_events->adoptHandler(m_events->forIStream().outputShutdown(), m_events->adoptHandler(m_events->forIStream().outputShutdown(),
stream->getEventTarget(), stream->getEventTarget(),
new TMethodEventJob<ClientProxy1_0>(this, new TMethodEventJob<ClientProxy1_0>(this,
@ -90,6 +94,8 @@ ClientProxy1_0::removeHandlers()
getStream()->getEventTarget()); getStream()->getEventTarget());
m_events->removeHandler(m_events->forIStream().outputShutdown(), m_events->removeHandler(m_events->forIStream().outputShutdown(),
getStream()->getEventTarget()); getStream()->getEventTarget());
m_events->removeHandler(m_events->forIStream().inputFormatError(),
getStream()->getEventTarget());
m_events->removeHandler(Event::kTimer, this); m_events->removeHandler(Event::kTimer, this);
// remove timer // remove timer
@ -148,12 +154,21 @@ ClientProxy1_0::handleData(const Event&, void*)
} }
// parse message // parse message
try {
LOG((CLOG_DEBUG2 "msg from \"%s\": %c%c%c%c", getName().c_str(), code[0], code[1], code[2], code[3])); LOG((CLOG_DEBUG2 "msg from \"%s\": %c%c%c%c", getName().c_str(), code[0], code[1], code[2], code[3]));
if (!(this->*m_parser)(code)) { if (!(this->*m_parser)(code)) {
LOG((CLOG_ERR "invalid message from client \"%s\": %c%c%c%c", getName().c_str(), code[0], code[1], code[2], code[3])); LOG((CLOG_ERR "invalid message from client \"%s\": %c%c%c%c", getName().c_str(), code[0], code[1], code[2], code[3]));
disconnect(); disconnect();
return; return;
} }
} catch (const XBadClient& e) {
// TODO: disconnect handling is currently dispersed across both parseMessage() and
// handleData() functions, we should collect that to a single place
LOG((CLOG_ERR "protocol error from client: %s", e.what()));
disconnect();
return;
}
// next message // next message
n = getStream()->read(code, 4); n = getStream()->read(code, 4);

View File

@ -118,6 +118,10 @@ ClientProxyUnknown::addStreamHandlers()
m_stream->getEventTarget(), m_stream->getEventTarget(),
new TMethodEventJob<ClientProxyUnknown>(this, new TMethodEventJob<ClientProxyUnknown>(this,
&ClientProxyUnknown::handleDisconnect)); &ClientProxyUnknown::handleDisconnect));
m_events->adoptHandler(m_events->forIStream().inputFormatError(),
m_stream->getEventTarget(),
new TMethodEventJob<ClientProxyUnknown>(this,
&ClientProxyUnknown::handleDisconnect));
m_events->adoptHandler(m_events->forIStream().outputShutdown(), m_events->adoptHandler(m_events->forIStream().outputShutdown(),
m_stream->getEventTarget(), m_stream->getEventTarget(),
new TMethodEventJob<ClientProxyUnknown>(this, new TMethodEventJob<ClientProxyUnknown>(this,
@ -149,6 +153,8 @@ ClientProxyUnknown::removeHandlers()
m_stream->getEventTarget()); m_stream->getEventTarget());
m_events->removeHandler(m_events->forIStream().inputShutdown(), m_events->removeHandler(m_events->forIStream().inputShutdown(),
m_stream->getEventTarget()); m_stream->getEventTarget());
m_events->removeHandler(m_events->forIStream().inputFormatError(),
m_stream->getEventTarget());
m_events->removeHandler(m_events->forIStream().outputShutdown(), m_events->removeHandler(m_events->forIStream().outputShutdown(),
m_stream->getEventTarget()); m_stream->getEventTarget());
} }