-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtraining.go
More file actions
96 lines (78 loc) · 2.39 KB
/
training.go
File metadata and controls
96 lines (78 loc) · 2.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
package main
import (
"fmt"
"math"
"sync"
"time"
)
// words in the list of words, learningRate is the learning rate, epochs is the number of loops to train the model
// TrainModel trains the model over the tokenized corpus for the given
// number of epochs. words is the token stream, learningRate scales each
// gradient step, epochs is the number of full passes, and workers is the
// number of goroutines that each process a disjoint chunk of the corpus.
//
// NOTE(review): context windows do not cross chunk boundaries, so pairs
// spanning two workers' chunks are skipped — presumably an accepted
// approximation of the original design; confirm if exact windows matter.
func (w2v *Word2Vec) TrainModel(words []string, learningRate float64, epochs, workers int) {
	// Guard against a zero/negative worker count (would divide by zero below).
	if workers < 1 {
		workers = 1
	}
	chunkSize := len(words) / workers
	for epoch := 0; epoch < epochs; epoch++ {
		startTime := time.Now()
		var wg sync.WaitGroup
		// lossMu guards totalLoss: the original added to it from every
		// goroutine without synchronization — a data race under -race.
		var lossMu sync.Mutex
		totalLoss := 0.0
		for w := 0; w < workers; w++ {
			start := w * chunkSize
			end := start + chunkSize
			// The last worker absorbs the remainder; the original's
			// min(start+chunkSize, len) silently dropped the tail words
			// whenever len(words) was not a multiple of workers.
			if w == workers-1 {
				end = len(words)
			}
			if start >= end {
				continue
			}
			wg.Add(1)
			go func(wordsSubset []string) {
				defer wg.Done()
				localLoss := 0.0
				for i, target := range wordsSubset {
					lo := max(0, i-windowSize)
					hi := min(len(wordsSubset), i+windowSize)
					for j := lo; j < hi; j++ {
						if i == j {
							continue
						}
						context := wordsSubset[j]
						// Vector updates mutate shared maps; serialize them.
						w2v.M.Lock()
						localLoss += w2v.UpdateVectors(target, context, learningRate)
						w2v.M.Unlock()
					}
				}
				// Fold the worker's loss in once, under the loss mutex.
				lossMu.Lock()
				totalLoss += localLoss
				lossMu.Unlock()
			}(words[start:end])
		}
		wg.Wait()
		// Round to milliseconds so the printed seconds have 3 decimal places.
		elapsedTime := time.Since(startTime).Round(time.Millisecond).Seconds()
		fmt.Printf("Epoch %v out of %v took %v seconds and had a loss of %.3f\n", epoch+1, epochs, elapsedTime, totalLoss)
	}
}
func sigmoid(x float64) float64 {
return 1 / (1 + math.Exp(-x))
}
// UpdateVectors performs one positive-sample gradient step that pulls the
// vectors for target and context toward each other, and returns the step's
// loss, -log(sigmoid(dot)). Words missing from w2v.Vectors contribute zero
// loss and leave the model untouched. Not safe for concurrent use on its
// own — callers (see TrainModel) hold w2v.M around this call.
func (w2v *Word2Vec) UpdateVectors(target, context string, learningRate float64) float64 {
	targetVector, targetExists := w2v.Vectors[target]
	contextVector, contextExists := w2v.Vectors[context]
	if !targetExists || !contextExists {
		return 0
	}
	// Dot product measures how aligned the two vectors currently are.
	dotProduct := 0.0
	for i := 0; i < vectorSize; i++ {
		dotProduct += targetVector[i] * contextVector[i]
	}
	probability := sigmoid(dotProduct)
	loss := -math.Log(probability)
	// d/d(dot) of -log(sigmoid(dot)) is -(1 - probability), so each
	// component moves by learningRate*(1-probability) times the other
	// vector's component. grad is loop-invariant, so hoist it.
	grad := learningRate * (1.0 - probability)
	for i := 0; i < vectorSize; i++ {
		// Snapshot the pre-update target component: the original used the
		// already-updated value when adjusting the context vector, biasing
		// the context update with a partially-applied step.
		oldTarget := targetVector[i]
		targetVector[i] += grad * contextVector[i]
		contextVector[i] += grad * oldTarget
	}
	// Write back so the maps hold the updated values even if vectors are
	// stored by value. Both words are known to exist here (early return
	// above), so the original's re-checks were dead code.
	w2v.Vectors[target] = targetVector
	w2v.Vectors[context] = contextVector
	w2v.UpdatedVectors[target] = targetVector
	w2v.UpdatedVectors[context] = contextVector
	return loss
}