Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -355,3 +355,5 @@ site/
/Llama.Mobile/Resources/Raw
/nuget_pack.bat
/temp
LLama.Web/wwwroot/lib/
LLama.Web/Models/
147 changes: 122 additions & 25 deletions LLama.Examples/Examples/MtmdInteractiveModeExecute.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ namespace LLama.Examples.Examples
// It uses the interactive executor to inference.
public class MtmdInteractiveModeExecute
{
/// <summary>
/// Markers derived from the model's chat template: <paramref name="AssistantEndMarker"/> is the text the
/// template emits after an assistant message, and <paramref name="AssistantToUserMarker"/> is the text
/// bridging an assistant turn to the next user turn. Either may be null when the template could not be
/// resolved (see <c>ResolveTemplateMarkers</c>); callers then fall back to the default "User:" anti-prompt.
/// </summary>
private sealed record TemplateMarkers(string? AssistantEndMarker, string? AssistantToUserMarker);

public static async Task Run()
{
string multiModalProj = UserSettings.GetMMProjPath();
string modelPath = UserSettings.GetModelPath();
const int maxTokens = 4096;

string? prompt = await File.ReadAllTextAsync("Assets/chat-with-bob.json");

var parameters = new ModelParams(modelPath);

var mtmdParameters = MtmdContextParams.Default();
Expand All @@ -40,8 +40,8 @@ public static async Task Run()

var ex = new InteractiveExecutor(context, clipModel);
var chatHistory = new ChatHistory();
var isFirstTurn = true;

var templateMarkers = ResolveTemplateMarkers(model);
var antiPrompts = GetEffectiveAntiPrompts(templateMarkers);
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("The executor has been enabled. In this example the maximum tokens is set to {0} and the context size is {1}.", maxTokens, parameters.ContextSize );
Console.WriteLine("Model: {0}", modelPath);
Expand All @@ -54,6 +54,7 @@ public static async Task Run()
Console.WriteLine("Video inputs are not supported (format would be {{c:/video.mp4}}).");
Console.WriteLine("Commands: /exit (return to main menu) | /clear (reset the chat history and KV cache).");
Console.WriteLine("Press Ctrl+c to return to main menu.");
Console.Write("User: ");

void ResetConversation()
{
Expand All @@ -64,7 +65,7 @@ void ResetConversation()
clipModel.ClearMedia();
chatHistory.Messages.Clear();
ex = new InteractiveExecutor(context, clipModel);
Console.WriteLine("User:");
Console.Write("User: ");
}

var inferenceParams = new InferenceParams
Expand All @@ -74,34 +75,34 @@ void ResetConversation()
Temperature = 0.1f
},

AntiPrompts = new List<string> { "User:" },
MaxTokens = maxTokens
AntiPrompts = antiPrompts,
MaxTokens = maxTokens,
DecodeSpecialTokens = ShouldDecodeSpecialTokens(antiPrompts)

};

do
{
if (!isFirstTurn)
{
Console.ForegroundColor = ConsoleColor.Green;
prompt = Console.ReadLine();
Console.WriteLine();
Console.ForegroundColor = ConsoleColor.Green;
var prompt = Console.ReadLine();
Console.WriteLine();

if (prompt == null || prompt.Equals("/exit", StringComparison.OrdinalIgnoreCase))
break;
if (prompt == null || prompt.Equals("/exit", StringComparison.OrdinalIgnoreCase))
break;

if (prompt.Equals("/clear", StringComparison.OrdinalIgnoreCase))
{
ResetConversation();
continue;
}
if (prompt.Equals("/clear", StringComparison.OrdinalIgnoreCase))
{
ResetConversation();
continue;
}
else

if (string.IsNullOrWhiteSpace(prompt))
{
isFirstTurn = false;
Console.Write("User: ");
continue;
}

var userPrompt = prompt ?? string.Empty;
var userPrompt = prompt;

// Evaluate if we have media
//
Expand Down Expand Up @@ -171,7 +172,7 @@ void ResetConversation()
audioList.Add(mediaPath);
}

var embed = clipModel.LoadMedia(mediaPath);
var embed = clipModel.LoadMediaStandalone(mediaPath);
embeds.Add(embed);
}
}
Expand All @@ -183,7 +184,8 @@ void ResetConversation()
Console.Write($"{exception.Message}");
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("Please try again.");
clipModel.ClearMedia();
foreach (var embed in embeds)
embed.Dispose();
break;
}

Expand Down Expand Up @@ -236,8 +238,9 @@ void ResetConversation()
Console.Write(text);
responseBuilder.Append(text);
}
Console.Write(" ");
Console.WriteLine();
chatHistory.AddMessage(AuthorRole.Assistant, responseBuilder.ToString());
Console.Write("User: ");

}
while(true);
Expand All @@ -259,6 +262,45 @@ private static string BuildChatDelta(LLamaWeights model, ChatHistory history, st
return delta;
}

/// <summary>
/// Builds the anti-prompt list used to stop generation: the template-derived
/// assistant-end and assistant-to-user markers when available, otherwise the
/// literal "User:" fallback.
/// </summary>
/// <param name="templateMarkers">Markers resolved from the model's chat template.</param>
/// <returns>A non-empty list of anti-prompt strings.</returns>
private static List<string> GetEffectiveAntiPrompts(TemplateMarkers templateMarkers)
{
    var prompts = new List<string>();

    // Each marker contributes both its raw and trimmed form (handled by AddMarker).
    AddMarker(prompts, templateMarkers.AssistantEndMarker);
    AddMarker(prompts, templateMarkers.AssistantToUserMarker);

    // No usable template markers: fall back to the conventional turn prefix.
    return prompts.Count > 0 ? prompts : new List<string> { "User:" };
}

