File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -9,12 +9,13 @@ namespace TensorSANN {
99
1010class ReLU : Layer{
1111public:
12-
13-
1412 Tensor forward (const Tensor &input) override ;
1513
1614 Tensor backward (const Tensor &grad_output) override ;
17- };
1815
16+ protected:
17+ bool isTrainable_ = false ;
18+
19+ };
1920
2021}
Original file line number Diff line number Diff line change 11#pragma once
22
3- #include " TensorSANN/layers/Layer .hpp"
3+ #include " TensorSANN/layers/TrainableLayer .hpp"
44#include " TensorSANN/utils/Tensor.hpp"
55#include < memory>
66#include < random>
99
1010namespace TensorSANN {
1111
12- class DenseLayer : Layer {
12+ class DenseLayer : TrainableLayer {
1313public:
1414 DenseLayer (size_t input_size, size_t output_size);
1515
1616 Tensor forward (const Tensor &input) override ;
1717
1818 Tensor backward (const Tensor &output_grad) override ;
1919
20- Tensor weights () {return weights_;}
21- Tensor biases () {return biases_;}
22- const Tensor weights () const {return weights_;}
23- const Tensor biases () const {return biases_;}
24-
25- protected:
26- Tensor weights_;
27- Tensor biases_;
28- // Tensor cachedInput_;
2920};
3021
3122
Original file line number Diff line number Diff line change @@ -19,7 +19,10 @@ class Layer{
1919protected:
2020 // Tensor output_;
2121 Tensor cachedInput_;
22- };
2322
2423
24+ public:
25+ bool isTrainable_ = false ;
26+ };
27+
2528}
Original file line number Diff line number Diff line change 1+ #pragma once
2+
3+ #include " TensorSANN/layers/Layer.hpp"
4+ #include " TensorSANN/utils/Tensor.hpp"
5+ #include < memory>
6+ #include < random>
7+ #include < vector>
8+ #include < string>
9+
10+ namespace TensorSANN {
11+
12+ class TrainableLayer : Layer{
13+ public:
14+ virtual Tensor forward (const Tensor &input);
15+
16+ virtual Tensor backward (const Tensor &output_grad);
17+
18+ Tensor weights () {return weights_;}
19+ Tensor biases () {return biases_;}
20+ const Tensor weights () const {return weights_;}
21+ const Tensor biases () const {return biases_;}
22+
23+ protected:
24+ Tensor weights_;
25+ Tensor biases_;
26+ // Tensor cachedInput_;
27+ };
28+
29+ public:
30+ bool isTrainable_ = true ;
31+
32+
33+ }
Original file line number Diff line number Diff line change @@ -8,7 +8,8 @@ namespace TensorSANN{
88
99class Optimizer {
1010public:
11- virtual void update (std::vector<Layer> & layers);
11+ virtual ~Optimizer () = default ;
12+ virtual bool update (Layer & layer);
1213};
1314
1415}// namespace TensorSANN
Original file line number Diff line number Diff line change @@ -11,7 +11,7 @@ class SGD : Optimizer{
1111public:
1212 SGD (float learning_rate);
1313
14- void update (std::vector< Layer> & layers ) override ;
14+ bool update (Layer & layer ) override ;
1515
1616private:
1717 float learning_rate;
Original file line number Diff line number Diff line change 22#include " TensorSANN/utils/Tensor.hpp"
33#include " TensorSANN/layers/DenseLayer.hpp"
44#include " TensorSANN/activations/ReLU.hpp"
5+ #include " TensorSANN/optimizers/SGD.hpp"
56
67int main () {
78 std::cout << " Hello world!" << std::endl;
9+ TensorSANN::SGD optimizer = TensorSANN::SGD (0 .01f );
810
911 std::vector<size_t > shape1 = {2 , 3 }; // 2x3 tensor
1012 TensorSANN::Tensor tensor_1 (shape1);
Original file line number Diff line number Diff line change 11#include " TensorSANN/optimizers/SGD.hpp"
2+ #include " TensorSANN/layers/TrainableLayer.hpp"
23
34namespace TensorSANN {
45
56SGD::SGD (float learning_rate) : learning_rate(learning_rate) {}
67
7- void SGD::update (std::vector<Layer> &layers){
8- for (auto & layer : layers){
9- // for (size_t i = 0; i < )
10- }
8+ bool SGD::update (Layer &layer){
9+ // check if trainable
10+ if (!layer.isTrainable_ ) return false ;
11+
12+ // if it is then we can attempt to downcast;
13+ TrainableLayer *t_layer = dynamic_cast <TrainableLayer*>(&layer);
14+
15+
16+ return true ;
1117}
1218
1319}
You can’t perform that action at this time.
0 commit comments