refactored SecureSocket to use interface #4313

This commit is contained in:
Xinyu Hou 2015-01-14 17:24:45 +00:00 committed by XinyuHou
parent be2b87fd39
commit 141b778477
28 changed files with 484 additions and 158 deletions

View File

@ -30,6 +30,11 @@ Arch::Arch()
s_instance = this; s_instance = this;
} }
Arch::Arch(Arch* arch)
{
s_instance = arch;
}
Arch::~Arch() Arch::~Arch()
{ {
#if SYSAPI_WIN32 #if SYSAPI_WIN32

View File

@ -99,6 +99,7 @@ class Arch : public ARCH_CONSOLE,
public ARCH_TIME { public ARCH_TIME {
public: public:
Arch(); Arch();
Arch(Arch* arch);
virtual ~Arch(); virtual ~Arch();
//! Call init on other arch classes. //! Call init on other arch classes.

View File

@ -42,11 +42,17 @@ public:
*/ */
virtual void load() = 0; virtual void load() = 0;
//! Init plugins //! Init the common parts
/*! /*!
Initializes loaded plugins. Initializes common parts like log and arch.
*/ */
virtual void init(void* eventTarget, IEventQueue* events) = 0; virtual void init(void* log, void* arch) = 0;
//! Init the event part
/*!
Initializes event parts.
*/
virtual void initEvent(void* eventTarget, IEventQueue* events) = 0;
//! Check if exists //! Check if exists
/*! /*!
@ -60,7 +66,7 @@ public:
*/ */
virtual void* invoke(const char* plugin, virtual void* invoke(const char* plugin,
const char* command, const char* command,
void* args) = 0; void** args) = 0;
//@} //@}

View File

@ -28,7 +28,8 @@
#include <dirent.h> #include <dirent.h>
#include <dlfcn.h> #include <dlfcn.h>
typedef int (*initFunc)(void (*sendEvent)(const char*, void*), void (*log)(const char*)); typedef void (*initFunc)(void*, void*);
typedef int (*initEventFunc)(void (*sendEvent)(const char*, void*));
typedef void* (*invokeFunc)(const char*, void*); typedef void* (*invokeFunc)(const char*, void*);
void* g_eventTarget = NULL; void* g_eventTarget = NULL;
@ -84,15 +85,25 @@ ArchPluginUnix::load()
} }
void void
ArchPluginUnix::init(void* eventTarget, IEventQueue* events) ArchPluginUnix::init(void* log, void* arch)
{
PluginTable::iterator it;
for (it = m_pluginTable.begin(); it != m_pluginTable.end(); it++) {
initFunc initPlugin = (initFunc)dlsym(it->second, "init");
initPlugin(log, arch);
}
}
void
ArchPluginUnix::initEvent(void* eventTarget, IEventQueue* events)
{ {
g_eventTarget = eventTarget; g_eventTarget = eventTarget;
g_events = events; g_events = events;
PluginTable::iterator it; PluginTable::iterator it;
for (it = m_pluginTable.begin(); it != m_pluginTable.end(); it++) { for (it = m_pluginTable.begin(); it != m_pluginTable.end(); it++) {
initFunc initPlugin = (initFunc)dlsym(it->second, "init"); initEventFunc initEventPlugin = (initEventFunc)dlsym(it->second, "initEvent");
initPlugin(&sendEvent, &log); initEventPlugin(&sendEvent);
} }
} }
@ -108,7 +119,7 @@ void*
ArchPluginUnix::invoke( ArchPluginUnix::invoke(
const char* plugin, const char* plugin,
const char* command, const char* command,
void* args) void** args)
{ {
PluginTable::iterator it; PluginTable::iterator it;
it = m_pluginTable.find(plugin); it = m_pluginTable.find(plugin);

View File

@ -32,11 +32,12 @@ public:
// IArchPlugin overrides // IArchPlugin overrides
void load(); void load();
void init(void* eventTarget, IEventQueue* events); void init(void* log, void* arch);
void initEvent(void* eventTarget, IEventQueue* events);
bool exists(const char* name); bool exists(const char* name);
virtual void* invoke(const char* pluginName, virtual void* invoke(const char* pluginName,
const char* functionName, const char* functionName,
void* args); void** args);
private: private:
String getPluginsDir(); String getPluginsDir();

View File

@ -27,8 +27,9 @@
#include <Windows.h> #include <Windows.h>
#include <iostream> #include <iostream>
typedef int (*initFunc)(void (*sendEvent)(const char*, void*), void (*log)(const char*)); typedef int (*initFunc)(void*, void*);
typedef void* (*invokeFunc)(const char*, void*); typedef int (*initEventFunc)(void (*sendEvent)(const char*, void*));
typedef void* (*invokeFunc)(const char*, void**);
void* g_eventTarget = NULL; void* g_eventTarget = NULL;
IEventQueue* g_events = NULL; IEventQueue* g_events = NULL;
@ -68,7 +69,19 @@ ArchPluginWindows::load()
} }
void void
ArchPluginWindows::init(void* eventTarget, IEventQueue* events) ArchPluginWindows::init(void* log, void* arch)
{
PluginTable::iterator it;
HINSTANCE lib;
for (it = m_pluginTable.begin(); it != m_pluginTable.end(); it++) {
lib = reinterpret_cast<HINSTANCE>(it->second);
initFunc initPlugin = (initFunc)GetProcAddress(lib, "init");
initPlugin(log, arch);
}
}
void
ArchPluginWindows::initEvent(void* eventTarget, IEventQueue* events)
{ {
g_eventTarget = eventTarget; g_eventTarget = eventTarget;
g_events = events; g_events = events;
@ -77,8 +90,8 @@ ArchPluginWindows::init(void* eventTarget, IEventQueue* events)
HINSTANCE lib; HINSTANCE lib;
for (it = m_pluginTable.begin(); it != m_pluginTable.end(); it++) { for (it = m_pluginTable.begin(); it != m_pluginTable.end(); it++) {
lib = reinterpret_cast<HINSTANCE>(it->second); lib = reinterpret_cast<HINSTANCE>(it->second);
initFunc initPlugin = (initFunc)GetProcAddress(lib, "init"); initEventFunc initEventPlugin = (initEventFunc)GetProcAddress(lib, "initEvent");
initPlugin(&sendEvent, &log); initEventPlugin(&sendEvent);
} }
} }
@ -95,7 +108,7 @@ void*
ArchPluginWindows::invoke( ArchPluginWindows::invoke(
const char* plugin, const char* plugin,
const char* command, const char* command,
void* args) void** args)
{ {
PluginTable::iterator it; PluginTable::iterator it;
it = m_pluginTable.find(plugin); it = m_pluginTable.find(plugin);

View File

@ -35,11 +35,12 @@ public:
// IArchPlugin overrides // IArchPlugin overrides
void load(); void load();
void init(void* eventTarget, IEventQueue* events); void init(void* log, void* arch);
void initEvent(void* eventTarget, IEventQueue* events);
bool exists(const char* name); bool exists(const char* name);
void* invoke(const char* pluginName, void* invoke(const char* pluginName,
const char* functionName, const char* functionName,
void* args); void** args);
private: private:
String getModuleDir(); String getModuleDir();

View File

@ -74,6 +74,11 @@ Log::Log()
s_log = this; s_log = this;
} }
Log::Log(Log* src)
{
s_log = src;
}
Log::~Log() Log::~Log()
{ {
// clean up // clean up

View File

@ -41,6 +41,7 @@ LOGC() provide convenient access.
class Log { class Log {
public: public:
Log(); Log();
Log(Log* src);
~Log(); ~Log();
//! @name manipulators //! @name manipulators

View File

@ -171,7 +171,7 @@ findReplaceAll(
String String
removeFileExt(String filename) removeFileExt(String filename)
{ {
unsigned dot = filename.find_last_of('.'); size_t dot = filename.find_last_of('.');
if (dot == String::npos) { if (dot == String::npos) {
return filename; return filename;

View File

@ -82,8 +82,7 @@ Client::Client(
m_crypto(crypto), m_crypto(crypto),
m_sendFileThread(NULL), m_sendFileThread(NULL),
m_writeToDropDirThread(NULL), m_writeToDropDirThread(NULL),
m_enableDragDrop(enableDragDrop), m_enableDragDrop(enableDragDrop)
m_secureSocket(NULL)
{ {
assert(m_socketFactory != NULL); assert(m_socketFactory != NULL);
assert(m_screen != NULL); assert(m_screen != NULL);
@ -108,11 +107,6 @@ Client::Client(
new TMethodEventJob<Client>(this, new TMethodEventJob<Client>(this,
&Client::handleFileRecieveCompleted)); &Client::handleFileRecieveCompleted));
} }
if (ARCH->plugin().exists(s_networkSecurity)) {
m_secureSocket = static_cast<SecureSocket*>(
ARCH->plugin().invoke("ns", "getSecureSocket", NULL));
}
} }
Client::~Client() Client::~Client()
@ -163,14 +157,16 @@ Client::connect()
} }
// create the socket // create the socket
IDataSocket* socket = m_socketFactory->create(); bool useSecureSocket = ARCH->plugin().exists(s_networkSecurity);
IDataSocket* socket = m_socketFactory->create(useSecureSocket);
// filter socket messages, including a packetizing filter // filter socket messages, including a packetizing filter
m_stream = socket; m_stream = socket;
bool adopt = !useSecureSocket;
if (m_streamFilterFactory != NULL) { if (m_streamFilterFactory != NULL) {
m_stream = m_streamFilterFactory->create(m_stream, true); m_stream = m_streamFilterFactory->create(m_stream, adopt);
} }
m_stream = new PacketStreamFilter(m_events, m_stream, true); m_stream = new PacketStreamFilter(m_events, m_stream, adopt);
if (m_crypto.m_mode != kDisabled) { if (m_crypto.m_mode != kDisabled) {
m_cryptoStream = new CryptoStream( m_cryptoStream = new CryptoStream(
@ -187,8 +183,7 @@ Client::connect()
catch (XBase& e) { catch (XBase& e) {
cleanupTimer(); cleanupTimer();
cleanupConnecting(); cleanupConnecting();
delete m_stream; cleanupStream();
m_stream = NULL;
LOG((CLOG_DEBUG1 "connection failed")); LOG((CLOG_DEBUG1 "connection failed"));
sendConnectionFailedEvent(e.what()); sendConnectionFailedEvent(e.what());
return; return;
@ -545,8 +540,7 @@ Client::cleanupConnection()
m_stream->getEventTarget()); m_stream->getEventTarget());
m_events->removeHandler(m_events->forISocket().disconnected(), m_events->removeHandler(m_events->forISocket().disconnected(),
m_stream->getEventTarget()); m_stream->getEventTarget());
delete m_stream; cleanupStream();
m_stream = NULL;
} }
} }
@ -577,6 +571,16 @@ Client::cleanupTimer()
} }
} }
void
Client::cleanupStream()
{
bool useSecureSocket = ARCH->plugin().exists(s_networkSecurity);
if (!useSecureSocket) {
delete m_stream;
m_stream = NULL;
}
}
void void
Client::handleConnected(const Event&, void*) Client::handleConnected(const Event&, void*)
{ {
@ -600,8 +604,7 @@ Client::handleConnectionFailed(const Event& event, void*)
cleanupTimer(); cleanupTimer();
cleanupConnecting(); cleanupConnecting();
delete m_stream; cleanupStream();
m_stream = NULL;
LOG((CLOG_DEBUG1 "connection failed")); LOG((CLOG_DEBUG1 "connection failed"));
sendConnectionFailedEvent(info->m_what.c_str()); sendConnectionFailedEvent(info->m_what.c_str());
delete info; delete info;
@ -613,8 +616,7 @@ Client::handleConnectTimeout(const Event&, void*)
cleanupTimer(); cleanupTimer();
cleanupConnecting(); cleanupConnecting();
cleanupConnection(); cleanupConnection();
delete m_stream; cleanupStream();
m_stream = NULL;
LOG((CLOG_DEBUG1 "connection timed out")); LOG((CLOG_DEBUG1 "connection timed out"));
sendConnectionFailedEvent("Timed out"); sendConnectionFailedEvent("Timed out");
} }

View File

@ -190,6 +190,7 @@ private:
void cleanupConnection(); void cleanupConnection();
void cleanupScreen(); void cleanupScreen();
void cleanupTimer(); void cleanupTimer();
void cleanupStream();
void handleConnected(const Event&, void*); void handleConnected(const Event&, void*);
void handleConnectionFailed(const Event&, void*); void handleConnectionFailed(const Event&, void*);
void handleConnectTimeout(const Event&, void*); void handleConnectTimeout(const Event&, void*);
@ -211,7 +212,8 @@ private:
String m_name; String m_name;
NetworkAddress m_serverAddress; NetworkAddress m_serverAddress;
ISocketFactory* m_socketFactory; ISocketFactory* m_socketFactory;
IStreamFilterFactory* m_streamFilterFactory; IStreamFilterFactory*
m_streamFilterFactory;
synergy::Screen* m_screen; synergy::Screen* m_screen;
synergy::IStream* m_stream; synergy::IStream* m_stream;
EventQueueTimer* m_timer; EventQueueTimer* m_timer;
@ -234,5 +236,4 @@ private:
Thread* m_sendFileThread; Thread* m_sendFileThread;
Thread* m_writeToDropDirThread; Thread* m_writeToDropDirThread;
bool m_enableDragDrop; bool m_enableDragDrop;
SecureSocket* m_secureSocket;
}; };

View File

@ -34,10 +34,10 @@ public:
//@{ //@{
//! Create data socket //! Create data socket
virtual IDataSocket* create() const = 0; virtual IDataSocket* create(bool secure) const = 0;
//! Create listen socket //! Create listen socket
virtual IListenSocket* createListen() const = 0; virtual IListenSocket* createListen(bool secure) const = 0;
//@} //@}
}; };

