Skip to content

Commit d2cd916

Browse files
authored
Merge pull request #74 from replicate/mattt/prompt-template
Use prompt template to format messages
2 parents 2505c5f + 15fa5b3 commit d2cd916

5 files changed

Lines changed: 238 additions & 24 deletions

File tree

app/api/route.js

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import Replicate from "replicate";
22
import { ReplicateStream, StreamingTextResponse } from "ai";
3-
43
export const runtime = "edge";
54

65
const replicate = new Replicate({
@@ -43,13 +42,12 @@ async function runLlama({
4342
}) {
4443
console.log("running llama");
4544

46-
const [owner, name] = model.split("/");
47-
48-
return await replicate.models.predictions.create(owner, name, {
45+
return await replicate.predictions.create({
46+
model: model,
4947
stream: true,
5048
input: {
5149
prompt: `${prompt}`,
52-
system_prompt: systemPrompt,
50+
prompt_template: "{prompt}",
5351
max_new_tokens: maxTokens,
5452
temperature: temperature,
5553
repetition_penalty: 1,

app/page.js

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import EmptyState from "./components/EmptyState";
88
import { Cog6ToothIcon, CodeBracketIcon } from "@heroicons/react/20/solid";
99
import { useCompletion } from "ai/react";
1010
import { Toaster, toast } from "react-hot-toast";
11+
import { LlamaTemplate } from "../src/prompt_template";
12+
1113
import { countTokens } from "./src/tokenizer.js";
1214

1315
const MODELS = [
@@ -38,6 +40,20 @@ const MODELS = [
3840
},
3941
];
4042

43+
const llamaTemplate = LlamaTemplate();
44+
45+
const generatePrompt = (template, systemPrompt, messages) => {
46+
const chat = messages.map((message) => ({
47+
"role": message.isUser ? "user" : "assistant",
48+
"content": message.text,
49+
}));
50+
51+
return template([{
52+
"role": "system",
53+
"content": systemPrompt,
54+
}, ...chat]);
55+
};
56+
4157
function CTA({ shortenedModelName }) {
4258
if (shortenedModelName == "Llava") {
4359
return (
@@ -141,7 +157,6 @@ export default function HomePage() {
141157

142158
const handleFileUpload = (file) => {
143159
if (file) {
144-
console.log(file);
145160
// determine if file is image or audio
146161
if (
147162
["audio/mpeg", "audio/wav", "audio/ogg"].includes(
@@ -192,16 +207,8 @@ export default function HomePage() {
192207
isUser: true,
193208
});
194209

195-
const generatePrompt = (messages) => {
196-
return messages
197-
.map((message) =>
198-
message.isUser ? `[INST] ${message.text} [/INST]` : `${message.text}`
199-
)
200-
.join("\n");
201-
};
202-
203210
// Generate initial prompt and calculate tokens
204-
let prompt = `${generatePrompt(messageHistory)}\n`;
211+
let prompt = `${generatePrompt(llamaTemplate, systemPrompt, messageHistory)}\n`;
205212
// Check if we exceed max tokens and truncate the message history if so.
206213
while (countTokens(prompt) > MAX_TOKENS) {
207214
if (messageHistory.length < 3) {
@@ -216,7 +223,7 @@ export default function HomePage() {
216223
messageHistory.splice(1, 2);
217224

218225
// Recreate the prompt
219-
prompt = `${SNIP}\n${generatePrompt(messageHistory)}\n`;
226+
prompt = `${SNIP}\n${generatePrompt(llamaTemplate, systemPrompt, messageHistory)}\n`;
220227
}
221228

222229
setMessages(messageHistory);

0 commit comments

Comments (0)