Skip to content

Commit 9dcde54

Browse files
authored
feat(js/ai): added support for model middleware that can manipulate the stream and context (#3903)
1 parent 0f5ca06 commit 9dcde54

File tree

5 files changed

+210
-22
lines changed

5 files changed

+210
-22
lines changed

js/ai/src/generate.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ import {
5252
type GenerationCommonConfigSchema,
5353
type MessageData,
5454
type ModelArgument,
55-
type ModelMiddleware,
55+
type ModelMiddlewareArgument,
5656
type Part,
5757
type ToolRequestPart,
5858
type ToolResponsePart,
@@ -171,7 +171,7 @@ export interface GenerateOptions<
171171
*/
172172
streamingCallback?: StreamingCallback<GenerateResponseChunk>;
173173
/** Middleware to be used with this model call. */
174-
use?: ModelMiddleware[];
174+
use?: ModelMiddlewareArgument[];
175175
/** Additional context (data, like e.g. auth) to be passed down to tools, prompts and other sub actions. */
176176
context?: ActionContext;
177177
/** Abort signal for the generate request. */

js/ai/src/generate/action.ts

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
*/
1616

1717
import {
18+
ActionRunOptions,
1819
GenkitError,
1920
StreamingCallback,
2021
defineAction,
@@ -42,6 +43,8 @@ import {
4243
GenerateResponseChunkSchema,
4344
GenerateResponseSchema,
4445
MessageData,
46+
ModelMiddlewareArgument,
47+
ModelMiddlewareWithOptions,
4548
resolveModel,
4649
type GenerateActionOptions,
4750
type GenerateActionOutputConfig,
@@ -85,7 +88,7 @@ export function defineGenerateAction(registry: Registry): GenerateAction {
8588
outputSchema: GenerateResponseSchema,
8689
streamSchema: GenerateResponseChunkSchema,
8790
},
88-
async (request, { streamingRequested, sendChunk }) => {
91+
async (request, { streamingRequested, sendChunk, context }) => {
8992
const generateFn = (
9093
sendChunk?: StreamingCallback<GenerateResponseChunk>
9194
) =>
@@ -96,6 +99,7 @@ export function defineGenerateAction(registry: Registry): GenerateAction {
9699
// Generate util action does not support middleware. Maybe when we add named/registered middleware....
97100
middleware: [],
98101
streamingCallback: sendChunk,
102+
context,
99103
});
100104
return streamingRequested
101105
? generateFn((c: GenerateResponseChunk) =>
@@ -113,18 +117,18 @@ export async function generateHelper(
113117
registry: Registry,
114118
options: {
115119
rawRequest: GenerateActionOptions;
116-
middleware?: ModelMiddleware[];
120+
middleware?: ModelMiddlewareArgument[];
117121
currentTurn?: number;
118122
messageIndex?: number;
119123
abortSignal?: AbortSignal;
120124
streamingCallback?: StreamingCallback<GenerateResponseChunk>;
125+
context?: Record<string, any>;
121126
}
122127
): Promise<GenerateResponseData> {
123128
const currentTurn = options.currentTurn ?? 0;
124129
const messageIndex = options.messageIndex ?? 0;
125130
// do tracing
126131
return await runInNewSpan(
127-
registry,
128132
{
129133
metadata: {
130134
name: options.rawRequest.stepName || 'generate',
@@ -143,6 +147,7 @@ export async function generateHelper(
143147
messageIndex,
144148
abortSignal: options.abortSignal,
145149
streamingCallback: options.streamingCallback,
150+
context: options.context,
146151
});
147152
metadata.output = JSON.stringify(output);
148153
return output;
@@ -247,13 +252,15 @@ async function generate(
247252
messageIndex,
248253
abortSignal,
249254
streamingCallback,
255+
context,
250256
}: {
251257
rawRequest: GenerateActionOptions;
252-
middleware: ModelMiddleware[] | undefined;
258+
middleware: ModelMiddlewareArgument[] | undefined;
253259
currentTurn: number;
254260
messageIndex: number;
255261
abortSignal?: AbortSignal;
256262
streamingCallback?: StreamingCallback<GenerateResponseChunk>;
263+
context?: Record<string, any>;
257264
}
258265
): Promise<GenerateResponseData> {
259266
const { model, tools, resources, format } = await resolveParameters(
@@ -320,29 +327,41 @@ async function generate(
320327
}
321328

322329
var response: GenerateResponse;
330+
const sendChunk =
331+
streamingCallback &&
332+
(((chunk: GenerateResponseChunkData) =>
333+
streamingCallback &&
334+
streamingCallback(makeChunk('model', chunk))) as any);
323335
const dispatch = async (
324336
index: number,
325-
req: z.infer<typeof GenerateRequestSchema>
337+
req: z.infer<typeof GenerateRequestSchema>,
338+
actionOpts: ActionRunOptions<any>
326339
) => {
327340
if (!middleware || index === middleware.length) {
328341
// end of the chain, call the original model action
329-
return await model(req, {
330-
abortSignal,
331-
onChunk:
332-
streamingCallback &&
333-
(((chunk: GenerateResponseChunkData) =>
334-
streamingCallback &&
335-
streamingCallback(makeChunk('model', chunk))) as any),
336-
});
342+
return await model(req, actionOpts);
337343
}
338344

339345
const currentMiddleware = middleware[index];
340-
return currentMiddleware(req, async (modifiedReq) =>
341-
dispatch(index + 1, modifiedReq || req)
342-
);
346+
if (currentMiddleware.length === 3) {
347+
return (currentMiddleware as ModelMiddlewareWithOptions)(
348+
req,
349+
actionOpts,
350+
async (modifiedReq, opts) =>
351+
dispatch(index + 1, modifiedReq || req, opts || actionOpts)
352+
);
353+
} else {
354+
return (currentMiddleware as ModelMiddleware)(req, async (modifiedReq) =>
355+
dispatch(index + 1, modifiedReq || req, actionOpts)
356+
);
357+
}
343358
};
344359

345-
const modelResponse = await dispatch(0, request);
360+
const modelResponse = await dispatch(0, request, {
361+
abortSignal,
362+
context,
363+
onChunk: sendChunk,
364+
});
346365

347366
if (model.__action.actionType === 'background-model') {
348367
response = new GenerateResponse(

js/ai/src/model.ts

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import {
1818
ActionFnArg,
1919
BackgroundAction,
2020
GenkitError,
21+
MiddlewareWithOptions,
2122
Operation,
2223
OperationSchema,
2324
action,
@@ -108,6 +109,16 @@ export type ModelMiddleware = SimpleMiddleware<
108109
z.infer<typeof GenerateResponseSchema>
109110
>;
110111

112+
export type ModelMiddlewareWithOptions = MiddlewareWithOptions<
113+
z.infer<typeof GenerateRequestSchema>,
114+
z.infer<typeof GenerateResponseSchema>,
115+
z.infer<typeof GenerateResponseChunkSchema>
116+
>;
117+
118+
export type ModelMiddlewareArgument =
119+
| ModelMiddleware
120+
| ModelMiddlewareWithOptions;
121+
111122
export type DefineModelOptions<
112123
CustomOptionsSchema extends z.ZodTypeAny = z.ZodTypeAny,
113124
> = {
@@ -121,7 +132,7 @@ export type DefineModelOptions<
121132
/** Descriptive name for this model e.g. 'Google AI - Gemini Pro'. */
122133
label?: string;
123134
/** Middleware to be used with this model. */
124-
use?: ModelMiddleware[];
135+
use?: ModelMiddlewareArgument[];
125136
};
126137

127138
export function model<CustomOptionsSchema extends z.ZodTypeAny = z.ZodTypeAny>(
@@ -324,11 +335,11 @@ export function backgroundModel<
324335
}
325336

326337
function getModelMiddleware(options: {
327-
use?: ModelMiddleware[];
338+
use?: ModelMiddlewareArgument[];
328339
name: string;
329340
supports?: ModelInfo['supports'];
330341
}) {
331-
const middleware: ModelMiddleware[] = options.use || [];
342+
const middleware: ModelMiddlewareArgument[] = options.use || [];
332343
if (!options?.supports?.context) middleware.push(augmentWithContext());
333344
const constratedSimulator = simulateConstrainedGeneration();
334345
middleware.push((req, next) => {

js/ai/tests/generate/generate_test.ts

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import {
3030
defineModel,
3131
type ModelAction,
3232
type ModelMiddleware,
33+
type ModelMiddlewareWithOptions,
3334
} from '../../src/model.js';
3435
import { defineResource } from '../../src/resource.js';
3536
import { defineTool } from '../../src/tool.js';
@@ -804,4 +805,159 @@ describe('generate', () => {
804805
},
805806
]);
806807
});
808+
809+
it('middleware can intercept streaming callback', async () => {
810+
const registry = new Registry();
811+
const echoModel = defineModel(
812+
registry,
813+
{
814+
apiVersion: 'v2',
815+
name: 'echoModel',
816+
supports: { tools: true },
817+
},
818+
async (_, { sendChunk }) => {
819+
if (sendChunk) {
820+
sendChunk({ content: [{ text: 'chunk1' }] });
821+
sendChunk({ content: [{ text: 'chunk2' }] });
822+
}
823+
return {
824+
message: {
825+
role: 'model',
826+
content: [{ text: 'done' }],
827+
},
828+
finishReason: 'stop',
829+
};
830+
}
831+
);
832+
833+
const interceptMiddleware: ModelMiddlewareWithOptions = async (
834+
req,
835+
opts,
836+
next
837+
) => {
838+
const originalOnChunk = opts!.onChunk;
839+
return next(req, {
840+
...opts,
841+
onChunk: (chunk) => {
842+
if (originalOnChunk) {
843+
const text = chunk.content?.[0]?.text;
844+
originalOnChunk({
845+
...chunk,
846+
content: [{ text: `intercepted: ${text}` }],
847+
});
848+
}
849+
},
850+
});
851+
};
852+
853+
const { response, stream } = generateStream(registry, {
854+
model: echoModel,
855+
prompt: 'test',
856+
use: [interceptMiddleware],
857+
});
858+
859+
const streamed: any[] = [];
860+
for await (const chunk of stream) {
861+
streamed.push(chunk.content[0].text);
862+
}
863+
864+
assert.deepStrictEqual(streamed, [
865+
'intercepted: chunk1',
866+
'intercepted: chunk2',
867+
]);
868+
await response;
869+
});
870+
871+
it('middleware can modify context', async () => {
872+
const registry = new Registry();
873+
const checkContextModel = defineModel(
874+
registry,
875+
{
876+
apiVersion: 'v2',
877+
name: 'checkContextModel',
878+
supports: { context: true },
879+
},
880+
async (request, { context }) => {
881+
return {
882+
message: {
883+
role: 'model',
884+
content: [{ text: `Context: ${context?.myValue}` }],
885+
},
886+
finishReason: 'stop',
887+
};
888+
}
889+
);
890+
891+
const contextMiddleware: ModelMiddlewareWithOptions = async (
892+
req,
893+
opts,
894+
next
895+
) => {
896+
return next(req, {
897+
...opts,
898+
context: {
899+
...opts?.context,
900+
myValue: 'foo',
901+
},
902+
});
903+
};
904+
905+
const response = await generate(registry, {
906+
model: checkContextModel,
907+
prompt: 'test',
908+
use: [contextMiddleware],
909+
});
910+
911+
assert.strictEqual(response.text, 'Context: foo');
912+
});
913+
914+
it('middleware can chain option modifications', async () => {
915+
const registry = new Registry();
916+
const checkContextModel = defineModel(
917+
registry,
918+
{
919+
apiVersion: 'v2',
920+
name: 'checkContextModel',
921+
supports: { context: true },
922+
},
923+
async (request, { context }) => {
924+
return {
925+
message: {
926+
role: 'model',
927+
content: [{ text: `Context: ${JSON.stringify(context)}` }],
928+
},
929+
finishReason: 'stop',
930+
};
931+
}
932+
);
933+
934+
const middleware1: ModelMiddlewareWithOptions = async (req, opts, next) => {
935+
return next(req, {
936+
...opts,
937+
context: {
938+
...opts?.context,
939+
val: [...(opts?.context?.val ?? []), 'A'],
940+
},
941+
});
942+
};
943+
944+
const middleware2: ModelMiddlewareWithOptions = async (req, opts, next) => {
945+
return next(req, {
946+
...opts,
947+
context: {
948+
...opts?.context,
949+
val: [...(opts?.context?.val ?? []), 'B'],
950+
},
951+
});
952+
};
953+
954+
const response = await generate(registry, {
955+
model: checkContextModel,
956+
prompt: 'test',
957+
use: [middleware1, middleware2],
958+
});
959+
960+
const context = JSON.parse(response.text.substring('Context: '.length));
961+
assert.deepStrictEqual(context.val, ['A', 'B']);
962+
});
807963
});

js/genkit/src/model.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ export {
5858
type ModelArgument,
5959
type ModelInfo,
6060
type ModelMiddleware,
61+
type ModelMiddlewareArgument,
62+
type ModelMiddlewareWithOptions,
6163
type ModelReference,
6264
type ModelRequest,
6365
type ModelResponseChunkData,

0 commit comments

Comments
 (0)