HTTP: add CHTTPRequest header accessors and an optional request size limit

CHTTPRequest now keeps its headers in insertion order behind
insertHeader/appendHeader/eraseHeader/isHeader/getHeader, and
CHTTPProtocol::readRequest() takes a maxSize argument: when it is greater
than zero, a request larger than maxSize bytes is rejected with XHTTP(413).
CHTTPServer passes s_maxRequestSize (32768).

Review fixes folded into this revision of the patch:

  * CHTTPRequest::eraseHeader() also removes the name from m_headerByName.
    Before, the map kept an iterator to the erased list node, so isHeader()
    still reported the header and a later getHeader(), insertHeader() or
    appendHeader() for that name used an invalid iterator.
  * CHTTPProtocol::readChunk() subtracted from the maxSize pointer rather
    than from the byte count it points to, so chunked bodies were never
    charged against the limit.
  * The chunk size check compares by subtraction instead of summing first,
    so a very large chunk size cannot wrap the sum (where that arithmetic
    is 32 bits wide) and pass the check.
  * The "cannot parse Content-Length" log message gained the %s for the
    header value it is already passed.
  * The Host header test is written as !isHeader() instead of comparing a
    bool with 0.

NOTE(review): this patch text was rebuilt from a copy in which line breaks,
indentation and angle-bracketed text had been lost. Please regenerate it
against the tree before applying. In particular:

  * Tab indentation and continuation-line alignment are assumed.
  * The first CHTTPProtocol.cpp hunk originally declared 15 old lines; only
    14 could be recovered (a whitespace-only context line appears to be
    missing), so its counts describe the text shown here.
  * The CHeaderList/CHeaderMap template arguments were reconstructed from
    how the members are used. The CHTTPUtil::CaselessCmp comparator and the
    removed typedefs (UInt32 index map, vector of CString) are assumptions
    to confirm.
  * The "<name>:[<value>]" comment text is likewise a reconstruction.
  * The index lines are carried over unchanged and no longer identify the
    post-image of CHTTPProtocol.cpp.

diff --git a/http/CHTTPProtocol.cpp b/http/CHTTPProtocol.cpp
index 4c52ca7d..0d00fb87 100644
--- a/http/CHTTPProtocol.cpp
+++ b/http/CHTTPProtocol.cpp
@@ -56,14 +56,88 @@ bool CHTTPUtil::CaselessCmp::operator()(
 	return less(a, b);
 }
 
+//
+// CHTTPRequest
+//
+
+CHTTPRequest::CHTTPRequest()
+{
+	// do nothing
+}
+
+CHTTPRequest::~CHTTPRequest()
+{
+	// do nothing
+}
+
+void CHTTPRequest::insertHeader(
+				const CString& name, const CString& value)
+{
+	CHeaderMap::iterator index = m_headerByName.find(name);
+	if (index != m_headerByName.end()) {
+		index->second->second = value;
+	}
+	else {
+		CHeaderList::iterator pos = m_headers.insert(
+								m_headers.end(), std::make_pair(name, value));
+		m_headerByName.insert(std::make_pair(name, pos));
+	}
+}
+
+void CHTTPRequest::appendHeader(
+				const CString& name, const CString& value)
+{
+	CHeaderMap::iterator index = m_headerByName.find(name);
+	if (index != m_headerByName.end()) {
+		index->second->second += ",";
+		index->second->second += value;
+	}
+	else {
+		CHeaderList::iterator pos = m_headers.insert(
+								m_headers.end(), std::make_pair(name, value));
+		m_headerByName.insert(std::make_pair(name, pos));
+	}
+}
+
+void CHTTPRequest::eraseHeader(const CString& name)
+{
+	CHeaderMap::iterator index = m_headerByName.find(name);
+	if (index != m_headerByName.end()) {
+		// drop the map entry as well. otherwise it would keep an
+		// iterator to the list node erased here.
+		m_headers.erase(index->second);
+		m_headerByName.erase(index);
+	}
+}
+
+bool CHTTPRequest::isHeader(const CString& name) const
+{
+	return (m_headerByName.find(name) != m_headerByName.end());
+}
+
+CString CHTTPRequest::getHeader(const CString& name) const
+{
+	CHeaderMap::const_iterator index = m_headerByName.find(name);
+	if (index != m_headerByName.end()) {
+		return index->second->second;
+	}
+	else {
+		return CString();
+	}
+}
+
 //
 // CHTTPProtocol
 //
 
