diff --git a/CHANGELOG.md b/CHANGELOG.md index aeb2d20..1386a79 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,12 @@ All notable changes to ext-websocket are documented here. +## Unreleased + +### Added + +- Added `WebSocket\Server::onHandshake()` with `WebSocket\Request`, `WebSocket\HandshakeResponse`, and `WebSocket\HandshakeException` for pre-upgrade handshake validation. + ## 1.2.1 - 2026-05-23 ### Added diff --git a/README.md b/README.md index b5f55bb..2607c89 100644 --- a/README.md +++ b/README.md @@ -116,6 +116,7 @@ Methods: | `__construct(ServerOptions\|array $options = [])` | Create a server | | `listen(string $host, int $port): void` | Bind address for `run()` | | `subprotocols(string ...$protocols): void` | Configure supported `Sec-WebSocket-Protocol` tokens | +| `onHandshake(Closure $handler): void` | Accept or reject valid HTTP Upgrade requests before `101 Switching Protocols` | | `onOpen(Closure $handler): void` | Register upgraded connection callback | | `onMessage(Closure $handler): void` | Register text/binary message callback | | `onClose(Closure $handler): void` | Register close callback | @@ -124,6 +125,43 @@ Methods: | `stop(): void` | Request shutdown | | `getDriver(): string` | Return selected I/O driver | +Handshake callbacks receive a `WebSocket\Request`. Return normally to continue the WebSocket upgrade, or throw `WebSocket\HandshakeException` to reject it before `101 Switching Protocols` is sent: + +```php +$server->onHandshake(static function (WebSocket\Request $request): void { + if ($request->header('Origin') !== 'https://app.test') { + throw new WebSocket\HandshakeException( + new WebSocket\HandshakeResponse(403, ['X-Reject' => 'origin']) + ); + } +}); +``` + +### `WebSocket\Request` + +| Method / property | Description | +|---|---| +| `header(string $name): ?string` | Return a case-insensitive request header value | +| `readonly string $method` | HTTP request method | +| `readonly string $target` | HTTP request target | +| `readonly array $headers` | Lower-case request headers | + +### `WebSocket\HandshakeResponse` + +| Method / property | Description | +|---|---| +| `__construct(int $status = 403, array $headers = [], string $body = '')` | Create a custom handshake rejection response | +| `readonly int $status` | HTTP status code | +| `readonly array $headers` | HTTP response headers | +| `readonly string $body` | HTTP response body | + +### `WebSocket\HandshakeException` + +| Method / property | Description | +|---|---| +| `__construct(?HandshakeResponse $response = null)` | Create a handshake rejection exception; defaults to `403 Forbidden` | +| `readonly HandshakeResponse $response` | HTTP response sent before closing the connection | + ### `WebSocket\Connection` | Method / property | Description | diff --git a/php_websocket.h b/php_websocket.h index 9e9e57b..c671180 100644 --- a/php_websocket.h +++ b/php_websocket.h @@ -90,6 +90,7 @@ typedef struct _websocket_server_object { zval options; zval subprotocols; zval on_open; + zval on_handshake; zval on_message; zval on_close; zval on_error; @@ -142,6 +143,9 @@ typedef struct _websocket_connection_object { extern zend_class_entry *websocket_server_ce; extern zend_class_entry *websocket_server_options_ce; +extern zend_class_entry *websocket_request_ce; +extern zend_class_entry *websocket_handshake_response_ce; +extern zend_class_entry *websocket_handshake_exception_ce; extern zend_class_entry *websocket_connection_ce; extern zend_class_entry *websocket_message_type_ce; extern zend_class_entry *websocket_frame_ce; diff --git a/tests/001-contracts.phpt b/tests/001-contracts.phpt index 170255b..bf02b5d 100644 --- a/tests/001-contracts.phpt +++ b/tests/001-contracts.phpt @@ -6,6 +6,9 @@ websocket isVariadic()); +var_dump((new ReflectionMethod(WebSocket\Request::class, 'header'))->getReturnType()->allowsNull()); +var_dump((new ReflectionMethod(WebSocket\HandshakeResponse::class, '__construct'))->getNumberOfParameters()); +var_dump((new ReflectionMethod(WebSocket\HandshakeException::class, '__construct'))->getNumberOfParameters()); var_dump((new ReflectionMethod(WebSocket\Connection::class, 'send'))->getNumberOfParameters()); var_dump((new ReflectionProperty(WebSocket\Connection::class, 'subprotocol'))->getType()->allowsNull()); var_dump((new ReflectionMethod(WebSocket\ServerOptions::class, '__construct'))->getNumberOfParameters()); @@ -41,9 +48,16 @@ try { echo $e->getMessage(), "\n"; } $server->onOpen(static function () {}); +$server->onHandshake(static function () {}); $server->onMessage(static function () {}); $server->onClose(static function () {}); $server->onError(static function () {}); +$response = new WebSocket\HandshakeResponse(401, ['X-Test' => 'ok'], 'nope'); +var_dump($response->status); +var_dump($response->headers); +var_dump($response->body); +$exception = new WebSocket\HandshakeException($response); +var_dump($exception->response === $response); var_dump(in_array($server->getDriver(), ['kqueue', 'epoll', 'poll', 'select'], true)); ?> --EXPECT-- @@ -54,10 +68,17 @@ bool(true) bool(true) bool(true) bool(true) +bool(true) +bool(true) +bool(true) bool(false) bool(false) bool(true) bool(true) +bool(true) +bool(true) +int(3) +int(1) int(2) bool(true) int(5) @@ -70,4 +91,11 @@ WebSocket\ServerOptions::__construct(): Argument #1 ($maxMessageSize) must be at int(3) int(2) WebSocket\Server::subprotocols(): Argument #1 must be a valid WebSocket subprotocol token +int(401) +array(1) { + ["X-Test"]=> + string(2) "ok" +} +string(4) "nope" +bool(true) bool(true) diff --git a/tests/016-server-handshake-hook.phpt b/tests/016-server-handshake-hook.phpt new file mode 100644 index 0000000..245de85 --- /dev/null +++ b/tests/016-server-handshake-hook.phpt @@ -0,0 +1,200 @@ +--TEST-- +WebSocket\Server validates handshakes before upgrade +--EXTENSIONS-- +websocket +--SKIPIF-- + +--FILE-- +listen('127.0.0.1', PORT_PLACEHOLDER); + +$server->onHandshake(static function (Request $request): void { + file_put_contents(EVENTS_PLACEHOLDER, $request->target . ':' . ($request->header('Origin') ?? '(none)') . "\n", FILE_APPEND); + + if ($request->header('Origin') === 'https://app.test') { + return; + } + + if ($request->target === '/custom') { + throw new HandshakeException(new HandshakeResponse(401, ['X-Reject' => 'origin'], 'nope')); + } + + throw new HandshakeException(); +}); + +$server->onOpen(static function (Connection $connection) use ($server): void { + file_put_contents(EVENTS_PLACEHOLDER, "open\n", FILE_APPEND); + $server->stop(); +}); + +$server->run(); +file_put_contents(EVENTS_PLACEHOLDER, "returned\n", FILE_APPEND); +PHP; + +$serverCode = str_replace( + ['PORT_PLACEHOLDER', 'EVENTS_PLACEHOLDER'], + [(string) $port, var_export($eventsFile, true)], + $serverCode, +); +file_put_contents($serverFile, $serverCode); + +$process = proc_open( + [PHP_BINARY, '-n', '-d', 'extension=' . $extension, $serverFile], + [ + 1 => ['pipe', 'w'], + 2 => ['pipe', 'w'], + ], + $pipes, +); + +if (!is_resource($process)) { + echo "cannot start server\n"; + exit; +} + +$connect = static function () use ($port, $process): mixed { + $client = false; + $deadline = microtime(true) + 5.0; + + do { + $client = @stream_socket_client('tcp://127.0.0.1:' . $port, $errno, $errstr, 0.1); + if ($client !== false) { + return $client; + } + + $status = proc_get_status($process); + if (!$status['running']) { + break; + } + + usleep(10000); + } while (microtime(true) < $deadline); + + return false; +}; + +$handshake = static function (string $target, ?string $origin) use ($connect, $port): string { + $client = $connect(); + if ($client === false) { + return ''; + } + + $lines = [ + 'GET ' . $target . ' HTTP/1.1', + 'Host: 127.0.0.1:' . $port, + 'Upgrade: websocket', + 'Connection: Upgrade', + 'Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==', + 'Sec-WebSocket-Version: 13', + ]; + + if ($origin !== null) { + $lines[] = 'Origin: ' . $origin; + } + + $lines[] = ''; + $lines[] = ''; + + fwrite($client, implode("\r\n", $lines)); + stream_set_timeout($client, 1); + $response = fread($client, 4096); + fclose($client); + + return $response; +}; + +$forbiddenResponse = $handshake('/chat', 'https://evil.test'); +$customResponse = $handshake('/custom', null); +$acceptedResponse = $handshake('/chat', 'https://app.test'); + +$deadline = microtime(true) + 5.0; +do { + $status = proc_get_status($process); + if (!$status['running']) { + break; + } + + usleep(10000); +} while (microtime(true) < $deadline); + +$status = proc_get_status($process); +if ($status['running']) { + proc_terminate($process); +} + +$stdout = stream_get_contents($pipes[1]); +$stderr = stream_get_contents($pipes[2]); +fclose($pipes[1]); +fclose($pipes[2]); +proc_close($process); + +$events = file_exists($eventsFile) ? file($eventsFile, FILE_IGNORE_NEW_LINES) : []; + +var_dump(str_contains($forbiddenResponse, "HTTP/1.1 403 Forbidden\r\n")); +var_dump(str_contains($customResponse, "HTTP/1.1 401 Unauthorized\r\n")); +var_dump(str_contains($customResponse, "X-Reject: origin\r\n")); +var_dump(str_ends_with($customResponse, "nope")); +var_dump(str_contains($acceptedResponse, "HTTP/1.1 101 Switching Protocols\r\n")); +var_dump($events); +var_dump($stdout === ''); +var_dump($stderr === ''); + +@unlink($eventsFile); +@unlink($serverFile); +@rmdir($tmpDir); +?> +--EXPECT-- +bool(true) +bool(true) +bool(true) +bool(true) +bool(true) +array(5) { + [0]=> + string(23) "/chat:https://evil.test" + [1]=> + string(14) "/custom:(none)" + [2]=> + string(22) "/chat:https://app.test" + [3]=> + string(4) "open" + [4]=> + string(8) "returned" +} +bool(true) +bool(true) diff --git a/websocket.c b/websocket.c index c3237fa..060c811 100644 --- a/websocket.c +++ b/websocket.c @@ -12,6 +12,9 @@ ZEND_DECLARE_MODULE_GLOBALS(websocket) zend_class_entry *websocket_server_ce; zend_class_entry *websocket_server_options_ce; +zend_class_entry *websocket_request_ce; +zend_class_entry *websocket_handshake_response_ce; +zend_class_entry *websocket_handshake_exception_ce; zend_class_entry *websocket_connection_ce; zend_class_entry *websocket_message_type_ce; zend_class_entry *websocket_frame_ce; diff --git a/websocket.stub.php b/websocket.stub.php index 348f6f6..b3d97f6 100644 --- a/websocket.stub.php +++ b/websocket.stub.php @@ -50,6 +50,17 @@ public function listen(string $host, int $port): void {} */ public function subprotocols(string ...$protocols): void {} + /** + * Register a callback called after a valid HTTP Upgrade request is parsed + * and before the WebSocket handshake response is sent. + * + * Return normally to accept the handshake. Throw HandshakeException to + * reject it before the WebSocket upgrade response is sent. + * + * @param \Closure(Request):void $handler + */ + public function onHandshake(\Closure $handler): void {} + /** * Register a callback called after a successful HTTP Upgrade. * @@ -158,6 +169,77 @@ public function __construct( ) {} } +/** + * HTTP Upgrade request passed to WebSocket\Server::onHandshake(). + */ +final class Request +{ + /** + * Request method. + */ + public readonly string $method; + + /** + * Request target from the HTTP request line. + */ + public readonly string $target; + + /** + * Lower-case HTTP headers. + * + * @var array + */ + public readonly array $headers; + + /** + * Return a header value by case-insensitive name, or null when absent. + */ + public function header(string $name): ?string {} +} + +/** + * HTTP response carried by HandshakeException to reject upgrade. + */ +final class HandshakeResponse +{ + /** + * HTTP status code. + */ + public readonly int $status; + + /** + * HTTP response headers. + * + * @var array + */ + public readonly array $headers; + + /** + * HTTP response body. + */ + public readonly string $body; + + /** + * @param array $headers + * + * @throws \ValueError If the status or any header is invalid. + */ + public function __construct(int $status = 403, array $headers = [], string $body = '') {} +} + +/** + * Exception thrown from WebSocket\Server::onHandshake() to reject upgrade. + */ +final class HandshakeException extends \Exception +{ + /** + * HTTP response sent before closing the connection. + */ + public readonly HandshakeResponse $response; + + public function __construct(?HandshakeResponse $response = null) {} +} + /** * Runtime connection accepted by WebSocket\Server. */ diff --git a/websocket_arginfo.h b/websocket_arginfo.h index 56ba822..ea2911b 100644 --- a/websocket_arginfo.h +++ b/websocket_arginfo.h @@ -1,5 +1,5 @@ /* This is a generated file, edit the .stub.php file instead. - * Stub hash: 3b15a096ef5cc7eadc01a0f97ff128ced381089e */ + * Stub hash: 6adb6001500f3acf211a163aa05ad143a2628fbc */ ZEND_BEGIN_ARG_INFO_EX(arginfo_class_WebSocket_Server___construct, 0, 0, 0) ZEND_ARG_OBJ_TYPE_MASK(0, options, WebSocket\\ServerOptions, MAY_BE_ARRAY, "[]") @@ -14,15 +14,17 @@ ZEND_BEGIN_ARG_WITH_RETURN_TYPE_INFO_EX(arginfo_class_WebSocket_Server_subprotoc ZEND_ARG_VARIADIC_TYPE_INFO(0, protocols, IS_STRING, 0) ZEND_END_ARG_INFO() -ZEND_BEGIN_ARG_WITH_RETURN_TYPE_INFO_EX(arginfo_class_WebSocket_Server_onOpen, 0, 1, IS_VOID, 0) +ZEND_BEGIN_ARG_WITH_RETURN_TYPE_INFO_EX(arginfo_class_WebSocket_Server_onHandshake, 0, 1, IS_VOID, 0) ZEND_ARG_OBJ_INFO(0, handler, Closure, 0) ZEND_END_ARG_INFO() -#define arginfo_class_WebSocket_Server_onMessage arginfo_class_WebSocket_Server_onOpen +#define arginfo_class_WebSocket_Server_onOpen arginfo_class_WebSocket_Server_onHandshake -#define arginfo_class_WebSocket_Server_onClose arginfo_class_WebSocket_Server_onOpen +#define arginfo_class_WebSocket_Server_onMessage arginfo_class_WebSocket_Server_onHandshake -#define arginfo_class_WebSocket_Server_onError arginfo_class_WebSocket_Server_onOpen +#define arginfo_class_WebSocket_Server_onClose arginfo_class_WebSocket_Server_onHandshake + +#define arginfo_class_WebSocket_Server_onError arginfo_class_WebSocket_Server_onHandshake ZEND_BEGIN_ARG_WITH_RETURN_TYPE_INFO_EX(arginfo_class_WebSocket_Server_run, 0, 0, IS_VOID, 0) ZEND_END_ARG_INFO() @@ -40,6 +42,20 @@ ZEND_BEGIN_ARG_INFO_EX(arginfo_class_WebSocket_ServerOptions___construct, 0, 0, ZEND_ARG_TYPE_INFO_WITH_DEFAULT_VALUE(0, idleTimeoutMs, IS_LONG, 0, "120000") ZEND_END_ARG_INFO() +ZEND_BEGIN_ARG_WITH_RETURN_TYPE_INFO_EX(arginfo_class_WebSocket_Request_header, 0, 1, IS_STRING, 1) + ZEND_ARG_TYPE_INFO(0, name, IS_STRING, 0) +ZEND_END_ARG_INFO() + +ZEND_BEGIN_ARG_INFO_EX(arginfo_class_WebSocket_HandshakeResponse___construct, 0, 0, 0) + ZEND_ARG_TYPE_INFO_WITH_DEFAULT_VALUE(0, status, IS_LONG, 0, "403") + ZEND_ARG_TYPE_INFO_WITH_DEFAULT_VALUE(0, headers, IS_ARRAY, 0, "[]") + ZEND_ARG_TYPE_INFO_WITH_DEFAULT_VALUE(0, body, IS_STRING, 0, "\'\'") +ZEND_END_ARG_INFO() + +ZEND_BEGIN_ARG_INFO_EX(arginfo_class_WebSocket_HandshakeException___construct, 0, 0, 0) + ZEND_ARG_OBJ_INFO_WITH_DEFAULT_VALUE(0, response, WebSocket\\HandshakeResponse, 1, "null") +ZEND_END_ARG_INFO() + ZEND_BEGIN_ARG_WITH_RETURN_TYPE_INFO_EX(arginfo_class_WebSocket_Connection_send, 0, 1, IS_VOID, 0) ZEND_ARG_TYPE_INFO(0, payload, IS_STRING, 0) ZEND_ARG_OBJ_INFO_WITH_DEFAULT_VALUE(0, type, WebSocket\\MessageType, 0, "WebSocket\\MessageType::Text") @@ -89,6 +105,7 @@ ZEND_END_ARG_INFO() ZEND_METHOD(WebSocket_Server, __construct); ZEND_METHOD(WebSocket_Server, listen); ZEND_METHOD(WebSocket_Server, subprotocols); +ZEND_METHOD(WebSocket_Server, onHandshake); ZEND_METHOD(WebSocket_Server, onOpen); ZEND_METHOD(WebSocket_Server, onMessage); ZEND_METHOD(WebSocket_Server, onClose); @@ -97,6 +114,9 @@ ZEND_METHOD(WebSocket_Server, run); ZEND_METHOD(WebSocket_Server, stop); ZEND_METHOD(WebSocket_Server, getDriver); ZEND_METHOD(WebSocket_ServerOptions, __construct); +ZEND_METHOD(WebSocket_Request, header); +ZEND_METHOD(WebSocket_HandshakeResponse, __construct); +ZEND_METHOD(WebSocket_HandshakeException, __construct); ZEND_METHOD(WebSocket_Connection, send); ZEND_METHOD(WebSocket_Connection, close); ZEND_METHOD(WebSocket_Connection, isOpen); @@ -112,6 +132,7 @@ static const zend_function_entry class_WebSocket_Server_methods[] = { ZEND_ME(WebSocket_Server, __construct, arginfo_class_WebSocket_Server___construct, ZEND_ACC_PUBLIC) ZEND_ME(WebSocket_Server, listen, arginfo_class_WebSocket_Server_listen, ZEND_ACC_PUBLIC) ZEND_ME(WebSocket_Server, subprotocols, arginfo_class_WebSocket_Server_subprotocols, ZEND_ACC_PUBLIC) + ZEND_ME(WebSocket_Server, onHandshake, arginfo_class_WebSocket_Server_onHandshake, ZEND_ACC_PUBLIC) ZEND_ME(WebSocket_Server, onOpen, arginfo_class_WebSocket_Server_onOpen, ZEND_ACC_PUBLIC) ZEND_ME(WebSocket_Server, onMessage, arginfo_class_WebSocket_Server_onMessage, ZEND_ACC_PUBLIC) ZEND_ME(WebSocket_Server, onClose, arginfo_class_WebSocket_Server_onClose, ZEND_ACC_PUBLIC) @@ -127,6 +148,21 @@ static const zend_function_entry class_WebSocket_ServerOptions_methods[] = { ZEND_FE_END }; +static const zend_function_entry class_WebSocket_Request_methods[] = { + ZEND_ME(WebSocket_Request, header, arginfo_class_WebSocket_Request_header, ZEND_ACC_PUBLIC) + ZEND_FE_END +}; + +static const zend_function_entry class_WebSocket_HandshakeResponse_methods[] = { + ZEND_ME(WebSocket_HandshakeResponse, __construct, arginfo_class_WebSocket_HandshakeResponse___construct, ZEND_ACC_PUBLIC) + ZEND_FE_END +}; + +static const zend_function_entry class_WebSocket_HandshakeException_methods[] = { + ZEND_ME(WebSocket_HandshakeException, __construct, arginfo_class_WebSocket_HandshakeException___construct, ZEND_ACC_PUBLIC) + ZEND_FE_END +}; + static const zend_function_entry class_WebSocket_Connection_methods[] = { ZEND_ME(WebSocket_Connection, send, arginfo_class_WebSocket_Connection_send, ZEND_ACC_PUBLIC) ZEND_ME(WebSocket_Connection, close, arginfo_class_WebSocket_Connection_close, ZEND_ACC_PUBLIC) @@ -203,6 +239,79 @@ static zend_class_entry *register_class_WebSocket_ServerOptions(void) return class_entry; } +static zend_class_entry *register_class_WebSocket_Request(void) +{ + zend_class_entry ce, *class_entry; + + INIT_NS_CLASS_ENTRY(ce, "WebSocket", "Request", class_WebSocket_Request_methods); + class_entry = zend_register_internal_class_with_flags(&ce, NULL, ZEND_ACC_FINAL); + + zval property_method_default_value; + ZVAL_UNDEF(&property_method_default_value); + zend_string *property_method_name = zend_string_init("method", sizeof("method") - 1, 1); + zend_declare_typed_property(class_entry, property_method_name, &property_method_default_value, ZEND_ACC_PUBLIC|ZEND_ACC_READONLY, NULL, (zend_type) ZEND_TYPE_INIT_MASK(MAY_BE_STRING)); + zend_string_release(property_method_name); + + zval property_target_default_value; + ZVAL_UNDEF(&property_target_default_value); + zend_string *property_target_name = zend_string_init("target", sizeof("target") - 1, 1); + zend_declare_typed_property(class_entry, property_target_name, &property_target_default_value, ZEND_ACC_PUBLIC|ZEND_ACC_READONLY, NULL, (zend_type) ZEND_TYPE_INIT_MASK(MAY_BE_STRING)); + zend_string_release(property_target_name); + + zval property_headers_default_value; + ZVAL_UNDEF(&property_headers_default_value); + zend_string *property_headers_name = zend_string_init("headers", sizeof("headers") - 1, 1); + zend_declare_typed_property(class_entry, property_headers_name, &property_headers_default_value, ZEND_ACC_PUBLIC|ZEND_ACC_READONLY, NULL, (zend_type) ZEND_TYPE_INIT_MASK(MAY_BE_ARRAY)); + zend_string_release(property_headers_name); + + return class_entry; +} + +static zend_class_entry *register_class_WebSocket_HandshakeResponse(void) +{ + zend_class_entry ce, *class_entry; + + INIT_NS_CLASS_ENTRY(ce, "WebSocket", "HandshakeResponse", class_WebSocket_HandshakeResponse_methods); + class_entry = zend_register_internal_class_with_flags(&ce, NULL, ZEND_ACC_FINAL); + + zval property_status_default_value; + ZVAL_UNDEF(&property_status_default_value); + zend_string *property_status_name = zend_string_init("status", sizeof("status") - 1, 1); + zend_declare_typed_property(class_entry, property_status_name, &property_status_default_value, ZEND_ACC_PUBLIC|ZEND_ACC_READONLY, NULL, (zend_type) ZEND_TYPE_INIT_MASK(MAY_BE_LONG)); + zend_string_release(property_status_name); + + zval property_headers_default_value; + ZVAL_UNDEF(&property_headers_default_value); + zend_string *property_headers_name = zend_string_init("headers", sizeof("headers") - 1, 1); + zend_declare_typed_property(class_entry, property_headers_name, &property_headers_default_value, ZEND_ACC_PUBLIC|ZEND_ACC_READONLY, NULL, (zend_type) ZEND_TYPE_INIT_MASK(MAY_BE_ARRAY)); + zend_string_release(property_headers_name); + + zval property_body_default_value; + ZVAL_UNDEF(&property_body_default_value); + zend_string *property_body_name = zend_string_init("body", sizeof("body") - 1, 1); + zend_declare_typed_property(class_entry, property_body_name, &property_body_default_value, ZEND_ACC_PUBLIC|ZEND_ACC_READONLY, NULL, (zend_type) ZEND_TYPE_INIT_MASK(MAY_BE_STRING)); + zend_string_release(property_body_name); + + return class_entry; +} + +static zend_class_entry *register_class_WebSocket_HandshakeException(zend_class_entry *class_entry_Exception) +{ + zend_class_entry ce, *class_entry; + + INIT_NS_CLASS_ENTRY(ce, "WebSocket", "HandshakeException", class_WebSocket_HandshakeException_methods); + class_entry = zend_register_internal_class_with_flags(&ce, class_entry_Exception, ZEND_ACC_FINAL); + + zval property_response_default_value; + ZVAL_UNDEF(&property_response_default_value); + zend_string *property_response_name = zend_string_init("response", sizeof("response") - 1, 1); + zend_string *property_response_class_WebSocket_HandshakeResponse = zend_string_init("WebSocket\\HandshakeResponse", sizeof("WebSocket\\HandshakeResponse")-1, 1); + zend_declare_typed_property(class_entry, property_response_name, &property_response_default_value, ZEND_ACC_PUBLIC|ZEND_ACC_READONLY, NULL, (zend_type) ZEND_TYPE_INIT_CLASS(property_response_class_WebSocket_HandshakeResponse, 0, 0)); + zend_string_release(property_response_name); + + return class_entry; +} + static zend_class_entry *register_class_WebSocket_Connection(void) { zend_class_entry ce, *class_entry; diff --git a/websocket_server.c b/websocket_server.c index 7e1046a..038f8d2 100644 --- a/websocket_server.c +++ b/websocket_server.c @@ -11,6 +11,7 @@ #include "php_websocket.h" #include "php_websocket_compat.h" #include "websocket_arginfo.h" +#include "Zend/zend_exceptions.h" #include @@ -31,6 +32,7 @@ static zend_object *websocket_server_create_object(zend_class_entry *ce) ZVAL_UNDEF(&intern->options); ZVAL_UNDEF(&intern->subprotocols); ZVAL_UNDEF(&intern->on_open); + ZVAL_UNDEF(&intern->on_handshake); ZVAL_UNDEF(&intern->on_message); ZVAL_UNDEF(&intern->on_close); ZVAL_UNDEF(&intern->on_error); @@ -59,6 +61,7 @@ static void websocket_server_free_object(zend_object *object) zval_ptr_dtor(&intern->options); zval_ptr_dtor(&intern->subprotocols); zval_ptr_dtor(&intern->on_open); + zval_ptr_dtor(&intern->on_handshake); zval_ptr_dtor(&intern->on_message); zval_ptr_dtor(&intern->on_close); zval_ptr_dtor(&intern->on_error); @@ -158,6 +161,17 @@ PHP_METHOD(WebSocket_Server, subprotocols) ZVAL_COPY_VALUE(&intern->subprotocols, &normalized); } +PHP_METHOD(WebSocket_Server, onHandshake) +{ + zval *handler; + + ZEND_PARSE_PARAMETERS_START(1, 1) + Z_PARAM_OBJECT_OF_CLASS(handler, zend_ce_closure) + ZEND_PARSE_PARAMETERS_END(); + + websocket_server_store_closure(&Z_WEBSOCKET_SERVER_P(ZEND_THIS)->on_handshake, handler); +} + PHP_METHOD(WebSocket_ServerOptions, __construct) { zend_long max_message_size = WEBSOCKET_DEFAULT_MAX_MESSAGE_SIZE; @@ -252,6 +266,41 @@ PHP_METHOD(WebSocket_Server, onOpen) intern->on_open_cache_initialized = function != NULL; } +PHP_METHOD(WebSocket_Request, header) +{ + zend_string *name; + zend_string *lower_name; + zval *headers; + zval *value; + + ZEND_PARSE_PARAMETERS_START(1, 1) + Z_PARAM_STR(name) + ZEND_PARSE_PARAMETERS_END(); + + lower_name = zend_string_tolower(name); + headers = zend_read_property(websocket_request_ce, Z_OBJ_P(ZEND_THIS), "headers", strlen("headers"), 0, NULL); + value = Z_TYPE_P(headers) == IS_ARRAY ? zend_hash_find(Z_ARRVAL_P(headers), lower_name) : NULL; + zend_string_release(lower_name); + + if (!value) { + RETURN_NULL(); + } + + RETURN_STR_COPY(Z_STR_P(value)); +} + +static bool websocket_header_value_is_valid(zend_string *value) +{ + return !memchr(ZSTR_VAL(value), '\r', ZSTR_LEN(value)) && !memchr(ZSTR_VAL(value), '\n', ZSTR_LEN(value)); +} + +static void websocket_handshake_response_set_properties(zval *object, zend_long status, zval *headers, zend_string *body) +{ + zend_update_property_long(websocket_handshake_response_ce, Z_OBJ_P(object), "status", strlen("status"), status); + zend_update_property(websocket_handshake_response_ce, Z_OBJ_P(object), "headers", strlen("headers"), headers); + zend_update_property_str(websocket_handshake_response_ce, Z_OBJ_P(object), "body", strlen("body"), body); +} + PHP_METHOD(WebSocket_Server, onMessage) { zval *handler; @@ -285,6 +334,87 @@ PHP_METHOD(WebSocket_Server, onError) websocket_server_store_closure(&Z_WEBSOCKET_SERVER_P(ZEND_THIS)->on_error, handler); } +PHP_METHOD(WebSocket_HandshakeResponse, __construct) +{ + zend_long status = 403; + zval *headers = NULL; + zend_string *body = ZSTR_EMPTY_ALLOC(); + zval normalized; + zend_string *name; + zval *value; + + ZEND_PARSE_PARAMETERS_START(0, 3) + Z_PARAM_OPTIONAL + Z_PARAM_LONG(status) + Z_PARAM_ARRAY(headers) + Z_PARAM_STR(body) + ZEND_PARSE_PARAMETERS_END(); + + if (status < 100 || status > 599 || status == 101) { + zend_argument_value_error(1, "must be a valid non-101 HTTP status code"); + RETURN_THROWS(); + } + + array_init(&normalized); + if (headers) { + ZEND_HASH_FOREACH_STR_KEY_VAL(Z_ARRVAL_P(headers), name, value) { + zval header_value; + + if (!name || !websocket_http_validate_subprotocol_token(ZSTR_VAL(name), ZSTR_LEN(name))) { + zval_ptr_dtor(&normalized); + zend_argument_value_error(2, "must contain valid HTTP header names"); + RETURN_THROWS(); + } + if (Z_TYPE_P(value) != IS_STRING) { + zval_ptr_dtor(&normalized); + zend_argument_type_error(2, "must contain string header values, %s given", websocket_zval_value_name(value)); + RETURN_THROWS(); + } + if (!websocket_header_value_is_valid(Z_STR_P(value))) { + zval_ptr_dtor(&normalized); + zend_argument_value_error(2, "must contain HTTP header values without CR or LF"); + RETURN_THROWS(); + } + + ZVAL_STR_COPY(&header_value, Z_STR_P(value)); + zend_hash_add_new(Z_ARRVAL(normalized), name, &header_value); + } ZEND_HASH_FOREACH_END(); + } + + websocket_handshake_response_set_properties(ZEND_THIS, status, &normalized, body); + zval_ptr_dtor(&normalized); +} + +PHP_METHOD(WebSocket_HandshakeException, __construct) +{ + zval *response = NULL; + zval default_response; + bool has_default_response = false; + + ZEND_PARSE_PARAMETERS_START(0, 1) + Z_PARAM_OPTIONAL + Z_PARAM_OBJECT_OF_CLASS_OR_NULL(response, websocket_handshake_response_ce) + ZEND_PARSE_PARAMETERS_END(); + + if (!response) { + zval headers; + + object_init_ex(&default_response, websocket_handshake_response_ce); + array_init(&headers); + websocket_handshake_response_set_properties(&default_response, 403, &headers, ZSTR_EMPTY_ALLOC()); + zval_ptr_dtor(&headers); + response = &default_response; + has_default_response = true; + } + + zend_update_property(websocket_handshake_exception_ce, Z_OBJ_P(ZEND_THIS), "response", strlen("response"), response); + zend_update_property_string(zend_ce_exception, Z_OBJ_P(ZEND_THIS), "message", strlen("message"), "WebSocket handshake rejected"); + + if (has_default_response) { + zval_ptr_dtor(&default_response); + } +} + PHP_METHOD(WebSocket_Server, run) { websocket_server_object *intern = Z_WEBSOCKET_SERVER_P(ZEND_THIS); @@ -345,6 +475,9 @@ void websocket_register_server_class(void) { websocket_server_ce = register_class_WebSocket_Server(); websocket_server_options_ce = register_class_WebSocket_ServerOptions(); + websocket_request_ce = register_class_WebSocket_Request(); + websocket_handshake_response_ce = register_class_WebSocket_HandshakeResponse(); + websocket_handshake_exception_ce = register_class_WebSocket_HandshakeException(zend_ce_exception); websocket_server_ce->create_object = websocket_server_create_object; memcpy(&websocket_server_handlers, zend_get_std_object_handlers(), sizeof(zend_object_handlers)); diff --git a/websocket_server_runtime.c b/websocket_server_runtime.c index 61c982b..9009eb2 100644 --- a/websocket_server_runtime.c +++ b/websocket_server_runtime.c @@ -10,6 +10,8 @@ #include "php.h" #include "php_websocket.h" #include "php_websocket_compat.h" +#include "Zend/zend_exceptions.h" +#include "Zend/zend_smart_str.h" #include #include @@ -41,6 +43,7 @@ typedef struct _websocket_server_frame { } websocket_server_frame; static bool websocket_server_close_with_code(websocket_connection_object *connection_obj, zend_long code, const char *reason); +static bool websocket_server_send_bytes(int fd, const char *buffer, size_t len); static uint64_t websocket_server_handshake_timeout_usec(websocket_server_object *intern); static uint64_t websocket_server_idle_timeout_usec(websocket_server_object *intern); @@ -56,6 +59,12 @@ static const char websocket_service_unavailable_response[] = "Content-Length: 0\r\n" "\r\n"; +static const char websocket_forbidden_response[] = + "HTTP/1.1 403 Forbidden\r\n" + "Connection: close\r\n" + "Content-Length: 0\r\n" + "\r\n"; + static uint64_t websocket_server_now_usec(void) { #ifdef CLOCK_MONOTONIC @@ -155,6 +164,274 @@ static bool websocket_server_call_open_handler(websocket_server_object *intern, return !EG(exception); } +static const char *websocket_http_reason_phrase(const zend_long status) +{ + switch (status) { + case 200: + return "OK"; + case 400: + return "Bad Request"; + case 401: + return "Unauthorized"; + case 403: + return "Forbidden"; + case 404: + return "Not Found"; + case 429: + return "Too Many Requests"; + case 500: + return "Internal Server Error"; + case 503: + return "Service Unavailable"; + default: + return "Rejected"; + } +} + +static bool websocket_server_send_handshake_response(const int fd, zval *response) +{ + zval *status_zv; + zval *headers_zv; + zval *body_zv; + zend_long status; + zend_string *body; + zend_string *name; + zval *value; + bool has_connection = false; + bool has_content_length = false; + smart_str buffer = {0}; + bool ok; + + status_zv = zend_read_property(websocket_handshake_response_ce, Z_OBJ_P(response), "status", strlen("status"), 0, NULL); + headers_zv = zend_read_property(websocket_handshake_response_ce, Z_OBJ_P(response), "headers", strlen("headers"), 0, NULL); + body_zv = zend_read_property(websocket_handshake_response_ce, Z_OBJ_P(response), "body", strlen("body"), 0, NULL); + + status = zval_get_long(status_zv); + body = zval_get_string(body_zv); + + smart_str_append_printf(&buffer, "HTTP/1.1 %ld %s\r\n", (long) status, websocket_http_reason_phrase(status)); + if (Z_TYPE_P(headers_zv) == IS_ARRAY) { + ZEND_HASH_FOREACH_STR_KEY_VAL(Z_ARRVAL_P(headers_zv), name, value) { + if (!name || Z_TYPE_P(value) != IS_STRING) { + continue; + } + if (zend_string_equals_literal_ci(name, "Connection")) { + has_connection = true; + } else if (zend_string_equals_literal_ci(name, "Content-Length")) { + has_content_length = true; + } + smart_str_append(&buffer, name); + smart_str_appendl(&buffer, ": ", 2); + smart_str_append(&buffer, Z_STR_P(value)); + smart_str_appendl(&buffer, "\r\n", 2); + } ZEND_HASH_FOREACH_END(); + } + if (!has_connection) { + smart_str_appendl(&buffer, "Connection: close\r\n", strlen("Connection: close\r\n")); + } + if (!has_content_length) { + smart_str_append_printf(&buffer, "Content-Length: %zu\r\n", ZSTR_LEN(body)); + } + smart_str_appendl(&buffer, "\r\n", 2); + smart_str_append(&buffer, body); + smart_str_0(&buffer); + + ok = buffer.s && websocket_server_send_bytes(fd, ZSTR_VAL(buffer.s), ZSTR_LEN(buffer.s)); + if (buffer.s) { + smart_str_free(&buffer); + } + zend_string_release(body); + + return ok; +} + +static zend_string *websocket_server_lower_header_name(const char *name, const size_t name_len) +{ + zend_string *header = zend_string_init(name, name_len, false); + zend_string *lower = zend_string_tolower(header); + zend_string_release(header); + return lower; +} + +static void websocket_server_add_request_header(zval *headers, const char *name, const size_t name_len, const char *value, const size_t value_len) +{ + zend_string *lower_name = websocket_server_lower_header_name(name, name_len); + zval *existing = zend_hash_find(Z_ARRVAL_P(headers), lower_name); + zval header_value; + + if (existing && Z_TYPE_P(existing) == IS_STRING) { + zend_string *joined = strpprintf(0, "%s, %.*s", Z_STRVAL_P(existing), (int) value_len, value); + ZVAL_STR(&header_value, joined); + zend_hash_update(Z_ARRVAL_P(headers), lower_name, &header_value); + } else { + ZVAL_STRINGL(&header_value, value, value_len); + zend_hash_update(Z_ARRVAL_P(headers), lower_name, &header_value); + } + + zend_string_release(lower_name); +} + +static void websocket_server_http_trim(const char **value, size_t *len) +{ + while (*len > 0 && ((*value)[0] == ' ' || (*value)[0] == '\t')) { + (*value)++; + (*len)--; + } + + while (*len > 0 && ((*value)[*len - 1] == ' ' || (*value)[*len - 1] == '\t')) { + (*len)--; + } +} + +static bool websocket_server_create_request(zval *request, const char *buffer, const size_t bytes_consumed) +{ + const char *header_end = buffer + bytes_consumed - 4; + const char *request_line_end = memchr(buffer, '\r', bytes_consumed); + const char *method_end; + const char *target_start; + const char *target_end; + const char *line; + zval headers; + + if (!request_line_end || request_line_end + 1 >= buffer + bytes_consumed || request_line_end[1] != '\n') { + return false; + } + + method_end = memchr(buffer, ' ', (size_t) (request_line_end - buffer)); + if (!method_end) { + return false; + } + target_start = method_end + 1; + target_end = memchr(target_start, ' ', (size_t) (request_line_end - target_start)); + if (!target_end) { + return false; + } + + object_init_ex(request, websocket_request_ce); + zend_update_property_stringl(websocket_request_ce, Z_OBJ_P(request), "method", strlen("method"), buffer, (size_t) (method_end - buffer)); + zend_update_property_stringl(websocket_request_ce, Z_OBJ_P(request), "target", strlen("target"), target_start, (size_t) (target_end - target_start)); + + array_init(&headers); + line = request_line_end + 2; + while (line < header_end) { + const char *line_end = memchr(line, '\r', (size_t) (header_end - line) + 1); + const char *colon; + const char *name; + const char *value; + size_t line_len; + size_t name_len; + size_t value_len; + + if (!line_end || line_end + 1 >= buffer + bytes_consumed || line_end[1] != '\n') { + zval_ptr_dtor(&headers); + zval_ptr_dtor(request); + ZVAL_UNDEF(request); + return false; + } + + line_len = (size_t) (line_end - line); + if (line_len == 0) { + break; + } + + colon = memchr(line, ':', line_len); + if (!colon) { + zval_ptr_dtor(&headers); + zval_ptr_dtor(request); + ZVAL_UNDEF(request); + return false; + } + + name = line; + name_len = (size_t) (colon - line); + value = colon + 1; + value_len = line_len - name_len - 1; + websocket_server_http_trim(&name, &name_len); + websocket_server_http_trim(&value, &value_len); + websocket_server_add_request_header(&headers, name, name_len, value, value_len); + + line = line_end + 2; + } + + zend_update_property(websocket_request_ce, Z_OBJ_P(request), "headers", strlen("headers"), &headers); + zval_ptr_dtor(&headers); + + return true; +} + +static bool websocket_server_handle_handshake_exception(websocket_connection_object *connection_obj, bool *accepted) +{ + zend_object *exception = EG(exception); + zval exception_zv; + zval *response; + + if (!exception || !instanceof_function(exception->ce, websocket_handshake_exception_ce)) { + return false; + } + + GC_ADDREF(exception); + zend_clear_exception(); + + ZVAL_OBJ(&exception_zv, exception); + response = zend_read_property(websocket_handshake_exception_ce, Z_OBJ(exception_zv), "response", strlen("response"), 0, NULL); + if (Z_TYPE_P(response) == IS_OBJECT && instanceof_function(Z_OBJCE_P(response), websocket_handshake_response_ce)) { + (void) websocket_server_send_handshake_response(connection_obj->fd, response); + } else { + (void) websocket_server_send_bytes(connection_obj->fd, websocket_forbidden_response, sizeof(websocket_forbidden_response) - 1); + } + + connection_obj->open = false; + *accepted = false; + zval_ptr_dtor(&exception_zv); + + return true; +} + +static bool websocket_server_call_handshake_handler(websocket_server_object *intern, zval *request, websocket_connection_object *connection_obj, bool *accepted) +{ + zval retval; + zval params[1]; + + *accepted = true; + + if (Z_ISUNDEF(intern->on_handshake)) { + return true; + } + + ZVAL_COPY(¶ms[0], request); + ZVAL_UNDEF(&retval); + if (call_user_function(EG(function_table), NULL, &intern->on_handshake, &retval, 1, params) == FAILURE) { + zval_ptr_dtor(¶ms[0]); + zend_throw_error(NULL, "Failed to call WebSocket server handshake handler"); + return false; + } + zval_ptr_dtor(¶ms[0]); + + if (EG(exception)) { + if (websocket_server_handle_handshake_exception(connection_obj, accepted)) { + if (!Z_ISUNDEF(retval)) { + zval_ptr_dtor(&retval); + } + return true; + } + if (!Z_ISUNDEF(retval)) { + zval_ptr_dtor(&retval); + } + return false; + } + + if (!Z_ISUNDEF(retval) && Z_TYPE(retval) != IS_NULL) { + zend_type_error("WebSocket handshake handler must return void, %s returned", websocket_zval_value_name(&retval)); + zval_ptr_dtor(&retval); + return false; + } + + if (!Z_ISUNDEF(retval)) { + zval_ptr_dtor(&retval); + } + return true; +} + static int websocket_server_create_listener(websocket_server_object *intern) { struct addrinfo hints; @@ -916,6 +1193,40 @@ static bool websocket_server_process_handshake(websocket_server_object *intern, return true; } + if (!Z_ISUNDEF(intern->on_handshake)) { + zval request; + bool accepted = true; + + ZVAL_UNDEF(&request); + if (!websocket_server_create_request(&request, connection_obj->read_buffer, bytes_consumed)) { + (void) websocket_server_send_bytes(connection_obj->fd, websocket_bad_request_response, sizeof(websocket_bad_request_response) - 1); + connection_obj->open = false; + zend_string_release(accept_key); + if (selected_subprotocol) { + zend_string_release(selected_subprotocol); + } + return true; + } + + if (!websocket_server_call_handshake_handler(intern, &request, connection_obj, &accepted)) { + zval_ptr_dtor(&request); + zend_string_release(accept_key); + if (selected_subprotocol) { + zend_string_release(selected_subprotocol); + } + return false; + } + zval_ptr_dtor(&request); + + if (!accepted) { + zend_string_release(accept_key); + if (selected_subprotocol) { + zend_string_release(selected_subprotocol); + } + return true; + } + } + if (!websocket_server_finish_upgrade(intern, connection, connection_obj, accept_key, selected_subprotocol, bytes_consumed)) { zend_string_release(accept_key); if (selected_subprotocol) {