View File

@ -33,7 +33,7 @@ A listen socket using TCP.
class TCPListenSocket : public IListenSocket { class TCPListenSocket : public IListenSocket {
public: public:
TCPListenSocket(IEventQueue* events, SocketMultiplexer* socketMultiplexer); TCPListenSocket(IEventQueue* events, SocketMultiplexer* socketMultiplexer);
~TCPListenSocket(); virtual ~TCPListenSocket();
// ISocket overrides // ISocket overrides
virtual void bind(const NetworkAddress&); virtual void bind(const NetworkAddress&);
@ -44,12 +44,14 @@ public:
virtual IDataSocket* virtual IDataSocket*
accept(); accept();
private: ArchSocket& getSocket() { return m_socket; }
public:
ISocketMultiplexerJob* ISocketMultiplexerJob*
serviceListening(ISocketMultiplexerJob*, serviceListening(ISocketMultiplexerJob*,
bool, bool, bool); bool, bool, bool);
private: protected:
ArchSocket m_socket; ArchSocket m_socket;
Mutex* m_mutex; Mutex* m_mutex;
IEventQueue* m_events; IEventQueue* m_events;

View File

@ -38,7 +38,7 @@ class TCPSocket : public IDataSocket {
public: public:
TCPSocket(IEventQueue* events, SocketMultiplexer* socketMultiplexer); TCPSocket(IEventQueue* events, SocketMultiplexer* socketMultiplexer);
TCPSocket(IEventQueue* events, SocketMultiplexer* socketMultiplexer, ArchSocket socket); TCPSocket(IEventQueue* events, SocketMultiplexer* socketMultiplexer, ArchSocket socket);
~TCPSocket(); virtual ~TCPSocket();
// ISocket overrides // ISocket overrides
virtual void bind(const NetworkAddress&); virtual void bind(const NetworkAddress&);
@ -57,15 +57,19 @@ public:
// IDataSocket overrides // IDataSocket overrides
virtual void connect(const NetworkAddress&); virtual void connect(const NetworkAddress&);
protected:
virtual void onConnected();
ArchSocket getSocket() { return m_socket; }
private: private:
void init(); void init();
void setJob(ISocketMultiplexerJob*); void setJob(ISocketMultiplexerJob*);
ISocketMultiplexerJob* newJob(); ISocketMultiplexerJob*
newJob();
void sendConnectionFailedEvent(const char*); void sendConnectionFailedEvent(const char*);
void sendEvent(Event::Type); void sendEvent(Event::Type);
void onConnected();
void onInputShutdown(); void onInputShutdown();
void onOutputShutdown(); void onOutputShutdown();
void onDisconnected(); void onDisconnected();

View File

@ -20,11 +20,19 @@
#include "net/TCPSocket.h" #include "net/TCPSocket.h"
#include "net/TCPListenSocket.h" #include "net/TCPListenSocket.h"
#include "arch/Arch.h"
#include "base/Log.h"
// //
// TCPSocketFactory // TCPSocketFactory
// //
#if defined _WIN32
static const char s_networkSecurity[] = { "ns" };
#else
static const char s_networkSecurity[] = { "libns" };
#endif
TCPSocketFactory::TCPSocketFactory(IEventQueue* events, SocketMultiplexer* socketMultiplexer) : TCPSocketFactory::TCPSocketFactory(IEventQueue* events, SocketMultiplexer* socketMultiplexer) :
m_events(events), m_events(events),
m_socketMultiplexer(socketMultiplexer) m_socketMultiplexer(socketMultiplexer)
@ -38,13 +46,43 @@ TCPSocketFactory::~TCPSocketFactory()
} }
IDataSocket* IDataSocket*
TCPSocketFactory::create() const TCPSocketFactory::create(bool secure) const
{ {
return new TCPSocket(m_events, m_socketMultiplexer); IDataSocket* socket = NULL;
if (secure) {
void* args[4] = {
m_events,
m_socketMultiplexer,
Log::getInstance(),
Arch::getInstance()
};
socket = static_cast<IDataSocket*>(
ARCH->plugin().invoke(s_networkSecurity, "getSecureSocket", args));
}
else {
socket = new TCPSocket(m_events, m_socketMultiplexer);
}
return socket;
} }
IListenSocket* IListenSocket*
TCPSocketFactory::createListen() const TCPSocketFactory::createListen(bool secure) const
{ {
return new TCPListenSocket(m_events, m_socketMultiplexer); IListenSocket* socket = NULL;
if (secure) {
void* args[4] = {
m_events,
m_socketMultiplexer,
Log::getInstance(),
Arch::getInstance()
};
socket = static_cast<IListenSocket*>(
ARCH->plugin().invoke(s_networkSecurity, "getSecureListenSocket", args));
}
else {
socket = new TCPListenSocket(m_events, m_socketMultiplexer);
}
return socket;
} }

View File

@ -31,9 +31,9 @@ public:
// ISocketFactory overrides // ISocketFactory overrides
virtual IDataSocket* virtual IDataSocket*
create() const; create(bool secure) const;
virtual IListenSocket* virtual IListenSocket*
createListen() const; createListen(bool secure) const;
private: private:
IEventQueue* m_events; IEventQueue* m_events;

View File

@ -36,6 +36,12 @@
// ClientListener // ClientListener
// //
#if defined _WIN32
static const char s_networkSecurity[] = { "ns" };
#else
static const char s_networkSecurity[] = { "libns" };
#endif
ClientListener::ClientListener(const NetworkAddress& address, ClientListener::ClientListener(const NetworkAddress& address,
ISocketFactory* socketFactory, ISocketFactory* socketFactory,
IStreamFilterFactory* streamFilterFactory, IStreamFilterFactory* streamFilterFactory,
@ -51,7 +57,8 @@ ClientListener::ClientListener(const NetworkAddress& address,
try { try {
// create listen socket // create listen socket
m_listen = m_socketFactory->createListen(); bool useSecureSocket = ARCH->plugin().exists(s_networkSecurity);
m_listen = m_socketFactory->createListen(useSecureSocket);
// bind listen address // bind listen address
LOG((CLOG_DEBUG1 "binding listen socket")); LOG((CLOG_DEBUG1 "binding listen socket"));

View File

@ -460,6 +460,8 @@ ClientApp::mainLoop()
// load all available plugins. // load all available plugins.
ARCH->plugin().load(); ARCH->plugin().load();
// pass log and arch into plugins.
ARCH->plugin().init(Log::getInstance(), Arch::getInstance());
// start client, etc // start client, etc
appUtil().startNode(); appUtil().startNode();
@ -470,8 +472,8 @@ ClientApp::mainLoop()
initIpcClient(); initIpcClient();
} }
// init all available plugins. // init event for all available plugins.
ARCH->plugin().init(m_clientScreen->getEventTarget(), m_events); ARCH->plugin().initEvent(m_clientScreen->getEventTarget(), m_events);
// run event loop. if startClient() failed we're supposed to retry // run event loop. if startClient() failed we're supposed to retry
// later. the timer installed by startClient() will take care of // later. the timer installed by startClient() will take care of

View File

@ -707,6 +707,11 @@ ServerApp::mainLoop()
return kExitFailed; return kExitFailed;
} }
// load all available plugins.
ARCH->plugin().load();
// pass log and arch into plugins.
ARCH->plugin().init(Log::getInstance(), Arch::getInstance());
// start server, etc // start server, etc
appUtil().startNode(); appUtil().startNode();
@ -716,8 +721,8 @@ ServerApp::mainLoop()
initIpcClient(); initIpcClient();
} }
// load all available plugins. // init event for all available plugins.
ARCH->plugin().init(m_serverScreen->getEventTarget(), m_events); ARCH->plugin().initEvent(m_serverScreen->getEventTarget(), m_events);
// handle hangup signal by reloading the server's configuration // handle hangup signal by reloading the server's configuration
ARCH->setSignalHandler(Arch::kHANGUP, &reloadSignalHandler, NULL); ARCH->setSignalHandler(Arch::kHANGUP, &reloadSignalHandler, NULL);

