This repository was archived by the owner on May 11, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathRidgeRegrTrainer.cs
More file actions
196 lines (185 loc) · 7.47 KB
/
RidgeRegrTrainer.cs
File metadata and controls
196 lines (185 loc) · 7.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
using RCNet.MathTools;
using RCNet.MathTools.MatrixMath;
using RCNet.MathTools.VectorMath;
using RCNet.Neural.Activation;
using System;
using System.Collections.Generic;
using System.Globalization;
namespace RCNet.Neural.Network.NonRecurrent.FF
{
/// <summary>
/// Implements the Ridge regression trainer of the feed forward network.
/// </summary>
/// <remarks>
/// <para>
/// The feed forward network to be trained must have no hidden layers and the Identity output activation.
/// </para>
/// <para>
/// Each training epoch tests one candidate lambda supplied by the configured finder and solves the
/// regularized normal equations (X^T.X + lambda*I)^-1 * (X^T.Y) in closed form, so only a single
/// training attempt makes sense.
/// </para>
/// </remarks>
[Serializable]
public class RidgeRegrTrainer : INonRecurrentNetworkTrainer
{
    //Constants
    //When two successively tested lambdas differ by less than this, the search converged and training stops.
    private const double StopLambdaDifference = 1e-10;
    //Attribute properties
    /// <inheritdoc/>
    public double MSE { get; private set; }
    /// <inheritdoc/>
    public int MaxAttempt { get; private set; }
    /// <inheritdoc/>
    public int Attempt { get; private set; }
    /// <inheritdoc/>
    public int MaxAttemptEpoch { get; private set; }
    /// <inheritdoc/>
    public int AttemptEpoch { get; private set; }
    /// <inheritdoc/>
    public string InfoMessage { get; private set; }
    //Attributes
    private readonly RidgeRegrTrainerSettings _cfg;
    private readonly FeedForwardNetwork _net;
    private readonly List<double[]> _inputVectorCollection;
    private readonly List<double[]> _outputVectorCollection;
    //Precomputed design-matrix products (X includes a leading bias column of ones)
    private readonly Matrix _XT;
    private readonly Matrix _XTdotX;
    private readonly Vector[] _XTdotY;
    //Ideal outputs split per output neuron (one column vector per output value)
    private readonly List<Vector> _outputSingleColVectorCollection;
    //Searches for the lambda giving the lowest MSE
    private readonly ParamValFinder _lambdaFinder;
    private double _currLambda;
    //Constructor
    /// <summary>
    /// Creates an initialized instance.
    /// </summary>
    /// <param name="net">The FF network to be trained.</param>
    /// <param name="inputVectorCollection">The input vectors (input).</param>
    /// <param name="outputVectorCollection">The output vectors (ideal).</param>
    /// <param name="cfg">The configuration of the trainer.</param>
    /// <exception cref="InvalidOperationException">
    /// Thrown when the network is not finalized, has hidden layers or a non-Identity activation,
    /// or when no training samples were provided.
    /// </exception>
    public RidgeRegrTrainer(FeedForwardNetwork net,
                            List<double[]> inputVectorCollection,
                            List<double[]> outputVectorCollection,
                            RidgeRegrTrainerSettings cfg
                            )
    {
        //Check network readiness
        if (!net.Finalized)
        {
            throw new InvalidOperationException("Can't create trainer. Network structure was not finalized.");
        }
        //Check network conditions
        if (net.LayerCollection.Count != 1 || !(net.LayerCollection[0].Activation is AFAnalogIdentity))
        {
            throw new InvalidOperationException("Can't create trainer. Network structure is not compliant (single layer having Identity activation).");
        }
        //Check samples conditions
        if (inputVectorCollection.Count == 0)
        {
            throw new InvalidOperationException("Can't create trainer. Missing training samples.");
        }
        //Collections (defensive copies so later caller-side mutations do not affect training)
        _inputVectorCollection = new List<double[]>(inputVectorCollection);
        _outputVectorCollection = new List<double[]>(outputVectorCollection);
        //Parameters
        _cfg = cfg;
        MaxAttempt = _cfg.NumOfAttempts;
        MaxAttemptEpoch = _cfg.NumOfAttemptEpochs;
        Attempt = 1;
        AttemptEpoch = 0;
        _net = net;
        //Split the ideal outputs into one column vector per output neuron
        _outputSingleColVectorCollection = new List<Vector>(_net.NumOfOutputValues);
        for (int outputIdx = 0; outputIdx < _net.NumOfOutputValues; outputIdx++)
        {
            Vector outputSingleColVector = new Vector(outputVectorCollection.Count);
            for (int row = 0; row < outputVectorCollection.Count; row++)
            {
                //Output
                outputSingleColVector.Data[row] = outputVectorCollection[row][outputIdx];
            }
            _outputSingleColVectorCollection.Add(outputSingleColVector);
        }
        //Lambda seeker
        _lambdaFinder = new ParamValFinder(_cfg.LambdaFinderCfg);
        _currLambda = 0;
        //Matrix setup: design matrix X = [1 | predictors] (leading column of ones models the bias)
        Matrix X = new Matrix(inputVectorCollection.Count, _net.NumOfInputValues + 1);
        for (int row = 0; row < inputVectorCollection.Count; row++)
        {
            //Add constant bias
            X.Data[row][0] = 1d;
            //Add predictors
            inputVectorCollection[row].CopyTo(X.Data[row], 1);
        }
        //Precompute the sample-independent parts of the normal equations once
        _XT = X.Transpose();
        _XTdotX = _XT * X;
        _XTdotY = new Vector[_net.NumOfOutputValues];
        for (int outputIdx = 0; outputIdx < _net.NumOfOutputValues; outputIdx++)
        {
            _XTdotY[outputIdx] = _XT * _outputSingleColVectorCollection[outputIdx];
        }
        return;
    }
    //Properties
    /// <inheritdoc/>
    public INonRecurrentNetwork Net { get { return _net; } }
    //Methods
    /// <inheritdoc/>
    public bool NextAttempt()
    {
        //Only one attempt makes the sense (closed-form solution) -> do nothing and return false
        return false;
    }
    /// <inheritdoc/>
    public bool Iteration()
    {
        //Primary stop condition: the configured number of epochs is exhausted
        if (AttemptEpoch == MaxAttemptEpoch)
        {
            return false;
        }
        //New lambda to be tested
        double newLambda = _lambdaFinder.Next;
        //Secondary stop condition: the lambda search has converged
        if (AttemptEpoch > 0 && Math.Abs(_currLambda - newLambda) < StopLambdaDifference)
        {
            return false;
        }
        //Next epoch allowed
        _currLambda = newLambda;
        ++AttemptEpoch;
        InfoMessage = $"lambda={_currLambda.ToString(CultureInfo.InvariantCulture)}";
        //Inverse of the regularized _XTdotX matrix
        Matrix I;
        if (_currLambda > 0)
        {
            Matrix B = new Matrix(_XTdotX);
            //Preserve the [0][0] element so the bias term is NOT regularized (standard ridge practice)
            double tmp = B.Data[0][0];
            B.AddScalarToDiagonal(_currLambda);
            B.Data[0][0] = tmp;
            I = B.Inverse(true);
        }
        else
        {
            //lambda == 0 degenerates to ordinary least squares
            I = _XTdotX.Inverse(true);
        }
        //New weights buffer
        double[] newWeights = new double[_net.NumOfWeights];
        //Weights for each output neuron
        for (int outputIdx = 0; outputIdx < _net.NumOfOutputValues; outputIdx++)
        {
            //Weights solution: w = (X^T.X + lambda*I)^-1 * X^T.y
            Vector weights = I * _XTdotY[outputIdx];
            //Store weights into the network's flat weight layout:
            //all input->output weights first, then the biases appended at the end
            //Bias
            newWeights[_net.NumOfOutputValues * _net.NumOfInputValues + outputIdx] = weights.Data[0];
            //Predictors
            for (int i = 0; i < _net.NumOfInputValues; i++)
            {
                newWeights[outputIdx * _net.NumOfInputValues + i] = weights.Data[i + 1];
            }
        }
        //Set new weights and compute error
        _net.SetWeights(newWeights);
        MSE = _net.ComputeBatchErrorStat(_inputVectorCollection, _outputVectorCollection).MeanSquare;
        //Update lambda seeker with the achieved error so it can steer the next candidate lambda
        _lambdaFinder.ProcessError(MSE);
        return true;
    }
}//RidgeRegrTrainer
}//Namespace