-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: Program.cs
More file actions
96 lines (79 loc) · 2.88 KB
/
Program.cs
File metadata and controls
96 lines (79 loc) · 2.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
using Microsoft.ML.OnnxRuntimeGenAI;
using System.ComponentModel;
using System.Diagnostics;
using System.Reflection;
using System.Text;
namespace MyNewCopilot_PC;
internal class Program
{
    // Download from https://huggingface.co/microsoft/
    // Phi-3.5-mini-instruct-onnx/tree/main
    private static readonly string modelDir =
        @"C:\Code\Phi-3.5-mini-instruct-onnx\cpu_and_mobile\cpu-int4-awq-block-128-acc-level-4";

    /// <summary>
    /// Loads the local Phi-3.5 ONNX model and tokenizer, reports the load time,
    /// then runs two sample airport queries against the model.
    /// </summary>
    static async Task Main(string[] args)
    {
        Console.WriteLine($"Loading model: {modelDir}");
        var sw = Stopwatch.StartNew();
        using var model = new Model(modelDir);
        using var tokenizer = new Tokenizer(model);
        sw.Stop();
        Console.WriteLine($"Model loading took {sw.ElapsedMilliseconds} ms");
        await ListIcao("KDFW", model, tokenizer);
        await ListIcao("KCLE", model, tokenizer);
    }

    /// <summary>
    /// Asks the model two questions about one airport: a brief description of
    /// its city/state, and a concise description of the ICAO code itself.
    /// </summary>
    /// <param name="icaoName">Four-letter ICAO identifier, e.g. "KDFW".</param>
    /// <param name="model">Loaded ONNX model (owned by the caller).</param>
    /// <param name="tokenizer">Tokenizer for <paramref name="model"/>.</param>
    /// <returns>Always <c>true</c>; reserved for future error reporting.</returns>
    private static async Task<bool> ListIcao(string icaoName, Model model, Tokenizer tokenizer)
    {
        var systemPrompt = "You are a helpful assistant.";
        // NOTE(review): Icao is a project-local singleton lookup — presumably a
        // static airport database; confirm it cannot return null for unknown codes.
        var icao = Icao.Instance.GetIcao(icaoName);

        // Phi-3 chat template: <|system|>…<|end|><|user|>…<|end|><|assistant|>
        var userPrompt = $"Tell me about {icao.City} {icao.State} . Be brief.";
        var prompt = $"<|system|>{systemPrompt}<|end|><|user|>{userPrompt}<|end|><|assistant|>";
        await WriteResults(prompt, model, tokenizer);
        Console.WriteLine();
        Console.WriteLine(); // blank separator line (was WriteLine(""))

        userPrompt = $"Tell me about ICAO {icao.ICAO}. Be concise. ";
        prompt = $"<|system|>{systemPrompt}<|end|><|user|>{userPrompt}<|end|><|assistant|>";
        await WriteResults(prompt, model, tokenizer);
        Console.WriteLine();
        return true;
    }

    /// <summary>
    /// Streams the model's answer for <paramref name="prompt"/> to the console
    /// as it is generated.
    /// </summary>
    /// <returns>Always <c>true</c>; reserved for future error reporting.</returns>
    private static async Task<bool> WriteResults(string prompt, Model model, Tokenizer tokenizer)
    {
        await foreach (var part in InferStreaming(prompt, model, tokenizer))
        {
            Console.Write(part);
        }
        return true;
    }

    /// <summary>
    /// Runs token-by-token generation for <paramref name="prompt"/> and yields
    /// each decoded text fragment as it is produced. Generation stops when the
    /// model reports completion, the max length is reached, or the accumulated
    /// output contains a chat-template control marker.
    /// </summary>
    public static async IAsyncEnumerable<string> InferStreaming(string prompt,
        Model model, Tokenizer tokenizer)
    {
        using var generatorParams = new GeneratorParams(model);
        using var sequences = tokenizer.Encode(prompt);
        generatorParams.SetSearchOption("max_length", 2048);
        generatorParams.SetSearchOption("top_p", 0.5);
        //generatorParams.SetSearchOption("top_k", 1);
        generatorParams.SetSearchOption("temperature", 0.8);
        generatorParams.SetInputSequences(sequences);
        generatorParams.TryGraphCaptureWithMaxBatchSize(1);
        using var tokenizerStream = tokenizer.CreateStream();
        using var generator = new Generator(model, generatorParams);
        StringBuilder stringBuilder = new();
        while (!generator.IsDone())
        {
            // Yield control so the caller stays responsive between tokens,
            // without the fixed 10 ms/token cost the original Task.Delay(10)
            // imposed (~1 s of dead time per 100 tokens).
            await Task.Yield();
            generator.ComputeLogits();
            generator.GenerateNextToken();
            string part = tokenizerStream.Decode(generator.GetSequence(0)[^1]);
            stringBuilder.Append(part);
            // Markers are checked against the ACCUMULATED text (a marker could
            // span several decoded fragments), but materialize it only once per
            // token — the original called ToString() up to three times here.
            string soFar = stringBuilder.ToString();
            if (soFar.Contains("<|end|>", StringComparison.Ordinal)
                || soFar.Contains("<|user|>", StringComparison.Ordinal)
                || soFar.Contains("<|system|>", StringComparison.Ordinal))
            {
                break;
            }
            if (!string.IsNullOrWhiteSpace(part))
                yield return part;
        }
    }
}