Skip to content

Commit fbc3bf2

Browse files
authored
Add abort signal option to run method and stop callback to wait method (#106)
* Add stop parameter to wait * Add abort signal parameter to run method
1 parent 7e7ceee commit fbc3bf2

File tree

3 files changed

+62
-4
lines changed

3 files changed

+62
-4
lines changed

index.d.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ declare module 'replicate' {
8888
wait?: { interval?: number; max_attempts?: number };
8989
webhook?: string;
9090
webhook_events_filter?: WebhookEventType[];
91+
signal?: AbortSignal;
9192
}
9293
): Promise<object>;
9394

@@ -105,7 +106,8 @@ declare module 'replicate' {
105106
options: {
106107
interval?: number;
107108
max_attempts?: number;
108-
}
109+
},
110+
stop?: (Prediction) => Promise<boolean>
109111
): Promise<Prediction>;
110112

111113
collections: {

index.js

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ class Replicate {
8585
* @param {number} [options.wait.max_attempts] - Maximum number of polling attempts. Defaults to no limit
8686
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
8787
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
88+
* @param {AbortSignal} [options.signal] - AbortSignal to cancel the prediction
8889
* @throws {Error} If the prediction failed
8990
* @returns {Promise<object>} - Resolves with the output of running the model
9091
*/
@@ -116,7 +117,16 @@ class Replicate {
116117
version,
117118
});
118119

119-
prediction = await this.wait(prediction, wait || {});
120+
const { signal } = options;
121+
122+
prediction = await this.wait(prediction, wait || {}, async ({ id }) => {
123+
if (signal && signal.aborted) {
124+
await this.predictions.cancel(id);
125+
return true; // stop polling
126+
}
127+
128+
return false; // continue polling
129+
});
120130

121131
if (prediction.status === 'failed') {
122132
throw new Error(`Prediction failed: ${prediction.error}`);
@@ -150,7 +160,11 @@ class Replicate {
150160
);
151161
}
152162

153-
const { method = 'GET', params = {}, data } = options;
163+
const {
164+
method = 'GET',
165+
params = {},
166+
data,
167+
} = options;
154168

155169
Object.entries(params).forEach(([key, value]) => {
156170
url.searchParams.append(key, value);
@@ -219,11 +233,12 @@ class Replicate {
219233
* @param {object} options - Options
220234
* @param {number} [options.interval] - Polling interval in milliseconds. Defaults to 250
221235
* @param {number} [options.max_attempts] - Maximum number of polling attempts. Defaults to no limit
236+
* @param {Function} [stop] - Async callback function that is called after each polling attempt. Receives the prediction object as an argument. Return false to cancel polling.
222237
* @throws {Error} If the prediction doesn't complete within the maximum number of attempts
223238
* @throws {Error} If the prediction failed
224239
* @returns {Promise<object>} Resolves with the completed prediction object
225240
*/
226-
async wait(prediction, options) {
241+
async wait(prediction, options, stop) {
227242
const { id } = prediction;
228243
if (!id) {
229244
throw new Error('Invalid prediction');
@@ -261,6 +276,9 @@ class Replicate {
261276
/* eslint-disable no-await-in-loop */
262277
await sleep(interval);
263278
updatedPrediction = await this.predictions.get(prediction.id);
279+
if (stop && await stop(updatedPrediction) === true) {
280+
break;
281+
}
264282
/* eslint-enable no-await-in-loop */
265283
}
266284

index.test.ts

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,44 @@ describe('Replicate client', () => {
581581
});
582582
}).rejects.toThrow('Invalid webhook URL');
583583
});
584+
585+
test('Aborts the operation when abort signal is invoked', async () => {
586+
const controller = new AbortController();
587+
const { signal } = controller;
588+
589+
const scope = nock(BASE_URL)
590+
.post('/predictions', (body) => {
591+
controller.abort();
592+
return body;
593+
})
594+
.reply(201, {
595+
id: 'ufawqhfynnddngldkgtslldrkq',
596+
status: 'processing',
597+
})
598+
.persist()
599+
.get('/predictions/ufawqhfynnddngldkgtslldrkq')
600+
.reply(200, {
601+
id: 'ufawqhfynnddngldkgtslldrkq',
602+
status: 'processing',
603+
})
604+
.post('/predictions/ufawqhfynnddngldkgtslldrkq/cancel')
605+
.reply(200, {
606+
id: 'ufawqhfynnddngldkgtslldrkq',
607+
status: 'canceled',
608+
});;
609+
610+
await client.run(
611+
'owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa',
612+
{
613+
input: { text: 'Hello, world!' },
614+
signal,
615+
}
616+
)
617+
618+
expect(signal.aborted).toBe(true);
619+
620+
scope.done();
621+
});
584622
});
585623

586624
// Continue with tests for other methods

0 commit comments

Comments
 (0)