-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodels.go
More file actions
210 lines (187 loc) · 6.98 KB
/
models.go
File metadata and controls
210 lines (187 loc) · 6.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
package gopheract
import (
"context"
"encoding/json"
"errors"
"fmt"
"reflect"
"github.com/mitchellh/mapstructure"
"github.com/openai/openai-go/v2"
"github.com/openai/openai-go/v2/option"
)
// LLM abstracts a chat-capable language model that can answer in a
// caller-specified structured format.
//
// Both arguments are deliberately untyped so that each backend can accept its
// own SDK's message and response-format types. The OpenAI implementation in this
// package type-asserts them and returns an error on a mismatch.
type LLM interface {
	// StructuredChat sends chatHistory to the model, asks for output shaped by
	// responseFormat, and returns the model's raw text answer.
	StructuredChat(chatHistory, responseFormat any) (string, error)
}
// OpenAILLM is the OpenAI-backed implementation of LLM (its StructuredChat
// method has a pointer receiver, so *OpenAILLM is what satisfies the interface).
type OpenAILLM struct {
	// Model is the OpenAI chat model every request is addressed to.
	Model openai.ChatModel
	// Client is the OpenAI SDK client used to send requests.
	Client *openai.Client
}
// NewOpenAILLM builds an OpenAILLM whose client authenticates with apiKey and
// whose requests target the chat model identified by model.
func NewOpenAILLM(apiKey, model string) *OpenAILLM {
	sdkClient := openai.NewClient(option.WithAPIKey(apiKey))
	llm := new(OpenAILLM)
	llm.Model = model
	llm.Client = &sdkClient
	return llm
}
// StructuredChat sends chatHistory to the configured OpenAI model and returns
// the text content of the first completion choice.
//
// Because this is the OpenAI implementation, chatHistory must be a
// []openai.ChatCompletionMessageParamUnion and responseFormat an
// openai.ChatCompletionNewParamsResponseFormatUnion. Any other dynamic type
// yields an error before a request is made.
//
// An error is also returned when the receiver or its client is nil, the API
// call fails, the response contains no choices, or the model refused to answer.
//
// The request is issued with context.Background(), so it carries no deadline
// or cancellation of its own.
func (o *OpenAILLM) StructuredChat(chatHistory any, responseFormat any) (string, error) {
	typedChatHistory, ok := chatHistory.([]openai.ChatCompletionMessageParamUnion)
	if !ok {
		return "", errors.New("chat history does not conform to the expected OpenAI format")
	}
	resFmt, ok := responseFormat.(openai.ChatCompletionNewParamsResponseFormatUnion)
	if !ok {
		return "", errors.New("response format doesn't conform with the one expected for OpenAI")
	}
	// A zero-value OpenAILLM (or a nil receiver) would otherwise panic below.
	if o == nil || o.Client == nil {
		return "", errors.New("openai client is not configured")
	}
	chat, err := o.Client.Chat.Completions.New(context.Background(), openai.ChatCompletionNewParams{
		Messages:       typedChatHistory,
		Model:          o.Model,
		ResponseFormat: resFmt,
	})
	if err != nil {
		return "", fmt.Errorf("openai chat completion: %w", err)
	}
	// Indexing Choices[0] without this check panics on an empty response.
	if len(chat.Choices) == 0 {
		return "", errors.New("openai chat completion returned no choices")
	}
	msg := chat.Choices[0].Message
	// With structured outputs the model may decline to answer. The explanation
	// is then reported in Refusal rather than Content, and returning an empty
	// Content with a nil error would hide it.
	if msg.Refusal != "" {
		return "", fmt.Errorf("openai model refused the request: %s", msg.Refusal)
	}
	return msg.Content, nil
}
// Thought is the reasoning step of the ReAct agent: a free-text note on how to
// proceed, derived from the chat history.
type Thought struct {
	Thought string `json:"thought" jsonschema_description:"Thought about the path forward, based on the chat history"`
}
// Observation is the observing step of the ReAct agent: a free-text assessment
// of the current state of things, derived from the chat history.
type Observation struct {
	Observation string `json:"observation" jsonschema_description:"Observation about the current state of things, based on the chat history"`
}
// StopReason explains why the agent chose to end its loop. It is the payload
// carried by a "_done" action.
type StopReason struct {
	Reason string `json:"reason" jsonschema_description:"Reason why the conversation should stop"`
}
// ToolCallArgs holds one chunk of arguments for a tool call.
//
// Because of typing constraints, ParameterValue carries the arguments as
// serialized JSON text: a JSON object mapping parameter names to values, which
// ToolCall.ArgsToMap decodes. The example in the schema description is itself
// valid JSON, because the model tends to copy it and json.Unmarshal rejects
// single-quoted keys.
type ToolCallArgs struct {
	ParameterValue string `json:"parameter_value" jsonschema_description:"Parameter name and value of the parameter as a JSON string (e.g. '{\"age\": 40, \"name\": \"John Doe\"}')"`
}

