Refactored secure read and write into SecureSocket

This commit is contained in:
Jerry (Xinyu Hou) 2016-08-24 17:52:02 +01:00 committed by Andrew Nelless
parent 61b489ab3d
commit 08a73218e6
4 changed files with 203 additions and 155 deletions

View File

@ -326,29 +326,11 @@ TCPSocket::init()
TCPSocket::EJobResult
TCPSocket::doRead()
{
try {
static UInt8 buffer[4096];
UInt8 buffer[4096];
memset(buffer, 0, sizeof(buffer));
int bytesRead = 0;
int status = 0;
size_t bytesRead = 0;
if (isSecure()) {
if (isSecureReady()) {
status = secureRead(buffer, sizeof(buffer), bytesRead);
if (status < 0) {
return kBreak;
}
else if (status == 0) {
return kNew;
}
}
else {
return kRetry;
}
}
else {
bytesRead = (int) ARCH->readSocket(m_socket, buffer, sizeof(buffer));
}
if (bytesRead > 0) {
bool wasEmpty = (m_inputBuffer.getSize() == 0);
@ -357,17 +339,8 @@ TCPSocket::doRead()
do {
m_inputBuffer.write(buffer, bytesRead);
if (isSecure() && isSecureReady()) {
status = secureRead(buffer, sizeof(buffer), bytesRead);
if (status < 0) {
return kBreak;
}
}
else {
bytesRead = (int) ARCH->readSocket(m_socket, buffer, sizeof(buffer));
}
} while (bytesRead > 0 || status > 0);
bytesRead = ARCH->readSocket(m_socket, buffer, sizeof(buffer));
} while (bytesRead > 0);
// send input ready if input buffer was empty
if (wasEmpty) {
@ -386,17 +359,6 @@ TCPSocket::doRead()
m_readable = false;
return kNew;
}
}
catch (XArchNetworkDisconnected&) {
// stream hungup
sendEvent(m_events->forISocket().disconnected());
onDisconnected();
return kNew;
}
catch (XArchNetwork& e) {
// ignore other read error
LOG((CLOG_WARN "error reading socket: %s", e.what()));
}
return kRetry;
}
@ -404,92 +366,16 @@ TCPSocket::doRead()
TCPSocket::EJobResult
TCPSocket::doWrite()
{
static bool s_retry = false;
static int s_retrySize = 0;
static void* s_staticBuffer = NULL;
try {
// write data
int bufferSize = 0;
UInt32 bufferSize = 0;
int bytesWrote = 0;
int status = 0;
if (s_retry) {
bufferSize = s_retrySize;
}
else {
bufferSize = m_outputBuffer.getSize();
s_staticBuffer = malloc(bufferSize);
memcpy(s_staticBuffer, m_outputBuffer.peek(bufferSize), bufferSize);
}
const void* buffer = m_outputBuffer.peek(bufferSize);
bytesWrote = (UInt32)ARCH->writeSocket(m_socket, buffer, bufferSize);
if (bufferSize == 0) {
return kRetry;
}
if (isSecure()) {
if (isSecureReady()) {
status = secureWrite(s_staticBuffer, bufferSize, bytesWrote);
if (status > 0) {
s_retry = false;
bufferSize = 0;
free(s_staticBuffer);
s_staticBuffer = NULL;
}
else if (status < 0) {
return kBreak;
}
else if (status == 0) {
s_retry = true;
s_retrySize = bufferSize;
return kNew;
}
}
else {
return kRetry;
}
}
else {
bytesWrote = (UInt32)ARCH->writeSocket(m_socket, s_staticBuffer, bufferSize);
bufferSize = 0;
free(s_staticBuffer);
s_staticBuffer = NULL;
}
// discard written data
if (bytesWrote > 0) {
m_outputBuffer.pop(bytesWrote);
if (m_outputBuffer.getSize() == 0) {
sendEvent(m_events->forIStream().outputFlushed());
m_flushed = true;
m_flushed.broadcast();
return kNew;
}
}
}
catch (XArchNetworkShutdown&) {
// remote read end of stream hungup. our output side
// has therefore shutdown.
onOutputShutdown();
sendEvent(m_events->forIStream().outputShutdown());
if (!m_readable && m_inputBuffer.getSize() == 0) {
sendEvent(m_events->forISocket().disconnected());
m_connected = false;
}
return kNew;
}
catch (XArchNetworkDisconnected&) {
// stream hungup
onDisconnected();
sendEvent(m_events->forISocket().disconnected());
return kNew;
}
catch (XArchNetwork& e) {
// other write error
LOG((CLOG_WARN "error writing socket: %s", e.what()));
onDisconnected();
sendEvent(m_events->forIStream().outputError());
sendEvent(m_events->forISocket().disconnected());
discardWrittenData(bytesWrote);
return kNew;
}
@ -550,6 +436,17 @@ TCPSocket::sendEvent(Event::Type type)
m_events->addEvent(Event(type, getEventTarget(), NULL));
}
void
TCPSocket::discardWrittenData(int bytesWrote)
{
m_outputBuffer.pop(bytesWrote);
if (m_outputBuffer.getSize() == 0) {
sendEvent(m_events->forIStream().outputFlushed());
m_flushed = true;
m_flushed.broadcast();
}
}
void
TCPSocket::onConnected()
{
@ -643,12 +540,51 @@ TCPSocket::serviceConnected(ISocketMultiplexerJob* job,
EJobResult result = kRetry;
if (write) {
try {
result = doWrite();
}
catch (XArchNetworkShutdown&) {
// remote read end of stream hungup. our output side
// has therefore shutdown.
onOutputShutdown();
sendEvent(m_events->forIStream().outputShutdown());
if (!m_readable && m_inputBuffer.getSize() == 0) {
sendEvent(m_events->forISocket().disconnected());
m_connected = false;
}
result = kNew;
}
catch (XArchNetworkDisconnected&) {
// stream hungup
onDisconnected();
sendEvent(m_events->forISocket().disconnected());
result = kNew;
}
catch (XArchNetwork& e) {
// other write error
LOG((CLOG_WARN "error writing socket: %s", e.what()));
onDisconnected();
sendEvent(m_events->forIStream().outputError());
sendEvent(m_events->forISocket().disconnected());
result = kNew;
}
}
if (read && m_readable) {
try {
result = doRead();
}
catch (XArchNetworkDisconnected&) {
// stream hungup
sendEvent(m_events->forISocket().disconnected());
onDisconnected();
result = kNew;
}
catch (XArchNetwork& e) {
// ignore other read error
LOG((CLOG_WARN "error reading socket: %s", e.what()));
}
}
return result == kBreak ? NULL : result == kNew ? newJob() : job;
}

View File

@ -89,6 +89,7 @@ protected:
Mutex& getMutex() { return m_mutex; }
void sendEvent(Event::Type);
void discardWrittenData(int bytesWrote);
private:
void init();
@ -111,12 +112,12 @@ protected:
bool m_writable;
bool m_connected;
IEventQueue* m_events;
StreamBuffer m_inputBuffer;
StreamBuffer m_outputBuffer;
private:
Mutex m_mutex;
ArchSocket m_socket;
StreamBuffer m_inputBuffer;
StreamBuffer m_outputBuffer;
CondVar<bool> m_flushed;
SocketMultiplexer* m_socketMultiplexer;
};