-CHTTPRequest* CHTTPProtocol::readRequest(IInputStream* stream)
+CHTTPRequest* CHTTPProtocol::readRequest(
+				IInputStream* stream, UInt32 maxSize)
 {
 	CString scratch;
 
+	// note if we should limit the request size
+	const bool checkSize = (maxSize > 0);
+
 	// parse request line by line
 	CHTTPRequest* request = new CHTTPRequest;
 	try {
@@ -73,6 +147,12 @@ CHTTPRequest* CHTTPProtocol::readRequest(IInputStream* stream)
 		// read request line. accept and discard leading empty lines.
 		do {
 			line = readLine(stream, scratch);
+			if (checkSize) {
+				if (line.size() + 2 > maxSize) {
+					throw XHTTP(413);
+				}
+				maxSize -= line.size() + 2;
+			}
 		} while (line.empty());
 
 		// parse request line:
@@ -109,12 +189,13 @@ CHTTPRequest* CHTTPProtocol::readRequest(IInputStream* stream)
 		}
 
 		// parse headers
-		readHeaders(stream, request, false, scratch);
+		readHeaders(stream, request, false, scratch,
+								checkSize ? &maxSize : NULL);
 
 		// HTTP/1.1 requests must have a Host header
 		if (request->m_majorVersion > 1 ||
 			(request->m_majorVersion == 1 && request->m_minorVersion >= 1)) {
-			if (request->m_headerIndexByName.count("Host") == 0) {
+			if (!request->isHeader("Host")) {
 				log((CLOG_DEBUG1 "Host header missing"));
 				throw XHTTP(400);
 			}
@@ -123,8 +204,8 @@ CHTTPRequest* CHTTPProtocol::readRequest(IInputStream* stream)
 		// some methods may not have a body. ensure that the headers
 		// that indicate the body length do not exist for those methods
 		// and do exist for others.
-		if ((request->m_headerIndexByName.count("Transfer-Encoding") == 0 &&
-			request->m_headerIndexByName.count("Content-Length") == 0) !=
+		if ((request->isHeader("Transfer-Encoding") ||
+			request->isHeader("Content-Length")) ==
 			(request->m_method == "GET" ||
 			request->m_method == "HEAD")) {
 			log((CLOG_DEBUG1 "HTTP method (%s)/body mismatch", request->m_method.c_str()));
@@ -136,13 +217,11 @@ CHTTPRequest* CHTTPProtocol::readRequest(IInputStream* stream)
 		// 1. Transfer-Encoding indicates a "chunked" transfer
 		// 2. Content-Length is present
 		// Content-Length is ignored for "chunked" transfers.
-		CHTTPRequest::CHeaderMap::iterator index = request->
-								m_headerIndexByName.find("Transfer-Encoding");
-		if (index != request->m_headerIndexByName.end()) {
+		CString header;
+		if (!(header = request->getHeader("Transfer-Encoding")).empty()) {
 			// we only understand "chunked" encodings
-			if (!CHTTPUtil::CaselessCmp::equal(
-								request->m_headers[index->second], "chunked")) {
-				log((CLOG_DEBUG1 "unsupported Transfer-Encoding %s", request->m_headers[index->second].c_str()));
+			if (!CHTTPUtil::CaselessCmp::equal(header, "chunked")) {
+				log((CLOG_DEBUG1 "unsupported Transfer-Encoding %s", header.c_str()));
 				throw XHTTP(501);
 			}
 
@@ -150,36 +229,39 @@ CHTTPRequest* CHTTPProtocol::readRequest(IInputStream* stream)
 			UInt32 oldSize;
 			do {
 				oldSize = request->m_body.size();
-				request->m_body += readChunk(stream, scratch);
+				request->m_body += readChunk(stream, scratch,
+								checkSize ? &maxSize : NULL);
 			} while (request->m_body.size() != oldSize);
 
 			// read footer
-			readHeaders(stream, request, true, scratch);
+			readHeaders(stream, request, true, scratch,
+								checkSize ? &maxSize : NULL);
 
 			// remove "chunked" from Transfer-Encoding and set the
 			// Content-Length.
-			// FIXME
-			// FIXME -- note that just deleting Transfer-Encoding will
-			// mess up indices in m_headerIndexByName, and replacing
-			// it with Content-Length could lead to two of those.
+			std::ostringstream s;
+			s << std::dec << request->m_body.size();
+			request->eraseHeader("Transfer-Encoding");
+			request->insertHeader("Content-Length", s.str());
 		}
-		else if ((index = request->m_headerIndexByName.
-								find("Content-Length")) !=
-								request->m_headerIndexByName.end()) {
-			// FIXME -- check for overly-long requests
-
+		else if (!(header = request->getHeader("Content-Length")).empty()) {
 			// parse content-length
 			UInt32 length;
 			{
-				std::istringstream s(request->m_headers[index->second]);
+				std::istringstream s(header);
 				s.exceptions(std::ios::goodbit);
 				s >> length;
 				if (!s) {
-					log((CLOG_DEBUG1 "cannot parse Content-Length", request->m_headers[index->second].c_str()));
+					log((CLOG_DEBUG1 "cannot parse Content-Length: %s", header.c_str()));
 					throw XHTTP(400);
 				}
 			}
 
+			// check against expected size
+			if (checkSize && length > maxSize) {
+				throw XHTTP(413);
+			}
+
 			// use content length
 			request->m_body = readBlock(stream, length, scratch);
 			if (request->m_body.size() != length) {
@@ -287,13 +369,11 @@ bool CHTTPProtocol::parseFormData(
 	static const char quote[] = "\"";
 
 	// find the Content-Type header
-	CHTTPRequest::CHeaderMap::const_iterator contentTypeIndex =
-								request.m_headerIndexByName.find("Content-Type");
-	if (contentTypeIndex == request.m_headerIndexByName.end()) {
+	const CString contentType = request.getHeader("Content-Type");
+	if (contentType.empty()) {
 		// missing required Content-Type header
 		return false;
 	}
-	const CString contentType = request.m_headers[contentTypeIndex->second];
 
 	// parse type
 	CString::const_iterator index = std::search(
@@ -335,8 +415,7 @@ bool CHTTPProtocol::parseFormData(
 		if (body.size() >= partIndex + 2 &&
 			body[partIndex ] == '-' &&
 			body[partIndex + 1] == '-') {
-			// found last part. success if there's no trailing data.
-			// FIXME -- check for trailing data (other than a single CRLF)
+			// found last part. ignore trailing data, if any.
 			return true;
 		}
 
@@ -500,7 +579,8 @@ CString CHTTPProtocol::readBlock(
 
 CString CHTTPProtocol::readChunk(
 				IInputStream* stream,
-				CString& tmpBuffer)
+				CString& tmpBuffer,
+				UInt32* maxSize)
 {
 	CString line;
 
@@ -522,8 +602,18 @@ CString CHTTPProtocol::readChunk(
 		return CString();
 	}
 
+	// check size
+	if (maxSize != NULL) {
+		// compare by subtraction so an oversized chunk size sent by
+		// the client cannot wrap the sum and slip under the limit.
+		const UInt32 overhead = line.size() + 2 + 2;
+		if (overhead > *maxSize || size > *maxSize - overhead) {
+			throw XHTTP(413);
+		}
+		*maxSize -= overhead + size;
+	}
+
 	// read size bytes
-	// FIXME -- check for overly-long requests
 	CString data = readBlock(stream, size, tmpBuffer);
 	if (data.size() != size) {
 		log((CLOG_DEBUG1 "expected/actual chunk size mismatch", size, data.size()));
@@ -544,27 +634,36 @@ void CHTTPProtocol::readHeaders(
 				IInputStream* stream,
 				CHTTPRequest* request,
 				bool isFooter,
-				CString& tmpBuffer)
+				CString& tmpBuffer,
+				UInt32* maxSize)
 {
 	// parse headers. done with headers when we get a blank line.
+	CString name;
 	CString line = readLine(stream, tmpBuffer);
 	while (!line.empty()) {
+		// check size
+		if (maxSize != NULL) {
+			if (line.size() + 2 > *maxSize) {
+				throw XHTTP(413);
+			}
+			*maxSize -= line.size() + 2;
+		}
+
 		// if line starts with space or tab then append it to the
 		// previous header. if there is no previous header then
 		// throw.
 		if (line[0] == ' ' || line[0] == '\t') {
-			if (request->m_headers.size() == 0) {
+			if (name.empty()) {
 				log((CLOG_DEBUG1 "first header is a continuation"));
 				throw XHTTP(400);
 			}
-			request->m_headers.back() += ",";
-			request->m_headers.back() == line;
+			request->appendHeader(name, line);
 		}
 
 		// line should have the form: <name>:[<value>]
 		else {
 			// parse
-			CString name, value;
+			CString value;
 			std::istringstream s(line);
 			s.exceptions(std::ios::goodbit);
 			std::getline(s, name, ':');
@@ -577,29 +676,14 @@ void CHTTPProtocol::readHeaders(
 			// check validity of name
 			if (isFooter) {
 				// FIXME -- only certain names are allowed in footers
+				// but which ones?
 			}
 
-			// check if we've seen this header before
-			CHTTPRequest::CHeaderMap::iterator index =
-								request->m_headerIndexByName.find(name);
-			if (index == request->m_headerIndexByName.end()) {
-				// it's a new header
-				request->m_headerIndexByName.insert(std::make_pair(name,
-								request->m_headers.size()));
-				request->m_headers.push_back(value);
-			}
-			else {
-				// it's an existing header. append value to previous
-				// header, separated by a comma.
-				request->m_headers[index->second] += ',';
-				request->m_headers[index->second] += value;
-			}
+			request->appendHeader(name, value);
 		}
 
 		// next header
 		line = readLine(stream, tmpBuffer);
-
-		// FIXME -- should check for overly-long requests
 	}
 }
 
diff --git a/http/CHTTPProtocol.h b/http/CHTTPProtocol.h
index 357c3175..316d1fd4 100644
--- a/http/CHTTPProtocol.h
+++ b/http/CHTTPProtocol.h
@@ -3,6 +3,7 @@
 
 #include "BasicTypes.h"
 #include "CString.h"
+#include "stdlist.h"
 #include "stdmap.h"
 #include "stdvector.h"
 
@@ -25,19 +26,52 @@ public:
 
 class CHTTPRequest {
 public:
-	typedef std::map<CString, UInt32, CHTTPUtil::CaselessCmp> CHeaderMap;
-	typedef std::vector<CString> CHeaderList;
+	typedef std::list<std::pair<CString, CString> > CHeaderList;
+	typedef std::map<CString, CHeaderList::iterator,
+							CHTTPUtil::CaselessCmp> CHeaderMap;
+	typedef CHeaderList::const_iterator const_iterator;
 
+	CHTTPRequest();
+	~CHTTPRequest();
+
+	// manipulators
+
+	// add a header by name. replaces existing header, if any.
+	// headers are sent in the order they're inserted. replacing
+	// a header does not change its original position in the order.
+	void insertHeader(const CString& name, const CString& value);
+
+	// append a header. equivalent to insertHeader() if the header
+	// doesn't exist, otherwise it appends a comma and the value to
+	// the existing header.
+	void appendHeader(const CString& name, const CString& value);
+
+	// remove a header by name. does nothing if no such header.
+	void eraseHeader(const CString& name);
+
+	// accessors
+
+	// returns true iff the header exists
+	bool isHeader(const CString& name) const;
+
+	// get a header by name. returns the empty string if no such header.
+	CString getHeader(const CString& name) const;
+
+	// get iterator over all headers in the order they were added
+	const_iterator begin() const { return m_headers.begin(); }
+	const_iterator end() const { return m_headers.end(); }
+
+public:
+	// note -- these members are public for convenience
 	CString m_method;
 	CString m_uri;
 	SInt32 m_majorVersion;
 	SInt32 m_minorVersion;
-
-	CHeaderList m_headers;
-	CHeaderMap m_headerIndexByName;
-
 	CString m_body;
-	// FIXME -- need parts-of-body for POST messages
+
+private:
+	CHeaderList m_headers;
+	CHeaderMap m_headerByName;
 };
 
 class CHTTPReply {
@@ -59,9 +93,11 @@ class CHTTPProtocol {
 public:
 	// read and parse an HTTP request. result is returned in a
 	// CHTTPRequest which the client must delete. throws an
-	// XHTTP if there was a parse error. throws an XIO
-	// exception if there was a read error.
-	static CHTTPRequest* readRequest(IInputStream*);
+	// XHTTP if there was a parse error. throws an XIO exception
+	// if there was a read error. if maxSize is greater than
+	// zero and the request is larger than maxSize bytes then
+	// throws XHTTP(413).
+	static CHTTPRequest* readRequest(IInputStream*, UInt32 maxSize = 0);
 
 	// send an HTTP reply on the stream
 	static void reply(IOutputStream*, CHTTPReply&);
@@ -77,10 +113,12 @@ private:
 	static CString readLine(IInputStream*, CString& tmpBuffer);
 	static CString readBlock(IInputStream*,
 								UInt32 numBytes, CString& tmpBuffer);
-	static CString readChunk(IInputStream*, CString& tmpBuffer);
+	static CString readChunk(IInputStream*, CString& tmpBuffer,
+								UInt32* maxSize);
 	static void readHeaders(IInputStream*,
 								CHTTPRequest*, bool isFooter,
-								CString& tmpBuffer);
+								CString& tmpBuffer,
+								UInt32* maxSize);
 
 	static bool isValidToken(const CString&);
 };
diff --git a/server/CHTTPServer.cpp b/server/CHTTPServer.cpp
index 05378d8e..d85b0978 100644
--- a/server/CHTTPServer.cpp
+++ b/server/CHTTPServer.cpp
@@ -14,6 +14,11 @@
 // CHTTPServer
 //
 
+// maximum size of an HTTP request. this should be large enough to
+// handle any reasonable request but small enough to prevent a
+// malicious client from causing us to use too much memory.
+const UInt32 CHTTPServer::s_maxRequestSize = 32768;
+
 CHTTPServer::CHTTPServer(CServer* server) : m_server(server)
 {
 	// do nothing
@@ -31,7 +36,8 @@ void CHTTPServer::processRequest(ISocket* socket)
 	CHTTPRequest* request = NULL;
 	try {
 		// parse request
-		request = CHTTPProtocol::readRequest(socket->getInputStream());
+		request = CHTTPProtocol::readRequest(
+								socket->getInputStream(), s_maxRequestSize);
 		if (request == NULL) {
 			throw XHTTP(400);
 		}
diff --git a/server/CHTTPServer.h b/server/CHTTPServer.h
index 50a5708b..7c0cf311 100644
--- a/server/CHTTPServer.h
+++ b/server/CHTTPServer.h
@@ -96,6 +96,7 @@ protected:
 
 private:
 	CServer* m_server;
+	static const UInt32 s_maxRequestSize;
 };
 
 #endif