// ToolCall is the payload of a "tool_call" action: the name of the tool to
// invoke together with its arguments.
type ToolCall struct {
	Name string         `json:"name" jsonschema_description:"Name of the tools to call"`
	Args []ToolCallArgs `json:"args" jsonschema_description:"Tool call arguments"`
}

// ArgsToMap merges all of the call's arguments into a single map.
//
// Each ParameterValue is decoded as a JSON object and its keys are copied into
// the result in slice order. A key repeated in a later argument therefore
// overrides an earlier one. A ParameterValue of JSON null contributes nothing.
// A call with no arguments yields an empty, non-nil map.
//
// It returns an error if t is nil, or if any ParameterValue is not a JSON
// object. The error names the tool and the index of the offending argument and
// wraps the underlying encoding/json error.
func (t *ToolCall) ArgsToMap() (map[string]any, error) {
	// Action.ToolCall is nil for "_done" actions, so guard instead of panicking.
	if t == nil {
		return nil, errors.New("cannot convert arguments of a nil tool call")
	}
	args := make(map[string]any)
	for i, arg := range t.Args {
		var decoded map[string]any
		if err := json.Unmarshal([]byte(arg.ParameterValue), &decoded); err != nil {
			return nil, fmt.Errorf("tool call %q: decoding argument %d: %w", t.Name, i, err)
		}
		for k, v := range decoded {
			args[k] = v
		}
	}
	return args, nil
}
// Action is the acting step of the ReAct agent.
//
// ActionType selects which optional payload is meaningful:
//   - "_done": StopReason is expected to be non-nil and explains why the loop ends;
//   - "tool_call": ToolCall is expected to be non-nil and names the tool plus its arguments.
//
// The type itself does not enforce this pairing, so consumers should nil-check
// the payload they read.
type Action struct {
	ActionType string      `json:"type" jsonschema:"enum=_done,enum=tool_call" jsonschema_description:"Type of the action to perform based on the chat history. Use '_done' if you think the conversation should stop, and 'tool_call' if you want to call a tool"`
	StopReason *StopReason `json:"stop_reason" jsonschema_description:"Reason why the conversation should stop. Only present when type is '_done'"`
	ToolCall   *ToolCall   `json:"tool_call" jsonschema_description:"Tool to call with its arguments. Only present when type is 'tool_call'"`
}
// ChatMessage is one entry of a chat history: who spoke (Role) and what was
// said (Content), serialized under the "role" and "content" JSON keys.
type ChatMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// NewChatMessage allocates a ChatMessage carrying the given role and content
// and returns a pointer to it.
func NewChatMessage(role, content string) *ChatMessage {
	msg := new(ChatMessage)
	msg.Role, msg.Content = role, content
	return msg
}
// ToolParamsMetadata describes a single parameter of a tool function. It is
// used when passing the tool definition to the agent's system prompt.
type ToolParamsMetadata struct {
	// JsonDef is the raw `json` struct tag of the parameter field (e.g.
	// "age,omitempty"), or the Go field name when the field has no json tag.
	JsonDef string
	// Description is the value of the field's `description` struct tag.
	Description string
	// Type is the Go type of the field as reported by reflect (e.g. "int").
	Type string
}

