diff --git a/mod_websocket.c b/mod_websocket.c index b5e73be..2964ee0 100644 --- a/mod_websocket.c +++ b/mod_websocket.c @@ -25,8 +25,10 @@ */ #include "apr_base64.h" +#include "apr_queue.h" #include "apr_sha1.h" #include "apr_strings.h" +#include "apr_thread_cond.h" #include "httpd.h" #include "http_config.h" @@ -58,6 +60,8 @@ typedef struct #define BLOCK_DATA_SIZE 4096 +#define QUEUE_CAPACITY 16 + #define DATA_FRAMING_MASK 0 #define DATA_FRAMING_START 1 #define DATA_FRAMING_PAYLOAD_LENGTH 2 @@ -214,10 +218,14 @@ typedef struct _WebSocketState { request_rec *r; apr_bucket_brigade *obb; + apr_os_thread_t main_thread; apr_thread_mutex_t *mutex; + apr_thread_cond_t *cond; apr_array_header_t *protocols; int closing; apr_int64_t protocol_version; + apr_pollset_t *pollset; + apr_queue_t *queue; } WebSocketState; static request_rec *CALLBACK mod_websocket_request(const WebSocketServer *server) @@ -290,85 +298,169 @@ static void CALLBACK mod_websocket_protocol_set(const WebSocketServer *server, } } +/* + * Sends data to the WebSocket connection using the given server state. The + * server state must be locked upon entering this function. buffer_size is + * assumed to be within the limits defined by the WebSocket protocol (i.e. fits + * in 63 bits). + */ +static size_t mod_websocket_send_internal(WebSocketState *state, + const int type, + const unsigned char *buffer, + const size_t buffer_size) +{ + apr_uint64_t payload_length = + (apr_uint64_t) ((buffer != NULL) ? buffer_size : 0); + size_t written = 0; + + if ((state->r != NULL) && (state->obb != NULL) && !state->closing) { + unsigned char header[32]; + ap_filter_t *of = state->r->connection->output_filters; + apr_size_t pos = 0; + unsigned char opcode; + + switch (type) { + case MESSAGE_TYPE_TEXT: + opcode = OPCODE_TEXT; + break; + case MESSAGE_TYPE_BINARY: + opcode = OPCODE_BINARY; + break; + case MESSAGE_TYPE_PING: + opcode = OPCODE_PING; + break; + case MESSAGE_TYPE_PONG: + opcode = OPCODE_PONG; + break; + case MESSAGE_TYPE_CLOSE: + default: + state->closing = 1; + opcode = OPCODE_CLOSE; + break; + } + header[pos++] = FRAME_SET_FIN(1) | FRAME_SET_OPCODE(opcode); + if (payload_length < 126) { + header[pos++] = + FRAME_SET_MASK(0) | FRAME_SET_LENGTH(payload_length, 0); + } + else { + if (payload_length < 65536) { + header[pos++] = FRAME_SET_MASK(0) | 126; + } + else { + header[pos++] = FRAME_SET_MASK(0) | 127; + header[pos++] = FRAME_SET_LENGTH(payload_length, 7); + header[pos++] = FRAME_SET_LENGTH(payload_length, 6); + header[pos++] = FRAME_SET_LENGTH(payload_length, 5); + header[pos++] = FRAME_SET_LENGTH(payload_length, 4); + header[pos++] = FRAME_SET_LENGTH(payload_length, 3); + header[pos++] = FRAME_SET_LENGTH(payload_length, 2); + } + header[pos++] = FRAME_SET_LENGTH(payload_length, 1); + header[pos++] = FRAME_SET_LENGTH(payload_length, 0); + } + ap_fwrite(of, state->obb, (const char *)header, pos); /* Header */ + if (payload_length > 0) { + if (ap_fwrite(of, state->obb, + (const char *)buffer, + buffer_size) == APR_SUCCESS) { /* Payload Data */ + written = buffer_size; + } + } + if (ap_fflush(of, state->obb) != APR_SUCCESS) { + written = 0; + } + } + + return written; +} + +typedef struct +{ + int type; + const unsigned char * buffer; + size_t buffer_size; + int done; + size_t written; +} WebSocketMessageData; + +/* + * Sends a buffer of data via the WebSocket. Returns the number of bytes that + * are actually written. + * + * If this function is called from a different thread than the one running the + * main framing loop, the message will be queued and the calling thread will + * block until the data is written by the main thread. + */ static size_t CALLBACK mod_websocket_plugin_send(const WebSocketServer *server, const int type, const unsigned char *buffer, const size_t buffer_size) { - apr_uint64_t payload_length = - (apr_uint64_t) ((buffer != NULL) ? buffer_size : 0); size_t written = 0; /* Deal with size more that 63 bits - FIXME */ - + /* FIXME - if sending a zero-length message, the API cannot distinguish + * between success and failure */ if ((server != NULL) && (server->state != NULL)) { WebSocketState *state = server->state; apr_thread_mutex_lock(state->mutex); - if ((state->r != NULL) && (state->obb != NULL) && !state->closing) { - unsigned char header[32]; - ap_filter_t *of = state->r->connection->output_filters; - apr_size_t pos = 0; - unsigned char opcode; - - switch (type) { - case MESSAGE_TYPE_TEXT: - opcode = OPCODE_TEXT; - break; - case MESSAGE_TYPE_BINARY: - opcode = OPCODE_BINARY; - break; - case MESSAGE_TYPE_PING: - opcode = OPCODE_PING; - break; - case MESSAGE_TYPE_PONG: - opcode = OPCODE_PONG; - break; - case MESSAGE_TYPE_CLOSE: - default: - state->closing = 1; - opcode = OPCODE_CLOSE; - break; - } - header[pos++] = FRAME_SET_FIN(1) | FRAME_SET_OPCODE(opcode); - if (payload_length < 126) { - header[pos++] = - FRAME_SET_MASK(0) | FRAME_SET_LENGTH(payload_length, 0); + if (apr_os_thread_equal(apr_os_thread_current(), state->main_thread)) { + /* This is the main thread. It's safe to write messages directly. */ + written = mod_websocket_send_internal(state, type, buffer, buffer_size); + } + else if ((state->pollset != NULL) && (state->queue != NULL) && + !state->closing) { + /* Dispatch this message to the main thread. */ + apr_status_t rv; + WebSocketMessageData msg = { 0 }; + + /* Populate the message data. */ + msg.type = type; + msg.buffer = buffer; + msg.buffer_size = buffer_size; + + /* Queue the message. */ + do { + rv = apr_queue_push(state->queue, &msg); + } while (APR_STATUS_IS_EINTR(rv)); + + if (rv != APR_SUCCESS) { + /* Couldn't push the message onto the queue. */ + goto send_unlock; } - else { - if (payload_length < 65536) { - header[pos++] = FRAME_SET_MASK(0) | 126; - } - else { - header[pos++] = FRAME_SET_MASK(0) | 127; - header[pos++] = FRAME_SET_LENGTH(payload_length, 7); - header[pos++] = FRAME_SET_LENGTH(payload_length, 6); - header[pos++] = FRAME_SET_LENGTH(payload_length, 5); - header[pos++] = FRAME_SET_LENGTH(payload_length, 4); - header[pos++] = FRAME_SET_LENGTH(payload_length, 3); - header[pos++] = FRAME_SET_LENGTH(payload_length, 2); - } - header[pos++] = FRAME_SET_LENGTH(payload_length, 1); - header[pos++] = FRAME_SET_LENGTH(payload_length, 0); + + /* Interrupt the pollset. */ + rv = apr_pollset_wakeup(state->pollset); + + if (rv != APR_SUCCESS) { + /* + * Couldn't wake up poll...? We can't return zero since we've + * already pushed the message, and it might actually be sent... + */ + /* TODO: log. */ } - ap_fwrite(of, state->obb, (const char *)header, pos); /* Header */ - if (payload_length > 0) { - if (ap_fwrite(of, state->obb, - (const char *)buffer, - buffer_size) == APR_SUCCESS) { /* Payload Data */ - written = buffer_size; - } + + /* Wait for the message to be written. */ + while (!msg.done && !state->closing) { + apr_thread_cond_wait(state->cond, state->mutex); } - if (ap_fflush(of, state->obb) != APR_SUCCESS) { - written = 0; + + if (msg.done) { + written = msg.written; } } + +send_unlock: apr_thread_mutex_unlock(state->mutex); } + return written; } + static void CALLBACK mod_websocket_plugin_close(const WebSocketServer * server) { @@ -381,26 +473,20 @@ static void CALLBACK mod_websocket_plugin_close(const WebSocketServer * /* * Read a buffer of data from the input stream. */ -static apr_size_t mod_websocket_read_block(request_rec *r, char *buffer, - apr_size_t bufsiz) +static apr_status_t mod_websocket_read_nonblock(request_rec *r, + apr_bucket_brigade *bb, + char *buffer, + apr_size_t *bufsiz) { apr_status_t rv; - apr_bucket_brigade *bb; - apr_size_t readbufsiz = 0; - - bb = apr_brigade_create(r->pool, r->connection->bucket_alloc); - if (bb != NULL) { - if ((rv = - ap_get_brigade(r->input_filters, bb, AP_MODE_READBYTES, - APR_BLOCK_READ, bufsiz)) == APR_SUCCESS) { - if ((rv = - apr_brigade_flatten(bb, buffer, &bufsiz)) == APR_SUCCESS) { - readbufsiz = bufsiz; - } - } - apr_brigade_destroy(bb); + + if ((rv = ap_get_brigade(r->input_filters, bb, AP_MODE_READBYTES, + APR_NONBLOCK_READ, *bufsiz)) == APR_SUCCESS) { + rv = apr_brigade_flatten(bb, buffer, bufsiz); + apr_brigade_cleanup(bb); } - return readbufsiz; + + return rv; } /* @@ -462,365 +548,513 @@ typedef struct _WebSocketFrameData unsigned int utf8_state; } WebSocketFrameData; -/* - * The data framing handler requires that the server state mutex is locked by - * the caller upon entering this function. It will be locked when leaving too. - */ -static void mod_websocket_data_framing(const WebSocketServer *server, - websocket_config_rec *conf, - void *plugin_private) +/* Variables that need to persist across calls to mod_websocket_handle_incoming */ +typedef struct { - WebSocketState *state = server->state; - request_rec *r = state->r; - apr_pool_t *pool = NULL; - apr_bucket_alloc_t *bucket_alloc; - apr_bucket_brigade *obb; - - /* We cannot use the same bucket allocator for the ouput bucket brigade - * obb as the one associated with the connection (r->connection->bucket_alloc) - * because the same bucket allocator cannot be used in two different - * threads, and we use the connection bucket allocator in this - * thread - see docs on apr_bucket_alloc_create(). This results in - * occasional core dumps. So create our own bucket allocator and pool - * for output thread bucket brigade. (Thanks to Alex Bligh -- abligh) - */ - - if ((apr_pool_create(&pool, r->pool) == APR_SUCCESS) && - ((bucket_alloc = apr_bucket_alloc_create(pool)) != NULL) && - ((obb = apr_brigade_create(pool, bucket_alloc)) != NULL)) { - unsigned char block[BLOCK_DATA_SIZE]; - apr_int64_t block_size; - apr_int64_t extension_bytes_remaining = 0; - apr_int64_t payload_length = 0; - apr_int64_t mask_offset = 0; - int framing_state = DATA_FRAMING_START; - int payload_length_bytes_remaining = 0; - int mask_index = 0, masking = 0; - unsigned char mask[4] = { 0, 0, 0, 0 }; - unsigned char fin = 0, opcode = 0xFF; - WebSocketFrameData control_frame = { 0, NULL, 1, 8, UTF8_VALID }; - WebSocketFrameData message_frame = { 0, NULL, 1, 0, UTF8_VALID }; - WebSocketFrameData *frame = &control_frame; - unsigned short status_code = STATUS_CODE_OK; - unsigned char status_code_buffer[2]; - - /* Allow the plugin to now write to the client */ - state->obb = obb; - apr_thread_mutex_unlock(state->mutex); + int framing_state; + unsigned short status_code; + /* XXX fin and opcode appear to be duplicated with frame; can they be removed? */ + unsigned char fin; + unsigned char opcode; + WebSocketFrameData control_frame; + WebSocketFrameData message_frame; + WebSocketFrameData *frame; + apr_int64_t payload_length; + apr_int64_t mask_offset; + apr_int64_t extension_bytes_remaining; + int payload_length_bytes_remaining; + int masking; + int mask_index; + unsigned char mask[4]; +} WebSocketReadState; + +static void mod_websocket_handle_incoming(const WebSocketServer *server, + unsigned char *block, + apr_int64_t block_size, + WebSocketReadState *state, + websocket_config_rec *conf, + void *plugin_private) +{ + apr_int64_t block_offset = 0; + + while (block_offset < block_size) { + switch (state->framing_state) { + case DATA_FRAMING_START: + /* + * Since we don't currently support any extensions, + * the reserve bits must be 0 + */ + if ((FRAME_GET_RSV1(block[block_offset]) != 0) || + (FRAME_GET_RSV2(block[block_offset]) != 0) || + (FRAME_GET_RSV3(block[block_offset]) != 0)) { + state->framing_state = DATA_FRAMING_CLOSE; + state->status_code = STATUS_CODE_PROTOCOL_ERROR; + break; + } + state->fin = FRAME_GET_FIN(block[block_offset]); + state->opcode = FRAME_GET_OPCODE(block[block_offset++]); - while ((framing_state != DATA_FRAMING_CLOSE) && - ((block_size = - mod_websocket_read_block(r, (char *)block, - sizeof(block))) > 0)) { - apr_int64_t block_offset = 0; + state->framing_state = DATA_FRAMING_PAYLOAD_LENGTH; - while (block_offset < block_size) { - switch (framing_state) { - case DATA_FRAMING_START: - /* - * Since we don't currently support any extensions, - * the reserve bits must be 0 - */ - if ((FRAME_GET_RSV1(block[block_offset]) != 0) || - (FRAME_GET_RSV2(block[block_offset]) != 0) || - (FRAME_GET_RSV3(block[block_offset]) != 0)) { - framing_state = DATA_FRAMING_CLOSE; - status_code = STATUS_CODE_PROTOCOL_ERROR; + if (state->opcode >= 0x8) { /* Control frame */ + if (state->fin) { + state->frame = &state->control_frame; + state->frame->opcode = state->opcode; + state->frame->utf8_state = UTF8_VALID; + } + else { + state->framing_state = DATA_FRAMING_CLOSE; + state->status_code = STATUS_CODE_PROTOCOL_ERROR; + break; + } + } + else { /* Message frame */ + state->frame = &state->message_frame; + if (state->opcode) { + if (state->frame->fin) { + state->frame->opcode = state->opcode; + state->frame->utf8_state = UTF8_VALID; + } + else { + state->framing_state = DATA_FRAMING_CLOSE; + state->status_code = STATUS_CODE_PROTOCOL_ERROR; break; } - fin = FRAME_GET_FIN(block[block_offset]); - opcode = FRAME_GET_OPCODE(block[block_offset++]); - - framing_state = DATA_FRAMING_PAYLOAD_LENGTH; + } + else if (state->frame->fin || + ((state->opcode = state->frame->opcode) == 0)) { + state->framing_state = DATA_FRAMING_CLOSE; + state->status_code = STATUS_CODE_PROTOCOL_ERROR; + break; + } + state->frame->fin = state->fin; + } + state->payload_length = 0; + state->payload_length_bytes_remaining = 0; - if (opcode >= 0x8) { /* Control frame */ - if (fin) { - frame = &control_frame; - frame->opcode = opcode; - frame->utf8_state = UTF8_VALID; - } - else { - framing_state = DATA_FRAMING_CLOSE; - status_code = STATUS_CODE_PROTOCOL_ERROR; - break; - } + if (block_offset >= block_size) { + break; /* Only break if we need more data */ + } + case DATA_FRAMING_PAYLOAD_LENGTH: + state->payload_length = (apr_int64_t) + FRAME_GET_PAYLOAD_LEN(block[block_offset]); + state->masking = FRAME_GET_MASK(block[block_offset++]); + + if (state->payload_length == 126) { + state->payload_length = 0; + state->payload_length_bytes_remaining = 2; + } + else if (state->payload_length == 127) { + state->payload_length = 0; + state->payload_length_bytes_remaining = 8; + } + else { + state->payload_length_bytes_remaining = 0; + } + if ((state->masking == 0) || /* Client-side mask is required */ + ((state->opcode >= 0x8) && /* Control opcodes cannot have a payload larger than 125 bytes */ + (state->payload_length_bytes_remaining != 0))) { + state->framing_state = DATA_FRAMING_CLOSE; + state->status_code = STATUS_CODE_PROTOCOL_ERROR; + break; + } + else { + state->framing_state = DATA_FRAMING_PAYLOAD_LENGTH_EXT; + } + if (block_offset >= block_size) { + break; /* Only break if we need more data */ + } + case DATA_FRAMING_PAYLOAD_LENGTH_EXT: + while ((state->payload_length_bytes_remaining > 0) && + (block_offset < block_size)) { + state->payload_length *= 256; + state->payload_length += block[block_offset++]; + state->payload_length_bytes_remaining--; + } + if (state->payload_length_bytes_remaining == 0) { + if ((state->payload_length < 0) || + (state->payload_length > conf->payload_limit)) { + /* Invalid payload length */ + state->framing_state = DATA_FRAMING_CLOSE; + state->status_code = (server->state->protocol_version >= 13) ? + STATUS_CODE_MESSAGE_TOO_LARGE : + STATUS_CODE_RESERVED; + break; + } + else if (state->masking != 0) { + state->framing_state = DATA_FRAMING_MASK; + } + else { + state->framing_state = DATA_FRAMING_EXTENSION_DATA; + break; + } + } + if (block_offset >= block_size) { + break; /* Only break if we need more data */ + } + case DATA_FRAMING_MASK: + while ((state->mask_index < 4) && (block_offset < block_size)) { + state->mask[state->mask_index++] = block[block_offset++]; + } + if (state->mask_index == 4) { + state->framing_state = DATA_FRAMING_EXTENSION_DATA; + state->mask_offset = 0; + state->mask_index = 0; + if ((state->mask[0] == 0) && (state->mask[1] == 0) && + (state->mask[2] == 0) && (state->mask[3] == 0)) { + state->masking = 0; + } + } + else { + break; + } + /* Fall through */ + case DATA_FRAMING_EXTENSION_DATA: + /* Deal with extension data when we support them -- FIXME */ + if (state->extension_bytes_remaining == 0) { + if (state->payload_length > 0) { + state->frame->application_data = (unsigned char *) + realloc(state->frame->application_data, + state->frame->application_data_offset + + state->payload_length); + if (state->frame->application_data == NULL) { + state->framing_state = DATA_FRAMING_CLOSE; + state->status_code = (server->state->protocol_version >= 13) ? + STATUS_CODE_INTERNAL_ERROR : + STATUS_CODE_GOING_AWAY; + break; } - else { /* Message frame */ - frame = &message_frame; - if (opcode) { - if (frame->fin) { - frame->opcode = opcode; - frame->utf8_state = UTF8_VALID; - } - else { - framing_state = DATA_FRAMING_CLOSE; - status_code = STATUS_CODE_PROTOCOL_ERROR; + } + state->framing_state = DATA_FRAMING_APPLICATION_DATA; + } + /* Fall through */ + case DATA_FRAMING_APPLICATION_DATA: + { + apr_int64_t block_data_length; + apr_int64_t block_length = 0; + apr_uint64_t application_data_offset = + state->frame->application_data_offset; + unsigned char *application_data = + state->frame->application_data; + + block_length = block_size - block_offset; + block_data_length = + (state->payload_length > + block_length) ? block_length : state->payload_length; + + if (state->masking) { + apr_int64_t i; + + if (state->opcode == OPCODE_TEXT) { + unsigned int utf8_state = state->frame->utf8_state; + unsigned char c; + + for (i = 0; i < block_data_length; i++) { + c = block[block_offset++] ^ + state->mask[state->mask_offset++ & 3]; + utf8_state = + validate_utf8[utf8_state + c]; + if (utf8_state == UTF8_INVALID) { + state->payload_length = block_data_length; break; } + application_data + [application_data_offset++] = c; } - else if (frame->fin || - ((opcode = frame->opcode) == 0)) { - framing_state = DATA_FRAMING_CLOSE; - status_code = STATUS_CODE_PROTOCOL_ERROR; - break; - } - frame->fin = fin; - } - payload_length = 0; - payload_length_bytes_remaining = 0; - - if (block_offset >= block_size) { - break; /* Only break if we need more data */ - } - case DATA_FRAMING_PAYLOAD_LENGTH: - payload_length = (apr_int64_t) - FRAME_GET_PAYLOAD_LEN(block[block_offset]); - masking = FRAME_GET_MASK(block[block_offset++]); - - if (payload_length == 126) { - payload_length = 0; - payload_length_bytes_remaining = 2; - } - else if (payload_length == 127) { - payload_length = 0; - payload_length_bytes_remaining = 8; - } - else { - payload_length_bytes_remaining = 0; - } - if ((masking == 0) || /* Client-side mask is required */ - ((opcode >= 0x8) && /* Control opcodes cannot have a payload larger than 125 bytes */ - (payload_length_bytes_remaining != 0))) { - framing_state = DATA_FRAMING_CLOSE; - status_code = STATUS_CODE_PROTOCOL_ERROR; - break; + state->frame->utf8_state = utf8_state; } else { - framing_state = DATA_FRAMING_PAYLOAD_LENGTH_EXT; - } - if (block_offset >= block_size) { - break; /* Only break if we need more data */ - } - case DATA_FRAMING_PAYLOAD_LENGTH_EXT: - while ((payload_length_bytes_remaining > 0) && - (block_offset < block_size)) { - payload_length *= 256; - payload_length += block[block_offset++]; - payload_length_bytes_remaining--; - } - if (payload_length_bytes_remaining == 0) { - if ((payload_length < 0) || - (payload_length > conf->payload_limit)) { - /* Invalid payload length */ - framing_state = DATA_FRAMING_CLOSE; - status_code = (state->protocol_version >= 13) ? - STATUS_CODE_MESSAGE_TOO_LARGE : - STATUS_CODE_RESERVED; - break; - } - else if (masking != 0) { - framing_state = DATA_FRAMING_MASK; - } - else { - framing_state = DATA_FRAMING_EXTENSION_DATA; - break; - } - } - if (block_offset >= block_size) { - break; /* Only break if we need more data */ - } - case DATA_FRAMING_MASK: - while ((mask_index < 4) && (block_offset < block_size)) { - mask[mask_index++] = block[block_offset++]; - } - if (mask_index == 4) { - framing_state = DATA_FRAMING_EXTENSION_DATA; - mask_offset = 0; - mask_index = 0; - if ((mask[0] == 0) && (mask[1] == 0) && - (mask[2] == 0) && (mask[3] == 0)) { - masking = 0; + /* Need to optimize the unmasking -- FIXME */ + for (i = 0; i < block_data_length; i++) { + application_data + [application_data_offset++] = + block[block_offset++] ^ + state->mask[state->mask_offset++ & 3]; } } - else { - break; - } - /* Fall through */ - case DATA_FRAMING_EXTENSION_DATA: - /* Deal with extension data when we support them -- FIXME */ - if (extension_bytes_remaining == 0) { - if (payload_length > 0) { - frame->application_data = (unsigned char *) - realloc(frame->application_data, - frame->application_data_offset + - payload_length); - if (frame->application_data == NULL) { - framing_state = DATA_FRAMING_CLOSE; - status_code = (state->protocol_version >= 13) ? - STATUS_CODE_INTERNAL_ERROR : - STATUS_CODE_GOING_AWAY; + } + else if (block_data_length > 0) { + memcpy(&application_data[application_data_offset], + &block[block_offset], block_data_length); + if (state->opcode == OPCODE_TEXT) { + apr_int64_t i, application_data_end = + application_data_offset + + block_data_length; + unsigned int utf8_state = state->frame->utf8_state; + + for (i = application_data_offset; + i < application_data_end; i++) { + utf8_state = + validate_utf8[utf8_state + + application_data[i]]; + if (utf8_state == UTF8_INVALID) { + state->payload_length = block_data_length; break; } } - framing_state = DATA_FRAMING_APPLICATION_DATA; + state->frame->utf8_state = utf8_state; } - /* Fall through */ - case DATA_FRAMING_APPLICATION_DATA: - { - apr_int64_t block_data_length; - apr_int64_t block_length = 0; - apr_uint64_t application_data_offset = - frame->application_data_offset; - unsigned char *application_data = - frame->application_data; - - block_length = block_size - block_offset; - block_data_length = - (payload_length > - block_length) ? block_length : payload_length; - - if (masking) { - apr_int64_t i; - - if (opcode == OPCODE_TEXT) { - unsigned int utf8_state = frame->utf8_state; - unsigned char c; - - for (i = 0; i < block_data_length; i++) { - c = block[block_offset++] ^ - mask[mask_offset++ & 3]; - utf8_state = - validate_utf8[utf8_state + c]; - if (utf8_state == UTF8_INVALID) { - payload_length = block_data_length; - break; - } - application_data - [application_data_offset++] = c; - } - frame->utf8_state = utf8_state; - } - else { - /* Need to optimize the unmasking -- FIXME */ - for (i = 0; i < block_data_length; i++) { - application_data - [application_data_offset++] = - block[block_offset++] ^ - mask[mask_offset++ & 3]; - } - } + application_data_offset += block_data_length; + block_offset += block_data_length; + } + state->payload_length -= block_data_length; + + if (state->payload_length == 0) { + int message_type = MESSAGE_TYPE_INVALID; + + switch (state->opcode) { + case OPCODE_TEXT: + if ((state->fin && + (state->frame->utf8_state != UTF8_VALID)) || + (state->frame->utf8_state == UTF8_INVALID)) { + state->framing_state = DATA_FRAMING_CLOSE; + state->status_code = STATUS_CODE_INVALID_UTF8; } - else if (block_data_length > 0) { - memcpy(&application_data[application_data_offset], - &block[block_offset], block_data_length); - if (opcode == OPCODE_TEXT) { - apr_int64_t i, application_data_end = - application_data_offset + - block_data_length; - unsigned int utf8_state = frame->utf8_state; - - for (i = application_data_offset; - i < application_data_end; i++) { - utf8_state = - validate_utf8[utf8_state + - application_data[i]]; - if (utf8_state == UTF8_INVALID) { - payload_length = block_data_length; - break; - } - } - frame->utf8_state = utf8_state; - } - application_data_offset += block_data_length; - block_offset += block_data_length; + else { + message_type = MESSAGE_TYPE_TEXT; } - payload_length -= block_data_length; - - if (payload_length == 0) { - int message_type = MESSAGE_TYPE_INVALID; - - switch (opcode) { - case OPCODE_TEXT: - if ((fin && - (frame->utf8_state != UTF8_VALID)) || - (frame->utf8_state == UTF8_INVALID)) { - framing_state = DATA_FRAMING_CLOSE; - status_code = STATUS_CODE_INVALID_UTF8; - } - else { - message_type = MESSAGE_TYPE_TEXT; - } - break; - case OPCODE_BINARY: - message_type = MESSAGE_TYPE_BINARY; - break; - case OPCODE_CLOSE: - framing_state = DATA_FRAMING_CLOSE; - status_code = STATUS_CODE_OK; - break; - case OPCODE_PING: - mod_websocket_plugin_send(server, - MESSAGE_TYPE_PONG, - application_data, - application_data_offset); - break; - case OPCODE_PONG: - break; - default: - framing_state = DATA_FRAMING_CLOSE; - status_code = STATUS_CODE_PROTOCOL_ERROR; - break; - } - if (fin && (message_type != MESSAGE_TYPE_INVALID)) { - conf->plugin->on_message(plugin_private, - server, message_type, - application_data, - application_data_offset); - } - if (framing_state != DATA_FRAMING_CLOSE) { - framing_state = DATA_FRAMING_START; - - if (fin) { - if (frame->application_data != NULL) { - free(frame->application_data); - frame->application_data = NULL; - } - application_data_offset = 0; - } + break; + case OPCODE_BINARY: + message_type = MESSAGE_TYPE_BINARY; + break; + case OPCODE_CLOSE: + state->framing_state = DATA_FRAMING_CLOSE; + state->status_code = STATUS_CODE_OK; + break; + case OPCODE_PING: + apr_thread_mutex_lock(server->state->mutex); + mod_websocket_send_internal(server->state, + MESSAGE_TYPE_PONG, + application_data, + application_data_offset); + apr_thread_mutex_unlock(server->state->mutex); + break; + case OPCODE_PONG: + break; + default: + state->framing_state = DATA_FRAMING_CLOSE; + state->status_code = STATUS_CODE_PROTOCOL_ERROR; + break; + } + if (state->fin && (message_type != MESSAGE_TYPE_INVALID)) { + conf->plugin->on_message(plugin_private, + server, message_type, + application_data, + application_data_offset); + } + if (state->framing_state != DATA_FRAMING_CLOSE) { + state->framing_state = DATA_FRAMING_START; + + if (state->fin) { + if (state->frame->application_data != NULL) { + free(state->frame->application_data); + state->frame->application_data = NULL; } + application_data_offset = 0; } - frame->application_data_offset = - application_data_offset; } - break; - case DATA_FRAMING_CLOSE: - block_offset = block_size; - break; - default: - framing_state = DATA_FRAMING_CLOSE; - status_code = STATUS_CODE_PROTOCOL_ERROR; - break; } + state->frame->application_data_offset = + application_data_offset; + } + break; + case DATA_FRAMING_CLOSE: + block_offset = block_size; + break; + default: + state->framing_state = DATA_FRAMING_CLOSE; + state->status_code = STATUS_CODE_PROTOCOL_ERROR; + break; + } + } +} + +static void mod_websocket_handle_outgoing(const WebSocketServer *server, + WebSocketMessageData *msg) +{ + apr_thread_mutex_lock(server->state->mutex); + msg->written = mod_websocket_send_internal(server->state, msg->type, + msg->buffer, msg->buffer_size); + + /* + * Notify plugin_send() that the message has been sent. + * + * XXX Wake up _all_ the waiting threads, since we don't know which one owns + * this message. This is contentious if there are a lot of threads writing + * in parallel. + */ + msg->done = 1; + apr_thread_cond_broadcast(server->state->cond); + + apr_thread_mutex_unlock(server->state->mutex); +} + +/* + * Compatibility wrapper for ap_get_conn_socket(), which doesn't exist in Apache + * 2.2. + */ +static apr_socket_t *get_conn_socket(conn_rec *conn) +{ +#if AP_MODULE_MAGIC_AT_LEAST(20110605,2) + return ap_get_conn_socket(conn); +#else + return ap_get_module_config(conn->conn_config, &core_module); +#endif +} + +/* + * The data framing handler requires that the server state mutex is locked by + * the caller upon entering this function. It will be locked when leaving too. + * + * The framing loop is the only place where data is written to or read from the + * socket via the bucket brigades, to prevent simultaneous access to the + * brigades. Having a read-only thread and a write-only thread isn't good + * enough, because filters (mod_ssl in particular) may read from the socket + * during a write and vice-versa. + * + * The framing loop runs on the main request thread given to us by Apache. + * Outgoing messages queued from another thread (by mod_websocket_plugin_send()) + * are dequeued and written here. + */ +static void mod_websocket_data_framing(const WebSocketServer *server, + websocket_config_rec *conf, + void *plugin_private) +{ + WebSocketState *state = server->state; + request_rec *r = state->r; + apr_bucket_brigade *ibb, *obb; + apr_pollset_t *pollset; + apr_pollfd_t pollfd = { 0 }; + const apr_pollfd_t *signalled; + apr_int32_t pollcnt; + apr_queue_t * queue; + + if (((ibb = apr_brigade_create(r->pool, r->connection->bucket_alloc)) != NULL) && + ((obb = apr_brigade_create(r->pool, r->connection->bucket_alloc)) != NULL) && + (apr_pollset_create(&pollset, 1, r->pool, APR_POLLSET_WAKEABLE) == APR_SUCCESS) && + (apr_queue_create(&queue, QUEUE_CAPACITY, r->pool) == APR_SUCCESS)) { + unsigned char block[BLOCK_DATA_SIZE]; + apr_int64_t block_size; + unsigned char status_code_buffer[2]; + WebSocketReadState read_state = { 0 }; + + read_state.framing_state = DATA_FRAMING_START; + read_state.status_code = STATUS_CODE_OK; + read_state.control_frame.fin = 1; + read_state.control_frame.opcode = 8; + read_state.control_frame.utf8_state = UTF8_VALID; + read_state.message_frame.fin = 1; + read_state.message_frame.opcode = 0; + read_state.message_frame.utf8_state = UTF8_VALID; + read_state.frame = &read_state.control_frame; + read_state.opcode = 0xFF; + + state->queue = queue; + + /* Initialize the pollset */ + pollfd.p = r->pool; + pollfd.desc_type = APR_POLL_SOCKET; + pollfd.reqevents = APR_POLLIN; + pollfd.desc.s = get_conn_socket(state->r->connection); + apr_pollset_add(pollset, &pollfd); + + state->pollset = pollset; + + /* Allow the plugin to now write to the client */ + state->obb = obb; + apr_thread_mutex_unlock(state->mutex); + + /* + * Main loop, inspired by mod_spdy. Alternate between data coming from + * the client and data coming from the server. Only block in poll() if + * there is no work to be done for either side. + */ + while ((read_state.framing_state != DATA_FRAMING_CLOSE)) { + apr_status_t rv; + apr_interval_time_t timeout; + WebSocketMessageData *msg; + int work_done = 0; + + /* Check to see if there is any data to read. */ + block_size = sizeof(block); + rv = mod_websocket_read_nonblock(r, ibb, (char *)block, &block_size); + + if (rv == APR_SUCCESS) { + mod_websocket_handle_incoming(server, block, block_size, + &read_state, conf, plugin_private); + work_done = 1; + } + else if (!APR_STATUS_IS_EAGAIN(rv)) { + read_state.status_code = STATUS_CODE_INTERNAL_ERROR; + break; + } + + /* Check to see if there is any data to write. */ + do { + void *el; + rv = apr_queue_trypop(state->queue, &el); + msg = el; + } while (APR_STATUS_IS_EINTR(rv)); + + if (rv == APR_SUCCESS) { + mod_websocket_handle_outgoing(server, msg); + work_done = 1; + } + else if (!APR_STATUS_IS_EAGAIN(rv)) { + read_state.status_code = STATUS_CODE_INTERNAL_ERROR; + break; + } + + /* + * If there's nothing to do, wait for new work to come in. + * + * Because Windows cannot poll on both a file pipe and a socket, + * plugin_send() uses apr_pollset_wakeup() to signal that new data + * is available to write. This is lossy (multiple threads calling + * wakeup() will result in only one wakeup here) so it's important + * that we do not block until state->queue has emptied. Otherwise + * it's possible to lose messages in the queue. + * + * NOTE: The wakeup pipe is drained only during apr_pollset_poll(), + * so we call it each iteration to avoid filling it up. We only + * block in poll() (negative timeout) if there was no work done + * during the current iteration. + */ + timeout = work_done ? 0 : -1; + rv = apr_pollset_poll(state->pollset, timeout, &pollcnt, &signalled); + + if ((rv != APR_SUCCESS) && !APR_STATUS_IS_EINTR(rv) && + !APR_STATUS_IS_TIMEUP(rv)) { + read_state.status_code = STATUS_CODE_INTERNAL_ERROR; + break; } } - if (message_frame.application_data != NULL) { - free(message_frame.application_data); + if (read_state.message_frame.application_data != NULL) { + free(read_state.message_frame.application_data); } - if (control_frame.application_data != NULL) { - free(control_frame.application_data); + if (read_state.control_frame.application_data != NULL) { + free(read_state.control_frame.application_data); } /* Send server-side closing handshake */ - status_code_buffer[0] = (status_code >> 8) & 0xFF; - status_code_buffer[1] = status_code & 0xFF; - mod_websocket_plugin_send(server, MESSAGE_TYPE_CLOSE, - status_code_buffer, - sizeof(status_code_buffer)); + status_code_buffer[0] = (read_state.status_code >> 8) & 0xFF; + status_code_buffer[1] = read_state.status_code & 0xFF; - /* We are done with the bucket brigade */ apr_thread_mutex_lock(state->mutex); + mod_websocket_send_internal(state, MESSAGE_TYPE_CLOSE, + status_code_buffer, + sizeof(status_code_buffer)); + + /* We are done with the bucket brigades */ state->obb = NULL; + apr_brigade_destroy(ibb); apr_brigade_destroy(obb); + + state->pollset = NULL; + apr_pollset_destroy(pollset); + + state->queue = NULL; + apr_queue_term(queue); } } @@ -888,8 +1122,10 @@ static int mod_websocket_method_handler(request_rec *r) &websocket_module); if ((conf != NULL) && (conf->plugin != NULL)) { - WebSocketState state = - { r, NULL, NULL, NULL, 0, protocol_version }; + WebSocketState state = { + r, NULL, apr_os_thread_current(), NULL, NULL, NULL, 0, + protocol_version, NULL, NULL + }; WebSocketServer server = { sizeof(WebSocketServer), 1, &state, mod_websocket_request, mod_websocket_header_get, @@ -946,6 +1182,8 @@ static int mod_websocket_method_handler(request_rec *r) apr_thread_mutex_create(&state.mutex, APR_THREAD_MUTEX_DEFAULT, r->pool); + apr_thread_cond_create(&state.cond, r->pool); + apr_thread_mutex_lock(state.mutex); /* @@ -959,9 +1197,8 @@ static int mod_websocket_method_handler(request_rec *r) * Now that the connection has been established, * disable the socket timeout */ - apr_socket_timeout_set(ap_get_module_config - (r->connection->conn_config, - &core_module), -1); + apr_socket_timeout_set(get_conn_socket(r->connection), + -1); /* Set response status code and status line */ r->status = HTTP_SWITCHING_PROTOCOLS; @@ -974,6 +1211,9 @@ static int mod_websocket_method_handler(request_rec *r) mod_websocket_data_framing(&server, conf, plugin_private); + /* Wake up any waiting plugin_sends before closing */ + apr_thread_cond_broadcast(state.cond); + apr_thread_mutex_unlock(state.mutex); /* Tell the plugin that we are disconnecting */ @@ -999,6 +1239,7 @@ static int mod_websocket_method_handler(request_rec *r) /* Close the connection */ ap_lingering_close(r->connection); + apr_thread_cond_destroy(state.cond); apr_thread_mutex_destroy(state.mutex); return OK;