From 8c846470205625383b5e586c6e0d1d43750551ce Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Sat, 21 Aug 2021 22:48:07 -0700 Subject: [PATCH 1/7] Add deflate/inflate helper classes --- src/deflate.cpp | 125 ++++++++++++++++++++++++++++++++++++++++++++++++ src/deflate.h | 49 +++++++++++++++++++ 2 files changed, 174 insertions(+) create mode 100644 src/deflate.cpp create mode 100644 src/deflate.h diff --git a/src/deflate.cpp b/src/deflate.cpp new file mode 100644 index 00000000..36f8b86f --- /dev/null +++ b/src/deflate.cpp @@ -0,0 +1,125 @@ +#include "deflate.h" +#include +#include "utils.h" + +namespace deflator { + +typedef int(*flate_func)(z_stream* strm, int flush); + +// Generic function for driving I/O loop for inflate/deflate +int flate(flate_func func, z_stream* strm, char* data, size_t data_len, std::vector& output) { + + strm->next_in = reinterpret_cast(data); + strm->avail_in = data_len; + + while (strm->avail_in) { + // Ensure enough room is allocated on the output to receive 1024 more bytes + const size_t CHUNK_SIZE = 1024; + output.reserve(output.size() + CHUNK_SIZE); + + strm->next_out = reinterpret_cast(&output[output.size()]); + strm->avail_out = output.capacity() - output.size(); + + size_t total_out_start = strm->total_out; + int error = func(strm, Z_SYNC_FLUSH); + if (error != Z_OK && error != Z_STREAM_END) { + // std::cerr << strm->msg << "\n"; + return error; + } + size_t bytes_written = strm->total_out - total_out_start; + std::copy(&output[output.size()], &output[output.size()] + bytes_written, + std::back_inserter(output)); + } + + return Z_OK; +} + +Deflator::Deflator() { + _stream = {0}; + _state = DeflatorStateStart; +} + +Deflator::~Deflator() { + if (_state == DeflatorStateReady) { + int error = deflateEnd(&_stream); + if (error != Z_OK) { + debug_log("deflateEnd failed", LOG_WARN); + } + } + _state = DeflatorStateEnded; +} + +int Deflator::init(DeflateMode mode, int level, int windowBits, int memLevel, int strategy) { + if (_state != DeflatorStateStart) { + throw std::runtime_error("Deflator.init() was called twice"); + } + + if (windowBits < 9 || windowBits > 15) { + throw std::runtime_error("Deflator.init() expects 9 <= windowBits <= 15"); + } + + if (mode == DeflateModeRaw) { + windowBits *= -1; + } else if (mode == DeflateModeGzip) { + windowBits += 16; + } + + int result = deflateInit2(&_stream, level, Z_DEFLATED, windowBits, memLevel, strategy); + if (result == Z_OK) { + _state = DeflatorStateReady; + } + return result; +} + +int Deflator::deflate(char* data, size_t data_len, std::vector& output) { + if (_state != DeflatorStateReady) { + throw std::runtime_error("Deflator.init() must be called before deflate()"); + } + return flate(::deflate, &_stream, data, data_len, output); +} + +Inflator::Inflator() { + _stream = {0}; + _state = DeflatorStateStart; +} + +Inflator::~Inflator() { + if (_state == DeflatorStateReady) { + int error = inflateEnd(&_stream); + if (error != Z_OK) { + debug_log("inflateEnd failed", LOG_WARN); + } + } + _state = DeflatorStateEnded; +} + +int Inflator::init(DeflateMode mode, int windowBits) { + if (_state != DeflatorStateStart) { + throw std::runtime_error("Inflator.init() was called twice"); + } + + if (windowBits < 8 || windowBits > 15) { + throw std::runtime_error("Inflator.init() expects 9 <= windowBits <= 15"); + } + + if (mode == DeflateModeRaw) { + windowBits *= -1; + } else if (mode == DeflateModeGzip) { + windowBits += 16; + } + + int result = inflateInit2(&_stream, windowBits); + if (result == Z_OK) { + _state = DeflatorStateReady; + } + return result; +} + +int Inflator::inflate(char* data, size_t data_len, std::vector& output) { + if (_state != DeflatorStateReady) { + throw std::runtime_error("Inflator.init() must be called before deflate()"); + } + return flate(::inflate, &_stream, data, data_len, output); +} + +} // namespace diff --git a/src/deflate.h b/src/deflate.h new file mode 100644 index 00000000..5be15ad9 --- /dev/null +++ b/src/deflate.h @@ -0,0 +1,49 @@ +#ifndef DEFLATE_H +#define DEFLATE_H + +#include +#include + +namespace deflator { + +enum DeflateMode { + DeflateModeZlib, + DeflateModeRaw, + DeflateModeGzip, +}; + +enum DeflatorState { + DeflatorStateStart, + DeflatorStateReady, + DeflatorStateEnded +}; + +class Deflator { +private: + z_stream _stream; + DeflatorState _state; + +public: + Deflator(); + ~Deflator(); + int init(DeflateMode mode, int level = Z_DEFAULT_COMPRESSION, + int windowBits = 13, int memLevel = 8, + int strategy = Z_DEFAULT_STRATEGY); + int deflate(char* data, size_t data_len, std::vector& output); +}; + +class Inflator { +private: + z_stream _stream; + DeflatorState _state; + +public: + Inflator(); + ~Inflator(); + int init(DeflateMode mode, int windowBits = 15); + int inflate(char* data, size_t data_len, std::vector& output); +}; + +} + +#endif // DEFLATE_H From 0b224f1370f11e484044e360d313a365811cce8d Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Mon, 23 Aug 2021 23:41:38 -0700 Subject: [PATCH 2/7] Fix deflator logic --- src/deflate.cpp | 26 +++++++++++++++----------- src/deflate.h | 4 ++-- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/src/deflate.cpp b/src/deflate.cpp index 36f8b86f..8e6c56bc 100644 --- a/src/deflate.cpp +++ b/src/deflate.cpp @@ -7,28 +7,32 @@ namespace deflator { typedef int(*flate_func)(z_stream* strm, int flush); // Generic function for driving I/O loop for inflate/deflate -int flate(flate_func func, z_stream* strm, char* data, size_t data_len, std::vector& output) { +int flate(flate_func func, z_stream* strm, const char* data, size_t data_len, std::vector& output) { - strm->next_in = reinterpret_cast(data); + strm->next_in = reinterpret_cast(const_cast(data)); strm->avail_in = data_len; + size_t orig_output_size = output.size(); + while (strm->avail_in) { // Ensure enough room is allocated on the output to receive 1024 more bytes const size_t CHUNK_SIZE = 1024; - output.reserve(output.size() + CHUNK_SIZE); + output.resize(output.size() + CHUNK_SIZE); - strm->next_out = reinterpret_cast(&output[output.size()]); - strm->avail_out = output.capacity() - output.size(); + strm->next_out = reinterpret_cast(&output[output.size() - CHUNK_SIZE]); + strm->avail_out = CHUNK_SIZE; - size_t total_out_start = strm->total_out; int error = func(strm, Z_SYNC_FLUSH); if (error != Z_OK && error != Z_STREAM_END) { // std::cerr << strm->msg << "\n"; return error; } - size_t bytes_written = strm->total_out - total_out_start; - std::copy(&output[output.size()], &output[output.size()] + bytes_written, - std::back_inserter(output)); + } + size_t bytes_written = strm->total_out; + size_t bytes_allocated = output.size() - orig_output_size; + size_t bytes_to_erase = bytes_allocated - bytes_written; + if (bytes_to_erase > 0) { + output.erase(output.end() - bytes_to_erase, output.end()); } return Z_OK; @@ -71,7 +75,7 @@ int Deflator::init(DeflateMode mode, int level, int windowBits, int memLevel, in return result; } -int Deflator::deflate(char* data, size_t data_len, std::vector& output) { +int Deflator::deflate(const char* data, size_t data_len, std::vector& output) { if (_state != DeflatorStateReady) { throw std::runtime_error("Deflator.init() must be called before deflate()"); } @@ -115,7 +119,7 @@ int Inflator::init(DeflateMode mode, int windowBits) { return result; } -int Inflator::inflate(char* data, size_t data_len, std::vector& output) { +int Inflator::inflate(const char* data, size_t data_len, std::vector& output) { if (_state != DeflatorStateReady) { throw std::runtime_error("Inflator.init() must be called before deflate()"); } diff --git a/src/deflate.h b/src/deflate.h index 5be15ad9..02addbf2 100644 --- a/src/deflate.h +++ b/src/deflate.h @@ -29,7 +29,7 @@ class Deflator { int init(DeflateMode mode, int level = Z_DEFAULT_COMPRESSION, int windowBits = 13, int memLevel = 8, int strategy = Z_DEFAULT_STRATEGY); - int deflate(char* data, size_t data_len, std::vector& output); + int deflate(const char* data, size_t data_len, std::vector& output); }; class Inflator { @@ -41,7 +41,7 @@ class Inflator { Inflator(); ~Inflator(); int init(DeflateMode mode, int windowBits = 15); - int inflate(char* data, size_t data_len, std::vector& output); + int inflate(const char* data, size_t data_len, std::vector& output); }; } From 6ab9394dca460b9807a6d015acebf51e9e2b1a2e Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Mon, 23 Aug 2021 23:42:14 -0700 Subject: [PATCH 3/7] wip --- src/constants.h | 19 +++++++++++++ src/websockets-base.cpp | 3 +- src/websockets-base.h | 11 ++++++-- src/websockets-hixie76.cpp | 7 +++-- src/websockets-hixie76.h | 5 ++-- src/websockets-hybi03.cpp | 3 +- src/websockets-hybi03.h | 5 ++-- src/websockets-ietf.cpp | 45 ++++++++++++++++++++++++++++- src/websockets-ietf.h | 3 +- src/websockets.cpp | 58 +++++++++++++++++++++++++++++++++----- src/websockets.h | 18 +++++++++--- 11 files changed, 153 insertions(+), 24 deletions(-) diff --git a/src/constants.h b/src/constants.h index a7bc98bb..a582470b 100644 --- a/src/constants.h +++ b/src/constants.h @@ -54,4 +54,23 @@ static inline std::string trim(const std::string &s) { return s.substr(start, end-start); } +static inline std::vector split(const std::string& s, const std::string& delim) { + std::vector results; + size_t pos = 0; + while (true) { + size_t i = s.find(delim, pos); + if (i == std::string::npos) { + break; + } + if (i != pos) { + results.push_back(s.substr(pos, i - pos)); + } + pos = i + 1; + } + if (pos != s.length()) { + results.push_back(s.substr(pos)); + } + return results; +} + #endif // CONSTANTS_H diff --git a/src/websockets-base.cpp b/src/websockets-base.cpp index 29b92288..ef0c1835 100644 --- a/src/websockets-base.cpp +++ b/src/websockets-base.cpp @@ -25,7 +25,7 @@ void swapByteOrder(unsigned char* pStart, unsigned char* pEnd) { } void WebSocketProto::createFrameHeader( - Opcode opcode, bool mask, size_t payloadSize, int32_t maskingKey, + Opcode opcode, bool rsv1, bool mask, size_t payloadSize, int32_t maskingKey, char pData[MAX_HEADER_BYTES], size_t* pLen) const { unsigned char* pBuf = (unsigned char*)pData; @@ -35,6 +35,7 @@ void WebSocketProto::createFrameHeader( pBuf[0] = toFin(true) << 7 | // FIN; always true + (rsv1 << 6) | encodeOpcode(opcode); pBuf[1] = mask ? 1 << 7 : 0; if (payloadSize_64 <= 125) { diff --git a/src/websockets-base.h b/src/websockets-base.h index 8781f5f3..6c5ac0db 100644 --- a/src/websockets-base.h +++ b/src/websockets-base.h @@ -5,6 +5,12 @@ #include "constants.h" +struct WebSocketConnectionContext { + bool permessageDeflate; + int clientMaxWindowBits; + int serverMaxWindowBits; +}; + class WebSocketProto { public: @@ -22,9 +28,10 @@ class WebSocketProto { const RequestHeaders& requestHeaders, char** ppData, size_t* pLen, ResponseHeaders* responseHeaders, - std::vector* pResponse) const = 0; + std::vector* pResponse, + WebSocketConnectionContext* pContext) const = 0; - void createFrameHeader(Opcode opcode, bool mask, size_t payloadSize, + void createFrameHeader(Opcode opcode, bool rsv1, bool mask, size_t payloadSize, int32_t maskingKey, char pData[MAX_HEADER_BYTES], size_t* pLen) const; diff --git a/src/websockets-hixie76.cpp b/src/websockets-hixie76.cpp index 5d0b9d83..960abb12 100644 --- a/src/websockets-hixie76.cpp +++ b/src/websockets-hixie76.cpp @@ -5,13 +5,14 @@ void WSHixie76Parser::handshake(const std::string& url, const RequestHeaders& requestHeaders, char** ppData, size_t* pLen, ResponseHeaders* responseHeaders, - std::vector* pResponse) const { + std::vector* pResponse, + WebSocketConnectionContext* pContext) const { _hybi03.handshake(url, requestHeaders, ppData, pLen, responseHeaders, - pResponse); + pResponse, pContext); } void WSHixie76Parser::createFrameHeaderFooter( - Opcode opcode, bool mask, size_t payloadSize, + Opcode opcode, bool rsv1, bool mask, size_t payloadSize, int32_t maskingKey, char pHeaderData[MAX_HEADER_BYTES], size_t* pHeaderLen, char pFooterData[MAX_FOOTER_BYTES], size_t* pFooterLen) const { diff --git a/src/websockets-hixie76.h b/src/websockets-hixie76.h index c8cbb0e9..332e89c4 100644 --- a/src/websockets-hixie76.h +++ b/src/websockets-hixie76.h @@ -38,10 +38,11 @@ class WSHixie76Parser : public WSParser { const RequestHeaders& requestHeaders, char** ppData, size_t* pLen, ResponseHeaders* responseHeaders, - std::vector* pResponse) const; + std::vector* pResponse, + WebSocketConnectionContext* pContext) const; void createFrameHeaderFooter( - Opcode opcode, bool mask, size_t payloadSize, + Opcode opcode, bool rsv1, bool mask, size_t payloadSize, int32_t maskingKey, char pHeaderData[MAX_HEADER_BYTES], size_t* pHeaderLen, char pFooterData[MAX_FOOTER_BYTES], size_t* pFooterLen diff --git a/src/websockets-hybi03.cpp b/src/websockets-hybi03.cpp index c579e25c..439fe08c 100644 --- a/src/websockets-hybi03.cpp +++ b/src/websockets-hybi03.cpp @@ -54,7 +54,8 @@ void WebSocketProto_HyBi03::handshake(const std::string& url, const RequestHeaders& requestHeaders, char** ppData, size_t* pLen, ResponseHeaders* pResponseHeaders, - std::vector* pResponse) const { + std::vector* pResponse, + WebSocketConnectionContext* pContext) const { assert(*pLen >= 8); diff --git a/src/websockets-hybi03.h b/src/websockets-hybi03.h index 3cc301ae..db68224d 100644 --- a/src/websockets-hybi03.h +++ b/src/websockets-hybi03.h @@ -16,9 +16,10 @@ class WebSocketProto_HyBi03 : public WebSocketProto { const RequestHeaders& requestHeaders, char** ppData, size_t* pLen, ResponseHeaders* pResponseHeaders, - std::vector* pResponse) const; + std::vector* pResponse, + WebSocketConnectionContext* pContext) const; - void createFrameHeader(Opcode opcode, bool mask, size_t payloadSize, + void createFrameHeader(Opcode opcode, bool rsv1, bool mask, size_t payloadSize, int32_t maskingKey, char pData[MAX_HEADER_BYTES], size_t* pLen) const; diff --git a/src/websockets-ietf.cpp b/src/websockets-ietf.cpp index f97101c3..1964df3a 100644 --- a/src/websockets-ietf.cpp +++ b/src/websockets-ietf.cpp @@ -1,10 +1,36 @@ #include "websockets-ietf.h" +#include +#include + #include "utils.h" #include "sha1/sha1.h" #include "base64/base64.hpp" +struct ExtensionInfo { + std::string extension; + std::map params; +}; + +ExtensionInfo parseExtensionInfo(const std::string str) { + std::vector parts = split(str, ";"); + std::string extension = parts[0]; + std::map params; + for (size_t i = 1; i < parts.size(); i++) { + std::vector param = split(parts[i], "="); + params[param[0]] = param.size() > 1 ? param[1] : ""; + } + ExtensionInfo result; + result.extension = extension; + result.params = params; + return result; +} + +std::vector splitExtensionsHeader(const std::string& header) { + return split(header, ","); +} + bool WebSocketProto_IETF::canHandle(const RequestHeaders& requestHeaders, const char* pData, size_t len) const { @@ -17,7 +43,8 @@ void WebSocketProto_IETF::handshake(const std::string& url, const RequestHeaders& requestHeaders, char** ppData, size_t* pLen, ResponseHeaders* pResponseHeaders, - std::vector* pResponse) const { + std::vector* pResponse, + WebSocketConnectionContext* pContext) const { std::string key = requestHeaders.at("sec-websocket-key"); @@ -37,6 +64,22 @@ void WebSocketProto_IETF::handshake(const std::string& url, std::pair("Upgrade", "websocket")); pResponseHeaders->push_back( std::pair("Sec-WebSocket-Accept", response)); + + auto swe = requestHeaders.find("sec-websocket-extensions"); + if (swe != requestHeaders.end()) { + auto extensions = split(swe->second, ","); + std::vector extInfos; + std::transform(extensions.begin(), extensions.end(), + std::back_inserter(extInfos), + parseExtensionInfo); + + for (auto pos = extInfos.begin(); pos != extInfos.end(); pos++) { + if (trim(pos->extension) == "permessage-deflate") { + pResponseHeaders->push_back(std::make_pair("Sec-WebSocket-Extensions", "permessage-deflate")); + pContext->permessageDeflate = true; + } + } + } } bool WebSocketProto_IETF::isFin(uint8_t firstBit) const { diff --git a/src/websockets-ietf.h b/src/websockets-ietf.h index dd166c7a..69204b35 100644 --- a/src/websockets-ietf.h +++ b/src/websockets-ietf.h @@ -16,7 +16,8 @@ class WebSocketProto_IETF : public WebSocketProto { const RequestHeaders& requestHeaders, char** ppData, size_t* pLen, ResponseHeaders* responseHeaders, - std::vector* pResponse) const; + std::vector* pResponse, + WebSocketConnectionContext* pContext) const; bool isFin(uint8_t firstBit) const; uint8_t toFin(bool isFin) const; diff --git a/src/websockets.cpp b/src/websockets.cpp index f5ece8f2..199edbb7 100644 --- a/src/websockets.cpp +++ b/src/websockets.cpp @@ -45,6 +45,7 @@ bool WSHyBiFrameHeader::isHeaderComplete() const { WSFrameHeaderInfo WSHyBiFrameHeader::info() const { WSFrameHeaderInfo inf; inf.fin = fin(); + inf.rsv1 = rsv1(); inf.opcode = opcode(); inf.hasLength = true; inf.masked = masked(); @@ -59,6 +60,9 @@ WSFrameHeaderInfo WSHyBiFrameHeader::info() const { bool WSHyBiFrameHeader::fin() const { return _pProto->isFin(read(0, 1)); } +bool WSHyBiFrameHeader::rsv1() const { + return read(1, 1); +} Opcode WSHyBiFrameHeader::opcode() const { uint8_t oc = read(4, 4); return _pProto->decodeOpcode(oc); @@ -140,19 +144,20 @@ void WSHyBiParser::handshake(const std::string& url, const RequestHeaders& requestHeaders, char** ppData, size_t* pLen, ResponseHeaders* pResponseHeaders, - std::vector* pResponse) const { + std::vector* pResponse, + WebSocketConnectionContext* pContext) const { ASSERT_BACKGROUND_THREAD() _pProto->handshake(url, requestHeaders, ppData, pLen, pResponseHeaders, - pResponse); + pResponse, pContext); } void WSHyBiParser::createFrameHeaderFooter( - Opcode opcode, bool mask, size_t payloadSize, + Opcode opcode, bool rsv1, bool mask, size_t payloadSize, int32_t maskingKey, char pHeaderData[MAX_HEADER_BYTES], size_t* pHeaderLen, char pFooterData[MAX_FOOTER_BYTES], size_t* pFooterLen ) const { - _pProto->createFrameHeader(opcode, mask, payloadSize, maskingKey, + _pProto->createFrameHeader(opcode, rsv1, mask, payloadSize, maskingKey, pHeaderData, pHeaderLen); } @@ -258,7 +263,20 @@ void WebSocketConnection::handshake(const std::string& url, if (_connState == WS_CLOSED) return; _pParser->handshake(url, requestHeaders, ppData, pLen, pResponseHeaders, - pResponse); + pResponse, &_context); + + if (_context.permessageDeflate) { + int error; + // TODO: Handle errors + error = _inflator.init(deflator::DeflateModeRaw); + if (error != Z_OK) { + debug_log("Failed to init inflator", LOG_ERROR); + } + error = _deflator.init(deflator::DeflateModeRaw, Z_DEFAULT_COMPRESSION, 9); + if (error != Z_OK) { + debug_log("Failed to init deflator", LOG_ERROR); + } + } } void WebSocketConnection::sendWSMessage(Opcode opcode, const char* pData, size_t length) { @@ -271,7 +289,23 @@ void WebSocketConnection::sendWSMessage(Opcode opcode, const char* pData, size_t size_t headerLength = 0; size_t footerLength = 0; - _pParser->createFrameHeaderFooter(opcode, false, length, 0, + std::vector deflated(0); + bool deflate = _context.permessageDeflate; + if (deflate) { + // TODO: Handle errors + int error = _deflator.deflate(pData, length, deflated); + if (error != Z_OK) { + debug_log("An error occurred during deflate", LOG_ERROR); + } + deflated.pop_back(); + deflated.pop_back(); + deflated.pop_back(); + deflated.pop_back(); + pData = safe_vec_addr(deflated); + length = deflated.size(); + } + + _pParser->createFrameHeaderFooter(opcode, deflate, false, length, 0, safe_vec_addr(header), &headerLength, safe_vec_addr(footer), &footerLength); header.resize(headerLength); @@ -375,7 +409,17 @@ void WebSocketConnection::onFrameComplete() { } case Text: case Binary: { - _pCallbacks->onWSMessage(_header.opcode == Binary, safe_vec_addr(_payload), _payload.size()); + if (_context.permessageDeflate && _header.rsv1) { + _payload.push_back(0); + _payload.push_back(0); + _payload.push_back(0xFF); + _payload.push_back(0xFF); + std::vector inflated(0); + _inflator.inflate(safe_vec_addr(_payload), _payload.size(), inflated); + _pCallbacks->onWSMessage(_header.opcode == Binary, safe_vec_addr(inflated), inflated.size()); + } else { + _pCallbacks->onWSMessage(_header.opcode == Binary, safe_vec_addr(_payload), _payload.size()); + } break; } case Close: { diff --git a/src/websockets.h b/src/websockets.h index ce26ab68..6010513f 100644 --- a/src/websockets.h +++ b/src/websockets.h @@ -9,6 +9,7 @@ #include #include "utils.h" +#include "deflate.h" #include "thread.h" #include "constants.h" #include "websockets-base.h" @@ -17,6 +18,7 @@ class WSFrameHeaderInfo { public: bool fin; + bool rsv1; Opcode opcode; bool masked; std::vector maskingKey; @@ -56,6 +58,9 @@ class WSHyBiFrameHeader { private: bool fin() const; + bool rsv1() const; + bool rsv2() const; + bool rsv3() const; Opcode opcode() const; bool masked() const; void maskingKey(uint8_t key[4]) const; @@ -92,10 +97,11 @@ class WSParser { const RequestHeaders& requestHeaders, char** ppData, size_t* pLen, ResponseHeaders* responseHeaders, - std::vector* pResponse) const = 0; + std::vector* pResponse, + WebSocketConnectionContext* pContext) const = 0; virtual void createFrameHeaderFooter( - Opcode opcode, bool mask, size_t payloadSize, + Opcode opcode, bool rsv1, bool mask, size_t payloadSize, int32_t maskingKey, char pHeaderData[MAX_HEADER_BYTES], size_t* pHeaderLen, char pFooterData[MAX_FOOTER_BYTES], size_t* pFooterLen @@ -125,10 +131,11 @@ class WSHyBiParser : public WSParser { const RequestHeaders& requestHeaders, char** ppData, size_t* pLen, ResponseHeaders* responseHeaders, - std::vector* pResponse) const; + std::vector* pResponse, + WebSocketConnectionContext* pContext) const; void createFrameHeaderFooter( - Opcode opcode, bool mask, size_t payloadSize, + Opcode opcode, bool rsv1, bool mask, size_t payloadSize, int32_t maskingKey, char pHeaderData[MAX_HEADER_BYTES], size_t* pHeaderLen, char pFooterData[MAX_FOOTER_BYTES], size_t* pFooterLen @@ -172,6 +179,9 @@ class WebSocketConnection : WSParserCallbacks, NoCopy { std::vector _incompleteContentPayload; std::vector _payload; uv_timer_t* _pPingTimer; + WebSocketConnectionContext _context; + deflator::Deflator _deflator; + deflator::Inflator _inflator; public: WebSocketConnection( From e7f4c167ece70d2304b66f47121eeda2a9387c72 Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Mon, 30 Aug 2021 17:52:29 -0700 Subject: [PATCH 4/7] Working permessage-deflate, minus some params --- src/deflate.cpp | 52 ++++++----- src/websockets-base.h | 10 +++ src/websockets-ietf.cpp | 42 +-------- src/websockets.cpp | 11 ++- src/websockets.h | 4 +- src/wse-permessage-deflate.cpp | 155 +++++++++++++++++++++++++++++++++ src/wse-permessage-deflate.h | 17 ++++ 7 files changed, 228 insertions(+), 63 deletions(-) create mode 100644 src/wse-permessage-deflate.cpp create mode 100644 src/wse-permessage-deflate.h diff --git a/src/deflate.cpp b/src/deflate.cpp index 8e6c56bc..c68c83cf 100644 --- a/src/deflate.cpp +++ b/src/deflate.cpp @@ -1,5 +1,6 @@ #include "deflate.h" #include +#include #include "utils.h" namespace deflator { @@ -8,32 +9,37 @@ typedef int(*flate_func)(z_stream* strm, int flush); // Generic function for driving I/O loop for inflate/deflate int flate(flate_func func, z_stream* strm, const char* data, size_t data_len, std::vector& output) { + int error; + const size_t CHUNK_SIZE = 256; + unsigned char temp[CHUNK_SIZE]; + memset(temp, 0, CHUNK_SIZE); + strm->next_in = reinterpret_cast(const_cast(data)); strm->avail_in = data_len; - size_t orig_output_size = output.size(); - - while (strm->avail_in) { - // Ensure enough room is allocated on the output to receive 1024 more bytes - const size_t CHUNK_SIZE = 1024; - output.resize(output.size() + CHUNK_SIZE); - - strm->next_out = reinterpret_cast(&output[output.size() - CHUNK_SIZE]); + do { + strm->next_out = temp; strm->avail_out = CHUNK_SIZE; + error = func(strm, Z_NO_FLUSH); + if (error != Z_OK && error != Z_BUF_ERROR) { + debug_log("flate failed", LOG_INFO); + return error; + } + // This might not copy any data; sometimes deflate() just populates buffers + output.insert(output.end(), (char*)temp, (char*)(temp + CHUNK_SIZE - strm->avail_out)); + } while (strm->avail_out == 0 || strm->avail_in > 0); - int error = func(strm, Z_SYNC_FLUSH); - if (error != Z_OK && error != Z_STREAM_END) { - // std::cerr << strm->msg << "\n"; + do { + strm->next_out = temp; + strm->avail_out = CHUNK_SIZE; + error = func(strm, Z_SYNC_FLUSH); + if (error != Z_OK && error != Z_BUF_ERROR) { + debug_log("flate failed", LOG_INFO); return error; } - } - size_t bytes_written = strm->total_out; - size_t bytes_allocated = output.size() - orig_output_size; - size_t bytes_to_erase = bytes_allocated - bytes_written; - if (bytes_to_erase > 0) { - output.erase(output.end() - bytes_to_erase, output.end()); - } + output.insert(output.end(), (char*)temp, (char*)(temp + CHUNK_SIZE - strm->avail_out)); + } while (strm->avail_out == 0); return Z_OK; } @@ -46,7 +52,9 @@ Deflator::Deflator() { Deflator::~Deflator() { if (_state == DeflatorStateReady) { int error = deflateEnd(&_stream); - if (error != Z_OK) { + if (error == Z_STREAM_ERROR) { + // deflateEnd can return other errors, but they're nothing to worry about. + // https://stackoverflow.com/a/19816633/139922 debug_log("deflateEnd failed", LOG_WARN); } } @@ -90,7 +98,9 @@ Inflator::Inflator() { Inflator::~Inflator() { if (_state == DeflatorStateReady) { int error = inflateEnd(&_stream); - if (error != Z_OK) { + if (error == Z_STREAM_ERROR) { + // inflateEnd can return other errors, but they're nothing to worry about. + // https://stackoverflow.com/a/19816633/139922 debug_log("inflateEnd failed", LOG_WARN); } } @@ -121,7 +131,7 @@ int Inflator::init(DeflateMode mode, int windowBits) { int Inflator::inflate(const char* data, size_t data_len, std::vector& output) { if (_state != DeflatorStateReady) { - throw std::runtime_error("Inflator.init() must be called before deflate()"); + throw std::runtime_error("Inflator.init() must be called before inflate()"); } return flate(::inflate, &_stream, data, data_len, output); } diff --git a/src/websockets-base.h b/src/websockets-base.h index 6c5ac0db..f6775062 100644 --- a/src/websockets-base.h +++ b/src/websockets-base.h @@ -6,9 +6,19 @@ #include "constants.h" struct WebSocketConnectionContext { + WebSocketConnectionContext() : + permessageDeflate(false), + clientMaxWindowBits(-1), + serverMaxWindowBits(-1), + clientNoContextTakeover(false), + serverNoContextTakeover(false) { + } + bool permessageDeflate; int clientMaxWindowBits; int serverMaxWindowBits; + bool clientNoContextTakeover; + bool serverNoContextTakeover; }; class WebSocketProto { diff --git a/src/websockets-ietf.cpp b/src/websockets-ietf.cpp index 1964df3a..8290bc7e 100644 --- a/src/websockets-ietf.cpp +++ b/src/websockets-ietf.cpp @@ -8,35 +8,15 @@ #include "sha1/sha1.h" #include "base64/base64.hpp" -struct ExtensionInfo { - std::string extension; - std::map params; -}; - -ExtensionInfo parseExtensionInfo(const std::string str) { - std::vector parts = split(str, ";"); - std::string extension = parts[0]; - std::map params; - for (size_t i = 1; i < parts.size(); i++) { - std::vector param = split(parts[i], "="); - params[param[0]] = param.size() > 1 ? param[1] : ""; - } - ExtensionInfo result; - result.extension = extension; - result.params = params; - return result; -} - -std::vector splitExtensionsHeader(const std::string& header) { - return split(header, ","); -} +#include "wse-permessage-deflate.h" bool WebSocketProto_IETF::canHandle(const RequestHeaders& requestHeaders, const char* pData, size_t len) const { return requestHeaders.find("upgrade") != requestHeaders.end() && strcasecmp(requestHeaders.at("upgrade").c_str(), "websocket") == 0 && - requestHeaders.find("sec-websocket-key") != requestHeaders.end(); + requestHeaders.find("sec-websocket-key") != requestHeaders.end() && + permessage_deflate::isValid(requestHeaders); } void WebSocketProto_IETF::handshake(const std::string& url, @@ -65,21 +45,7 @@ void WebSocketProto_IETF::handshake(const std::string& url, pResponseHeaders->push_back( std::pair("Sec-WebSocket-Accept", response)); - auto swe = requestHeaders.find("sec-websocket-extensions"); - if (swe != requestHeaders.end()) { - auto extensions = split(swe->second, ","); - std::vector extInfos; - std::transform(extensions.begin(), extensions.end(), - std::back_inserter(extInfos), - parseExtensionInfo); - - for (auto pos = extInfos.begin(); pos != extInfos.end(); pos++) { - if (trim(pos->extension) == "permessage-deflate") { - pResponseHeaders->push_back(std::make_pair("Sec-WebSocket-Extensions", "permessage-deflate")); - pContext->permessageDeflate = true; - } - } - } + permessage_deflate::handshake(requestHeaders, pResponseHeaders, pContext); } bool WebSocketProto_IETF::isFin(uint8_t firstBit) const { diff --git a/src/websockets.cpp b/src/websockets.cpp index 199edbb7..3fe1f221 100644 --- a/src/websockets.cpp +++ b/src/websockets.cpp @@ -268,11 +268,11 @@ void WebSocketConnection::handshake(const std::string& url, if (_context.permessageDeflate) { int error; // TODO: Handle errors - error = _inflator.init(deflator::DeflateModeRaw); + error = _inflator.init(deflator::DeflateModeRaw, _context.clientMaxWindowBits); if (error != Z_OK) { debug_log("Failed to init inflator", LOG_ERROR); } - error = _deflator.init(deflator::DeflateModeRaw, Z_DEFAULT_COMPRESSION, 9); + error = _deflator.init(deflator::DeflateModeRaw, Z_DEFAULT_COMPRESSION, _context.serverMaxWindowBits); if (error != Z_OK) { debug_log("Failed to init deflator", LOG_ERROR); } @@ -415,7 +415,12 @@ void WebSocketConnection::onFrameComplete() { _payload.push_back(0xFF); _payload.push_back(0xFF); std::vector inflated(0); - _inflator.inflate(safe_vec_addr(_payload), _payload.size(), inflated); + int error = _inflator.inflate(safe_vec_addr(_payload), _payload.size(), inflated); + if (error != Z_OK) { + // TODO: Handle error + std::cerr << "Inflate failed with error " << error << "\n"; + } + _pCallbacks->onWSMessage(_header.opcode == Binary, safe_vec_addr(inflated), inflated.size()); } else { _pCallbacks->onWSMessage(_header.opcode == Binary, safe_vec_addr(_payload), _payload.size()); diff --git a/src/websockets.h b/src/websockets.h index 6010513f..59c849b1 100644 --- a/src/websockets.h +++ b/src/websockets.h @@ -190,7 +190,9 @@ class WebSocketConnection : WSParserCallbacks, NoCopy { : _pLoop(pLoop), _connState(WS_OPEN), _pCallbacks(callbacks), - _pParser(NULL) { + _pParser(NULL), + _deflator(), + _inflator() { ASSERT_BACKGROUND_THREAD() debug_log("WebSocketConnection::WebSocketConnection", LOG_DEBUG); diff --git a/src/wse-permessage-deflate.cpp b/src/wse-permessage-deflate.cpp new file mode 100644 index 00000000..8af0df19 --- /dev/null +++ b/src/wse-permessage-deflate.cpp @@ -0,0 +1,155 @@ +#include "wse-permessage-deflate.h" +#include +#include +#include +#include + +struct ExtensionInfo { + std::string extension; + std::map params; +}; + +ExtensionInfo parseExtensionInfo(const std::string str) { + std::vector parts = split(str, ";"); + + std::transform(parts.cbegin(), parts.cend(), parts.begin(), trim); + + std::string extension = parts[0]; + std::map params; + for (size_t i = 1; i < parts.size(); i++) { + std::vector param = split(parts[i], "="); + params[trim(param[0])] = param.size() > 1 ? trim(param[1]) : ""; + } + ExtensionInfo result; + result.extension = extension; + result.params = params; + return result; +} + +bool parseWindowBits(const ExtensionInfo& extInfo, const std::string& param, bool* present, int* value) { + *present = false; + *value = 0; + + auto match = extInfo.params.find(param); + if (match == extInfo.params.end()) { + // Not present--this is valid, so return true + return true; + } + + *present = true; + + if (match->second.size() == 0) { + // Param is present, but value is not + return true; + } + + if (match->second.size() > 2) { + // Too many digits + return false; + } + *value = atoi(match->second.c_str()); + if (*value < 8 || *value > 15) { + // Out of range + return false; + } + return true; +} + +bool parsePermessageDeflate(const ExtensionInfo& extInfo, WebSocketConnectionContext* pContext) { + if (extInfo.extension != "permessage-deflate") { + return false; + } + + pContext->permessageDeflate = true; + if (extInfo.params.find("server_no_context_takeover") != extInfo.params.end()) { + pContext->serverNoContextTakeover = true; + } + if (extInfo.params.find("client_no_context_takeover") != extInfo.params.end()) { + pContext->clientNoContextTakeover = true; + } + + bool hasServerMaxWindowBits; + bool hasClientMaxWindowBits; + + if (!parseWindowBits(extInfo, "server_max_window_bits", &hasServerMaxWindowBits, &pContext->serverMaxWindowBits)) { + return false; + } + if (hasServerMaxWindowBits && pContext->serverMaxWindowBits == 0) { + // If server_max_window_bits is present, the value is required + return false; + } + + if (!parseWindowBits(extInfo, "client_max_window_bits", &hasClientMaxWindowBits, &pContext->clientMaxWindowBits)) { + return false; + } + + // Set defaults + if (pContext->serverMaxWindowBits <= 0) { + pContext->serverMaxWindowBits = 15; + } + if (pContext->clientMaxWindowBits <= 0) { + pContext->clientMaxWindowBits = 15; + } + + return true; +} + +bool handle(const RequestHeaders& requestHeaders, + ResponseHeaders* pResponseHeaders, + WebSocketConnectionContext* pContext) { + + auto swe = requestHeaders.find("sec-websocket-extensions"); + if (swe != requestHeaders.end()) { + auto extensions = split(swe->second, ","); + std::vector extInfos; + std::transform(extensions.begin(), extensions.end(), + std::back_inserter(extInfos), + parseExtensionInfo); + + for (auto &extInfo : extInfos) { + if (trim(extInfo.extension) == "permessage-deflate") { + if (!parsePermessageDeflate(extInfo, pContext)) { + return false; + } + } + } + } + + if (pResponseHeaders && pContext->permessageDeflate) { + std::string params; + if (pContext->clientNoContextTakeover) { + params.append("; client_no_context_takeover"); + } + if (pContext->serverNoContextTakeover) { + params.append("; server_no_context_takeover"); + } + if (pContext->serverMaxWindowBits != 0) { + params.append("; server_max_window_bits="); + params.append(std::to_string(pContext->serverMaxWindowBits)); + } + if (pContext->clientMaxWindowBits != 0) { + params.append("; client_max_window_bits="); + params.append(std::to_string(pContext->clientMaxWindowBits)); + } + std::string exts = "permessage-deflate" + params; + pResponseHeaders->push_back(std::make_pair("Sec-WebSocket-Extensions", exts)); + } + + return true; +} + +namespace permessage_deflate { + +bool isValid(const RequestHeaders& requestHeaders) { + WebSocketConnectionContext context; + return handle(requestHeaders, NULL, &context); +} + +void handshake(const RequestHeaders& requestHeaders, + ResponseHeaders* pResponseHeaders, + WebSocketConnectionContext* pContext) { + + handle(requestHeaders, pResponseHeaders, pContext); +} + +} // namespace permessage_deflate \ No newline at end of file diff --git a/src/wse-permessage-deflate.h b/src/wse-permessage-deflate.h new file mode 100644 index 00000000..44cde463 --- /dev/null +++ b/src/wse-permessage-deflate.h @@ -0,0 +1,17 @@ +#ifndef WSEPERMESSAGEDEFLATE_H +#define WSEPERMESSAGEDEFLATE_H + +#include "constants.h" +#include "websockets-base.h" + +namespace permessage_deflate { + +bool isValid(const RequestHeaders& requestHeaders); + +void handshake(const RequestHeaders& requestHeaders, + ResponseHeaders* pResponseHeaders, + WebSocketConnectionContext* pContext); + +} // namespace permessage_deflate + +#endif \ No newline at end of file From c31836f794a00e5f57b2b22804bd51dbbba8d3ac Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Mon, 30 Aug 2021 22:20:41 -0700 Subject: [PATCH 5/7] Add Node.js test client --- deflate-client/.gitignore | 1 + deflate-client/README.md | 15 ++++++ deflate-client/index.mjs | 76 +++++++++++++++++++++++++++ deflate-client/lib/util.mjs | 89 ++++++++++++++++++++++++++++++++ deflate-client/package-lock.json | 43 +++++++++++++++ deflate-client/package.json | 12 +++++ deflate-client/server.R | 28 ++++++++++ src/websockets.cpp | 2 +- 8 files changed, 265 insertions(+), 1 deletion(-) create mode 100644 deflate-client/.gitignore create mode 100644 deflate-client/README.md create mode 100644 deflate-client/index.mjs create mode 100644 deflate-client/lib/util.mjs create mode 100644 deflate-client/package-lock.json create mode 100644 deflate-client/package.json create mode 100644 deflate-client/server.R diff --git a/deflate-client/.gitignore b/deflate-client/.gitignore new file mode 100644 index 00000000..3c3629e6 --- /dev/null +++ b/deflate-client/.gitignore @@ -0,0 +1 @@ +node_modules diff --git a/deflate-client/README.md b/deflate-client/README.md new file mode 100644 index 00000000..9b35ebaa --- /dev/null +++ b/deflate-client/README.md @@ -0,0 +1,15 @@ +This directory contains Node.js based tests for exercising different WebSocket compression parameters in httpuv. + +To setup, you'll need Node.js installed, then: + +```sh +npm install +``` + +Be sure to also install the version of httpuv you want to test. + +To run the test: + +```sh +npm test +``` diff --git a/deflate-client/index.mjs b/deflate-client/index.mjs new file mode 100644 index 00000000..a862abdf --- /dev/null +++ b/deflate-client/index.mjs @@ -0,0 +1,76 @@ +import assert from "assert"; +import child_process from "child_process"; +import { randomBytes } from "crypto"; +import { sleep, md5, WebSocketReceiver, withWS } from "./lib/util.mjs"; + +const MAX_BITS = 21; + +async function runMD5Test(address, perMessageDeflate) { + await withWS(address + "md5", { perMessageDeflate }, async (ws, wsr) => { + process.stderr.write(" md5: "); + for (let i = 0; i <= MAX_BITS; i++) { + process.stderr.write("."); + const payload = randomBytes(2**i); + ws.send(payload); + const { data } = await wsr.nextEvent("message"); + assert.strictEqual(data.toString("utf-8"), md5(payload), "md5 of payload must match server's response"); + } + process.stderr.write("\n"); + }); +} + +async function runEchoTest(address, perMessageDeflate) { + await withWS(address + "echo", { perMessageDeflate }, async (ws, wsr) => { + process.stderr.write(" echo: "); + for (let i = 0; i <= MAX_BITS; i++) { + process.stderr.write("."); + const payload = randomBytes(2**i); + ws.send(payload); + const { data } = await wsr.nextEvent("message"); + assert(data.equals(payload), "payload must match server's response"); + } + process.stderr.write("\n"); + }); +} + +async function runTest(address, perMessageDeflate) { + await runMD5Test(address, perMessageDeflate); + await runEchoTest(address, perMessageDeflate); +} + +async function main() { + await sleep(1000); + await runTest("ws://127.0.0.1:9100/", false); + await runTest("ws://127.0.0.1:9100/", true); + await runTest("ws://127.0.0.1:9100/", { threshold: 0 }); + await runTest("ws://127.0.0.1:9100/", { + threshold: 0, + serverMaxWindowBits: 9, + clientMaxWindowBits: 9 + }); + await runTest("ws://127.0.0.1:9100/", { + threshold: 0, + serverMaxWindowBits: 9, + clientMaxWindowBits: 9, + serverNoContextTakeover: true + }); +} + +console.error("Launching httpuv"); +const rprocess = child_process.spawn("Rscript", ["server.R"], { + stdio: ["inherit", "inherit", "inherit"] +}); +process.on("exit", () => { + rprocess.kill(); +}); + +main().then( + result => { + console.error("All tests passed!"); + process.exit(0); + }, + error => { + console.error(error ?? "An unknown error has occurred!"); + process.exit(1); + } +); diff --git a/deflate-client/lib/util.mjs b/deflate-client/lib/util.mjs new file mode 100644 index 00000000..d8c10362 --- /dev/null +++ b/deflate-client/lib/util.mjs @@ -0,0 +1,89 @@ +import { createHash } from "crypto"; +import { WebSocket } from "ws"; + +export async function sleep(millis) { + await new Promise(resolve => { + setTimeout(() => { + resolve(undefined); + }, millis); + }); +} + +export function md5(data) { + const hash = createHash("md5"); + hash.update(data); + return hash.digest("hex"); +} + +export class WebSocketReceiver { + constructor(ws) { + this.messages = []; + this.pending = null; + + ws.on("open", () => { + this.#log("open"); + this.#push({type: "open"}); + }); + ws.on("message", (data, isBinary) => { + this.#log("message"); + this.#push({type: "message", data, isBinary}); + }); + ws.on("close", ({code, reason}) => { + this.#log("close"); + this.#push({type: "close", code, reason}); + }); + ws.on("error", error => { + this.#log("error"); + this.#push({type: "error", error}); + }); + } + + #log(...args) { + // console.log(...args); + } + + #push(message) { + this.messages.push(message); + if (this.messages.length === 1 && this.pending) { + const prevPending = this.pending + this.pending = null; + prevPending.resolve(null); + } + } + + async nextEvent(type = "message") { + while (this.messages.length === 0) { + if (!this.pending) { + let resolve, reject; + let promise = new Promise((resolve_, reject_) => { + resolve = resolve_; + reject = reject_; + }); + this.pending = {promise, resolve, reject}; + } + await this.pending.promise; + } + const msg = this.messages.shift(); + if (type && msg.type !== type) { + if (msg.type === "error") { + throw msg.error; + } else { + throw new Error(`Unexpected WebSocket event (expected '${type}', got '${msg.type}')`); + } + } + return msg; + } +} + +export async function withWS(address, options, callback) { + console.log("Connecting to", address, "with", options); + const ws = new WebSocket(address, options); + try { + const wsr = new WebSocketReceiver(ws); + await wsr.nextEvent("open"); + + return await callback(ws, wsr); + } finally { + ws.close(); + } +} diff --git a/deflate-client/package-lock.json b/deflate-client/package-lock.json new file mode 100644 index 00000000..eedef1a3 --- /dev/null +++ b/deflate-client/package-lock.json @@ -0,0 +1,43 @@ +{ + "name": "deflate-client", + "version": "1.0.0", + "lockfileVersion": 2, + "requires": true, + "packages": { + "": { + "version": "1.0.0", + "license": "GPL-2.0-or-later", + "dependencies": { + "ws": "^8.2.1" + } + }, + "node_modules/ws": { + "version": "8.2.1", + "resolved": "https://registry.npmjs.org/ws/-/ws-8.2.1.tgz", + "integrity": "sha512-XkgWpJU3sHU7gX8f13NqTn6KQ85bd1WU7noBHTT8fSohx7OS1TPY8k+cyRPCzFkia7C4mM229yeHr1qK9sM4JQ==", + "engines": { + "node": ">=10.0.0" + }, + "peerDependencies": { + "bufferutil": "^4.0.1", + "utf-8-validate": "^5.0.2" + }, + "peerDependenciesMeta": { + "bufferutil": { + "optional": true + }, + "utf-8-validate": { + "optional": true + } + } + } + }, + "dependencies": { + "ws": { + "version": "8.2.1", + "resolved": "https://registry.npmjs.org/ws/-/ws-8.2.1.tgz", + "integrity": "sha512-XkgWpJU3sHU7gX8f13NqTn6KQ85bd1WU7noBHTT8fSohx7OS1TPY8k+cyRPCzFkia7C4mM229yeHr1qK9sM4JQ==", + "requires": {} + } + } +} diff --git a/deflate-client/package.json b/deflate-client/package.json new file mode 100644 index 00000000..66c38c51 --- /dev/null +++ b/deflate-client/package.json @@ -0,0 +1,12 @@ +{ + "name": "deflate-client", + "version": "1.0.0", + "description": "This directory contains Node.js based tests for exercising different WebSocket compression parameters in httpuv.", + "main": "index.js", + "scripts": { + "test": "node index.mjs" + }, + "dependencies": { + "ws": "^8.2.1" + } +} diff --git a/deflate-client/server.R b/deflate-client/server.R new file mode 100644 index 00000000..01dd7fea --- /dev/null +++ b/deflate-client/server.R @@ -0,0 +1,28 @@ +library(httpuv) + + +echo_server <- function(ws) { + ws$onMessage(function(isBinary, data) { + ws$send(data) + }) +} + +md5_server <- function(ws) { + ws$onMessage(function(isBinary, data) { + ws$send(digest::digest(data, serialize = FALSE)) + }) +} + +server <- list( + onWSOpen = function(ws) { + if (identical(ws$request$PATH_INFO, "/echo")) { + echo_server(ws) + } else if (identical(ws$request$PATH_INFO, "/md5")) { + md5_server(ws) + } else { + ws$close() + } + } +) + +httpuv::runServer("127.0.0.1", 9100, server) \ No newline at end of file diff --git a/src/websockets.cpp b/src/websockets.cpp index 3fe1f221..2c96394a 100644 --- a/src/websockets.cpp +++ b/src/websockets.cpp @@ -290,7 +290,7 @@ void WebSocketConnection::sendWSMessage(Opcode opcode, const char* pData, size_t size_t footerLength = 0; std::vector deflated(0); - bool deflate = _context.permessageDeflate; + bool deflate = _context.permessageDeflate && (opcode == Continuation || opcode == Text || opcode == Binary); if (deflate) { // TODO: Handle errors int error = _deflator.deflate(pData, length, deflated); From d17abaa0a16403d9b89093417a1a8c91f8dbac83 Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Tue, 31 Aug 2021 17:04:12 -0700 Subject: [PATCH 6/7] no_context_takeover support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This implements server_no_context_takeover and client_no_context_takeover support, but in a pretty naive way for now that gives correct results but doesn’t provide any memory efficiency advantage, which is the whole point of this feature. --- .Rbuildignore | 1 + deflate-client/index.mjs | 12 ++++++++++-- src/deflate.cpp | 8 ++++++++ src/deflate.h | 2 ++ src/websockets.cpp | 22 ++++++++++++++++++---- 5 files changed, 39 insertions(+), 6 deletions(-) diff --git a/.Rbuildignore b/.Rbuildignore index 920f9259..97f5ab4d 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -22,3 +22,4 @@ ^cran-comments\.md$ ^CRAN-RELEASE$ ^CRAN-SUBMISSION$ +^deflate-client$ diff --git a/deflate-client/index.mjs b/deflate-client/index.mjs index a862abdf..56494228 100644 --- a/deflate-client/index.mjs +++ b/deflate-client/index.mjs @@ -7,7 +7,7 @@ const MAX_BITS = 21; async function runMD5Test(address, perMessageDeflate) { await withWS(address + "md5", { perMessageDeflate }, async (ws, wsr) => { - process.stderr.write(" md5: "); + process.stderr.write("Testing: "); for (let i = 0; i <= MAX_BITS; i++) { process.stderr.write("."); const payload = randomBytes(2**i); @@ -21,7 +21,7 @@ async function runMD5Test(address, perMessageDeflate) { async function runEchoTest(address, perMessageDeflate) { await withWS(address + "echo", { perMessageDeflate }, async (ws, wsr) => { - process.stderr.write(" echo: "); + process.stderr.write("Testing: "); for (let i = 0; i <= MAX_BITS; i++) { process.stderr.write("."); const payload = randomBytes(2**i); @@ -43,6 +43,7 @@ async function main() { await runTest("ws://127.0.0.1:9100/", false); await runTest("ws://127.0.0.1:9100/", true); await runTest("ws://127.0.0.1:9100/", { threshold: 0 }); + await runTest("ws://127.0.0.1:9100/", { threshold: 0, serverMaxWindowBits: 9 }); await runTest("ws://127.0.0.1:9100/", { threshold: 0, serverMaxWindowBits: 9, @@ -54,6 +55,13 @@ async function main() { clientMaxWindowBits: 9, serverNoContextTakeover: true }); + await runTest("ws://127.0.0.1:9100/", { + threshold: 0, + serverMaxWindowBits: 9, + clientMaxWindowBits: 9, + serverNoContextTakeover: true, + clientNoContextTakeover: true + }); } console.error("Launching httpuv"); diff --git a/src/deflate.cpp b/src/deflate.cpp index c68c83cf..af8cb9cc 100644 --- a/src/deflate.cpp +++ b/src/deflate.cpp @@ -83,6 +83,10 @@ int Deflator::init(DeflateMode mode, int level, int windowBits, int memLevel, in return result; } +int Deflator::reset() { + return deflateReset(&_stream); +} + int Deflator::deflate(const char* data, size_t data_len, std::vector& output) { if (_state != DeflatorStateReady) { throw std::runtime_error("Deflator.init() must be called before deflate()"); @@ -129,6 +133,10 @@ int Inflator::init(DeflateMode mode, int windowBits) { return result; } +int Inflator::reset() { + return inflateReset(&_stream); +} + int Inflator::inflate(const char* data, size_t data_len, std::vector& output) { if (_state != DeflatorStateReady) { throw std::runtime_error("Inflator.init() must be called before inflate()"); diff --git a/src/deflate.h b/src/deflate.h index 02addbf2..c65f3e1f 100644 --- a/src/deflate.h +++ b/src/deflate.h @@ -29,6 +29,7 @@ class Deflator { int init(DeflateMode mode, int level = Z_DEFAULT_COMPRESSION, int windowBits = 13, int memLevel = 8, int strategy = Z_DEFAULT_STRATEGY); + int reset(); int deflate(const char* data, size_t data_len, std::vector& output); }; @@ -41,6 +42,7 @@ class Inflator { Inflator(); ~Inflator(); int init(DeflateMode mode, int windowBits = 15); + int reset(); int inflate(const char* data, size_t data_len, std::vector& output); }; diff --git a/src/websockets.cpp b/src/websockets.cpp index 2c96394a..a9741ae1 100644 --- a/src/websockets.cpp +++ b/src/websockets.cpp @@ -292,8 +292,15 @@ void WebSocketConnection::sendWSMessage(Opcode opcode, const char* pData, size_t std::vector deflated(0); bool deflate = _context.permessageDeflate && (opcode == Continuation || opcode == Text || opcode == Binary); if (deflate) { + int error; // TODO: Handle errors - int error = _deflator.deflate(pData, length, deflated); + if (_context.serverNoContextTakeover) { + error = _deflator.reset(); + if (error != Z_OK) { + debug_log("An error occurred during deflate reset", LOG_ERROR); + } + } + error = _deflator.deflate(pData, length, deflated); if (error != Z_OK) { debug_log("An error occurred during deflate", LOG_ERROR); } @@ -415,10 +422,17 @@ void WebSocketConnection::onFrameComplete() { _payload.push_back(0xFF); _payload.push_back(0xFF); std::vector inflated(0); - int error = _inflator.inflate(safe_vec_addr(_payload), _payload.size(), inflated); + int error; + // TODO: Handle errors + if (_context.clientNoContextTakeover) { + error = _inflator.reset(); + if (error != Z_OK) { + debug_log("An error occurred during inflate reset", LOG_ERROR); + } + } + error = _inflator.inflate(safe_vec_addr(_payload), _payload.size(), inflated); if (error != Z_OK) { - // TODO: Handle error - std::cerr << "Inflate failed with error " << error << "\n"; + debug_log("An error occurred during inflate", LOG_ERROR); } _pCallbacks->onWSMessage(_header.opcode == Binary, safe_vec_addr(inflated), inflated.size()); From 1762ce77819ae542ed169bafa892fdc563c2ba19 Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Sun, 21 May 2023 21:21:14 -0700 Subject: [PATCH 7/7] Change port number; try GHA --- .github/workflows/R-CMD-check.yaml | 28 +++++++++++++++++++++++++++- deflate-client/index.mjs | 14 +++++++------- deflate-client/server.R | 2 +- 3 files changed, 35 insertions(+), 9 deletions(-) diff --git a/.github/workflows/R-CMD-check.yaml b/.github/workflows/R-CMD-check.yaml index f6de5dbb..0aac0a21 100644 --- a/.github/workflows/R-CMD-check.yaml +++ b/.github/workflows/R-CMD-check.yaml @@ -7,7 +7,7 @@ on: branches: [main, rc-**] pull_request: schedule: - - cron: '0 6 * * 1' # every monday + - cron: "0 6 * * 1" # every monday name: Package checks @@ -20,3 +20,29 @@ jobs: uses: rstudio/shiny-workflows/.github/workflows/R-CMD-check.yaml@v1 with: ubuntu: "ubuntu-20.04 ubuntu-latest" + permessage-deflate-test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + + - name: Setup R + uses: r-lib/actions/setup-r@v2 + + - name: Install R package dependencies + uses: r-lib/actions/setup-r-dependencies@v2 + + - name: Install R package + run: R CMD INSTALL . + + - name: Setup Node.js + uses: actions/setup-node@v2 + with: + node-version: "18" + + - name: Install npm dependencies + run: npm install + working-directory: deflate-client + + - name: Run npm tests + run: npm test + working-directory: deflate-client diff --git a/deflate-client/index.mjs b/deflate-client/index.mjs index 56494228..76825102 100644 --- a/deflate-client/index.mjs +++ b/deflate-client/index.mjs @@ -40,22 +40,22 @@ async function runTest(address, perMessageDeflate) { async function main() { await sleep(1000); - await runTest("ws://127.0.0.1:9100/", false); - await runTest("ws://127.0.0.1:9100/", true); - await runTest("ws://127.0.0.1:9100/", { threshold: 0 }); - await runTest("ws://127.0.0.1:9100/", { threshold: 0, serverMaxWindowBits: 9 }); - await runTest("ws://127.0.0.1:9100/", { + await runTest("ws://127.0.0.1:14252/", false); + await runTest("ws://127.0.0.1:14252/", true); + await runTest("ws://127.0.0.1:14252/", { threshold: 0 }); + await runTest("ws://127.0.0.1:14252/", { threshold: 0, serverMaxWindowBits: 9 }); + await runTest("ws://127.0.0.1:14252/", { threshold: 0, serverMaxWindowBits: 9, clientMaxWindowBits: 9 }); - await runTest("ws://127.0.0.1:9100/", { + await runTest("ws://127.0.0.1:14252/", { threshold: 0, serverMaxWindowBits: 9, clientMaxWindowBits: 9, serverNoContextTakeover: true }); - await runTest("ws://127.0.0.1:9100/", { + await runTest("ws://127.0.0.1:14252/", { threshold: 0, serverMaxWindowBits: 9, clientMaxWindowBits: 9, diff --git a/deflate-client/server.R b/deflate-client/server.R index 01dd7fea..d262fbc3 100644 --- a/deflate-client/server.R +++ b/deflate-client/server.R @@ -25,4 +25,4 @@ server <- list( } ) -httpuv::runServer("127.0.0.1", 9100, server) \ No newline at end of file +httpuv::runServer("127.0.0.1", 14252, server)