-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathCNN.h
More file actions
109 lines (88 loc) · 3.3 KB
/
CNN.h
File metadata and controls
109 lines (88 loc) · 3.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
//
// Created by 峰池 on 2017/10/7.
//
#ifndef HANDBUILTNN_CNN_H
#define HANDBUILTNN_CNN_H
#include <iostream>
#include <string>
#include <vector>
#include <cmath>
#include <algorithm>
#include <fstream>
#include <numeric>
#include <random>
#include <functional>
// A float vector is the universal tensor/buffer type in this hand-built NN.
typedef std::vector<float> FLTARY;

// A single 2-D "valid" convolution layer (no padding, stride 1).
//
// Weight layout (innermost index varies fastest):
//   _weights[((i * _filterWidth + m) * _filterWidth + n) * _outSize + j]
// where i = input channel, (m, n) = filter row/col, j = output channel.
// _bias holds one value per output channel.
// Feature maps are assumed square and stored channel-major:
//   data[channel * width * width + row * width + col].
class ConvLayer{
public:
    FLTARY _weights;                // filter bank, see layout note above
    FLTARY _bias;                   // one bias per output channel
    FLTARY _inData;                 // input cached by Forward (for a future Backward)
    uint32_t _inSize, _outSize;     // input / output channel counts
    uint32_t _inWidth, _outWidth;   // spatial widths (maps are square)
    uint32_t _filterWidth;          // square filter side length

    // Allocates the parameter buffers and fills them via `gen`, a nullary
    // functor returning float (e.g. a bound RNG distribution).
    // Returns the output width of a valid convolution:
    //   outWidth = inWidth - filterWidth + 1.
    template<typename _Gen>
    uint32_t Initialize(uint32_t nInSize, uint32_t nOutSize, uint32_t nInWidth,
                        _Gen gen, uint32_t filterWidth = 3) {
        _inSize = nInSize;
        _outSize = nOutSize;
        _filterWidth = filterWidth;
        _weights.resize(_inSize * _filterWidth * _filterWidth * _outSize);
        _bias.resize(_outSize);
        _inWidth = nInWidth;
        _outWidth = _inWidth - _filterWidth + 1;
        std::generate(_weights.begin(), _weights.end(), gen);
        std::generate(_bias.begin(), _bias.end(), gen);
        return _outWidth;
    }

    // Plain SGD step: param -= lr * grad, for every weight and bias.
    // Gradient buffers must use the same layout as _weights / _bias.
    void Update(const FLTARY &weightsGrads, const FLTARY &biasGrads, float lr) {
        // Strides for the (i, m, n, j) weight layout documented on the class.
        uint32_t iStride = _filterWidth * _filterWidth * _outSize;
        uint32_t mStride = _filterWidth * _outSize;
        uint32_t nStride = _outSize;
        for (uint32_t i = 0; i < _inSize; ++i) {
            for (uint32_t m = 0; m < _filterWidth; ++m) {
                for (uint32_t n = 0; n < _filterWidth; ++n) {
                    for (uint32_t j = 0; j < _outSize; ++j) {
                        // BUGFIX: the leading stride (fw*fw*_outSize) belongs to
                        // the input-channel axis, so `i` must multiply it and `j`
                        // is the innermost offset. The old j*stride+...+i form
                        // indexed out of bounds whenever _outSize > _inSize.
                        uint32_t idx = i * iStride + m * mStride + n * nStride + j;
                        _weights[idx] -= lr * weightsGrads[idx];
                    }
                }
            }
        }
        for (uint32_t j = 0; j < _outSize; ++j) {
            _bias[j] -= lr * biasGrads[j];
        }
    }

    // Valid convolution:
    //   out[j][p][q] = bias[j] + sum over (i, m, n) of
    //                  w[i][m][n][j] * in[i][p + m][q + n]
    // Also caches inData for a later Backward pass.
    void Forward(const FLTARY &inData, FLTARY &outData) {
        _inData = inData;
        // BUGFIX: assign() zero-fills even when outData already has this size;
        // the old resize() kept stale values that the accumulation summed into
        // on every reuse of the buffer.
        outData.assign(_outSize * _outWidth * _outWidth, 0.0f);
        uint32_t outMapArea = _outWidth * _outWidth;
        uint32_t inMapArea = _inWidth * _inWidth;
        uint32_t iStride = _filterWidth * _filterWidth * _outSize;
        uint32_t mStride = _filterWidth * _outSize;
        uint32_t nStride = _outSize;
        for (uint32_t j = 0; j < _outSize; j++) {
            for (uint32_t p = 0; p < _outWidth; p++) {
                for (uint32_t q = 0; q < _outWidth; q++) {
                    // BUGFIX: start from the bias — it was allocated and
                    // trained in Update but never applied in the old Forward.
                    float acc = _bias[j];
                    for (uint32_t m = 0; m < _filterWidth; m++) {
                        for (uint32_t n = 0; n < _filterWidth; n++) {
                            for (uint32_t i = 0; i < _inSize; i++) {
                                // Same (i, m, n, j) layout fix as in Update.
                                acc += _weights[i * iStride + m * mStride + n * nStride + j]
                                     * _inData[i * inMapArea + (p + m) * _inWidth + (q + n)];
                            }
                        }
                    }
                    outData[j * outMapArea + p * _outWidth + q] = acc;
                }
            }
        }
    }

    // Backpropagation is not implemented yet.
    // TODO: compute bottomGrads / weightsGrads / biaGrads from topGrads
    // and the _inData cached by Forward.
    void Backward(const FLTARY &topGrads, FLTARY &bottomGrads,
                  FLTARY &weightsGrads, FLTARY &biaGrads) {
    }
};
#endif //HANDBUILTNN_CNN_H