diff --git a/src/workerd/api/container.c++ b/src/workerd/api/container.c++ index 24137471ff0..a708d98b4ac 100644 --- a/src/workerd/api/container.c++ +++ b/src/workerd/api/container.c++ @@ -71,6 +71,20 @@ jsg::Promise Container::setInactivityTimeout(jsg::Lock& js, int64_t durati return IoContext::current().awaitIo(js, req.sendIgnoringResult()); } +void Container::setEgressTcp(jsg::Lock& js, kj::String addr, jsg::Ref binding) { + auto& ioctx = IoContext::current(); + auto channel = binding->getSubrequestChannel(ioctx); + + // Get a channel token for RPC usage, the container runtime can use this + // token later to redeem a Fetcher. + auto token = channel->getToken(IoChannelFactory::ChannelTokenUsage::RPC); + + auto req = rpcClient->setEgressTcpRequest(); + req.setAddr(addr); + req.setChannelToken(token); + ioctx.addTask(req.sendIgnoringResult()); +} + jsg::Promise Container::monitor(jsg::Lock& js) { JSG_REQUIRE(running, Error, "monitor() cannot be called on a container that is not running."); diff --git a/src/workerd/api/container.h b/src/workerd/api/container.h index 1cbacce4027..6785affe501 100644 --- a/src/workerd/api/container.h +++ b/src/workerd/api/container.h @@ -62,6 +62,7 @@ class Container: public jsg::Object { void signal(jsg::Lock& js, int signo); jsg::Ref getTcpPort(jsg::Lock& js, int port); jsg::Promise setInactivityTimeout(jsg::Lock& js, int64_t durationMs); + void setEgressTcp(jsg::Lock& js, kj::String addr, jsg::Ref binding); // TODO(containers): listenTcp() @@ -73,6 +74,10 @@ class Container: public jsg::Object { JSG_METHOD(signal); JSG_METHOD(getTcpPort); JSG_METHOD(setInactivityTimeout); + + if (flags.getWorkerdExperimental()) { + JSG_METHOD(setEgressTcp); + } } void visitForMemoryInfo(jsg::MemoryTracker& tracker) const { diff --git a/src/workerd/io/container.capnp b/src/workerd/io/container.capnp index f79e85defb2..f3a4c18c8ca 100644 --- a/src/workerd/io/container.capnp +++ b/src/workerd/io/container.capnp @@ -110,4 +110,8 @@ interface Container 
@0x9aaceefc06523bca { # Note that if there is an open connection to the container, the runtime must not shutdown the container. # If there is no activity timeout duration configured and no container connection, it's up to the runtime # to decide when to signal the container to exit. + + setEgressTcp @8 (addr :Text, channelToken :Data); + # Configures egress TCP routing for the container. When the container attempts to connect to the + # specified address, the connection should be routed back to the Workers runtime using the channel token. } diff --git a/src/workerd/server/BUILD.bazel b/src/workerd/server/BUILD.bazel index 63eaea3511b..a081b4bc78e 100644 --- a/src/workerd/server/BUILD.bazel +++ b/src/workerd/server/BUILD.bazel @@ -193,19 +193,29 @@ wd_cc_library( ], ) +wd_cc_library( + name = "channel-token", + srcs = ["channel-token.c++"], + hdrs = ["channel-token.h"], + deps = [ + ":channel-token_capnp", + "//src/workerd/io", + "//src/workerd/util:entropy", + ], +) + wd_cc_library( name = "server", srcs = [ - "channel-token.c++", "server.c++", ], hdrs = [ - "channel-token.h", "server.h", ], deps = [ ":actor-id-impl", ":alarm-scheduler", + ":channel-token", ":channel-token_capnp", ":container-client", ":facet-tree-index", @@ -257,7 +267,9 @@ wd_cc_library( hdrs = ["container-client.h"], visibility = ["//visibility:public"], deps = [ + ":channel-token", ":docker-api_capnp", + "//src/workerd/io", "//src/workerd/io:container_capnp", "//src/workerd/jsg", "@capnp-cpp//src/capnp/compat:http-over-capnp", diff --git a/src/workerd/server/container-client.c++ b/src/workerd/server/container-client.c++ index 52c5e24fa3d..f22e2cc61d9 100644 --- a/src/workerd/server/container-client.c++ +++ b/src/workerd/server/container-client.c++ @@ -108,7 +108,8 @@ ContainerClient::ContainerClient(capnp::ByteStreamFactory& byteStreamFactory, kj::String containerName, kj::String imageName, kj::TaskSet& waitUntilTasks, - kj::Function cleanupCallback) + kj::Function cleanupCallback, + 
ChannelTokenHandler& channelTokenHandler) : byteStreamFactory(byteStreamFactory), timer(timer), network(network), @@ -116,7 +117,8 @@ ContainerClient::ContainerClient(capnp::ByteStreamFactory& byteStreamFactory, containerName(kj::encodeUriComponent(kj::mv(containerName))), imageName(kj::mv(imageName)), waitUntilTasks(waitUntilTasks), - cleanupCallback(kj::mv(cleanupCallback)) {} + cleanupCallback(kj::mv(cleanupCallback)), + channelTokenHandler(channelTokenHandler) {} ContainerClient::~ContainerClient() noexcept(false) { // Call the cleanup callback to remove this client from the ActorNamespace map @@ -466,6 +468,27 @@ kj::Promise ContainerClient::listenTcp(ListenTcpContext context) { KJ_UNIMPLEMENTED("listenTcp not implemented for Docker containers - use port mapping instead"); } +kj::Promise ContainerClient::setEgressTcp(SetEgressTcpContext context) { + auto params = context.getParams(); + auto addr = kj::str(params.getAddr()); + auto tokenBytes = params.getChannelToken(); + + // Redeem the channel token to get a SubrequestChannel + auto subrequestChannel = channelTokenHandler.decodeSubrequestChannelToken( + workerd::IoChannelFactory::ChannelTokenUsage::RPC, tokenBytes); + + // Store the mapping + egressMappings.upsert(kj::mv(addr), kj::mv(subrequestChannel), + [](auto& existing, auto&& newValue) { existing = kj::mv(newValue); }); + + // TODO: At some point we need to figure out how to actually route + // connections to an egress mapping in local development. + // For now, just exercise decoding of the subrequest channel token + // for testing purposes. 
+ + co_return; +} + kj::Own ContainerClient::addRef() { return kj::addRef(*this); } diff --git a/src/workerd/server/container-client.h b/src/workerd/server/container-client.h index 456c8911eb5..a80d0065cfa 100644 --- a/src/workerd/server/container-client.h +++ b/src/workerd/server/container-client.h @@ -5,6 +5,8 @@ #pragma once #include +#include +#include #include #include @@ -33,7 +35,8 @@ class ContainerClient final: public rpc::Container::Server, public kj::Refcounte kj::String containerName, kj::String imageName, kj::TaskSet& waitUntilTasks, - kj::Function cleanupCallback); + kj::Function cleanupCallback, + ChannelTokenHandler& channelTokenHandler); ~ContainerClient() noexcept(false); @@ -46,6 +49,7 @@ class ContainerClient final: public rpc::Container::Server, public kj::Refcounte kj::Promise getTcpPort(GetTcpPortContext context) override; kj::Promise listenTcp(ListenTcpContext context) override; kj::Promise setInactivityTimeout(SetInactivityTimeoutContext context) override; + kj::Promise setEgressTcp(SetEgressTcpContext context) override; kj::Own addRef(); @@ -93,6 +97,12 @@ class ContainerClient final: public rpc::Container::Server, public kj::Refcounte // Cleanup callback to remove from ActorNamespace map when destroyed kj::Function cleanupCallback; + + // For redeeming channel tokens received via setEgressTcp + ChannelTokenHandler& channelTokenHandler; + + // Egress TCP mappings: address -> SubrequestChannel + kj::HashMap> egressMappings; }; } // namespace workerd::server diff --git a/src/workerd/server/server.c++ b/src/workerd/server/server.c++ index 2ff0cb33639..29849e56eed 100644 --- a/src/workerd/server/server.c++ +++ b/src/workerd/server/server.c++ @@ -1929,9 +1929,9 @@ class Server::WorkerService final: public Service, } auto actorClass = kj::refcounted(*this, entry.key, Frankenvalue()); - auto ns = - kj::heap(kj::mv(actorClass), entry.value, threadContext.getUnsafeTimer(), - threadContext.getByteStreamFactory(), network, dockerPath, 
waitUntilTasks); + auto ns = kj::heap(kj::mv(actorClass), entry.value, + threadContext.getUnsafeTimer(), threadContext.getByteStreamFactory(), channelTokenHandler, + network, dockerPath, waitUntilTasks); actorNamespaces.insert(entry.key, kj::mv(ns)); } } @@ -2195,6 +2195,7 @@ class Server::WorkerService final: public Service, const ActorConfig& config, kj::Timer& timer, capnp::ByteStreamFactory& byteStreamFactory, + ChannelTokenHandler& channelTokenHandler, kj::Network& dockerNetwork, kj::Maybe dockerPath, kj::TaskSet& waitUntilTasks) @@ -2202,6 +2203,7 @@ class Server::WorkerService final: public Service, config(config), timer(timer), byteStreamFactory(byteStreamFactory), + channelTokenHandler(channelTokenHandler), dockerNetwork(dockerNetwork), dockerPath(dockerPath), waitUntilTasks(waitUntilTasks) {} @@ -2856,7 +2858,7 @@ class Server::WorkerService final: public Service, auto client = kj::refcounted(byteStreamFactory, timer, dockerNetwork, kj::str(dockerPathRef), kj::str(containerId), kj::str(imageName), waitUntilTasks, - kj::mv(cleanupCallback)); + kj::mv(cleanupCallback), channelTokenHandler); // Store raw pointer in map (does not own) containerClients.insert(kj::str(containerId), client.get()); @@ -2901,6 +2903,7 @@ class Server::WorkerService final: public Service, kj::Maybe> cleanupTask; kj::Timer& timer; capnp::ByteStreamFactory& byteStreamFactory; + ChannelTokenHandler& channelTokenHandler; kj::Network& dockerNetwork; kj::Maybe dockerPath; kj::TaskSet& waitUntilTasks; diff --git a/src/workerd/server/tests/container-client/container-client.wd-test b/src/workerd/server/tests/container-client/container-client.wd-test index 06e5b2435c3..7dd4321d2d5 100644 --- a/src/workerd/server/tests/container-client/container-client.wd-test +++ b/src/workerd/server/tests/container-client/container-client.wd-test @@ -8,7 +8,7 @@ const unitTests :Workerd.Config = ( modules = [ (name = "worker", esModule = embed "test.js") ], - compatibilityFlags = ["nodejs_compat", 
"experimental"], + compatibilityFlags = ["enable_ctx_exports", "nodejs_compat", "experimental"], containerEngine = (localDocker = (socketPath = "unix:/var/run/docker.sock")), durableObjectNamespaces = [ ( className = "DurableObjectExample", diff --git a/src/workerd/server/tests/container-client/test.js b/src/workerd/server/tests/container-client/test.js index d1659d0f3a1..6517d937f59 100644 --- a/src/workerd/server/tests/container-client/test.js +++ b/src/workerd/server/tests/container-client/test.js @@ -1,4 +1,4 @@ -import { DurableObject } from 'cloudflare:workers'; +import { DurableObject, WorkerEntrypoint } from 'cloudflare:workers'; import assert from 'node:assert'; import { scheduler } from 'node:timers/promises'; @@ -66,7 +66,7 @@ export class DurableObjectExample extends DurableObject { { let resp; // The retry count here is arbitrary. Can increase it if necessary. - const maxRetries = 6; + const maxRetries = 15; for (let i = 1; i <= maxRetries; i++) { try { resp = await container @@ -260,6 +260,33 @@ export class DurableObjectExample extends DurableObject { getStatus() { return this.ctx.container.running; } + + async testSetEgressTcp() { + const container = this.ctx.container; + if (container.running) { + let monitor = container.monitor().catch((_err) => {}); + await container.destroy(); + await monitor; + } + assert.strictEqual(container.running, false); + + // Start container + container.start(); + assert.strictEqual(container.running, true); + + // Set up egress TCP mapping to route requests to the binding + // This registers the binding's channel token with the container runtime + container.setEgressTcp( + '10.0.0.1:9999', + this.ctx.exports.TestService({ props: {} }) + ); + } +} + +export class TestService extends WorkerEntrypoint { + fetch() { + return new Response('you have hit TestService'); + } } export class DurableObjectExample2 extends DurableObjectExample {} @@ -394,3 +421,12 @@ export const testSetInactivityTimeout = { } }, }; + +// Test 
setEgressTcp functionality - registers a binding's channel token with the container +export const testSetEgressTcp = { + async test(_ctrl, env) { + const id = env.MY_CONTAINER.idFromName('testSetEgressTcp'); + const stub = env.MY_CONTAINER.get(id); + await stub.testSetEgressTcp(); + }, +}; diff --git a/types/generated-snapshot/experimental/index.d.ts b/types/generated-snapshot/experimental/index.d.ts index 4f42ea0c98c..c823e6cfbcb 100755 --- a/types/generated-snapshot/experimental/index.d.ts +++ b/types/generated-snapshot/experimental/index.d.ts @@ -3829,6 +3829,7 @@ interface Container { signal(signo: number): void; getTcpPort(port: number): Fetcher; setInactivityTimeout(durationMs: number | bigint): Promise; + setEgressTcp(addr: string, binding: Fetcher): void; } interface ContainerStartupOptions { entrypoint?: string[]; diff --git a/types/generated-snapshot/experimental/index.ts b/types/generated-snapshot/experimental/index.ts index 53215321c66..b6180076e82 100755 --- a/types/generated-snapshot/experimental/index.ts +++ b/types/generated-snapshot/experimental/index.ts @@ -3840,6 +3840,7 @@ export interface Container { signal(signo: number): void; getTcpPort(port: number): Fetcher; setInactivityTimeout(durationMs: number | bigint): Promise; + setEgressTcp(addr: string, binding: Fetcher): void; } export interface ContainerStartupOptions { entrypoint?: string[];