Skip to content

Commit d671caa

Browse files
committed
prob not working
1 parent c6cc0a7 commit d671caa

File tree

5 files changed

+32
-17
lines changed

5 files changed

+32
-17
lines changed

include/TensorSANN/layers/DenseLayer.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,15 @@
99

1010
namespace TensorSANN {
1111

12-
class DenseLayer : TrainableLayer{
12+
class DenseLayer : public TrainableLayer{
1313
public:
1414
DenseLayer(size_t input_size, size_t output_size);
1515

1616
Tensor forward(const Tensor &input) override;
1717

1818
Tensor backward(const Tensor &output_grad) override;
1919

20+
void update_weights_biases(float learning_rate) override;
2021
};
2122

2223

include/TensorSANN/layers/TrainableLayer.hpp

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,25 +9,27 @@
99

1010
namespace TensorSANN {
1111

12-
class TrainableLayer : Layer{
12+
class TrainableLayer : public Layer{
1313
public:
14-
virtual Tensor forward(const Tensor &input);
14+
virtual ~TrainableLayer() = default;
1515

16-
virtual Tensor backward(const Tensor &output_grad);
16+
virtual Tensor forward(const Tensor &input);
1717

18-
Tensor weights() {return weights_;}
19-
Tensor biases() {return biases_;}
20-
const Tensor weights() const {return weights_;}
21-
const Tensor biases() const {return biases_;}
18+
virtual Tensor backward(const Tensor &output_grad);
2219

23-
protected:
24-
Tensor weights_;
25-
Tensor biases_;
26-
// Tensor cachedInput_;
27-
};
20+
virtual void update_weights_biases(float learning_rate);
21+
22+
Tensor weights() {return weights_;}
23+
Tensor biases() {return biases_;}
24+
const Tensor weights() const {return weights_;}
25+
const Tensor biases() const {return biases_;}
2826

2927
public:
30-
bool isTrainable_ = true;
28+
bool isTrainable_ = true; // note might be redudnat
29+
Tensor weights_;
30+
Tensor biases_;
31+
Tensor cachedInput_;
3132

33+
};
3234

3335
}

src/layers/DenseLayer.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
#include "TensorSANN/utils/Tensor.hpp"
44

55
namespace TensorSANN{
6-
DenseLayer::DenseLayer(size_t input_size, size_t output_size) : weights_(Tensor({input_size, output_size})), biases_(Tensor({output_size})) {
7-
// normal distrubtion random that i got from chat gpt
6+
DenseLayer::DenseLayer(size_t input_size, size_t output_size){
7+
weights_ = Tensor({input_size, output_size});
8+
biases_ = Tensor({output_size});
9+
810
std::random_device rd;
911
std::mt19937 gen(rd());
1012
std::normal_distribution<float> dist(0.0f, 0.1f);
@@ -47,4 +49,10 @@ namespace TensorSANN{
4749
// return d_Z to continue back prop
4850
return d_Z;
4951
}
52+
53+
void DenseLayer::update_weights_biases(float learning_rate){
54+
weights_ = weights_ - ((*(weights_.grad())) * learning_rate);
55+
biases_ = biases_ - ((*(biases_.grad())) * learning_rate);
56+
57+
}
5058
} // namespace TensorSANN

src/main.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include <iostream>
22
#include "TensorSANN/utils/Tensor.hpp"
33
#include "TensorSANN/layers/DenseLayer.hpp"
4+
#include "TensorSANN/layers/TrainableLayer.hpp"
45
#include "TensorSANN/activations/ReLU.hpp"
56
#include "TensorSANN/optimizers/SGD.hpp"
67

src/optimizers/SGD.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,11 @@ bool SGD::update(Layer &layer){
1111

1212
// if it is then we can attempt to downcast;
1313
TrainableLayer *t_layer = dynamic_cast<TrainableLayer*>(&layer);
14+
if (!t_layer) return false; // just in case downcast fails
15+
16+
// update weights
17+
(*t_layer).update_weights_biases(learning_rate);
1418

15-
1619
return true;
1720
}
1821

0 commit comments

Comments
 (0)