// ToString renders the parameter metadata as a single human-readable line.
func (tp *ToolParamsMetadata) ToString() string {
	return fmt.Sprintf("JSON Definition of the parameter: %s; Description: %s; Type: %s", tp.JsonDef, tp.Description, tp.Type)
}

// ToolMetadata bundles what the agent needs to advertise a tool: its name,
// its description, and per-parameter metadata.
type ToolMetadata struct {
	Name               string
	Description        string
	ParametersMetadata []ToolParamsMetadata
}

// Tool is the interface every tool definition must implement.
type Tool interface {
	// GetMetadata describes the tool for the agent's system prompt.
	GetMetadata() ToolMetadata
	// Execute runs the tool with the arguments received from a ToolCall.
	Execute(map[string]any) (any, error)
}

// ToolDefinition is a generic implementation of the Tool interface that wraps
// a Go function.
//
// The type parameter T is the type of Fn's single argument. It is normally a
// struct (or pointer to struct) whose fields carry `json` tags (parameter
// names) and `description` tags.
//
// A good practice is to write Name and Description in as much detail and as
// explicitly as possible, since they are shown to the model.
type ToolDefinition[T any] struct {
	Fn          func(T) (any, error)
	Name        string
	Description string
}

// GetMetadata builds the tool's metadata by reflecting over T.
//
// For each exported field of T (or of *T's element type) it records the json
// tag, falling back to the field name when the tag is absent. It also records
// the `description` tag and the field's Go type. Fields tagged `json:"-"` are
// skipped. When T is not a struct or pointer to struct (e.g. map[string]any),
// there are no named fields to describe and ParametersMetadata is empty
// (non-nil) rather than panicking.
func (t ToolDefinition[T]) GetMetadata() ToolMetadata {
	paramMeta := []ToolParamsMetadata{}
	// Reflect on T directly; this works even when Fn is nil.
	paramType := reflect.TypeOf((*T)(nil)).Elem()
	if paramType.Kind() == reflect.Pointer {
		paramType = paramType.Elem()
	}
	// NumField panics on non-struct kinds, so only structs are inspected.
	if paramType.Kind() == reflect.Struct {
		for i := 0; i < paramType.NumField(); i++ {
			field := paramType.Field(i)
			// Execute's decoder cannot set unexported fields, so don't advertise them.
			if !field.IsExported() {
				continue
			}
			jsonDef := field.Tag.Get("json")
			if jsonDef == "-" {
				continue
			}
			if jsonDef == "" {
				// Without a tag the field's own name is the key the model must use.
				jsonDef = field.Name
			}
			paramMeta = append(paramMeta, ToolParamsMetadata{
				JsonDef:     jsonDef,
				Description: field.Tag.Get("description"),
				Type:        field.Type.String(),
			})
		}
	}
	return ToolMetadata{
		Name:               t.Name,
		Description:        t.Description,
		ParametersMetadata: paramMeta,
	}
}
// Execute runs the tool with the parameters received from the ToolCall action.
//
// It first decodes params into a value of type T with mapstructure, matching
// map keys against the fields' `json` tags. It then calls Fn with the decoded
// value and returns Fn's result and error unchanged.
//
// A nil Fn, a decoder construction failure, or a decoding failure is reported
// as an error that names the tool (decoder errors are wrapped with %w).
func (t ToolDefinition[T]) Execute(params map[string]any) (any, error) {
	// Calling a nil func would panic; report it as an ordinary error instead.
	if t.Fn == nil {
		return nil, fmt.Errorf("tool %q has no function to execute", t.Name)
	}
	var typedParams T
	config := &mapstructure.DecoderConfig{
		TagName: "json",
		Result:  &typedParams,
	}
	decoder, err := mapstructure.NewDecoder(config)
	if err != nil {
		return nil, fmt.Errorf("tool %q: creating parameter decoder: %w", t.Name, err)
	}
	if err := decoder.Decode(params); err != nil {
		return nil, fmt.Errorf("tool %q: decoding parameters: %w", t.Name, err)
	}
	return t.Fn(typedParams)
}