View File

@ -0,0 +1,77 @@
/*
* synergy -- mouse and keyboard sharing utility
* Copyright (C) 2015 Synergy Si Ltd.
*
* This package is free software; you can redistribute it and/or
* modify it under the terms of the GNU General Public License
* found in the file COPYING that should have accompanied this file.
*
* This package is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "SecureListenSocket.h"
#include "SecureSocket.h"
#include "net/NetworkAddress.h"
#include "net/SocketMultiplexer.h"
#include "net/TSocketMultiplexerMethodJob.h"
#include "arch/XArch.h"
#include "base/Log.h"
//
// SecureListenSocket
//
SecureListenSocket::SecureListenSocket(
IEventQueue* events,
SocketMultiplexer* socketMultiplexer) :
TCPListenSocket(events, socketMultiplexer)
{
}
SecureListenSocket::~SecureListenSocket()
{
}
IDataSocket*
SecureListenSocket::accept()
{
SecureSocket* socket = NULL;
try {
socket = new SecureSocket(
m_events,
m_socketMultiplexer,
ARCH->acceptSocket(m_socket, NULL));
socket->initSsl(true);
// TODO: customized certificate path
socket->loadCertificates("C:\\Temp\\synergy.pem");
if (socket != NULL) {
m_socketMultiplexer->addSocket(this,
new TSocketMultiplexerMethodJob<TCPListenSocket>(
this, &TCPListenSocket::serviceListening,
m_socket, true, false));
socket->acceptSecureSocket();
}
return dynamic_cast<IDataSocket*>(socket);
}
catch (XArchNetwork&) {
if (socket != NULL) {
delete socket;
}
return NULL;
}
catch (std::exception &ex) {
if (socket != NULL) {
delete socket;
}
throw ex;
}
}

View File

@ -0,0 +1,34 @@
/*
* synergy -- mouse and keyboard sharing utility
* Copyright (C) 2015 Synergy Si Ltd.
*
* This package is free software; you can redistribute it and/or
* modify it under the terms of the GNU General Public License
* found in the file COPYING that should have accompanied this file.
*
* This package is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#pragma once
#include "net/TCPListenSocket.h"
class IEventQueue;
class SocketMultiplexer;
class SecureListenSocket : public TCPListenSocket{
public:
SecureListenSocket(IEventQueue* events,
SocketMultiplexer* socketMultiplexer);
~SecureListenSocket();
// IListenSocket overrides
virtual IDataSocket*
accept();
};

View File

@ -17,26 +17,42 @@
#include "SecureSocket.h" #include "SecureSocket.h"
#include "base/String.h" #include "net/TCPSocket.h"
#include "arch/XArch.h"
#include "base/Log.h"
#include <openssl/ssl.h> #include <openssl/ssl.h>
#include <openssl/err.h> #include <openssl/err.h>
#include <cstring>
#include <cstdlib>
#include <memory>
// //
// SecureSocket // SecureSocket
// //
#define MAX_ERROR_SIZE 65535
struct Ssl { struct Ssl {
SSL_CTX* m_context; SSL_CTX* m_context;
SSL* m_ssl; SSL* m_ssl;
}; };
SecureSocket::SecureSocket() : SecureSocket::SecureSocket(
m_ready(false), IEventQueue* events,
m_errorSize(65535) SocketMultiplexer* socketMultiplexer) :
TCPSocket(events, socketMultiplexer),
m_ready(false)
{
}
SecureSocket::SecureSocket(
IEventQueue* events,
SocketMultiplexer* socketMultiplexer,
ArchSocket socket) :
TCPSocket(events, socketMultiplexer, socket),
m_ready(false)
{ {
m_ssl = new Ssl();
m_ssl->m_context = NULL;
m_ssl->m_ssl = NULL;
m_error = new char[m_errorSize];
} }
SecureSocket::~SecureSocket() SecureSocket::~SecureSocket()
@ -51,6 +67,81 @@ SecureSocket::~SecureSocket()
delete[] m_error; delete[] m_error;
} }
UInt32
SecureSocket::read(void* buffer, UInt32 n)
{
bool retry = false;
int r = 0;
if (m_ssl != NULL) {
r = SSL_read(m_ssl->m_ssl, buffer, n);
retry = checkResult(r);
if (retry) {
r = 0;
}
}
return r > 0 ? (UInt32)r : 0;
}
void
SecureSocket::write(const void* buffer, UInt32 n)
{
bool retry = false;
int r = 0;
if (m_ssl != NULL) {
r = SSL_write(m_ssl->m_ssl, buffer, n);
retry = checkResult(r);
if (retry) {
r = 0;
}
}
}
bool
SecureSocket::isReady() const
{
return m_ready;
}
void
SecureSocket::connectSecureSocket()
{
#ifdef SYSAPI_WIN32
secureConnect(static_cast<int>(getSocket()->m_socket));
#elif SYSAPI_UNIX
secureConnect(getSocket()->m_fd);
#endif
}
void
SecureSocket::acceptSecureSocket()
{
#ifdef SYSAPI_WIN32
secureAccept(static_cast<int>(getSocket()->m_socket));
#elif SYSAPI_UNIX
secureAccept(getSocket()->m_fd);
#endif
}
void
SecureSocket::initSsl(bool server)
{
m_ssl = new Ssl();
m_ssl->m_context = NULL;
m_ssl->m_ssl = NULL;
m_error = new char[MAX_ERROR_SIZE];
initContext(server);
}
void
SecureSocket::onConnected()
{
TCPSocket::onConnected();
connectSecureSocket();
}
void void
SecureSocket::initContext(bool server) SecureSocket::initContext(bool server)
{ {
@ -111,7 +202,7 @@ SecureSocket::createSSL()
} }
void void
SecureSocket::accept(int socket) SecureSocket::secureAccept(int socket)
{ {
createSSL(); createSSL();
@ -124,6 +215,7 @@ SecureSocket::accept(int socket)
bool retry = checkResult(r); bool retry = checkResult(r);
while (retry) { while (retry) {
ARCH->sleep(.5f); ARCH->sleep(.5f);
LOG((CLOG_INFO "secureAccept sleep .5s"));
r = SSL_accept(m_ssl->m_ssl); r = SSL_accept(m_ssl->m_ssl);
retry = checkResult(r); retry = checkResult(r);
} }
@ -132,7 +224,7 @@ SecureSocket::accept(int socket)
} }
void void
SecureSocket::connect(int socket) SecureSocket::secureConnect(int socket)
{ {
createSSL(); createSSL();
@ -152,38 +244,6 @@ SecureSocket::connect(int socket)
showCertificate(); showCertificate();
} }
size_t
SecureSocket::write(const void* buffer, int size)
{
bool retry = false;
int n = 0;
if (m_ssl != NULL) {
n = SSL_write(m_ssl->m_ssl, buffer, size);
retry = checkResult(n);
if (retry) {
n = 0;
}
}
return n > 0 ? n : 0;
}
size_t
SecureSocket::read(void* buffer, int size)
{
bool retry = false;
int n = 0;
if (m_ssl != NULL) {
n = SSL_read(m_ssl->m_ssl, buffer, size);
retry = checkResult(n);
if (retry) {
n = 0;
}
}
return n > 0 ? n : 0;
}
void void
SecureSocket::showCertificate() SecureSocket::showCertificate()
{ {
@ -277,7 +337,7 @@ SecureSocket::getError()
bool errorUpdated = false; bool errorUpdated = false;
if (e != 0) { if (e != 0) {
ERR_error_string_n(e, m_error, m_errorSize); ERR_error_string_n(e, m_error, MAX_ERROR_SIZE);
errorUpdated = true; errorUpdated = true;
} }
else { else {
@ -286,9 +346,3 @@ SecureSocket::getError()
return errorUpdated; return errorUpdated;
} }
bool
SecureSocket::isReady()
{
return m_ready;
}

View File

@ -17,34 +17,48 @@
#pragma once #pragma once
#include "net/TCPSocket.h"
#include "base/XBase.h" #include "base/XBase.h"
#include "base/Log.h"
class IEventQueue;
class SocketMultiplexer;
struct Ssl;
//! Generic socket exception //! Generic socket exception
XBASE_SUBCLASS(XSecureSocket, XBase); XBASE_SUBCLASS(XSecureSocket, XBase);
//! SSL
//! Secure socket
/*! /*!
Secure socket layer using OpenSSL. A secure socket using SSL.
*/ */
class SecureSocket : public TCPSocket {
struct Ssl;
class SecureSocket {
public: public:
SecureSocket(); SecureSocket(IEventQueue* events, SocketMultiplexer* socketMultiplexer);
SecureSocket(IEventQueue* events,
SocketMultiplexer* socketMultiplexer,
ArchSocket socket);
~SecureSocket(); ~SecureSocket();
void initContext(bool server); // IStream overrides
virtual UInt32 read(void* buffer, UInt32 n);
virtual void write(const void* buffer, UInt32 n);
virtual bool isReady() const;
void connectSecureSocket();
void acceptSecureSocket();
void initSsl(bool server);
void loadCertificates(const char* CertFile); void loadCertificates(const char* CertFile);
void createSSL();
void accept(int s);
void connect(int s);
size_t write(const void* buffer, int size);
size_t read(void* buffer, int size);
bool isReady();
private: private:
void onConnected();
// SSL
void initContext(bool server);
void createSSL();
void secureAccept(int s);
void secureConnect(int s);
void showCertificate(); void showCertificate();
bool checkResult(int n); bool checkResult(int n);
void showError(); void showError();
@ -55,5 +69,4 @@ private:
Ssl* m_ssl; Ssl* m_ssl;
bool m_ready; bool m_ready;
char* m_error; char* m_error;
const size_t m_errorSize;
}; };