/// <summary>
/// Appends <paramref name="marker"/> (and its whitespace-trimmed form, when different)
/// to <paramref name="values"/>, skipping nulls, whitespace-only markers, and duplicates.
/// </summary>
/// <param name="values">Destination list, mutated in place.</param>
/// <param name="marker">Candidate anti-prompt marker; ignored when null or whitespace.</param>
private static void AddMarker(List<string> values, string? marker)
{
    if (string.IsNullOrWhiteSpace(marker))
        return;

    // Local add-if-absent helper; List.Contains uses ordinal string equality.
    void AddUnique(string candidate)
    {
        if (!values.Contains(candidate))
            values.Add(candidate);
    }

    AddUnique(marker);

    // The trimmed variant also matters: the executor may emit the marker without
    // the template's surrounding whitespace. After Trim() on a non-whitespace
    // string the result is never empty, but keep the guard shape of the contract.
    var trimmed = marker.Trim();
    if (trimmed.Length > 0)
        AddUnique(trimmed);
}

/// <summary>
/// Decides whether special tokens must be decoded into the output stream so that
/// anti-prompt matching can see them: true when any non-blank anti-prompt contains
/// an angle bracket (the typical shape of special tokens such as &lt;|im_end|&gt;).
/// </summary>
/// <param name="antiPrompts">Anti-prompts that will be matched against generated text.</param>
/// <returns>True if at least one anti-prompt looks like a special token.</returns>
private static bool ShouldDecodeSpecialTokens(IReadOnlyList<string> antiPrompts)
{
    for (var i = 0; i < antiPrompts.Count; i++)
    {
        var candidate = antiPrompts[i];
        if (string.IsNullOrWhiteSpace(candidate))
            continue;

        // IndexOf(char) performs an ordinal scan, matching the original
        // Contains(char, StringComparison.Ordinal) semantics.
        if (candidate.IndexOf('<') >= 0 || candidate.IndexOf('>') >= 0)
            return true;
    }

    return false;
}

private static string FormatChatHistory(LLamaWeights model, ChatHistory history, bool addAssistant)
{
var template = new LLamaTemplate(model.NativeHandle)
Expand All @@ -271,5 +313,60 @@ private static string FormatChatHistory(LLamaWeights model, ChatHistory history,

return LLamaTemplate.Encoding.GetString(template.Apply());
}

/// <summary>
/// Probes the model's chat template with sentinel strings to discover (a) the text
/// the template appends after an assistant message and (b) the text bridging an
/// assistant turn to the following user turn. Returns nulls for anything that
/// cannot be determined, so callers can fall back to a default anti-prompt.
/// </summary>
/// <param name="model">Loaded model whose native chat template is queried.</param>
/// <returns>The resolved markers; members are null when resolution fails.</returns>
private static TemplateMarkers ResolveTemplateMarkers(LLamaWeights model)
{
    const string firstUserToken = "__LLAMA_USER_A__";
    const string assistantToken = "__LLAMA_ASSISTANT_A__";
    const string secondUserToken = "__LLAMA_USER_B__";

    try
    {
        // Pass 1: a single user/assistant exchange. Whatever the rendered text
        // places after the assistant sentinel is the assistant-end marker.
        var singleTurn = new LLamaTemplate(model.NativeHandle)
        {
            AddAssistant = false
        };
        singleTurn.Add("user", firstUserToken);
        singleTurn.Add("assistant", assistantToken);

        var singleRendered = LLamaTemplate.Encoding.GetString(singleTurn.Apply());
        var sentinelPos = singleRendered.IndexOf(assistantToken, StringComparison.Ordinal);
        if (sentinelPos < 0)
            return new TemplateMarkers(null, null);

        var endMarker = singleRendered[(sentinelPos + assistantToken.Length)..];

        // Pass 2: add a second user turn. The text between the assistant sentinel
        // and the second user sentinel is the assistant-to-user transition marker.
        var multiTurn = new LLamaTemplate(model.NativeHandle)
        {
            AddAssistant = false
        };
        multiTurn.Add("user", firstUserToken);
        multiTurn.Add("assistant", assistantToken);
        multiTurn.Add("user", secondUserToken);

        var multiRendered = LLamaTemplate.Encoding.GetString(multiTurn.Apply());
        var assistantPos = multiRendered.IndexOf(assistantToken, StringComparison.Ordinal);
        var nextUserPos = multiRendered.IndexOf(secondUserToken, StringComparison.Ordinal);

        // Sentinels missing or out of order: keep the end marker, drop the transition.
        if (assistantPos < 0 || nextUserPos <= assistantPos)
            return new TemplateMarkers(NormalizeMarker(endMarker), null);

        var transitionStart = assistantPos + assistantToken.Length;
        var transition = multiRendered[transitionStart..nextUserPos];

        return new TemplateMarkers(
            NormalizeMarker(endMarker),
            NormalizeMarker(transition));
    }
    catch
    {
        // Best effort: any template failure (unsupported/missing template) falls
        // back to nulls so the caller uses the default "User:" anti-prompt.
        return new TemplateMarkers(null, null);
    }
}

/// <summary>
/// Collapses null, empty, or whitespace-only markers to null; any other marker
/// is returned unchanged (surrounding whitespace is preserved).
/// </summary>
/// <param name="marker">Candidate marker text.</param>
/// <returns>The original marker, or null when it carries no content.</returns>
private static string? NormalizeMarker(string? marker)
    => string.IsNullOrWhiteSpace(marker) ? null : marker;
}
}
Loading
Loading