-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathannotator_ollama.go
More file actions
243 lines (215 loc) · 7.43 KB
/
annotator_ollama.go
File metadata and controls
243 lines (215 loc) · 7.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
package main
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"reflect"
"regexp"
"sort"
"strings"
"time"
"github.com/avast/retry-go"
api "github.com/ollama/ollama/api"
log "github.com/sirupsen/logrus"
)
// Package-level defaults for the Ollama annotator. Annotate uses each value
// unless the annotator's own configuration supplies a non-zero replacement.
var (
// OllamaAnnotatorFirstInstructions is the default leading part of the system
// prompt sent to the model; a configured SystemPrompt takes its place.
// Annotate appends the user prompt and OllamaAnnotatorFinalInstructions to it.
OllamaAnnotatorFirstInstructions = `You will be provided
instructions and then a communication message.
Answer any questions that may have been asked about the message.
If asked to process the message, you will use your skills and any examples
or rules provided to edit, select, transform, evaluate or otherwise process
the text strictly according to the directions given. Only make additions or
subtractions from the original text. Do not replace or transform words such
as to modify case unless specifically instructed to.
If asked to evaluate the message numerically, use your skills and any
examples rules or criteria given to calculate a numerical result for the
message.
Here's the criteria:
`
// OllamaAnnotatorFinalInstructions is the trailing part of the system prompt;
// it tells the model which response field carries which result.
// NOTE(review): the text below refers to a 'question' field, while the JSON
// schema requested from the model names the boolean property
// 'question_answer' - confirm the prompt wording is intended.
OllamaAnnotatorFinalInstructions = `
Return true or false corresponding to the answer in the 'question' field.
Return the processed text in the 'processed_text' field.
Return any numerical evaluation in the 'processed_number' field.
Provide feedback summarizing your actions or commentary in the
'model_feedback' field.
`
// OllamaAnnotatorTimeout is the default request timeout; Annotate multiplies
// it by time.Second, so the unit is seconds.
OllamaAnnotatorTimeout = 120
// OllamaAnnotatorMaxRetryAttempts is the default number of attempts handed
// to retry.Attempts.
OllamaAnnotatorMaxRetryAttempts = 6
// OllamaAnnotatorRetryDelaySeconds is the default delay, in seconds, that
// Annotate records as RetryAfter on the RetriableError wrapping a failed
// attempt.
OllamaAnnotatorRetryDelaySeconds = 5
)
// OllamaAnnotator annotates messages by sending their text to a model served
// through the Ollama API and merging the model's structured reply back into
// the message (see Annotate).
//
// The settings Annotate reads (Model, UserPrompt, SystemPrompt, URL, APIKey,
// Timeout, MaxRetryAttempts, MaxRetryDelaySeconds, Options) are promoted from
// the embedded types - presumably OllamaCommonConfig; confirm against its
// definition.
type OllamaAnnotator struct {
Annotator
Module
OllamaCommonConfig
// SelectedFields limits which fields are provided to future steps.
// It is not read anywhere in this file.
SelectedFields []string
}
// OllamaAnnotatorResponse is the structure Annotate decodes the model's JSON
// reply into. The json tags match the property names of the requested
// response schema; the ap tags are presumably the field names used by
// FormatAsAPMessage when the result is merged into a message - confirm
// against that helper.
type OllamaAnnotatorResponse struct {
// ModelFeedbackText is the model's own summary or commentary.
ModelFeedbackText string `json:"model_feedback" ap:"LLMModelFeedbackText"`
// ProcessedNumber is the numerical evaluation; Annotate clamps it to 1..100
// before merging.
ProcessedNumber int `json:"processed_number" ap:"LLMProcessedNumber"`
// ProcessedText is the text returned by the model after processing.
ProcessedText string `json:"processed_text" ap:"LLMProcessedText"`
// YesNoQuestionAnswer is the model's true/false answer.
YesNoQuestionAnswer bool `json:"question_answer" ap:"LLMYesNoQuestionAnswer"`
}
// OllamaAnnotatorResponseFormat is the JSON-schema-shaped description of the
// reply requested from the model. Annotate marshals a value of this type and
// passes it as the Format of the generate request.
type OllamaAnnotatorResponseFormat struct {
// Type is the schema type of the reply as a whole.
Type string `json:"type"`
// Properties describes each property of the reply object.
Properties OllamaAnnotatorResponseFormatRequestedProperties `json:"properties"`
// Required lists the property names that must be present in the reply.
Required []string `json:"required"`
}
// OllamaAnnotatorResponseFormatRequestedProperties holds one schema entry per
// property requested from the model. Its json tags use the same names as the
// json tags on OllamaAnnotatorResponse.
type OllamaAnnotatorResponseFormatRequestedProperties struct {
ModelFeedbackText OllamaAnnotatorResponseFormatRequestedProperty `json:"model_feedback"`
ProcessedNumber OllamaAnnotatorResponseFormatRequestedProperty `json:"processed_number"`
ProcessedText OllamaAnnotatorResponseFormatRequestedProperty `json:"processed_text"`
YesNoQuestionAnswer OllamaAnnotatorResponseFormatRequestedProperty `json:"question_answer"`
}
// OllamaAnnotatorResponseFormatRequestedProperty is the schema entry for a
// single requested property; only its type is specified.
type OllamaAnnotatorResponseFormatRequestedProperty struct {
// Type is the schema type name of the property (this file uses "string",
// "integer" and "boolean").
Type string `json:"type"`
}
// OllamaAnnotatorResponseRequestedFormat is the response schema Annotate
// marshals to JSON and sends with every generate request: an object with a
// string model_feedback, an integer processed_number, a string processed_text
// and a boolean question_answer, all four required.
var OllamaAnnotatorResponseRequestedFormat = OllamaAnnotatorResponseFormat{
Type: "object",
Properties: OllamaAnnotatorResponseFormatRequestedProperties{
ModelFeedbackText: OllamaAnnotatorResponseFormatRequestedProperty{
Type: "string",
},
ProcessedNumber: OllamaAnnotatorResponseFormatRequestedProperty{
Type: "integer",
},
ProcessedText: OllamaAnnotatorResponseFormatRequestedProperty{
Type: "string",
},
YesNoQuestionAnswer: OllamaAnnotatorResponseFormatRequestedProperty{
Type: "boolean",
},
},
// Every property is required so the decoded reply always has all four fields.
Required: []string{"model_feedback", "processed_number", "processed_text", "question_answer"},
}
// Name reports the annotator's Go type name, obtained through reflection.
func (a OllamaAnnotator) Name() string {
	typ := reflect.TypeOf(a)
	return typ.Name()
}
// GetDefaultFields returns, in sorted order, the keys that FormatAsAPMessage
// yields for an empty OllamaAnnotatorResponse under this annotator's name.
func (a OllamaAnnotator) GetDefaultFields() (s []string) {
	formatted := FormatAsAPMessage(OllamaAnnotatorResponse{}, a.Name())
	for fieldName := range formatted {
		s = append(s, fieldName)
	}
	// Sort so the result has a stable order across calls.
	sort.Strings(s)
	return s
}
// Configured reports whether the annotator differs from its zero value, i.e.
// whether any of its settings has been filled in.
func (a OllamaAnnotator) Configured() bool {
	var unset OllamaAnnotator
	if reflect.DeepEqual(a, unset) {
		return false
	}
	return true
}
// ollamaAnnotatorJSONObjectRegex matches a flat (non-nested) JSON object. It
// is compiled once here instead of on every model response.
var ollamaAnnotatorJSONObjectRegex = regexp.MustCompile(`\{[^{}]+\}`)

