Skip to content

Commit 3d23c8e

Browse files
committed
working mnist dataloader
1 parent 1895b4b commit 3d23c8e

4 files changed

Lines changed: 28 additions & 19 deletions

File tree

include/TensorSANN/data/Dataset.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
template<typename T>
44
class Dataset {
55
public:
6-
virtual size_t size() const = 0;
7-
virtual std::pair<std::vector<T>, int> get_item(size_t index) const = 0;
6+
virtual int size() const = 0;
7+
// virtual std::pair<T, int> get_item(int index) const = 0;
88
virtual ~Dataset() = default;
99
};

include/TensorSANN/data/MNISTDataset.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,15 @@
77

88
namespace TensorSANN{
99

10-
class MNISTDataset : public Dataset<float>{
10+
class MNISTDataset : public Dataset<Tensor>{
1111
protected:
1212
std::vector<Tensor> images;
1313
std::vector<unsigned int> labels;
1414

1515
public:
1616
MNISTDataset(std::string labels_path, std::string images_path);
17-
size_t size() const override;
18-
std::pair<Tensor, int> get_item(size_t index) const override;
17+
int size() const override;
18+
// std::pair<Tensor, int> get_item(int index) const override;
1919
void print_mnist_index(int index);
2020

2121

src/data/MNISTDataset.cpp

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,11 @@ static std::vector<Tensor> load_mnist_images(std::string path){
6868

6969
std::vector<Tensor> images;
7070
for (int i = 0; i < num_images; i++) {
71-
std::vector<float> buffer(rows * cols);
71+
std::vector<uint8_t> buffer(rows * cols);
7272
file.read(reinterpret_cast<char*>(buffer.data()), rows * cols);
73-
73+
std::vector<float> float_buffer(buffer.begin(), buffer.end());
7474
//@@ temporary
75-
images.push_back(Tensor({rows * cols, 1}, buffer) / 255.0);
75+
images.push_back(Tensor({rows * cols, 1}, float_buffer) / 255.0);
7676
}
7777
file.close();
7878
return images;
@@ -85,12 +85,12 @@ MNISTDataset::MNISTDataset(std::string labels_path, std::string images_path){
8585
}
8686

8787

88-
size_t MNISTDataset::size(){
89-
return 0;
90-
}
91-
std::pair<Tensor, int> MNISTDataset::get_item(size_t index){
92-
88+
int MNISTDataset::size() const{
89+
return 101;
9390
}
91+
// std::pair<Tensor, int> MNISTDataset::get_item(int index) const{
92+
// return nullptr;
93+
// }
9494

9595
static char pixel_to_char(float pixel_value) {
9696
if (pixel_value > 0.8f) return '#'; // Dark pixel
@@ -100,16 +100,17 @@ static char pixel_to_char(float pixel_value) {
100100
else return ' '; // Very light pixel (almost white)
101101
}
102102

103-
void TensorSANN::print_mnist_index(int index){
104-
auto image = images[index].data;
103+
void MNISTDataset::print_mnist_index(int index){
104+
auto image = images[index].data();
105105
int label = labels[index];
106106
std::cout << "Index: " << index << " Label: " << label << std::endl;
107107

108108
for (int i = 0; i < MINIST_IMG_DIM; i++) {
109109
for (int j = 0; j < MINIST_IMG_DIM; j++) {
110-
printf("%c ", pixel_to_char(image[i * MINIST_IMG_DIM + j]));
110+
std::cout << pixel_to_char(image[i * MINIST_IMG_DIM + j]);
111+
// std::cout << (iemage[i * MINIST_IMG_DIM + j]);
111112
}
112-
printf("\n");
113+
std::cout<< std::endl;
113114
}
114115

115116
}

src/model_main.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include <iostream>
22
#include <vector>
33
#include <memory>
4+
#include <string>
45

56
#include "TensorSANN/utils/Tensor.hpp"
67
#include "TensorSANN/layers/DenseLayer.hpp"
@@ -10,9 +11,16 @@
1011
#include "TensorSANN/activations/ReLU.hpp"
1112
#include "TensorSANN/optimizers/SGD.hpp"
1213

14+
#include "TensorSANN/data/MNISTDataset.hpp"
15+
1316
int main() {
1417
std::cout << "MODEL MAIN CPP ==========" << std::endl;
15-
18+
19+
std::string train_label_path = "../lib/datasets/mnist/t10k-labels.idx1-ubyte";
20+
std::string train_images_path = "../lib/datasets/mnist/t10k-images.idx3-ubyte";
21+
TensorSANN::MNISTDataset md(train_label_path, train_images_path);
22+
23+
md.print_mnist_index(0);
1624

1725
std::vector<size_t> shape2 = {16,1}; // 2x3 tensor
1826
std::vector<float> data = {
@@ -43,7 +51,7 @@ int main() {
4351
layers.push_back(std::make_shared<TensorSANN::Softmax>());
4452

4553

46-
int epoch = 2000;
54+
int epoch = 1;
4755
for (int i = 1; i <= epoch; ++i){
4856
TensorSANN::Tensor fwd_op = input_tensor.transpose();
4957
// fwd

0 commit comments

Comments
 (0)