1515 */
1616
1717import {
18+ ActionRunOptions ,
1819 GenkitError ,
1920 StreamingCallback ,
2021 defineAction ,
@@ -42,6 +43,8 @@ import {
4243 GenerateResponseChunkSchema ,
4344 GenerateResponseSchema ,
4445 MessageData ,
46+ ModelMiddlewareArgument ,
47+ ModelMiddlewareWithOptions ,
4548 resolveModel ,
4649 type GenerateActionOptions ,
4750 type GenerateActionOutputConfig ,
@@ -85,7 +88,7 @@ export function defineGenerateAction(registry: Registry): GenerateAction {
8588 outputSchema : GenerateResponseSchema ,
8689 streamSchema : GenerateResponseChunkSchema ,
8790 } ,
88- async ( request , { streamingRequested, sendChunk } ) => {
91+ async ( request , { streamingRequested, sendChunk, context } ) => {
8992 const generateFn = (
9093 sendChunk ?: StreamingCallback < GenerateResponseChunk >
9194 ) =>
@@ -96,6 +99,7 @@ export function defineGenerateAction(registry: Registry): GenerateAction {
9699 // Generate util action does not support middleware. Maybe when we add named/registered middleware....
97100 middleware : [ ] ,
98101 streamingCallback : sendChunk ,
102+ context,
99103 } ) ;
100104 return streamingRequested
101105 ? generateFn ( ( c : GenerateResponseChunk ) =>
@@ -113,18 +117,18 @@ export async function generateHelper(
113117 registry : Registry ,
114118 options : {
115119 rawRequest : GenerateActionOptions ;
116- middleware ?: ModelMiddleware [ ] ;
120+ middleware ?: ModelMiddlewareArgument [ ] ;
117121 currentTurn ?: number ;
118122 messageIndex ?: number ;
119123 abortSignal ?: AbortSignal ;
120124 streamingCallback ?: StreamingCallback < GenerateResponseChunk > ;
125+ context ?: Record < string , any > ;
121126 }
122127) : Promise < GenerateResponseData > {
123128 const currentTurn = options . currentTurn ?? 0 ;
124129 const messageIndex = options . messageIndex ?? 0 ;
125130 // do tracing
126131 return await runInNewSpan (
127- registry ,
128132 {
129133 metadata : {
130134 name : options . rawRequest . stepName || 'generate' ,
@@ -143,6 +147,7 @@ export async function generateHelper(
143147 messageIndex,
144148 abortSignal : options . abortSignal ,
145149 streamingCallback : options . streamingCallback ,
150+ context : options . context ,
146151 } ) ;
147152 metadata . output = JSON . stringify ( output ) ;
148153 return output ;
@@ -247,13 +252,15 @@ async function generate(
247252 messageIndex,
248253 abortSignal,
249254 streamingCallback,
255+ context,
250256 } : {
251257 rawRequest : GenerateActionOptions ;
252- middleware : ModelMiddleware [ ] | undefined ;
258+ middleware : ModelMiddlewareArgument [ ] | undefined ;
253259 currentTurn : number ;
254260 messageIndex : number ;
255261 abortSignal ?: AbortSignal ;
256262 streamingCallback ?: StreamingCallback < GenerateResponseChunk > ;
263+ context ?: Record < string , any > ;
257264 }
258265) : Promise < GenerateResponseData > {
259266 const { model, tools, resources, format } = await resolveParameters (
@@ -320,29 +327,41 @@ async function generate(
320327 }
321328
322329 var response : GenerateResponse ;
330+ const sendChunk =
331+ streamingCallback &&
332+ ( ( ( chunk : GenerateResponseChunkData ) =>
333+ streamingCallback &&
334+ streamingCallback ( makeChunk ( 'model' , chunk ) ) ) as any ) ;
323335 const dispatch = async (
324336 index : number ,
325- req : z . infer < typeof GenerateRequestSchema >
337+ req : z . infer < typeof GenerateRequestSchema > ,
338+ actionOpts : ActionRunOptions < any >
326339 ) => {
327340 if ( ! middleware || index === middleware . length ) {
328341 // end of the chain, call the original model action
329- return await model ( req , {
330- abortSignal,
331- onChunk :
332- streamingCallback &&
333- ( ( ( chunk : GenerateResponseChunkData ) =>
334- streamingCallback &&
335- streamingCallback ( makeChunk ( 'model' , chunk ) ) ) as any ) ,
336- } ) ;
342+ return await model ( req , actionOpts ) ;
337343 }
338344
339345 const currentMiddleware = middleware [ index ] ;
340- return currentMiddleware ( req , async ( modifiedReq ) =>
341- dispatch ( index + 1 , modifiedReq || req )
342- ) ;
346+ if ( currentMiddleware . length === 3 ) {
347+ return ( currentMiddleware as ModelMiddlewareWithOptions ) (
348+ req ,
349+ actionOpts ,
350+ async ( modifiedReq , opts ) =>
351+ dispatch ( index + 1 , modifiedReq || req , opts || actionOpts )
352+ ) ;
353+ } else {
354+ return ( currentMiddleware as ModelMiddleware ) ( req , async ( modifiedReq ) =>
355+ dispatch ( index + 1 , modifiedReq || req , actionOpts )
356+ ) ;
357+ }
343358 } ;
344359
345- const modelResponse = await dispatch ( 0 , request ) ;
360+ const modelResponse = await dispatch ( 0 , request , {
361+ abortSignal,
362+ context,
363+ onChunk : sendChunk ,
364+ } ) ;
346365
347366 if ( model . __action . actionType === 'background-model' ) {
348367 response = new GenerateResponse (
0 commit comments