diff --git a/.Rbuildignore b/.Rbuildignore index 920f9259..97f5ab4d 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -22,3 +22,4 @@ ^cran-comments\.md$ ^CRAN-RELEASE$ ^CRAN-SUBMISSION$ +^deflate-client$ diff --git a/.github/workflows/R-CMD-check.yaml b/.github/workflows/R-CMD-check.yaml index f6de5dbb..0aac0a21 100644 --- a/.github/workflows/R-CMD-check.yaml +++ b/.github/workflows/R-CMD-check.yaml @@ -7,7 +7,7 @@ on: branches: [main, rc-**] pull_request: schedule: - - cron: '0 6 * * 1' # every monday + - cron: "0 6 * * 1" # every monday name: Package checks @@ -20,3 +20,29 @@ jobs: uses: rstudio/shiny-workflows/.github/workflows/R-CMD-check.yaml@v1 with: ubuntu: "ubuntu-20.04 ubuntu-latest" + permessage-deflate-test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + + - name: Setup R + uses: r-lib/actions/setup-r@v2 + + - name: Install R package dependencies + uses: r-lib/actions/setup-r-dependencies@v2 + + - name: Install R package + run: R CMD INSTALL . + + - name: Setup Node.js + uses: actions/setup-node@v2 + with: + node-version: "18" + + - name: Install npm dependencies + run: npm install + working-directory: deflate-client + + - name: Run npm tests + run: npm test + working-directory: deflate-client diff --git a/deflate-client/.gitignore b/deflate-client/.gitignore new file mode 100644 index 00000000..3c3629e6 --- /dev/null +++ b/deflate-client/.gitignore @@ -0,0 +1 @@ +node_modules diff --git a/deflate-client/README.md b/deflate-client/README.md new file mode 100644 index 00000000..9b35ebaa --- /dev/null +++ b/deflate-client/README.md @@ -0,0 +1,15 @@ +This directory contains Node.js based tests for exercising different WebSocket compression parameters in httpuv. + +To setup, you'll need Node.js installed, then: + +```sh +npm install +``` + +Be sure to also install the version of httpuv you want to test. + +To run the test: + +```sh +npm test +``` diff --git a/deflate-client/index.mjs b/deflate-client/index.mjs new file mode 100644 index 00000000..76825102 --- /dev/null +++ b/deflate-client/index.mjs @@ -0,0 +1,84 @@ +import assert from "assert"; +import child_process from "child_process"; +import { randomBytes } from "crypto"; +import { sleep, md5, WebSocketReceiver, withWS } from "./lib/util.mjs"; + +const MAX_BITS = 21; + +async function runMD5Test(address, perMessageDeflate) { + await withWS(address + "md5", { perMessageDeflate }, async (ws, wsr) => { + process.stderr.write("Testing: "); + for (let i = 0; i <= MAX_BITS; i++) { + process.stderr.write("."); + const payload = randomBytes(2**i); + ws.send(payload); + const { data } = await wsr.nextEvent("message"); + assert.strictEqual(data.toString("utf-8"), md5(payload), "md5 of payload must match server's response"); + } + process.stderr.write("\n"); + }); +} + +async function runEchoTest(address, perMessageDeflate) { + await withWS(address + "echo", { perMessageDeflate }, async (ws, wsr) => { + process.stderr.write("Testing: "); + for (let i = 0; i <= MAX_BITS; i++) { + process.stderr.write("."); + const payload = randomBytes(2**i); + ws.send(payload); + const { data } = await wsr.nextEvent("message"); + assert(data.equals(payload), "payload must match server's response"); + } + process.stderr.write("\n"); + }); +} + +async function runTest(address, perMessageDeflate) { + await runMD5Test(address, perMessageDeflate); + await runEchoTest(address, perMessageDeflate); +} + +async function main() { + await sleep(1000); + await runTest("ws://127.0.0.1:14252/", false); + await runTest("ws://127.0.0.1:14252/", true); + await runTest("ws://127.0.0.1:14252/", { threshold: 0 }); + await runTest("ws://127.0.0.1:14252/", { threshold: 0, serverMaxWindowBits: 9 }); + await runTest("ws://127.0.0.1:14252/", { + threshold: 0, + serverMaxWindowBits: 9, + clientMaxWindowBits: 9 + }); + await runTest("ws://127.0.0.1:14252/", { + threshold: 0, + serverMaxWindowBits: 9, + clientMaxWindowBits: 9, + serverNoContextTakeover: true + }); + await runTest("ws://127.0.0.1:14252/", { + threshold: 0, + serverMaxWindowBits: 9, + clientMaxWindowBits: 9, + serverNoContextTakeover: true, + clientNoContextTakeover: true + }); +} + +console.error("Launching httpuv"); +const rprocess = child_process.spawn("Rscript", ["server.R"], { + stdio: ["inherit", "inherit", "inherit"] +}); +process.on("exit", () => { + rprocess.kill(); +}); + +main().then( + result => { + console.error("All tests passed!"); + process.exit(0); + }, + error => { + console.error(error ?? "An unknown error has occurred!"); + process.exit(1); + } +); diff --git a/deflate-client/lib/util.mjs b/deflate-client/lib/util.mjs new file mode 100644 index 00000000..d8c10362 --- /dev/null +++ b/deflate-client/lib/util.mjs @@ -0,0 +1,89 @@ +import { createHash } from "crypto"; +import { WebSocket } from "ws"; + +export async function sleep(millis) { + await new Promise(resolve => { + setTimeout(() => { + resolve(undefined); + }, millis); + }); +} + +export function md5(data) { + const hash = createHash("md5"); + hash.update(data); + return hash.digest("hex"); +} + +export class WebSocketReceiver { + constructor(ws) { + this.messages = []; + this.pending = null; + + ws.on("open", () => { + this.#log("open"); + this.#push({type: "open"}); + }); + ws.on("message", (data, isBinary) => { + this.#log("message"); + this.#push({type: "message", data, isBinary}); + }); + ws.on("close", ({code, reason}) => { + this.#log("close"); + this.#push({type: "close", code, reason}); + }); + ws.on("error", error => { + this.#log("error"); + this.#push({type: "error", error}); + }); + } + + #log(...args) { + // console.log(...args); + } + + #push(message) { + this.messages.push(message); + if (this.messages.length === 1 && this.pending) { + const prevPending = this.pending + this.pending = null; + prevPending.resolve(null); + } + } + + async nextEvent(type = "message") { + while (this.messages.length === 0) { + if (!this.pending) { + let resolve, reject; + let promise = new Promise((resolve_, reject_) => { + resolve = resolve_; + reject = reject_; + }); + this.pending = {promise, resolve, reject}; + } + await this.pending.promise; + } + const msg = this.messages.shift(); + if (type && msg.type !== type) { + if (msg.type === "error") { + throw msg.error; + } else { + throw new Error(`Unexpected WebSocket event (expected '${type}', got '${msg.type}')`); + } + } + return msg; + } +} + +export async function withWS(address, options, callback) { + console.log("Connecting to", address, "with", options); + const ws = new WebSocket(address, options); + try { + const wsr = new WebSocketReceiver(ws); + await wsr.nextEvent("open"); + + return await callback(ws, wsr); + } finally { + ws.close(); + } +} diff --git a/deflate-client/package-lock.json b/deflate-client/package-lock.json new file mode 100644 index 00000000..eedef1a3 --- /dev/null +++ b/deflate-client/package-lock.json @@ -0,0 +1,43 @@ +{ + "name": "deflate-client", + "version": "1.0.0", + "lockfileVersion": 2, + "requires": true, + "packages": { + "": { + "version": "1.0.0", + "license": "GPL-2.0-or-later", + "dependencies": { + "ws": "^8.2.1" + } + }, + "node_modules/ws": { + "version": "8.2.1", + "resolved": "https://registry.npmjs.org/ws/-/ws-8.2.1.tgz", + "integrity": "sha512-XkgWpJU3sHU7gX8f13NqTn6KQ85bd1WU7noBHTT8fSohx7OS1TPY8k+cyRPCzFkia7C4mM229yeHr1qK9sM4JQ==", + "engines": { + "node": ">=10.0.0" + }, + "peerDependencies": { + "bufferutil": "^4.0.1", + "utf-8-validate": "^5.0.2" + }, + "peerDependenciesMeta": { + "bufferutil": { + "optional": true + }, + "utf-8-validate": { + "optional": true + } + } + } + }, + "dependencies": { + "ws": { + "version": "8.2.1", + "resolved": "https://registry.npmjs.org/ws/-/ws-8.2.1.tgz", + "integrity": "sha512-XkgWpJU3sHU7gX8f13NqTn6KQ85bd1WU7noBHTT8fSohx7OS1TPY8k+cyRPCzFkia7C4mM229yeHr1qK9sM4JQ==", + "requires": {} + } + } +} diff --git a/deflate-client/package.json b/deflate-client/package.json new file mode 100644 index 00000000..66c38c51 --- /dev/null +++ b/deflate-client/package.json @@ -0,0 +1,12 @@ +{ + "name": "deflate-client", + "version": "1.0.0", + "description": "This directory contains Node.js based tests for exercising different WebSocket compression parameters in httpuv.", + "main": "index.js", + "scripts": { + "test": "node index.mjs" + }, + "dependencies": { + "ws": "^8.2.1" + } +} diff --git a/deflate-client/server.R b/deflate-client/server.R new file mode 100644 index 00000000..d262fbc3 --- /dev/null +++ b/deflate-client/server.R @@ -0,0 +1,28 @@ +library(httpuv) + + +echo_server <- function(ws) { + ws$onMessage(function(isBinary, data) { + ws$send(data) + }) +} + +md5_server <- function(ws) { + ws$onMessage(function(isBinary, data) { + ws$send(digest::digest(data, serialize = FALSE)) + }) +} + +server <- list( + onWSOpen = function(ws) { + if (identical(ws$request$PATH_INFO, "/echo")) { + echo_server(ws) + } else if (identical(ws$request$PATH_INFO, "/md5")) { + md5_server(ws) + } else { + ws$close() + } + } +) + +httpuv::runServer("127.0.0.1", 14252, server) diff --git a/src/constants.h b/src/constants.h index a7bc98bb..a582470b 100644 --- a/src/constants.h +++ b/src/constants.h @@ -54,4 +54,23 @@ static inline std::string trim(const std::string &s) { return s.substr(start, end-start); } +static inline std::vector split(const std::string& s, const std::string& delim) { + std::vector results; + size_t pos = 0; + while (true) { + size_t i = s.find(delim, pos); + if (i == std::string::npos) { + break; + } + if (i != pos) { + results.push_back(s.substr(pos, i - pos)); + } + pos = i + 1; + } + if (pos != s.length()) { + results.push_back(s.substr(pos)); + } + return results; +} + #endif // CONSTANTS_H diff --git a/src/deflate.cpp b/src/deflate.cpp new file mode 100644 index 00000000..af8cb9cc --- /dev/null +++ b/src/deflate.cpp @@ -0,0 +1,147 @@ +#include "deflate.h" +#include +#include +#include "utils.h" + +namespace deflator { + +typedef int(*flate_func)(z_stream* strm, int flush); + +// Generic function for driving I/O loop for inflate/deflate +int flate(flate_func func, z_stream* strm, const char* data, size_t data_len, std::vector& output) { + int error; + + const size_t CHUNK_SIZE = 256; + unsigned char temp[CHUNK_SIZE]; + memset(temp, 0, CHUNK_SIZE); + + strm->next_in = reinterpret_cast(const_cast(data)); + strm->avail_in = data_len; + + do { + strm->next_out = temp; + strm->avail_out = CHUNK_SIZE; + error = func(strm, Z_NO_FLUSH); + if (error != Z_OK && error != Z_BUF_ERROR) { + debug_log("flate failed", LOG_INFO); + return error; + } + // This might not copy any data; sometimes deflate() just populates buffers + output.insert(output.end(), (char*)temp, (char*)(temp + CHUNK_SIZE - strm->avail_out)); + } while (strm->avail_out == 0 || strm->avail_in > 0); + + do { + strm->next_out = temp; + strm->avail_out = CHUNK_SIZE; + error = func(strm, Z_SYNC_FLUSH); + if (error != Z_OK && error != Z_BUF_ERROR) { + debug_log("flate failed", LOG_INFO); + return error; + } + output.insert(output.end(), (char*)temp, (char*)(temp + CHUNK_SIZE - strm->avail_out)); + } while (strm->avail_out == 0); + + return Z_OK; +} + +Deflator::Deflator() { + _stream = {0}; + _state = DeflatorStateStart; +} + +Deflator::~Deflator() { + if (_state == DeflatorStateReady) { + int error = deflateEnd(&_stream); + if (error == Z_STREAM_ERROR) { + // deflateEnd can return other errors, but they're nothing to worry about. + // https://stackoverflow.com/a/19816633/139922 + debug_log("deflateEnd failed", LOG_WARN); + } + } + _state = DeflatorStateEnded; +} + +int Deflator::init(DeflateMode mode, int level, int windowBits, int memLevel, int strategy) { + if (_state != DeflatorStateStart) { + throw std::runtime_error("Deflator.init() was called twice"); + } + + if (windowBits < 9 || windowBits > 15) { + throw std::runtime_error("Deflator.init() expects 9 <= windowBits <= 15"); + } + + if (mode == DeflateModeRaw) { + windowBits *= -1; + } else if (mode == DeflateModeGzip) { + windowBits += 16; + } + + int result = deflateInit2(&_stream, level, Z_DEFLATED, windowBits, memLevel, strategy); + if (result == Z_OK) { + _state = DeflatorStateReady; + } + return result; +} + +int Deflator::reset() { + return deflateReset(&_stream); +} + +int Deflator::deflate(const char* data, size_t data_len, std::vector& output) { + if (_state != DeflatorStateReady) { + throw std::runtime_error("Deflator.init() must be called before deflate()"); + } + return flate(::deflate, &_stream, data, data_len, output); +} + +Inflator::Inflator() { + _stream = {0}; + _state = DeflatorStateStart; +} + +Inflator::~Inflator() { + if (_state == DeflatorStateReady) { + int error = inflateEnd(&_stream); + if (error == Z_STREAM_ERROR) { + // inflateEnd can return other errors, but they're nothing to worry about. + // https://stackoverflow.com/a/19816633/139922 + debug_log("inflateEnd failed", LOG_WARN); + } + } + _state = DeflatorStateEnded; +} + +int Inflator::init(DeflateMode mode, int windowBits) { + if (_state != DeflatorStateStart) { + throw std::runtime_error("Inflator.init() was called twice"); + } + + if (windowBits < 8 || windowBits > 15) { + throw std::runtime_error("Inflator.init() expects 9 <= windowBits <= 15"); + } + + if (mode == DeflateModeRaw) { + windowBits *= -1; + } else if (mode == DeflateModeGzip) { + windowBits += 16; + } + + int result = inflateInit2(&_stream, windowBits); + if (result == Z_OK) { + _state = DeflatorStateReady; + } + return result; +} + +int Inflator::reset() { + return inflateReset(&_stream); +} + +int Inflator::inflate(const char* data, size_t data_len, std::vector& output) { + if (_state != DeflatorStateReady) { + throw std::runtime_error("Inflator.init() must be called before inflate()"); + } + return flate(::inflate, &_stream, data, data_len, output); +} + +} // namespace diff --git a/src/deflate.h b/src/deflate.h new file mode 100644 index 00000000..c65f3e1f --- /dev/null +++ b/src/deflate.h @@ -0,0 +1,51 @@ +#ifndef DEFLATE_H +#define DEFLATE_H + +#include +#include + +namespace deflator { + +enum DeflateMode { + DeflateModeZlib, + DeflateModeRaw, + DeflateModeGzip, +}; + +enum DeflatorState { + DeflatorStateStart, + DeflatorStateReady, + DeflatorStateEnded +}; + +class Deflator { +private: + z_stream _stream; + DeflatorState _state; + +public: + Deflator(); + ~Deflator(); + int init(DeflateMode mode, int level = Z_DEFAULT_COMPRESSION, + int windowBits = 13, int memLevel = 8, + int strategy = Z_DEFAULT_STRATEGY); + int reset(); + int deflate(const char* data, size_t data_len, std::vector& output); +}; + +class Inflator { +private: + z_stream _stream; + DeflatorState _state; + +public: + Inflator(); + ~Inflator(); + int init(DeflateMode mode, int windowBits = 15); + int reset(); + int inflate(const char* data, size_t data_len, std::vector& output); +}; + +} + +#endif // DEFLATE_H diff --git a/src/websockets-base.cpp b/src/websockets-base.cpp index 29b92288..ef0c1835 100644 --- a/src/websockets-base.cpp +++ b/src/websockets-base.cpp @@ -25,7 +25,7 @@ void swapByteOrder(unsigned char* pStart, unsigned char* pEnd) { } void WebSocketProto::createFrameHeader( - Opcode opcode, bool mask, size_t payloadSize, int32_t maskingKey, + Opcode opcode, bool rsv1, bool mask, size_t payloadSize, int32_t maskingKey, char pData[MAX_HEADER_BYTES], size_t* pLen) const { unsigned char* pBuf = (unsigned char*)pData; @@ -35,6 +35,7 @@ void WebSocketProto::createFrameHeader( pBuf[0] = toFin(true) << 7 | // FIN; always true + (rsv1 << 6) | encodeOpcode(opcode); pBuf[1] = mask ? 1 << 7 : 0; if (payloadSize_64 <= 125) { diff --git a/src/websockets-base.h b/src/websockets-base.h index 8781f5f3..f6775062 100644 --- a/src/websockets-base.h +++ b/src/websockets-base.h @@ -5,6 +5,22 @@ #include "constants.h" +struct WebSocketConnectionContext { + WebSocketConnectionContext() : + permessageDeflate(false), + clientMaxWindowBits(-1), + serverMaxWindowBits(-1), + clientNoContextTakeover(false), + serverNoContextTakeover(false) { + } + + bool permessageDeflate; + int clientMaxWindowBits; + int serverMaxWindowBits; + bool clientNoContextTakeover; + bool serverNoContextTakeover; +}; + class WebSocketProto { public: @@ -22,9 +38,10 @@ class WebSocketProto { const RequestHeaders& requestHeaders, char** ppData, size_t* pLen, ResponseHeaders* responseHeaders, - std::vector* pResponse) const = 0; + std::vector* pResponse, + WebSocketConnectionContext* pContext) const = 0; - void createFrameHeader(Opcode opcode, bool mask, size_t payloadSize, + void createFrameHeader(Opcode opcode, bool rsv1, bool mask, size_t payloadSize, int32_t maskingKey, char pData[MAX_HEADER_BYTES], size_t* pLen) const; diff --git a/src/websockets-hixie76.cpp b/src/websockets-hixie76.cpp index 5d0b9d83..960abb12 100644 --- a/src/websockets-hixie76.cpp +++ b/src/websockets-hixie76.cpp @@ -5,13 +5,14 @@ void WSHixie76Parser::handshake(const std::string& url, const RequestHeaders& requestHeaders, char** ppData, size_t* pLen, ResponseHeaders* responseHeaders, - std::vector* pResponse) const { + std::vector* pResponse, + WebSocketConnectionContext* pContext) const { _hybi03.handshake(url, requestHeaders, ppData, pLen, responseHeaders, - pResponse); + pResponse, pContext); } void WSHixie76Parser::createFrameHeaderFooter( - Opcode opcode, bool mask, size_t payloadSize, + Opcode opcode, bool rsv1, bool mask, size_t payloadSize, int32_t maskingKey, char pHeaderData[MAX_HEADER_BYTES], size_t* pHeaderLen, char pFooterData[MAX_FOOTER_BYTES], size_t* pFooterLen) const { diff --git a/src/websockets-hixie76.h b/src/websockets-hixie76.h index c8cbb0e9..332e89c4 100644 --- a/src/websockets-hixie76.h +++ b/src/websockets-hixie76.h @@ -38,10 +38,11 @@ class WSHixie76Parser : public WSParser { const RequestHeaders& requestHeaders, char** ppData, size_t* pLen, ResponseHeaders* responseHeaders, - std::vector* pResponse) const; + std::vector* pResponse, + WebSocketConnectionContext* pContext) const; void createFrameHeaderFooter( - Opcode opcode, bool mask, size_t payloadSize, + Opcode opcode, bool rsv1, bool mask, size_t payloadSize, int32_t maskingKey, char pHeaderData[MAX_HEADER_BYTES], size_t* pHeaderLen, char pFooterData[MAX_FOOTER_BYTES], size_t* pFooterLen diff --git a/src/websockets-hybi03.cpp b/src/websockets-hybi03.cpp index c579e25c..439fe08c 100644 --- a/src/websockets-hybi03.cpp +++ b/src/websockets-hybi03.cpp @@ -54,7 +54,8 @@ void WebSocketProto_HyBi03::handshake(const std::string& url, const RequestHeaders& requestHeaders, char** ppData, size_t* pLen, ResponseHeaders* pResponseHeaders, - std::vector* pResponse) const { + std::vector* pResponse, + WebSocketConnectionContext* pContext) const { assert(*pLen >= 8); diff --git a/src/websockets-hybi03.h b/src/websockets-hybi03.h index 3cc301ae..db68224d 100644 --- a/src/websockets-hybi03.h +++ b/src/websockets-hybi03.h @@ -16,9 +16,10 @@ class WebSocketProto_HyBi03 : public WebSocketProto { const RequestHeaders& requestHeaders, char** ppData, size_t* pLen, ResponseHeaders* pResponseHeaders, - std::vector* pResponse) const; + std::vector* pResponse, + WebSocketConnectionContext* pContext) const; - void createFrameHeader(Opcode opcode, bool mask, size_t payloadSize, + void createFrameHeader(Opcode opcode, bool rsv1, bool mask, size_t payloadSize, int32_t maskingKey, char pData[MAX_HEADER_BYTES], size_t* pLen) const; diff --git a/src/websockets-ietf.cpp b/src/websockets-ietf.cpp index f97101c3..8290bc7e 100644 --- a/src/websockets-ietf.cpp +++ b/src/websockets-ietf.cpp @@ -1,23 +1,30 @@ #include "websockets-ietf.h" +#include +#include + #include "utils.h" #include "sha1/sha1.h" #include "base64/base64.hpp" +#include "wse-permessage-deflate.h" + bool WebSocketProto_IETF::canHandle(const RequestHeaders& requestHeaders, const char* pData, size_t len) const { return requestHeaders.find("upgrade") != requestHeaders.end() && strcasecmp(requestHeaders.at("upgrade").c_str(), "websocket") == 0 && - requestHeaders.find("sec-websocket-key") != requestHeaders.end(); + requestHeaders.find("sec-websocket-key") != requestHeaders.end() && + permessage_deflate::isValid(requestHeaders); } void WebSocketProto_IETF::handshake(const std::string& url, const RequestHeaders& requestHeaders, char** ppData, size_t* pLen, ResponseHeaders* pResponseHeaders, - std::vector* pResponse) const { + std::vector* pResponse, + WebSocketConnectionContext* pContext) const { std::string key = requestHeaders.at("sec-websocket-key"); @@ -37,6 +44,8 @@ void WebSocketProto_IETF::handshake(const std::string& url, std::pair("Upgrade", "websocket")); pResponseHeaders->push_back( std::pair("Sec-WebSocket-Accept", response)); + + permessage_deflate::handshake(requestHeaders, pResponseHeaders, pContext); } bool WebSocketProto_IETF::isFin(uint8_t firstBit) const { diff --git a/src/websockets-ietf.h b/src/websockets-ietf.h index dd166c7a..69204b35 100644 --- a/src/websockets-ietf.h +++ b/src/websockets-ietf.h @@ -16,7 +16,8 @@ class WebSocketProto_IETF : public WebSocketProto { const RequestHeaders& requestHeaders, char** ppData, size_t* pLen, ResponseHeaders* responseHeaders, - std::vector* pResponse) const; + std::vector* pResponse, + WebSocketConnectionContext* pContext) const; bool isFin(uint8_t firstBit) const; uint8_t toFin(bool isFin) const; diff --git a/src/websockets.cpp b/src/websockets.cpp index f5ece8f2..a9741ae1 100644 --- a/src/websockets.cpp +++ b/src/websockets.cpp @@ -45,6 +45,7 @@ bool WSHyBiFrameHeader::isHeaderComplete() const { WSFrameHeaderInfo WSHyBiFrameHeader::info() const { WSFrameHeaderInfo inf; inf.fin = fin(); + inf.rsv1 = rsv1(); inf.opcode = opcode(); inf.hasLength = true; inf.masked = masked(); @@ -59,6 +60,9 @@ WSFrameHeaderInfo WSHyBiFrameHeader::info() const { bool WSHyBiFrameHeader::fin() const { return _pProto->isFin(read(0, 1)); } +bool WSHyBiFrameHeader::rsv1() const { + return read(1, 1); +} Opcode WSHyBiFrameHeader::opcode() const { uint8_t oc = read(4, 4); return _pProto->decodeOpcode(oc); @@ -140,19 +144,20 @@ void WSHyBiParser::handshake(const std::string& url, const RequestHeaders& requestHeaders, char** ppData, size_t* pLen, ResponseHeaders* pResponseHeaders, - std::vector* pResponse) const { + std::vector* pResponse, + WebSocketConnectionContext* pContext) const { ASSERT_BACKGROUND_THREAD() _pProto->handshake(url, requestHeaders, ppData, pLen, pResponseHeaders, - pResponse); + pResponse, pContext); } void WSHyBiParser::createFrameHeaderFooter( - Opcode opcode, bool mask, size_t payloadSize, + Opcode opcode, bool rsv1, bool mask, size_t payloadSize, int32_t maskingKey, char pHeaderData[MAX_HEADER_BYTES], size_t* pHeaderLen, char pFooterData[MAX_FOOTER_BYTES], size_t* pFooterLen ) const { - _pProto->createFrameHeader(opcode, mask, payloadSize, maskingKey, + _pProto->createFrameHeader(opcode, rsv1, mask, payloadSize, maskingKey, pHeaderData, pHeaderLen); } @@ -258,7 +263,20 @@ void WebSocketConnection::handshake(const std::string& url, if (_connState == WS_CLOSED) return; _pParser->handshake(url, requestHeaders, ppData, pLen, pResponseHeaders, - pResponse); + pResponse, &_context); + + if (_context.permessageDeflate) { + int error; + // TODO: Handle errors + error = _inflator.init(deflator::DeflateModeRaw, _context.clientMaxWindowBits); + if (error != Z_OK) { + debug_log("Failed to init inflator", LOG_ERROR); + } + error = _deflator.init(deflator::DeflateModeRaw, Z_DEFAULT_COMPRESSION, _context.serverMaxWindowBits); + if (error != Z_OK) { + debug_log("Failed to init deflator", LOG_ERROR); + } + } } void WebSocketConnection::sendWSMessage(Opcode opcode, const char* pData, size_t length) { @@ -271,7 +289,30 @@ void WebSocketConnection::sendWSMessage(Opcode opcode, const char* pData, size_t size_t headerLength = 0; size_t footerLength = 0; - _pParser->createFrameHeaderFooter(opcode, false, length, 0, + std::vector deflated(0); + bool deflate = _context.permessageDeflate && (opcode == Continuation || opcode == Text || opcode == Binary); + if (deflate) { + int error; + // TODO: Handle errors + if (_context.serverNoContextTakeover) { + error = _deflator.reset(); + if (error != Z_OK) { + debug_log("An error occurred during deflate reset", LOG_ERROR); + } + } + error = _deflator.deflate(pData, length, deflated); + if (error != Z_OK) { + debug_log("An error occurred during deflate", LOG_ERROR); + } + deflated.pop_back(); + deflated.pop_back(); + deflated.pop_back(); + deflated.pop_back(); + pData = safe_vec_addr(deflated); + length = deflated.size(); + } + + _pParser->createFrameHeaderFooter(opcode, deflate, false, length, 0, safe_vec_addr(header), &headerLength, safe_vec_addr(footer), &footerLength); header.resize(headerLength); @@ -375,7 +416,29 @@ void WebSocketConnection::onFrameComplete() { } case Text: case Binary: { - _pCallbacks->onWSMessage(_header.opcode == Binary, safe_vec_addr(_payload), _payload.size()); + if (_context.permessageDeflate && _header.rsv1) { + _payload.push_back(0); + _payload.push_back(0); + _payload.push_back(0xFF); + _payload.push_back(0xFF); + std::vector inflated(0); + int error; + // TODO: Handle errors + if (_context.clientNoContextTakeover) { + error = _inflator.reset(); + if (error != Z_OK) { + debug_log("An error occurred during inflate reset", LOG_ERROR); + } + } + error = _inflator.inflate(safe_vec_addr(_payload), _payload.size(), inflated); + if (error != Z_OK) { + debug_log("An error occurred during inflate", LOG_ERROR); + } + + _pCallbacks->onWSMessage(_header.opcode == Binary, safe_vec_addr(inflated), inflated.size()); + } else { + _pCallbacks->onWSMessage(_header.opcode == Binary, safe_vec_addr(_payload), _payload.size()); + } break; } case Close: { diff --git a/src/websockets.h b/src/websockets.h index ce26ab68..59c849b1 100644 --- a/src/websockets.h +++ b/src/websockets.h @@ -9,6 +9,7 @@ #include #include "utils.h" +#include "deflate.h" #include "thread.h" #include "constants.h" #include "websockets-base.h" @@ -17,6 +18,7 @@ class WSFrameHeaderInfo { public: bool fin; + bool rsv1; Opcode opcode; bool masked; std::vector maskingKey; @@ -56,6 +58,9 @@ class WSHyBiFrameHeader { private: bool fin() const; + bool rsv1() const; + bool rsv2() const; + bool rsv3() const; Opcode opcode() const; bool masked() const; void maskingKey(uint8_t key[4]) const; @@ -92,10 +97,11 @@ class WSParser { const RequestHeaders& requestHeaders, char** ppData, size_t* pLen, ResponseHeaders* responseHeaders, - std::vector* pResponse) const = 0; + std::vector* pResponse, + WebSocketConnectionContext* pContext) const = 0; virtual void createFrameHeaderFooter( - Opcode opcode, bool mask, size_t payloadSize, + Opcode opcode, bool rsv1, bool mask, size_t payloadSize, int32_t maskingKey, char pHeaderData[MAX_HEADER_BYTES], size_t* pHeaderLen, char pFooterData[MAX_FOOTER_BYTES], size_t* pFooterLen @@ -125,10 +131,11 @@ class WSHyBiParser : public WSParser { const RequestHeaders& requestHeaders, char** ppData, size_t* pLen, ResponseHeaders* responseHeaders, - std::vector* pResponse) const; + std::vector* pResponse, + WebSocketConnectionContext* pContext) const; void createFrameHeaderFooter( - Opcode opcode, bool mask, size_t payloadSize, + Opcode opcode, bool rsv1, bool mask, size_t payloadSize, int32_t maskingKey, char pHeaderData[MAX_HEADER_BYTES], size_t* pHeaderLen, char pFooterData[MAX_FOOTER_BYTES], size_t* pFooterLen @@ -172,6 +179,9 @@ class WebSocketConnection : WSParserCallbacks, NoCopy { std::vector _incompleteContentPayload; std::vector _payload; uv_timer_t* _pPingTimer; + WebSocketConnectionContext _context; + deflator::Deflator _deflator; + deflator::Inflator _inflator; public: WebSocketConnection( @@ -180,7 +190,9 @@ class WebSocketConnection : WSParserCallbacks, NoCopy { : _pLoop(pLoop), _connState(WS_OPEN), _pCallbacks(callbacks), - _pParser(NULL) { + _pParser(NULL), + _deflator(), + _inflator() { ASSERT_BACKGROUND_THREAD() debug_log("WebSocketConnection::WebSocketConnection", LOG_DEBUG); diff --git a/src/wse-permessage-deflate.cpp b/src/wse-permessage-deflate.cpp new file mode 100644 index 00000000..8af0df19 --- /dev/null +++ b/src/wse-permessage-deflate.cpp @@ -0,0 +1,155 @@ +#include "wse-permessage-deflate.h" +#include +#include +#include +#include + +struct ExtensionInfo { + std::string extension; + std::map params; +}; + +ExtensionInfo parseExtensionInfo(const std::string str) { + std::vector parts = split(str, ";"); + + std::transform(parts.cbegin(), parts.cend(), parts.begin(), trim); + + std::string extension = parts[0]; + std::map params; + for (size_t i = 1; i < parts.size(); i++) { + std::vector param = split(parts[i], "="); + params[trim(param[0])] = param.size() > 1 ? trim(param[1]) : ""; + } + ExtensionInfo result; + result.extension = extension; + result.params = params; + return result; +} + +bool parseWindowBits(const ExtensionInfo& extInfo, const std::string& param, bool* present, int* value) { + *present = false; + *value = 0; + + auto match = extInfo.params.find(param); + if (match == extInfo.params.end()) { + // Not present--this is valid, so return true + return true; + } + + *present = true; + + if (match->second.size() == 0) { + // Param is present, but value is not + return true; + } + + if (match->second.size() > 2) { + // Too many digits + return false; + } + *value = atoi(match->second.c_str()); + if (*value < 8 || *value > 15) { + // Out of range + return false; + } + return true; +} + +bool parsePermessageDeflate(const ExtensionInfo& extInfo, WebSocketConnectionContext* pContext) { + if (extInfo.extension != "permessage-deflate") { + return false; + } + + pContext->permessageDeflate = true; + if (extInfo.params.find("server_no_context_takeover") != extInfo.params.end()) { + pContext->serverNoContextTakeover = true; + } + if (extInfo.params.find("client_no_context_takeover") != extInfo.params.end()) { + pContext->clientNoContextTakeover = true; + } + + bool hasServerMaxWindowBits; + bool hasClientMaxWindowBits; + + if (!parseWindowBits(extInfo, "server_max_window_bits", &hasServerMaxWindowBits, &pContext->serverMaxWindowBits)) { + return false; + } + if (hasServerMaxWindowBits && pContext->serverMaxWindowBits == 0) { + // If server_max_window_bits is present, the value is required + return false; + } + + if (!parseWindowBits(extInfo, "client_max_window_bits", &hasClientMaxWindowBits, &pContext->clientMaxWindowBits)) { + return false; + } + + // Set defaults + if (pContext->serverMaxWindowBits <= 0) { + pContext->serverMaxWindowBits = 15; + } + if (pContext->clientMaxWindowBits <= 0) { + pContext->clientMaxWindowBits = 15; + } + + return true; +} + +bool handle(const RequestHeaders& requestHeaders, + ResponseHeaders* pResponseHeaders, + WebSocketConnectionContext* pContext) { + + auto swe = requestHeaders.find("sec-websocket-extensions"); + if (swe != requestHeaders.end()) { + auto extensions = split(swe->second, ","); + std::vector extInfos; + std::transform(extensions.begin(), extensions.end(), + std::back_inserter(extInfos), + parseExtensionInfo); + + for (auto &extInfo : extInfos) { + if (trim(extInfo.extension) == "permessage-deflate") { + if (!parsePermessageDeflate(extInfo, pContext)) { + return false; + } + } + } + } + + if (pResponseHeaders && pContext->permessageDeflate) { + std::string params; + if (pContext->clientNoContextTakeover) { + params.append("; client_no_context_takeover"); + } + if (pContext->serverNoContextTakeover) { + params.append("; server_no_context_takeover"); + } + if (pContext->serverMaxWindowBits != 0) { + params.append("; server_max_window_bits="); + params.append(std::to_string(pContext->serverMaxWindowBits)); + } + if (pContext->clientMaxWindowBits != 0) { + params.append("; client_max_window_bits="); + params.append(std::to_string(pContext->clientMaxWindowBits)); + } + std::string exts = "permessage-deflate" + params; + pResponseHeaders->push_back(std::make_pair("Sec-WebSocket-Extensions", exts)); + } + + return true; +} + +namespace permessage_deflate { + +bool isValid(const RequestHeaders& requestHeaders) { + WebSocketConnectionContext context; + return handle(requestHeaders, NULL, &context); +} + +void handshake(const RequestHeaders& requestHeaders, + ResponseHeaders* pResponseHeaders, + WebSocketConnectionContext* pContext) { + + handle(requestHeaders, pResponseHeaders, pContext); +} + +} // namespace permessage_deflate \ No newline at end of file diff --git a/src/wse-permessage-deflate.h b/src/wse-permessage-deflate.h new file mode 100644 index 00000000..44cde463 --- /dev/null +++ b/src/wse-permessage-deflate.h @@ -0,0 +1,17 @@ +#ifndef WSEPERMESSAGEDEFLATE_H +#define WSEPERMESSAGEDEFLATE_H + +#include "constants.h" +#include "websockets-base.h" + +namespace permessage_deflate { + +bool isValid(const RequestHeaders& requestHeaders); + +void handshake(const RequestHeaders& requestHeaders, + ResponseHeaders* pResponseHeaders, + WebSocketConnectionContext* pContext); + +} // namespace permessage_deflate + +#endif \ No newline at end of file