Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/silver-dryers-run.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@cloudflare/ai-chat": patch
---

Moved `/get-messages` endpoint handling from a prototype `override onRequest()` method to a constructor wrapper. This ensures the endpoint always works, even when users override `onRequest` without calling `super.onRequest()`.
24 changes: 11 additions & 13 deletions packages/ai-chat/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,17 @@ export class AIChatAgent<
// Forward unhandled messages to consumer's onMessage
return _onMessage(connection, message);
};

const _onRequest = this.onRequest.bind(this);
this.onRequest = async (request: Request) => {
return this._tryCatchChat(async () => {
const url = new URL(request.url);
if (url.pathname.endsWith("/get-messages")) {
return Response.json(this._loadMessagesFromDb());
}
return _onRequest(request);
});
};
}

/**
Expand Down Expand Up @@ -871,19 +882,6 @@ export class AIChatAgent<
.filter((msg): msg is ChatMessage => msg !== null);
}

override async onRequest(request: Request): Promise<Response> {
return this._tryCatchChat(async () => {
const url = new URL(request.url);

if (url.pathname.endsWith("/get-messages")) {
const messages = this._loadMessagesFromDb();
return Response.json(messages);
}

return super.onRequest(request);
});
}

private async _tryCatchChat<T>(fn: () => T | Promise<T>) {
try {
return await fn();
Expand Down
26 changes: 26 additions & 0 deletions packages/ai-chat/src/tests/get-messages-endpoint.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,29 @@ describe("GET /get-messages endpoint", () => {
expect(res.status).toBe(404);
});
});

describe("onRequest override patterns", () => {
it("/get-messages works when user overrides onRequest and calls super", async () => {
const room = crypto.randomUUID();

const { ws } = await connectChatWS(`/agents/agent-with-super-call/${room}`);
await new Promise((r) => setTimeout(r, 50));

const agentStub = await getAgentByName(env.AgentWithSuperCall, room);
const messages: ChatMessage[] = [
{ id: "test-1", role: "user", parts: [{ type: "text", text: "Hello" }] }
];
await agentStub.persistMessages(messages);
ws.close(1000);

const req = new Request(
`http://example.com/agents/agent-with-super-call/${room}/get-messages`
);
const res = await worker.fetch(req, env, createExecutionContext());

expect(res.status).toBe(200);
const returned = (await res.json()) as ChatMessage[];
expect(returned.length).toBe(1);
expect(returned[0].id).toBe("test-1");
});
});
16 changes: 16 additions & 0 deletions packages/ai-chat/src/tests/worker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ type TestToolCallPart = Extract<

export type Env = {
TestChatAgent: DurableObjectNamespace<TestChatAgent>;
AgentWithSuperCall: DurableObjectNamespace<AgentWithSuperCall>;
};

export class TestChatAgent extends AIChatAgent<Env> {
Expand Down Expand Up @@ -320,6 +321,21 @@ export class TestChatAgent extends AIChatAgent<Env> {
}
}

// Test agent that overrides onRequest and calls super.onRequest()
export class AgentWithSuperCall extends AIChatAgent<Env> {
async onRequest(request: Request): Promise<Response> {
const url = new URL(request.url);
if (url.pathname.endsWith("/custom-route")) {
return new Response("custom route");
}
return super.onRequest(request);
}

async onChatMessage() {
return new Response("chat response");
}
}

export default {
async fetch(request: Request, env: Env, _ctx: ExecutionContext) {
const url = new URL(request.url);
Expand Down
8 changes: 8 additions & 0 deletions packages/ai-chat/src/tests/wrangler.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
{
"class_name": "TestChatAgent",
"name": "TestChatAgent"
},
{
"class_name": "AgentWithSuperCall",
"name": "AgentWithSuperCall"
}
]
},
Expand All @@ -21,6 +25,10 @@
{
"new_sqlite_classes": ["TestChatAgent"],
"tag": "v1"
},
{
"new_sqlite_classes": ["AgentWithSuperCall"],
"tag": "v2"
}
]
}
Loading