// Annotate sends the message's MessageText to the configured Ollama model and
// merges the model's structured reply into the message.
//
// It returns m unchanged with a nil error when the text is blank or when no
// usable reply was obtained (a failed request is logged, not returned). It
// returns m with a non-nil error when Model or UserPrompt is missing, when the
// URL cannot be parsed, or when the response schema cannot be marshaled.
//
// SystemPrompt, Timeout, MaxRetryAttempts and MaxRetryDelaySeconds override
// the package-level defaults for this call only; the package-level variables
// themselves are never modified, so one annotator's configuration cannot leak
// into another's.
func (a OllamaAnnotator) Annotate(m APMessage) (APMessage, error) {
	msg := GetAPMessageCommonFieldAsString(m, "MessageText")

	// Nothing to annotate when the message text is blank.
	if regexp.MustCompile(emptyStringRegex).MatchString(msg) {
		log.Debug(Aside("%s: message was blank, not annotating", a.Name()))
		return m, nil
	}
	if a.Model == "" || a.UserPrompt == "" {
		return m, fmt.Errorf("model and prompt are required to use the Ollama annotator")
	}

	// Named baseURL so the net/url package is not shadowed.
	baseURL, err := url.Parse(a.URL)
	if err != nil {
		return m, fmt.Errorf("url could not be parsed: %w", err)
	}

	// The ollama client has no native API-key support, so the key is injected
	// by a wrapping transport.
	httpClient := &http.Client{}
	if a.APIKey != "" {
		httpClient.Transport = &apiHeaderTransport{
			key:  a.APIKey,
			base: http.DefaultTransport,
		}
	}
	client := api.NewClient(baseURL, httpClient)

	// Resolve per-annotator overrides into locals rather than writing them
	// back into the shared package-level defaults.
	systemPrompt := OllamaAnnotatorFirstInstructions
	if a.SystemPrompt != "" {
		systemPrompt = a.SystemPrompt
	}
	timeoutSeconds := OllamaAnnotatorTimeout
	if a.Timeout != 0 {
		timeoutSeconds = a.Timeout
	}
	maxAttempts := OllamaAnnotatorMaxRetryAttempts
	if a.MaxRetryAttempts != 0 {
		maxAttempts = a.MaxRetryAttempts
	}
	retryDelaySeconds := OllamaAnnotatorRetryDelaySeconds
	if a.MaxRetryDelaySeconds != 0 {
		retryDelaySeconds = a.MaxRetryDelaySeconds
	}

	requestedFormatJSON, err := json.Marshal(OllamaAnnotatorResponseRequestedFormat)
	if err != nil {
		return m, fmt.Errorf("error setting Ollama response format: %w", err)
	}

	opts := map[string]any{}
	for _, opt := range a.Options {
		opts[opt.Name] = opt.Value
	}

	stream := false
	req := &api.GenerateRequest{
		Model:  a.Model,
		Format: requestedFormatJSON,
		System: systemPrompt + a.UserPrompt + OllamaAnnotatorFinalInstructions,
		Stream: &stream,
		// Interpreted string: a raw (backtick) literal would send the two
		// characters `\n` to the model instead of a line break.
		Prompt:  "Here is the message to evaluate:\n" + msg,
		Options: opts,
	}

	var r OllamaAnnotatorResponse
	respFunc := func(resp api.GenerateResponse) error {
		// Use the last JSON object in the reply in case the model reasons
		// about one in the middle of its output.
		matches := ollamaAnnotatorJSONObjectRegex.FindAllStringIndex(resp.Response, -1)
		if len(matches) == 0 {
			return fmt.Errorf("did not find a json object in response: %s", resp.Response)
		}
		last := matches[len(matches)-1]
		content := SanitizeJSONString(resp.Response[last[0]:last[1]])

		parseErr := json.Unmarshal([]byte(content), &r)
		if parseErr != nil || resp.DoneReason != "stop" {
			// %v because parseErr is nil when only the done reason is wrong.
			return fmt.Errorf("%s done, reason %v, full response from Ollama: %s",
				resp.DoneReason, parseErr,
				Aside(strings.ReplaceAll(resp.Response, "\n", "\t")))
		}
		return nil
	}

	// One deadline bounds all attempts together, not each attempt separately.
	ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSeconds)*time.Second)
	defer cancel()

	log.Debug(Aside("%s: annotating message ending in \"", a.Name()),
		Note(Last20Characters(msg)),
		Aside("\", model "),
		Note(a.Model))

	err = retry.Do(func() error {
		if genErr := client.Generate(ctx, req, respFunc); genErr != nil {
			return &RetriableError{
				Err:        genErr,
				RetryAfter: time.Duration(retryDelaySeconds) * time.Second,
			}
		}
		return nil
	},
		retry.Attempts(uint(maxAttempts)),
		retry.DelayType(retry.BackOffDelay),
	)
	if err != nil {
		// Annotation is best effort: surface the failure in the log and fall
		// through, so the message still passes on as before.
		log.Warn(Aside("%s: Ollama request failed after retries: %s", a.Name(), err.Error()))
	}

	if (r == OllamaAnnotatorResponse{}) {
		log.Debug(Aside("%s: response was empty", a.Name()))
		return m, nil
	}

	// Clamp to 1..100; this ensures the field is never zero.
	r.ProcessedNumber = min(100, max(1, r.ProcessedNumber))
	return MergeAPMessages(FormatAsAPMessage(r, a.Name()), m), nil
}
// apiHeaderTransport is an http.RoundTripper that adds a bearer-token
// Authorization header to each request before delegating to base. It is
// needed because the ollama package doesn't support API keys natively.
// It is also used with the Ollama filter.
type apiHeaderTransport struct {
	key  string            // token sent as "Bearer <key>"
	base http.RoundTripper // underlying transport; http.DefaultTransport is used when nil
}

// RoundTrip sends a copy of req, with the Authorization header set, through
// the base transport and returns its response and error unchanged.
//
// The caller's request is left untouched: the http.RoundTripper contract
// forbids modifying the request, so the header is set on a clone.
func (t *apiHeaderTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	base := t.base
	if base == nil {
		base = http.DefaultTransport
	}

	authed := req.Clone(req.Context())
	if authed.Header == nil {
		// Cloning a nil header yields nil; Set on a nil map would panic.
		authed.Header = make(http.Header)
	}
	authed.Header.Set("Authorization", "Bearer "+t.key)
	return base.RoundTrip(authed)
}