View File

@ -18,28 +18,62 @@
#include "ns.h" #include "ns.h"
#include "SecureSocket.h" #include "SecureSocket.h"
#include "SecureListenSocket.h"
#include "arch/Arch.h"
#include "base/Log.h"
#include <iostream> #include <iostream>
SecureSocket* g_secureSocket = NULL; SecureSocket* g_secureSocket = NULL;
SecureListenSocket* g_secureListenSocket = NULL;
Arch* g_arch = NULL;
Log* g_log = NULL;
extern "C" { extern "C" {
void
init(void* log, void* arch)
{
if (g_log == NULL) {
g_log = new Log(reinterpret_cast<Log*>(log));
}
if (g_arch == NULL) {
g_arch = new Arch(reinterpret_cast<Arch*>(arch));
}
}
int int
init(void (*sendEvent)(const char*, void*), void (*log)(const char*)) initEvent(void (*sendEvent)(const char*, void*))
{ {
return 0; return 0;
} }
void* void*
invoke(const char* command, void* args) invoke(const char* command, void** args)
{ {
if (strcmp(command, "getSecureSocket") == 0) { IEventQueue* arg1 = NULL;
if (g_secureSocket == NULL) { SocketMultiplexer* arg2 = NULL;
g_secureSocket = new SecureSocket(); if (args != NULL) {
arg1 = reinterpret_cast<IEventQueue*>(args[0]);
arg2 = reinterpret_cast<SocketMultiplexer*>(args[1]);
} }
if (strcmp(command, "getSecureSocket") == 0) {
if (g_secureSocket != NULL) {
delete g_secureSocket;
}
g_secureSocket = new SecureSocket(arg1, arg2);
g_secureSocket->initSsl(false);
return g_secureSocket; return g_secureSocket;
} }
else if (strcmp(command, "getSecureListenSocket") == 0) {
if (g_secureListenSocket != NULL) {
delete g_secureListenSocket;
}
g_secureListenSocket = new SecureListenSocket(arg1, arg2);
return g_secureListenSocket;
}
else { else {
return NULL; return NULL;
} }
@ -52,6 +86,10 @@ cleanup()
delete g_secureSocket; delete g_secureSocket;
} }
if (g_secureListenSocket != NULL) {
delete g_secureListenSocket;
}
return 0; return 0;
} }

