Skip to content
This repository was archived by the owner on Aug 17, 2021. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 22 additions & 22 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

35 changes: 20 additions & 15 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,51 +161,51 @@ fn main() {
"t10k-images-idx3-ubyte.gz",
"t10k-labels-idx1-ubyte.gz",
];
// download_datasets(
// &datasets,
// "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com",
// );
// println!("{}", "Fashion MNIST dataset downloaded".to_string());
download_datasets(
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really remember why this was commented out ...

&datasets,
"http://fashion-mnist.s3-website.eu-central-1.amazonaws.com",
);
println!("{}", "Fashion MNIST dataset downloaded".to_string());
// TODO avoid repeated effort here
unzip_datasets(&datasets);
println!("{}", "Fashion MNIST dataset decompressed".to_string());
}
_ => println!("{}", "Failed to download MNIST dataset!".to_string()),
}
} else if args.cmd_mnist {
#[cfg(all(feature = "cuda"))]
#[cfg(any(feature = "cuda", feature = "native"))]
run_mnist(
args.arg_model_name,
args.arg_batch_size,
args.arg_learning_rate,
args.arg_momentum,
);
#[cfg(not(feature = "cuda"))]
#[cfg(not(any(feature = "cuda", feature = "native")))]
{
println!(
"Right now, you really need cuda! Not all features are available for all backends and as such, this one -as of now - only works with cuda."
"Right now, you really need cuda or to build with native features! Not all features are available for all backends and as such, this one -as of now - only works with cuda or native."
);
panic!()
}
} else if args.cmd_fashion {
#[cfg(all(feature = "cuda"))]
#[cfg(any(feature = "cuda", feature = "native"))]
run_fashion(
args.arg_model_name,
args.arg_batch_size,
args.arg_learning_rate,
args.arg_momentum,
);
#[cfg(not(feature = "cuda"))]
#[cfg(not(any(feature = "cuda", feature = "native")))]
{
println!(
"Right now, you really need cuda! Not all features are available for all backends and as such, this one -as of now - only works with cuda."
"Right now, you really need cuda or to build with native features! Not all features are available for all backends and as such, this one -as of now - only works with cuda or native."
);
panic!()
}
}
}

#[cfg(all(feature = "cuda"))]
#[cfg(any(feature = "cuda", feature = "native"))]
fn run_mnist(
model_name: Option<String>,
batch_size: Option<usize>,
Expand Down Expand Up @@ -323,8 +323,10 @@ fn run_mnist(
classifier_cfg.add_layer(nll_cfg);

// set up backends
#[cfg(feature = "cuda")]
let backend = ::std::rc::Rc::new(Backend::<Cuda>::default().unwrap());
// let native_backend = ::std::rc::Rc::new(Backend::<Native>::default().unwrap());
#[cfg(feature = "native")]
let backend = ::std::rc::Rc::new(Backend::<Native>::default().unwrap());

// set up solver
let mut solver_cfg = SolverConfig {
Expand Down Expand Up @@ -373,7 +375,7 @@ fn run_mnist(
}
}

#[cfg(all(feature = "cuda"))]
#[cfg(any(feature = "cuda", feature = "native"))]
fn run_fashion(
model_name: Option<String>,
batch_size: Option<usize>,
Expand Down Expand Up @@ -478,8 +480,11 @@ fn run_fashion(
classifier_cfg.add_layer(nll_cfg);

// set up backends
#[cfg(feature = "cuda")]
let backend = ::std::rc::Rc::new(Backend::<Cuda>::default().unwrap());
// let native_backend = ::std::rc::Rc::new(Backend::<Native>::default().unwrap());
#[cfg(feature = "native")]
let backend = ::std::rc::Rc::new(Backend::<Native>::default().unwrap());

// set up solver
let mut solver_cfg = SolverConfig {
Expand Down