Skip to content

Commit 1ccf9fd

Browse files
committed
Refactor neural network training and initialization
- Updated the optimizer to include SGDMomentum with momentum support. - Enhanced the DenseLayer to accept weight and bias initializers. - Introduced various initializers: Constant, HeNormal, XavierUniform, and Zeros. - Modified the CPU and GPU backends to support random normal initialization. - Adjusted the neural network serialization to accommodate new layer structures. - Updated integration tests to validate changes in optimizer and layer initialization.
1 parent 625042f commit 1ccf9fd

19 files changed

Lines changed: 516 additions & 274 deletions

File tree

examples/mnist/cpu.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use meuron::cost::CrossEntropy;
44
use meuron::layer::DenseLayer;
55
use meuron::metric::classification::accuracy;
66
use meuron::optimizer::SGD;
7+
use meuron::initializer::{HeNormal, Zeros, XavierUniform};
78
use meuron::train::TrainOptions;
89
use meuron::{Layers, NetworkType, NeuralNetwork};
910
use ndarray::Array2;
@@ -38,7 +39,7 @@ fn ensure_mnist(dir: &Path) -> io::Result<()> {
3839

3940
let response = ureq::get(&url)
4041
.call()
41-
.map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?;
42+
.map_err(|e| io::Error::other(e.to_string()))?;
4243

4344
let mut body = response.into_body();
4445
let mut gz = GzDecoder::new(body.as_reader());
@@ -111,8 +112,8 @@ fn main() {
111112
println!("Creating new model...");
112113
NeuralNetwork::new(
113114
Layers![
114-
DenseLayer::new(28 * 28, 128, ReLU),
115-
DenseLayer::new(128, 10, Softmax)
115+
DenseLayer::new(28 * 28, 128, ReLU, HeNormal, Zeros),
116+
DenseLayer::new(128, 10, Softmax, XavierUniform, Zeros)
116117
],
117118
CrossEntropy,
118119
)

0 commit comments

Comments
 (0)