View File

@ -33,8 +33,9 @@
extern "C" { extern "C" {
NS_API int init(void (*sendEvent)(const char*, void*), void (*log)(const char*)); NS_API void init(void* log, void* arch);
NS_API void* invoke(const char* command, void* args); NS_API int initEvent(void (*sendEvent)(const char*, void*));
NS_API void* invoke(const char* command, void** args);
NS_API int cleanup(); NS_API int cleanup();
} }

View File

@ -36,12 +36,15 @@ static void (*s_log)(const char*) = NULL;
extern "C" { extern "C" {
void
init(void* log, void* arch)
{
}
int int
init(void (*sendEvent)(const char*, void*), void (*log)(const char*)) initEvent(void (*sendEvent)(const char*, void*))
{ {
s_sendEvent = sendEvent; s_sendEvent = sendEvent;
s_log = log;
LOG("init");
CreateThread(NULL, 0, mainLoop, NULL, 0, NULL); CreateThread(NULL, 0, mainLoop, NULL, 0, NULL);
return 0; return 0;
} }

View File

@ -29,7 +29,8 @@
extern "C" { extern "C" {
WINMMJOY_API int init(void (*sendEvent)(const char*, void*), void (*log)(const char*)); WINMMJOY_API void init(void* log, void* arch);
WINMMJOY_API int initEvent(void (*sendEvent)(const char*, void*));
WINMMJOY_API int cleanup(); WINMMJOY_API int cleanup();
} }