Skip to content

Commit fb72ba0

Browse files
committed
softmax
1 parent 540d123 commit fb72ba0

5 files changed

Lines changed: 98 additions & 4 deletions

File tree

include/TensorSANN/activations/ReLU.hpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,6 @@ class ReLU : Layer{
1313

1414
Tensor backward(const Tensor &grad_output) override;
1515

16-
protected:
17-
bool isTrainable_ = false;
18-
1916
};
2017

2118
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#pragma once
2+
3+
#include "TensorSANN/layers/Layer.hpp"
4+
#include "TensorSANN/utils/Tensor.hpp"
5+
#include <memory>
6+
#include <string>
7+
8+
namespace TensorSANN {
9+
10+
// Softmax activation layer: maps a vector of logits to a probability
// distribution (non-negative entries summing to 1).
// forward() caches its result in output_ so that backward() can apply
// the softmax Jacobian-vector product without recomputing.
// NOTE(review): inherits Layer privately (same convention as ReLU) --
// confirm that callers never need a Layer* to a Softmax.
class Softmax : Layer{
public:

    // Computes softmax of `input` and caches the result in output_.
    Tensor forward(const Tensor &input) override;

    // Given dL/dS (gradient w.r.t. the softmax output), returns dL/dz
    // (gradient w.r.t. the logits), using the cached output_.
    Tensor backward(const Tensor &grad_output) override;

protected:
    // Softmax probabilities produced by the most recent forward() call.
    Tensor output_;

};
21+
22+
23+
24+
}

src/activations/ReLU.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
namespace TensorSANN{
55

66
Tensor ReLU::forward(const Tensor &input){
7+
cachedInput_ = input;
78
Tensor output = input;
89

910
for (size_t i = 0; i < output.size(); ++i){

src/activations/Softmax.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#include "TensorSANN/activations/Softmax.hpp"
2+
#include "TensorSANN/utils/Tensor.hpp"
3+
#include <algorithm>
4+
#include <numeric>
5+
#include <cmath>
6+
#include <cassert>
7+
8+
namespace TensorSANN{
9+
10+
/// Computes the softmax of `input` element-wise over its flat data:
///   S_i = exp(z_i - max(z)) / sum_j exp(z_j - max(z))
/// The max-shift keeps exp() from overflowing (numerical stability).
/// Caches the input in cachedInput_ and the result in output_ for backward().
///
/// @param input  logits tensor (treated as one flat distribution).
/// @return       tensor of the same shape holding probabilities.
Tensor Softmax::forward(const Tensor &input){
    cachedInput_ = input;
    Tensor output = input;
    std::vector<float>& data = output.data();

    // Empty tensor: nothing to normalize (max_element on an empty
    // range would be undefined behavior).
    if (data.empty()) {
        this->output_ = output;
        return output;
    }

    // BUGFIX: the previous version clamped every value with
    // std::max(0.0f, x) (a ReLU copied from ReLU::forward) before
    // exponentiating, which corrupted the softmax for any negative
    // logit. Softmax must act on the raw logits.
    const float max_val = *std::max_element(data.begin(), data.end());

    // Exponentiate the shifted logits and accumulate the normalizer
    // in a single pass.
    float sum = 0.0f;
    for (float& v : data) {
        v = std::exp(v - max_val);
        sum += v;
    }

    // Normalize to probabilities. sum >= exp(0) = 1 here (the max
    // element contributes exactly 1), so no division-by-zero guard
    // is needed.
    for (float& v : data) {
        v /= sum;
    }

    this->output_ = output;
    return output;
}
39+
40+
/// Backpropagates through the softmax using the cached output_.
/// For S = softmax(z) and upstream gradient g = dL/dS, the
/// Jacobian-vector product collapses to
///   dL/dz_i = S_i * (g_i - sum_j g_j * S_j)
/// so the full Jacobian is never materialized.
///
/// @param grad_output  dL/dS, same size as the cached output_.
/// @return             dL/dz with the cached output's shape.
Tensor Softmax::backward(const Tensor &grad_output) {
    // assert(output_.size() == grad_output.size() && "Gradient size mismatch");

    const std::vector<float>& probs = output_.data();
    const std::vector<float>& upstream = grad_output.data();

    // dot = sum_j g_j * S_j (the shared correction term).
    const float dot =
        std::inner_product(probs.begin(), probs.end(), upstream.begin(), 0.0f);

    // dL/dz_i = S_i * (g_i - dot)
    std::vector<float> grad_input(probs.size());
    for (size_t i = 0; i < grad_input.size(); ++i) {
        grad_input[i] = probs[i] * (upstream[i] - dot);
    }

    return Tensor(output_.shape(), grad_input);
}
61+
} // namespace TensorSANN

src/model_main.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "TensorSANN/utils/Tensor.hpp"
33
#include "TensorSANN/layers/DenseLayer.hpp"
44

5+
#include "TensorSANN/activations/Softmax.hpp"
56
#include "TensorSANN/activations/ReLU.hpp"
67
#include "TensorSANN/optimizers/SGD.hpp"
78

@@ -22,7 +23,17 @@ int main() {
2223
std::cout << d1w.to_string() << std::endl;
2324
std::cout << d1b.to_string() << std::endl;
2425

25-
dense1.forward(input_tensor.transpose());
26+
TensorSANN::Tensor f1 = dense1.forward(input_tensor.transpose());
27+
std::cout << f1.to_string() << std::endl;
28+
29+
TensorSANN::Softmax smax;
30+
31+
TensorSANN::Tensor f2 = smax.forward(f1);
32+
std::cout << f2.to_string() << std::endl;
33+
34+
TensorSANN::Tensor b2 = smax.backward(f2);
35+
std::cout << b2.to_string() << std::endl;
36+
2637
return 0;
2738

2839
}

0 commit comments

Comments
 (0)