View File

@ -140,6 +140,115 @@ SecureSocket::secureAccept()
getSocket(), isReadable(), isWritable()));
}
TCPSocket::EJobResult
SecureSocket::doRead()
{
static UInt8 buffer[4096];
memset(buffer, 0, sizeof(buffer));
int bytesRead = 0;
int status = 0;
if (isSecureReady()) {
status = secureRead(buffer, sizeof(buffer), bytesRead);
if (status < 0) {
return kBreak;
}
else if (status == 0) {
return kNew;
}
}
else {
return kRetry;
}
if (bytesRead > 0) {
bool wasEmpty = (m_inputBuffer.getSize() == 0);
// slurp up as much as possible
do {
m_inputBuffer.write(buffer, bytesRead);
status = secureRead(buffer, sizeof(buffer), bytesRead);
if (status < 0) {
return kBreak;
}
} while (bytesRead > 0 || status > 0);
// send input ready if input buffer was empty
if (wasEmpty) {
sendEvent(m_events->forIStream().inputReady());
}
}
else {
// remote write end of stream hungup. our input side
// has therefore shutdown but don't flush our buffer
// since there's still data to be read.
sendEvent(m_events->forIStream().inputShutdown());
if (!m_writable && m_inputBuffer.getSize() == 0) {
sendEvent(m_events->forISocket().disconnected());
m_connected = false;
}
m_readable = false;
return kNew;
}
return kRetry;
}
TCPSocket::EJobResult
SecureSocket::doWrite()
{
static bool s_retry = false;
static int s_retrySize = 0;
static void* s_staticBuffer = NULL;
// write data
int bufferSize = 0;
int bytesWrote = 0;
int status = 0;
if (s_retry) {
bufferSize = s_retrySize;
}
else {
bufferSize = m_outputBuffer.getSize();
s_staticBuffer = malloc(bufferSize);
memcpy(s_staticBuffer, m_outputBuffer.peek(bufferSize), bufferSize);
}
if (bufferSize == 0) {
return kRetry;
}
if (isSecureReady()) {
status = secureWrite(s_staticBuffer, bufferSize, bytesWrote);
if (status > 0) {
s_retry = false;
bufferSize = 0;
free(s_staticBuffer);
s_staticBuffer = NULL;
}
else if (status < 0) {
return kBreak;
}
else if (status == 0) {
s_retry = true;
s_retrySize = bufferSize;
return kNew;
}
}
else {
return kRetry;
}
if (bytesWrote > 0) {
discardWrittenData(bytesWrote);
return kNew;
}
return kRetry;
}
int
SecureSocket::secureRead(void* buffer, int size, int& read)
{

View File

@ -48,6 +48,8 @@ public:
newJob();
void secureConnect();
void secureAccept();
EJobResult doRead();
EJobResult doWrite();
bool isReady() const { return m_secureReady; }
bool isFatal() const { return m_fatal; }
void isFatal(bool b) { m_fatal = b; }