Refactored secure read and write into SecureSocket

This commit is contained in:
Jerry (Xinyu Hou) 2016-08-24 17:52:02 +01:00 committed by Andrew Nelless
parent 61b489ab3d
commit 08a73218e6
4 changed files with 203 additions and 155 deletions

View File

@ -326,77 +326,39 @@ TCPSocket::init()
TCPSocket::EJobResult TCPSocket::EJobResult
TCPSocket::doRead() TCPSocket::doRead()
{ {
try { UInt8 buffer[4096];
static UInt8 buffer[4096]; memset(buffer, 0, sizeof(buffer));
memset(buffer, 0, sizeof(buffer)); size_t bytesRead = 0;
int bytesRead = 0;
int status = 0; bytesRead = (int) ARCH->readSocket(m_socket, buffer, sizeof(buffer));
if (bytesRead > 0) {
bool wasEmpty = (m_inputBuffer.getSize() == 0);
if (isSecure()) { // slurp up as much as possible
if (isSecureReady()) { do {
status = secureRead(buffer, sizeof(buffer), bytesRead); m_inputBuffer.write(buffer, bytesRead);
if (status < 0) {
return kBreak; bytesRead = ARCH->readSocket(m_socket, buffer, sizeof(buffer));
} } while (bytesRead > 0);
else if (status == 0) {
return kNew;
}
}
else {
return kRetry;
}
}
else {
bytesRead = (int) ARCH->readSocket(m_socket, buffer, sizeof(buffer));
}
if (bytesRead > 0) { // send input ready if input buffer was empty
bool wasEmpty = (m_inputBuffer.getSize() == 0); if (wasEmpty) {
sendEvent(m_events->forIStream().inputReady());
// slurp up as much as possible
do {
m_inputBuffer.write(buffer, bytesRead);
if (isSecure() && isSecureReady()) {
status = secureRead(buffer, sizeof(buffer), bytesRead);
if (status < 0) {
return kBreak;
}
}
else {
bytesRead = (int) ARCH->readSocket(m_socket, buffer, sizeof(buffer));
}
} while (bytesRead > 0 || status > 0);
// send input ready if input buffer was empty
if (wasEmpty) {
sendEvent(m_events->forIStream().inputReady());
}
}
else {
// remote write end of stream hungup. our input side
// has therefore shutdown but don't flush our buffer
// since there's still data to be read.
sendEvent(m_events->forIStream().inputShutdown());
if (!m_writable && m_inputBuffer.getSize() == 0) {
sendEvent(m_events->forISocket().disconnected());
m_connected = false;
}
m_readable = false;
return kNew;
} }
} }
catch (XArchNetworkDisconnected&) { else {
// stream hungup // remote write end of stream hungup. our input side
sendEvent(m_events->forISocket().disconnected()); // has therefore shutdown but don't flush our buffer
onDisconnected(); // since there's still data to be read.
sendEvent(m_events->forIStream().inputShutdown());
if (!m_writable && m_inputBuffer.getSize() == 0) {
sendEvent(m_events->forISocket().disconnected());
m_connected = false;
}
m_readable = false;
return kNew; return kNew;
} }
catch (XArchNetwork& e) {
// ignore other read error
LOG((CLOG_WARN "error reading socket: %s", e.what()));
}
return kRetry; return kRetry;
} }
@ -404,92 +366,16 @@ TCPSocket::doRead()
TCPSocket::EJobResult TCPSocket::EJobResult
TCPSocket::doWrite() TCPSocket::doWrite()
{ {
static bool s_retry = false; // write data
static int s_retrySize = 0; UInt32 bufferSize = 0;
static void* s_staticBuffer = NULL; int bytesWrote = 0;
try { bufferSize = m_outputBuffer.getSize();
// write data const void* buffer = m_outputBuffer.peek(bufferSize);
int bufferSize = 0; bytesWrote = (UInt32)ARCH->writeSocket(m_socket, buffer, bufferSize);
int bytesWrote = 0;
int status = 0; if (bytesWrote > 0) {
discardWrittenData(bytesWrote);
if (s_retry) {
bufferSize = s_retrySize;
}
else {
bufferSize = m_outputBuffer.getSize();
s_staticBuffer = malloc(bufferSize);
memcpy(s_staticBuffer, m_outputBuffer.peek(bufferSize), bufferSize);
}
if (bufferSize == 0) {
return kRetry;
}
if (isSecure()) {
if (isSecureReady()) {
status = secureWrite(s_staticBuffer, bufferSize, bytesWrote);
if (status > 0) {
s_retry = false;
bufferSize = 0;
free(s_staticBuffer);
s_staticBuffer = NULL;
}
else if (status < 0) {
return kBreak;
}
else if (status == 0) {
s_retry = true;
s_retrySize = bufferSize;
return kNew;
}
}
else {
return kRetry;
}
}
else {
bytesWrote = (UInt32)ARCH->writeSocket(m_socket, s_staticBuffer, bufferSize);
bufferSize = 0;
free(s_staticBuffer);
s_staticBuffer = NULL;
}
// discard written data
if (bytesWrote > 0) {
m_outputBuffer.pop(bytesWrote);
if (m_outputBuffer.getSize() == 0) {
sendEvent(m_events->forIStream().outputFlushed());
m_flushed = true;
m_flushed.broadcast();
return kNew;
}
}
}
catch (XArchNetworkShutdown&) {
// remote read end of stream hungup. our output side
// has therefore shutdown.
onOutputShutdown();
sendEvent(m_events->forIStream().outputShutdown());
if (!m_readable && m_inputBuffer.getSize() == 0) {
sendEvent(m_events->forISocket().disconnected());
m_connected = false;
}
return kNew;
}
catch (XArchNetworkDisconnected&) {
// stream hungup
onDisconnected();
sendEvent(m_events->forISocket().disconnected());
return kNew;
}
catch (XArchNetwork& e) {
// other write error
LOG((CLOG_WARN "error writing socket: %s", e.what()));
onDisconnected();
sendEvent(m_events->forIStream().outputError());
sendEvent(m_events->forISocket().disconnected());
return kNew; return kNew;
} }
@ -550,6 +436,17 @@ TCPSocket::sendEvent(Event::Type type)
m_events->addEvent(Event(type, getEventTarget(), NULL)); m_events->addEvent(Event(type, getEventTarget(), NULL));
} }
void
TCPSocket::discardWrittenData(int bytesWrote)
{
m_outputBuffer.pop(bytesWrote);
if (m_outputBuffer.getSize() == 0) {
sendEvent(m_events->forIStream().outputFlushed());
m_flushed = true;
m_flushed.broadcast();
}
}
void void
TCPSocket::onConnected() TCPSocket::onConnected()
{ {
@ -643,11 +540,50 @@ TCPSocket::serviceConnected(ISocketMultiplexerJob* job,
EJobResult result = kRetry; EJobResult result = kRetry;
if (write) { if (write) {
result = doWrite(); try {
result = doWrite();
}
catch (XArchNetworkShutdown&) {
// remote read end of stream hungup. our output side
// has therefore shutdown.
onOutputShutdown();
sendEvent(m_events->forIStream().outputShutdown());
if (!m_readable && m_inputBuffer.getSize() == 0) {
sendEvent(m_events->forISocket().disconnected());
m_connected = false;
}
result = kNew;
}
catch (XArchNetworkDisconnected&) {
// stream hungup
onDisconnected();
sendEvent(m_events->forISocket().disconnected());
result = kNew;
}
catch (XArchNetwork& e) {
// other write error
LOG((CLOG_WARN "error writing socket: %s", e.what()));
onDisconnected();
sendEvent(m_events->forIStream().outputError());
sendEvent(m_events->forISocket().disconnected());
result = kNew;
}
} }
if (read && m_readable) { if (read && m_readable) {
result = doRead(); try {
result = doRead();
}
catch (XArchNetworkDisconnected&) {
// stream hungup
sendEvent(m_events->forISocket().disconnected());
onDisconnected();
result = kNew;
}
catch (XArchNetwork& e) {
// ignore other read error
LOG((CLOG_WARN "error reading socket: %s", e.what()));
}
} }
return result == kBreak ? NULL : result == kNew ? newJob() : job; return result == kBreak ? NULL : result == kNew ? newJob() : job;

View File

@ -89,6 +89,7 @@ protected:
Mutex& getMutex() { return m_mutex; } Mutex& getMutex() { return m_mutex; }
void sendEvent(Event::Type); void sendEvent(Event::Type);
void discardWrittenData(int bytesWrote);
private: private:
void init(); void init();
@ -111,12 +112,12 @@ protected:
bool m_writable; bool m_writable;
bool m_connected; bool m_connected;
IEventQueue* m_events; IEventQueue* m_events;
StreamBuffer m_inputBuffer;
StreamBuffer m_outputBuffer;
private: private:
Mutex m_mutex; Mutex m_mutex;
ArchSocket m_socket; ArchSocket m_socket;
StreamBuffer m_inputBuffer;
StreamBuffer m_outputBuffer;
CondVar<bool> m_flushed; CondVar<bool> m_flushed;
SocketMultiplexer* m_socketMultiplexer; SocketMultiplexer* m_socketMultiplexer;
}; };

View File

@ -140,6 +140,115 @@ SecureSocket::secureAccept()
getSocket(), isReadable(), isWritable())); getSocket(), isReadable(), isWritable()));
} }
TCPSocket::EJobResult
SecureSocket::doRead()
{
static UInt8 buffer[4096];
memset(buffer, 0, sizeof(buffer));
int bytesRead = 0;
int status = 0;
if (isSecureReady()) {
status = secureRead(buffer, sizeof(buffer), bytesRead);
if (status < 0) {
return kBreak;
}
else if (status == 0) {
return kNew;
}
}
else {
return kRetry;
}
if (bytesRead > 0) {
bool wasEmpty = (m_inputBuffer.getSize() == 0);
// slurp up as much as possible
do {
m_inputBuffer.write(buffer, bytesRead);
status = secureRead(buffer, sizeof(buffer), bytesRead);
if (status < 0) {
return kBreak;
}
} while (bytesRead > 0 || status > 0);
// send input ready if input buffer was empty
if (wasEmpty) {
sendEvent(m_events->forIStream().inputReady());
}
}
else {
// remote write end of stream hungup. our input side
// has therefore shutdown but don't flush our buffer
// since there's still data to be read.
sendEvent(m_events->forIStream().inputShutdown());
if (!m_writable && m_inputBuffer.getSize() == 0) {
sendEvent(m_events->forISocket().disconnected());
m_connected = false;
}
m_readable = false;
return kNew;
}
return kRetry;
}
TCPSocket::EJobResult
SecureSocket::doWrite()
{
static bool s_retry = false;
static int s_retrySize = 0;
static void* s_staticBuffer = NULL;
// write data
int bufferSize = 0;
int bytesWrote = 0;
int status = 0;
if (s_retry) {
bufferSize = s_retrySize;
}
else {
bufferSize = m_outputBuffer.getSize();
s_staticBuffer = malloc(bufferSize);
memcpy(s_staticBuffer, m_outputBuffer.peek(bufferSize), bufferSize);
}
if (bufferSize == 0) {
return kRetry;
}
if (isSecureReady()) {
status = secureWrite(s_staticBuffer, bufferSize, bytesWrote);
if (status > 0) {
s_retry = false;
bufferSize = 0;
free(s_staticBuffer);
s_staticBuffer = NULL;
}
else if (status < 0) {
return kBreak;
}
else if (status == 0) {
s_retry = true;
s_retrySize = bufferSize;
return kNew;
}
}
else {
return kRetry;
}
if (bytesWrote > 0) {
discardWrittenData(bytesWrote);
return kNew;
}
return kRetry;
}
int int
SecureSocket::secureRead(void* buffer, int size, int& read) SecureSocket::secureRead(void* buffer, int size, int& read)
{ {

View File

@ -48,6 +48,8 @@ public:
newJob(); newJob();
void secureConnect(); void secureConnect();
void secureAccept(); void secureAccept();
EJobResult doRead();
EJobResult doWrite();
bool isReady() const { return m_secureReady; } bool isReady() const { return m_secureReady; }
bool isFatal() const { return m_fatal; } bool isFatal() const { return m_fatal; }
void isFatal(bool b) { m_fatal = b; } void isFatal(bool b) { m_fatal = b; }