Skip to content

Commit c6cc0a7

Browse files
committed
more project strucutre
1 parent ed2b091 commit c6cc0a7

8 files changed

Lines changed: 58 additions & 21 deletions

File tree

include/TensorSANN/activations/ReLU.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,13 @@ namespace TensorSANN {
99

1010
class ReLU : Layer{
1111
public:
12-
13-
1412
Tensor forward(const Tensor &input) override;
1513

1614
Tensor backward(const Tensor &grad_output) override;
17-
};
1815

16+
protected:
17+
bool isTrainable_ = false;
18+
19+
};
1920

2021
}

include/TensorSANN/layers/DenseLayer.hpp

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#pragma once
22

3-
#include "TensorSANN/layers/Layer.hpp"
3+
#include "TensorSANN/layers/TrainableLayer.hpp"
44
#include "TensorSANN/utils/Tensor.hpp"
55
#include <memory>
66
#include <random>
@@ -9,23 +9,14 @@
99

1010
namespace TensorSANN {
1111

12-
class DenseLayer : Layer{
12+
class DenseLayer : TrainableLayer{
1313
public:
1414
DenseLayer(size_t input_size, size_t output_size);
1515

1616
Tensor forward(const Tensor &input) override;
1717

1818
Tensor backward(const Tensor &output_grad) override;
1919

20-
Tensor weights() {return weights_;}
21-
Tensor biases() {return biases_;}
22-
const Tensor weights() const {return weights_;}
23-
const Tensor biases() const {return biases_;}
24-
25-
protected:
26-
Tensor weights_;
27-
Tensor biases_;
28-
// Tensor cachedInput_;
2920
};
3021

3122

include/TensorSANN/layers/Layer.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@ class Layer{
1919
protected:
2020
// Tensor output_;
2121
Tensor cachedInput_;
22-
};
2322

2423

24+
public:
25+
bool isTrainable_ = false;
26+
};
27+
2528
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
#pragma once
2+
3+
#include "TensorSANN/layers/Layer.hpp"
4+
#include "TensorSANN/utils/Tensor.hpp"
5+
#include <memory>
6+
#include <random>
7+
#include <vector>
8+
#include <string>
9+
10+
namespace TensorSANN {
11+
12+
class TrainableLayer : Layer{
13+
public:
14+
virtual Tensor forward(const Tensor &input);
15+
16+
virtual Tensor backward(const Tensor &output_grad);
17+
18+
Tensor weights() {return weights_;}
19+
Tensor biases() {return biases_;}
20+
const Tensor weights() const {return weights_;}
21+
const Tensor biases() const {return biases_;}
22+
23+
protected:
24+
Tensor weights_;
25+
Tensor biases_;
26+
// Tensor cachedInput_;
27+
};
28+
29+
public:
30+
bool isTrainable_ = true;
31+
32+
33+
}

include/TensorSANN/optimizers/Optimizer.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ namespace TensorSANN{
88

99
class Optimizer{
1010
public:
11-
virtual void update(std::vector<Layer> & layers);
11+
virtual ~Optimizer() = default;
12+
virtual bool update(Layer & layer);
1213
};
1314

1415
}// namespace TensorSANN

include/TensorSANN/optimizers/SGD.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class SGD : Optimizer{
1111
public:
1212
SGD(float learning_rate);
1313

14-
void update(std::vector<Layer> & layers) override;
14+
bool update(Layer & layer) override;
1515

1616
private:
1717
float learning_rate;

src/main.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
#include "TensorSANN/utils/Tensor.hpp"
33
#include "TensorSANN/layers/DenseLayer.hpp"
44
#include "TensorSANN/activations/ReLU.hpp"
5+
#include "TensorSANN/optimizers/SGD.hpp"
56

67
int main() {
78
std::cout << "Hello world!" << std::endl;
9+
TensorSANN::SGD optimizer = TensorSANN::SGD(0.01f);
810

911
std::vector<size_t> shape1 = {2, 3}; // 2x3 tensor
1012
TensorSANN::Tensor tensor_1(shape1);

src/optimizers/SGD.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,19 @@
11
#include "TensorSANN/optimizers/SGD.hpp"
2+
#include "TensorSANN/layers/TrainableLayer.hpp"
23

34
namespace TensorSANN{
45

56
SGD::SGD(float learning_rate) : learning_rate(learning_rate) {}
67

7-
void SGD::update(std::vector<Layer> &layers){
8-
for (auto& layer : layers){
9-
// for (size_t i = 0; i < )
10-
}
8+
bool SGD::update(Layer &layer){
9+
// check if trainable
10+
if (!layer.isTrainable_) return false;
11+
12+
// if it is then we can attempt to downcast;
13+
TrainableLayer *t_layer = dynamic_cast<TrainableLayer*>(&layer);
14+
15+
16+
return true;
1117
}
1218

1319
}

0 commit comments

Comments
 (0)