-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpointer_test.go
More file actions
176 lines (147 loc) · 5.12 KB
/
pointer_test.go
File metadata and controls
176 lines (147 loc) · 5.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
package knowledgesdk
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
// TestPointerUtils covers the pointer helpers: the value-to-pointer
// constructors (String, Int, Float32, Bool, Time), the nil-safe
// dereferencers (StringValue, IntValue, Float32Value, BoolValue, TimeValue),
// and the HasXxxValue nil checks.
func TestPointerUtils(t *testing.T) {
	// Each constructor must return a non-nil pointer to the given value.
	// Guarding the dereference turns a nil return into a test failure
	// instead of a panic that aborts the whole test binary.
	t.Run("测试基本类型转指针", func(t *testing.T) {
		s := "test"
		if sPtr := String(s); assert.NotNil(t, sPtr) {
			assert.Equal(t, s, *sPtr)
		}

		i := 42
		if iPtr := Int(i); assert.NotNil(t, iPtr) {
			assert.Equal(t, i, *iPtr)
		}

		f := float32(3.14)
		if fPtr := Float32(f); assert.NotNil(t, fPtr) {
			assert.Equal(t, f, *fPtr)
		}

		b := true
		if bPtr := Bool(b); assert.NotNil(t, bPtr) {
			assert.Equal(t, b, *bPtr)
		}

		// Compare instants with time.Time.Equal rather than deep equality so
		// the check does not depend on the monotonic-clock reading or the
		// *time.Location pointer surviving unchanged.
		now := time.Now()
		if tPtr := Time(now); assert.NotNil(t, tPtr) {
			assert.True(t, now.Equal(*tPtr), "expected %v, got %v", now, *tPtr)
		}
	})

	// The XxxValue helpers return the pointed-to value for non-nil input and
	// the type's zero value for nil.
	t.Run("测试安全取值", func(t *testing.T) {
		// Non-nil pointers.
		assert.Equal(t, "test", StringValue(String("test")))
		assert.Equal(t, 42, IntValue(Int(42)))
		assert.Equal(t, float32(3.14), Float32Value(Float32(3.14)))
		assert.True(t, BoolValue(Bool(true)))
		now := time.Now()
		got := TimeValue(Time(now))
		assert.True(t, now.Equal(got), "expected %v, got %v", now, got)

		// Nil pointers yield zero values.
		assert.Equal(t, "", StringValue(nil))
		assert.Equal(t, 0, IntValue(nil))
		assert.Equal(t, float32(0), Float32Value(nil))
		assert.False(t, BoolValue(nil))
		assert.True(t, TimeValue(nil).IsZero())
	})

	// HasXxxValue reports whether the pointer is non-nil.
	t.Run("测试HasValue函数", func(t *testing.T) {
		// Non-nil pointers.
		assert.True(t, HasStringValue(String("test")))
		assert.True(t, HasIntValue(Int(42)))
		assert.True(t, HasFloat32Value(Float32(3.14)))
		assert.True(t, HasBoolValue(Bool(true)))
		assert.True(t, HasTimeValue(Time(time.Now())))

		// Nil pointers.
		assert.False(t, HasStringValue(nil))
		assert.False(t, HasIntValue(nil))
		assert.False(t, HasFloat32Value(nil))
		assert.False(t, HasBoolValue(nil))
		assert.False(t, HasTimeValue(nil))
	})
}
// TestModelDefaults checks the defaults applied by KnowledgeBase.SetDefaults
// and Document.SetDefaults, and that the factory functions (NewKnowledgeBase,
// NewDocument, NewChunk, NewKnowledgeBaseWithDefaults) store their arguments
// and apply those defaults.
func TestModelDefaults(t *testing.T) {
	// SetDefaults on an empty KnowledgeBase fills in the default settings.
	t.Run("测试KnowledgeBase默认值", func(t *testing.T) {
		kb := &KnowledgeBase{}
		kb.SetDefaults()

		assert.Equal(t, "", StringValue(kb.ModelID))
		assert.Equal(t, float32(0.7), Float32Value(kb.Temperature))
		assert.False(t, BoolValue(kb.EnableRigorousAnswer))
		assert.Equal(t, 1000, IntValue(kb.ChunkSize))
		assert.Equal(t, 50, IntValue(kb.Overlap))
		assert.Equal(t, 5, IntValue(kb.TopK))
		assert.Equal(t, float32(0.6), Float32Value(kb.SimilarityThreshold))
		assert.Equal(t, 3000, IntValue(kb.MaxReferenceLength))
	})

	// SetDefaults on an empty Document sets its status to "pending".
	t.Run("测试Document默认值", func(t *testing.T) {
		doc := &Document{}
		doc.SetDefaults()

		assert.Equal(t, "pending", StringValue(doc.Status))
	})

	// Factory functions: each result is nil-checked before its fields are
	// read, so a nil return fails the test instead of panicking.
	t.Run("测试工厂函数", func(t *testing.T) {
		kb := NewKnowledgeBase("测试知识库", "测试描述")
		if assert.NotNil(t, kb) {
			assert.Equal(t, "测试知识库", StringValue(kb.Name))
			assert.Equal(t, "测试描述", StringValue(kb.Description))
			// Defaults must also have been applied.
			assert.Equal(t, float32(0.7), Float32Value(kb.Temperature))
			assert.Equal(t, 1000, IntValue(kb.ChunkSize))
		}

		doc := NewDocument("kb123", "测试文档", "内容", "user123")
		if assert.NotNil(t, doc) {
			assert.Equal(t, "kb123", StringValue(doc.KBID))
			assert.Equal(t, "测试文档", StringValue(doc.Name))
			assert.Equal(t, "内容", StringValue(doc.OriginalContent))
			assert.Equal(t, "user123", StringValue(doc.CreatorID))
			// Defaults must also have been applied.
			assert.Equal(t, "pending", StringValue(doc.Status))
		}

		chunk := NewChunk("doc123", 1, "分块内容")
		if assert.NotNil(t, chunk) {
			assert.Equal(t, "doc123", StringValue(chunk.DocumentID))
			assert.Equal(t, 1, IntValue(chunk.ChunkIndex))
			assert.Equal(t, "分块内容", StringValue(chunk.Content))
		}
	})

	// Overrides replace the matching defaults; every other default is kept.
	t.Run("测试NewKnowledgeBaseWithDefaults", func(t *testing.T) {
		overrides := map[string]interface{}{
			"temperature": float32(0.9),
			"chunk_size":  1500,
			"top_k":       10,
			"model_id":    "gpt-4",
		}
		kb := NewKnowledgeBaseWithDefaults("测试知识库", "测试描述", overrides)
		if !assert.NotNil(t, kb) {
			return
		}

		assert.Equal(t, "测试知识库", StringValue(kb.Name))
		assert.Equal(t, "测试描述", StringValue(kb.Description))

		// Overridden fields.
		assert.Equal(t, float32(0.9), Float32Value(kb.Temperature))
		assert.Equal(t, 1500, IntValue(kb.ChunkSize))
		assert.Equal(t, 10, IntValue(kb.TopK))
		assert.Equal(t, "gpt-4", StringValue(kb.ModelID))

		// Fields not overridden keep their defaults.
		assert.Equal(t, 50, IntValue(kb.Overlap))
		assert.Equal(t, float32(0.6), Float32Value(kb.SimilarityThreshold))
		assert.Equal(t, 3000, IntValue(kb.MaxReferenceLength))
		assert.False(t, BoolValue(kb.EnableRigorousAnswer))
	})
}
// TestZeroValueUpdates verifies that pointer-typed fields can hold an
// explicit zero value that stays distinguishable from "unset" (nil). The
// original comments state the design intent: a non-nil pointer to a zero
// value is what lets GORM write the zero instead of skipping the field.
func TestZeroValueUpdates(t *testing.T) {
	t.Run("测试零值更新", func(t *testing.T) {
		kb := NewKnowledgeBase("测试", "描述")
		if !assert.NotNil(t, kb) {
			return
		}

		// Precondition: these defaults are non-zero, so the assignments
		// below are an observable change rather than a no-op.
		assert.NotEqual(t, float32(0), Float32Value(kb.Temperature))
		assert.NotEqual(t, 0, IntValue(kb.ChunkSize))

		kb.Temperature = Float32(0)
		kb.ChunkSize = Int(0)
		kb.EnableRigorousAnswer = Bool(false)

		// Check non-nil first, then dereference directly. Float32Value,
		// IntValue and BoolValue return the zero value for nil, so going
		// through them alone could not tell a stored zero from a nil field.
		if assert.NotNil(t, kb.Temperature) {
			assert.Equal(t, float32(0), *kb.Temperature)
		}
		if assert.NotNil(t, kb.ChunkSize) {
			assert.Equal(t, 0, *kb.ChunkSize)
		}
		if assert.NotNil(t, kb.EnableRigorousAnswer) {
			assert.False(t, *kb.EnableRigorousAnswer)
		}
	})
}