Skip to content

Commit 78e88b7

Browse files
committed
importing old mnist loader
1 parent a994cbd commit 78e88b7

2 files changed

Lines changed: 71 additions & 16 deletions

File tree

include/TensorSANN/data/MNISTDataset.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@ namespace TensorSANN{
1010
/// MNIST dataset exposed through the Dataset<float> interface.
class MNISTDataset : public Dataset<float>{
    protected:
        /// Image tensors, assigned by the constructor from the image file.
        std::vector<Tensor> images;
        /// Labels, assigned by the constructor from the label file; the entry
        /// at a given position is printed together with the image at the same
        /// position by print_mnist_index().
        std::vector<unsigned int> labels;

    public:
        /// Loads the label file and the image file found at the given paths.
        MNISTDataset(std::string labels_path, std::string images_path);

        /// Number of samples in the dataset.
        size_t size() const override;

        /// Returns the (image, label) pair stored at `index`.
        std::pair<Tensor, int> get_item(size_t index) const override;

        /// Prints the label and an ASCII rendering of the image at `index`.
        void print_mnist_index(int index);
};

src/data/MNISTDataset.cpp

Lines changed: 68 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,66 +1,119 @@
11
#include "TensorSANN/data/MNISTDataset.hpp"
#include "TensorSANN/utils/Tensor.hpp"

#include <cstddef>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
58

69
#define MNIST_IMAGE_MNUM 2051
710
#define MNIST_LABEL_MNUM 2049
811
#define MINIST_IMG_DIM 28
912

1013

14+
namespace TensorSANN{
15+
16+
/**
 * Reads an MNIST label file: a big-endian 4-byte magic number (2049), a
 * big-endian 4-byte label count, then one byte per label.
 *
 * @param path Location of the label file.
 * @return The labels in file order, widened to unsigned int.
 * @throws std::invalid_argument if the file cannot be opened, the magic number
 *         does not match, or the file ends before all declared data is read.
 */
static std::vector<unsigned int> load_mnist_labels(const std::string& path){
    // Magic number identifying a label file (same value as MNIST_LABEL_MNUM).
    constexpr std::uint32_t kLabelMagic = 2049;

    std::ifstream file(path, std::ios::binary);
    if (!file.is_open()) {
        throw std::invalid_argument("Could not open file.");
    }

    // Header fields are stored most-significant byte first. Assembling the
    // value byte by byte works on any host endianness and any compiler.
    const auto read_be_u32 = [&file]() -> std::uint32_t {
        unsigned char bytes[4] = {};
        if (!file.read(reinterpret_cast<char*>(bytes), sizeof(bytes))) {
            throw std::invalid_argument("Unexpected end of file.");
        }
        return (static_cast<std::uint32_t>(bytes[0]) << 24)
             | (static_cast<std::uint32_t>(bytes[1]) << 16)
             | (static_cast<std::uint32_t>(bytes[2]) << 8)
             |  static_cast<std::uint32_t>(bytes[3]);
    };

    if (read_be_u32() != kLabelMagic) throw std::invalid_argument("Invalid magic number.");

    const std::uint32_t num_labels = read_be_u32();

    // No reserve(): the count comes from the file, and a corrupt header should
    // fail on the first missing byte rather than request a huge allocation.
    std::vector<unsigned int> labels;
    for (std::uint32_t i = 0; i < num_labels; ++i) {
        const auto byte = file.get();
        if (byte == std::ifstream::traits_type::eof()) {
            throw std::invalid_argument("Unexpected end of file.");
        }
        labels.push_back(static_cast<unsigned int>(byte));
    }

    // The stream closes itself when it goes out of scope.
    return labels;
}
26-
static std::vector<Tensor> MNISTDataset::load_mnist_images(std::string path){
45+
static std::vector<Tensor> load_mnist_images(std::string path){
2746
std::ifstream file(path, std::ios::binary);
2847
if (!file.is_open()) {
29-
throw std::invalid_argument("Could not open file: %s", path);
48+
throw std::invalid_argument("Could not open file.");
3049
}
3150

3251
int magic_number = 0;
3352
file.read(reinterpret_cast<char*>(&magic_number), 4);
3453
magic_number = __builtin_bswap32(magic_number); // Convert from big-endian
3554

36-
if (magic_number != MNIST_IMAGE_MNUM) throw std::invalid_argume
55+
if (magic_number != MNIST_IMAGE_MNUM) throw std::invalid_argument("Invalid magic number.");
3756

57+
int num_images = 0;
3858
file.read(reinterpret_cast<char*>(&num_images), 4);
3959
num_images = __builtin_bswap32(num_images);
4060

61+
size_t rows = 0;
4162
file.read(reinterpret_cast<char*>(&rows), 4);
4263
rows = __builtin_bswap32(rows);
4364

65+
size_t cols = 0;
4466
file.read(reinterpret_cast<char*>(&cols), 4);
4567
cols = __builtin_bswap32(cols);
4668

47-
vector<vector<uint8_t>> images(num_images, vector<uint8_t>(rows * cols));
69+
std::vector<Tensor> images;
4870
for (int i = 0; i < num_images; i++) {
49-
file.read(reinterpret_cast<char*>(images[i].data()), rows * cols);
71+
std::vector<float> buffer(rows * cols);
72+
file.read(reinterpret_cast<char*>(buffer.data()), rows * cols);
73+
74+
//@@ temporary
75+
images.push_back(Tensor({rows * cols, 1}, buffer) / 255.0);
5076
}
51-
5277
file.close();
5378
return images;
5479
}
5580

81+
/**
 * Builds the dataset by loading the label file, then the image file.
 *
 * @param labels_path Location of the label file.
 * @param images_path Location of the image file.
 * @throws std::invalid_argument if either loader rejects its file, or if the
 *         two files do not contain the same number of entries.
 */
MNISTDataset::MNISTDataset(std::string labels_path, std::string images_path){
    labels = load_mnist_labels(labels_path);
    images = load_mnist_images(images_path);

    // Both vectors are addressed with the same index, so a count mismatch
    // would pair images with the wrong labels or read past the shorter one.
    if (labels.size() != images.size()) {
        throw std::invalid_argument("Label and image counts do not match.");
    }
}
5686

57-
size_t MNISTDataset::size(){
5887

88+
size_t MNISTDataset::size(){
89+
return 0;
5990
}
6091
std::pair<Tensor, int> MNISTDataset::get_item(size_t index){
6192

6293
}
6394

95+
/**
 * Picks the ASCII glyph used to draw one pixel.
 *
 * Comparisons are strict, so a value exactly on a threshold falls into the
 * band below it. Any value that passes no comparison (including NaN) maps to
 * a space.
 *
 * @param pixel_value Pixel intensity; the thresholds assume roughly [0, 1].
 * @return '#' above 0.8, 'O' above 0.6, '+' above 0.4, '.' above 0.2, else ' '.
 */
static constexpr char pixel_to_char(float pixel_value) noexcept {
    if (pixel_value > 0.8f) return '#';
    if (pixel_value > 0.6f) return 'O';
    if (pixel_value > 0.4f) return '+';
    if (pixel_value > 0.2f) return '.';
    return ' ';
}
102+
103+
void TensorSANN::print_mnist_index(int index){
104+
auto image = images[index].data;
105+
int label = labels[index];
106+
std::cout << "Index: " << index << " Label: " << label << std::endl;
107+
108+
for (int i = 0; i < MINIST_IMG_DIM; i++) {
109+
for (int j = 0; j < MINIST_IMG_DIM; j++) {
110+
printf("%c ", pixel_to_char(image[i * MINIST_IMG_DIM + j]));
111+
}
112+
printf("\n");
113+
}
114+
115+
}
116+
}
64117

65118

66119

0 commit comments

Comments
 (0)