diff --git a/src/lib/client/Client.cpp b/src/lib/client/Client.cpp index 9227834b..2dd9819a 100644 --- a/src/lib/client/Client.cpp +++ b/src/lib/client/Client.cpp @@ -32,6 +32,7 @@ #include "mt/Thread.h" #include "io/IStreamFilterFactory.h" #include "io/CryptoStream.h" +#include "net/TCPSocket.h" #include "net/IDataSocket.h" #include "net/ISocketFactory.h" #include "arch/Arch.h" @@ -82,7 +83,8 @@ Client::Client( m_crypto(crypto), m_sendFileThread(NULL), m_writeToDropDirThread(NULL), - m_enableDragDrop(enableDragDrop) + m_enableDragDrop(enableDragDrop), + m_socket(NULL) { assert(m_socketFactory != NULL); assert(m_screen != NULL); @@ -159,6 +161,7 @@ Client::connect() // create the socket bool useSecureSocket = ARCH->plugin().exists(s_networkSecurity); IDataSocket* socket = m_socketFactory->create(useSecureSocket); + m_socket = dynamic_cast(socket); // filter socket messages, including a packetizing filter m_stream = socket; @@ -594,6 +597,8 @@ Client::handleConnected(const Event&, void*) m_sentClipboard[id] = false; m_timeClipboard[id] = 0; } + + m_socket->secureConnect(); } void diff --git a/src/lib/client/Client.h b/src/lib/client/Client.h index 37c7f770..f3437279 100644 --- a/src/lib/client/Client.h +++ b/src/lib/client/Client.h @@ -37,7 +37,7 @@ class IStreamFilterFactory; class IEventQueue; class CryptoStream; class Thread; -class SecureSocket; +class TCPSocket; //! Synergy client /*! @@ -236,4 +236,5 @@ private: Thread* m_sendFileThread; Thread* m_writeToDropDirThread; bool m_enableDragDrop; + TCPSocket* m_socket; }; diff --git a/src/lib/net/TCPSocket.cpp b/src/lib/net/TCPSocket.cpp index 2b2fd1cf..c54df40b 100644 --- a/src/lib/net/TCPSocket.cpp +++ b/src/lib/net/TCPSocket.cpp @@ -467,7 +467,17 @@ TCPSocket::serviceConnected(ISocketMultiplexerJob* job, // write data UInt32 n = m_outputBuffer.getSize(); const void* buffer = m_outputBuffer.peek(n); - n = (UInt32)ARCH->writeSocket(m_socket, buffer, n); + if (isSecure()) { + if (isSecureReady()) { + n = secureWrite(buffer, n); + } + else { + return job; + } + } + else { + n = (UInt32)ARCH->writeSocket(m_socket, buffer, n); + } // discard written data if (n > 0) { @@ -510,14 +520,34 @@ TCPSocket::serviceConnected(ISocketMultiplexerJob* job, if (read && m_readable) { try { UInt8 buffer[4096]; - size_t n = ARCH->readSocket(m_socket, buffer, sizeof(buffer)); + size_t n = 0; + + if (isSecure()) { + if (isSecureReady()) { + n = secureRead(buffer, sizeof(buffer)); + } + else { + return job; + } + } + else { + n = ARCH->readSocket(m_socket, buffer, sizeof(buffer)); + } + if (n > 0) { bool wasEmpty = (m_inputBuffer.getSize() == 0); // slurp up as much as possible do { m_inputBuffer.write(buffer, (UInt32)n); - n = ARCH->readSocket(m_socket, buffer, sizeof(buffer)); + + if (isSecure() && isSecureReady()) { + n = secureRead(buffer, sizeof(buffer)); + } + else { + n = ARCH->readSocket(m_socket, buffer, sizeof(buffer)); + } + } while (n > 0); // send input ready if input buffer was empty diff --git a/src/lib/net/TCPSocket.h b/src/lib/net/TCPSocket.h index a026b608..cece95fd 100644 --- a/src/lib/net/TCPSocket.h +++ b/src/lib/net/TCPSocket.h @@ -57,19 +57,31 @@ public: // IDataSocket overrides virtual void connect(const NetworkAddress&); + virtual void secureConnect() {} + virtual void secureAccept() {} + protected: - virtual void onConnected(); ArchSocket getSocket() { return m_socket; } -private: - void init(); + virtual bool isSecureReady() { return false; } + virtual bool isSecure() { return false; } + virtual UInt32 secureRead(void* buffer, UInt32) { return 0; } + virtual UInt32 secureWrite(const void*, UInt32) { return 0; } void setJob(ISocketMultiplexerJob*); ISocketMultiplexerJob* newJob(); + bool isReadable() { return m_readable; } + bool isWritable() { return m_writable; } + + Mutex& getMutex() { return m_mutex; } + +private: + void init(); + void sendConnectionFailedEvent(const char*); void sendEvent(Event::Type); - + void onConnected(); void onInputShutdown(); void onOutputShutdown(); void onDisconnected(); diff --git a/src/lib/net/TCPSocketFactory.cpp b/src/lib/net/TCPSocketFactory.cpp index 2f272d94..ea638cd3 100644 --- a/src/lib/net/TCPSocketFactory.cpp +++ b/src/lib/net/TCPSocketFactory.cpp @@ -50,11 +50,9 @@ TCPSocketFactory::create(bool secure) const { IDataSocket* socket = NULL; if (secure) { - void* args[4] = { + void* args[2] = { m_events, - m_socketMultiplexer, - Log::getInstance(), - Arch::getInstance() + m_socketMultiplexer }; socket = static_cast( ARCH->plugin().invoke(s_networkSecurity, "getSecureSocket", args)); @@ -71,11 +69,9 @@ TCPSocketFactory::createListen(bool secure) const { IListenSocket* socket = NULL; if (secure) { - void* args[4] = { + void* args[2] = { m_events, - m_socketMultiplexer, - Log::getInstance(), - Arch::getInstance() + m_socketMultiplexer }; socket = static_cast( ARCH->plugin().invoke(s_networkSecurity, "getSecureListenSocket", args)); diff --git a/src/lib/plugin/ns/SecureListenSocket.cpp b/src/lib/plugin/ns/SecureListenSocket.cpp index e8c5d0c9..672669c6 100644 --- a/src/lib/plugin/ns/SecureListenSocket.cpp +++ b/src/lib/plugin/ns/SecureListenSocket.cpp @@ -22,7 +22,6 @@ #include "net/SocketMultiplexer.h" #include "net/TSocketMultiplexerMethodJob.h" #include "arch/XArch.h" -#include "base/Log.h" // // SecureListenSocket @@ -51,14 +50,11 @@ SecureListenSocket::accept() socket->initSsl(true); // TODO: customized certificate path socket->loadCertificates("C:\\Temp\\synergy.pem"); - if (socket != NULL) { m_socketMultiplexer->addSocket(this, new TSocketMultiplexerMethodJob( this, &TCPListenSocket::serviceListening, m_socket, true, false)); - - socket->acceptSecureSocket(); } return dynamic_cast(socket); } diff --git a/src/lib/plugin/ns/SecureSocket.cpp b/src/lib/plugin/ns/SecureSocket.cpp index 8c3d61c6..c4ea34e5 100644 --- a/src/lib/plugin/ns/SecureSocket.cpp +++ b/src/lib/plugin/ns/SecureSocket.cpp @@ -17,7 +17,9 @@ #include "SecureSocket.h" +#include "net/TSocketMultiplexerMethodJob.h" #include "net/TCPSocket.h" +#include "mt/Lock.h" #include "arch/XArch.h" #include "base/Log.h" @@ -42,7 +44,7 @@ SecureSocket::SecureSocket( IEventQueue* events, SocketMultiplexer* socketMultiplexer) : TCPSocket(events, socketMultiplexer), - m_ready(false) + m_secureReady(false) { } @@ -51,7 +53,7 @@ SecureSocket::SecureSocket( SocketMultiplexer* socketMultiplexer, ArchSocket socket) : TCPSocket(events, socketMultiplexer, socket), - m_ready(false) + m_secureReady(false) { } @@ -67,8 +69,24 @@ SecureSocket::~SecureSocket() delete[] m_error; } +void +SecureSocket::secureConnect() +{ + setJob(new TSocketMultiplexerMethodJob( + this, &SecureSocket::serviceConnect, + getSocket(), isReadable(), isWritable())); +} + +void +SecureSocket::secureAccept() +{ + setJob(new TSocketMultiplexerMethodJob( + this, &SecureSocket::serviceAccept, + getSocket(), isReadable(), isWritable())); +} + UInt32 -SecureSocket::read(void* buffer, UInt32 n) +SecureSocket::secureRead(void* buffer, UInt32 n) { bool retry = false; int r = 0; @@ -83,8 +101,8 @@ SecureSocket::read(void* buffer, UInt32 n) return r > 0 ? (UInt32)r : 0; } -void -SecureSocket::write(const void* buffer, UInt32 n) +UInt32 +SecureSocket::secureWrite(const void* buffer, UInt32 n) { bool retry = false; int r = 0; @@ -95,32 +113,14 @@ SecureSocket::write(const void* buffer, UInt32 n) r = 0; } } + + return r > 0 ? (UInt32)r : 0; } bool -SecureSocket::isReady() const +SecureSocket::isSecureReady() { - return m_ready; -} - -void -SecureSocket::connectSecureSocket() -{ -#ifdef SYSAPI_WIN32 - secureConnect(static_cast(getSocket()->m_socket)); -#elif SYSAPI_UNIX - secureConnect(getSocket()->m_fd); -#endif -} - -void -SecureSocket::acceptSecureSocket() -{ -#ifdef SYSAPI_WIN32 - secureAccept(static_cast(getSocket()->m_socket)); -#elif SYSAPI_UNIX - secureAccept(getSocket()->m_fd); -#endif + return m_secureReady; } void @@ -134,43 +134,6 @@ SecureSocket::initSsl(bool server) initContext(server); } -void -SecureSocket::onConnected() -{ - TCPSocket::onConnected(); - - connectSecureSocket(); -} - -void -SecureSocket::initContext(bool server) -{ - SSL_library_init(); - - const SSL_METHOD* method; - - // load & register all cryptos, etc. - OpenSSL_add_all_algorithms(); - - // load all error messages - SSL_load_error_strings(); - - if (server) { - // create new server-method instance - method = SSLv3_server_method(); - } - else { - // create new client-method instance - method = SSLv3_client_method(); - } - - // create new context from method - m_ssl->m_context = SSL_CTX_new(method); - if (m_ssl->m_context == NULL) { - showError(); - } -} - void SecureSocket::loadCertificates(const char* filename) { @@ -191,6 +154,36 @@ SecureSocket::loadCertificates(const char* filename) showError(); } } + +void +SecureSocket::initContext(bool server) +{ + SSL_library_init(); + + const SSL_METHOD* method; + + // load & register all cryptos, etc. + OpenSSL_add_all_algorithms(); + + // load all error messages + SSL_load_error_strings(); + + if (server) { + // create new server-method instance + method = SSLv23_server_method(); + } + else { + // create new client-method instance + method = SSLv23_client_method(); + } + + // create new context from method + m_ssl->m_context = SSL_CTX_new(method); + if (m_ssl->m_context == NULL) { + showError(); + } +} + void SecureSocket::createSSL() { @@ -201,7 +194,7 @@ SecureSocket::createSSL() } } -void +bool SecureSocket::secureAccept(int socket) { createSSL(); @@ -210,38 +203,46 @@ SecureSocket::secureAccept(int socket) SSL_set_fd(m_ssl->m_ssl, socket); // do SSL-protocol accept + LOG((CLOG_DEBUG "secureAccept")); int r = SSL_accept(m_ssl->m_ssl); - bool retry = checkResult(r); + + //TODO: don't use this infinite loop while (retry) { ARCH->sleep(.5f); - LOG((CLOG_INFO "secureAccept sleep .5s")); + SSL_set_fd(m_ssl->m_ssl, socket); r = SSL_accept(m_ssl->m_ssl); retry = checkResult(r); } - m_ready = true; + m_secureReady = !retry; + return retry; } -void +bool SecureSocket::secureConnect(int socket) { createSSL(); // attach the socket descriptor SSL_set_fd(m_ssl->m_ssl, socket); - + LOG((CLOG_DEBUG "secureConnect")); int r = SSL_connect(m_ssl->m_ssl); bool retry = checkResult(r); + + //TODO: don't use this infinite loop while (retry) { ARCH->sleep(.5f); r = SSL_connect(m_ssl->m_ssl); retry = checkResult(r); } - m_ready= true; + m_secureReady= true; showCertificate(); + + m_secureReady = !retry; + return retry; } void @@ -257,9 +258,6 @@ SecureSocket::showCertificate() line = X509_NAME_oneline(X509_get_subject_name(cert), 0, 0); LOG((CLOG_INFO "subject: %s", line)); OPENSSL_free(line); - line = X509_NAME_oneline(X509_get_issuer_name(cert), 0, 0); - LOG((CLOG_INFO "issuer: %s", line)); - OPENSSL_free(line); X509_free(cert); } else { @@ -346,3 +344,34 @@ SecureSocket::getError() return errorUpdated; } + +ISocketMultiplexerJob* +SecureSocket::serviceConnect(ISocketMultiplexerJob* job, + bool, bool write, bool error) +{ + Lock lock(&getMutex()); + + bool retry = true; +#ifdef SYSAPI_WIN32 + retry = secureConnect(static_cast(getSocket()->m_socket)); +#elif SYSAPI_UNIX + retry = secureConnect(getSocket()->m_fd); +#endif + + return retry ? job : newJob(); +} + +ISocketMultiplexerJob* +SecureSocket::serviceAccept(ISocketMultiplexerJob* job, + bool, bool write, bool error) +{ + Lock lock(&getMutex()); + + bool retry = true; +#ifdef SYSAPI_WIN32 + retry = secureAccept(static_cast(getSocket()->m_socket)); +#elif SYSAPI_UNIX + retry = secureAccept(getSocket()->m_fd); +#endif + return retry ? job : newJob(); +} diff --git a/src/lib/plugin/ns/SecureSocket.h b/src/lib/plugin/ns/SecureSocket.h index e6aea016..bae6d909 100644 --- a/src/lib/plugin/ns/SecureSocket.h +++ b/src/lib/plugin/ns/SecureSocket.h @@ -22,6 +22,7 @@ class IEventQueue; class SocketMultiplexer; +class ISocketMultiplexerJob; struct Ssl; @@ -41,32 +42,38 @@ public: ArchSocket socket); ~SecureSocket(); - // IStream overrides - virtual UInt32 read(void* buffer, UInt32 n); - virtual void write(const void* buffer, UInt32 n); - virtual bool isReady() const; - void connectSecureSocket(); - void acceptSecureSocket(); + void secureConnect(); + void secureAccept(); + bool isSecureReady(); + bool isSecure() { return true; } + UInt32 secureRead(void* buffer, UInt32 n); + UInt32 secureWrite(const void* buffer, UInt32 n); void initSsl(bool server); void loadCertificates(const char* CertFile); private: - void onConnected(); - // SSL void initContext(bool server); void createSSL(); - void secureAccept(int s); - void secureConnect(int s); + bool secureAccept(int s); + bool secureConnect(int s); void showCertificate(); bool checkResult(int n); void showError(); void throwError(const char* reason); bool getError(); + ISocketMultiplexerJob* + serviceConnect(ISocketMultiplexerJob*, + bool, bool, bool); + + ISocketMultiplexerJob* + serviceAccept(ISocketMultiplexerJob*, + bool, bool, bool); + private: Ssl* m_ssl; - bool m_ready; + bool m_secureReady; char* m_error; }; diff --git a/src/lib/server/ClientListener.cpp b/src/lib/server/ClientListener.cpp index a4fd7638..3b4d872b 100644 --- a/src/lib/server/ClientListener.cpp +++ b/src/lib/server/ClientListener.cpp @@ -21,6 +21,7 @@ #include "server/ClientProxy.h" #include "server/ClientProxyUnknown.h" #include "synergy/PacketStreamFilter.h" +#include "net/TCPSocket.h" #include "net/IDataSocket.h" #include "net/IListenSocket.h" #include "net/ISocketFactory.h" @@ -142,7 +143,8 @@ ClientListener::handleClientConnecting(const Event&, void*) return; } LOG((CLOG_NOTE "accepted client connection")); - + TCPSocket* socket = dynamic_cast(stream); + socket->secureAccept(); // filter socket messages, including a packetizing filter if (m_streamFilterFactory != NULL) { stream = m_streamFilterFactory->create(stream, true);