|
#include "TensorSANN/data/MNISTDataset.hpp"
#include "TensorSANN/utils/Tensor.hpp"

#include <cstdint>
#include <fstream>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>
5 | 8 |
|
// Magic number compared against the first header field of an MNIST image file.
#define MNIST_IMAGE_MNUM 2051
// Magic number compared against the first header field of an MNIST label file.
#define MNIST_LABEL_MNUM 2049
// Number of rows and of columns walked when an image is printed.
// NOTE(review): "MINIST" looks like a typo for "MNIST"; the spelling is kept
// because other code in this file refers to the macro by this name.
#define MINIST_IMG_DIM 28
9 | 12 |
|
10 | 13 |
|
| 14 | +namespace TensorSANN{ |
| 15 | + |
| 16 | +static std::vector<unsigned int> load_mnist_labels(std::string path){ |
| 17 | + std::ifstream file(path, std::ios::binary); |
| 18 | + if (!file.is_open()) { |
| 19 | + throw std::invalid_argument("Could not open file."); |
| 20 | + } |
11 | 21 |
|
12 | | -protected: |
13 | | -std::vector<Tensor> images; |
14 | | -std::vector<int> labels; |
| 22 | + int magic_number = 0; |
| 23 | + file.read(reinterpret_cast<char*>(&magic_number), 4); |
| 24 | + magic_number = __builtin_bswap32(magic_number); |
15 | 25 |
|
| 26 | + if (magic_number != MNIST_LABEL_MNUM) throw std::invalid_argument("Invalid magic nubmer."); |
16 | 27 |
|
17 | | -static std::vector<unsigned int> load_mnist_labels(std::string path){ |
| 28 | + int num_labels = 0; |
| 29 | + file.read(reinterpret_cast<char*>(&num_labels), 4); |
| 30 | + num_labels = __builtin_bswap32(num_labels); |
18 | 31 |
|
19 | | - std::ifstream file(path, std::ios::binary); |
20 | | - if (!file.is_open()) { |
21 | | - throw std::invalid_argument("Could not open file: %s", path); |
| 32 | + std::vector<unsigned int> labels; |
| 33 | + for (int i = 0; i < num_labels; i++) { |
| 34 | + uint8_t label; |
| 35 | + file.read(reinterpret_cast<char*>(&label), 1); |
| 36 | + labels.push_back(label); |
22 | 37 | } |
23 | 38 |
|
| 39 | + file.close(); |
| 40 | + |
| 41 | + return labels; |
| 42 | + |
24 | 43 |
|
25 | 44 | } |
26 | | -static std::vector<Tensor> MNISTDataset::load_mnist_images(std::string path){ |
| 45 | +static std::vector<Tensor> load_mnist_images(std::string path){ |
27 | 46 | std::ifstream file(path, std::ios::binary); |
28 | 47 | if (!file.is_open()) { |
29 | | - throw std::invalid_argument("Could not open file: %s", path); |
| 48 | + throw std::invalid_argument("Could not open file."); |
30 | 49 | } |
31 | 50 |
|
32 | 51 | int magic_number = 0; |
33 | 52 | file.read(reinterpret_cast<char*>(&magic_number), 4); |
34 | 53 | magic_number = __builtin_bswap32(magic_number); // Convert from big-endian |
35 | 54 |
|
36 | | - if (magic_number != MNIST_IMAGE_MNUM) throw std::invalid_argume |
| 55 | + if (magic_number != MNIST_IMAGE_MNUM) throw std::invalid_argument("Invalid magic number."); |
37 | 56 |
|
| 57 | + int num_images = 0; |
38 | 58 | file.read(reinterpret_cast<char*>(&num_images), 4); |
39 | 59 | num_images = __builtin_bswap32(num_images); |
40 | 60 |
|
| 61 | + size_t rows = 0; |
41 | 62 | file.read(reinterpret_cast<char*>(&rows), 4); |
42 | 63 | rows = __builtin_bswap32(rows); |
43 | 64 |
|
| 65 | + size_t cols = 0; |
44 | 66 | file.read(reinterpret_cast<char*>(&cols), 4); |
45 | 67 | cols = __builtin_bswap32(cols); |
46 | 68 |
|
47 | | - vector<vector<uint8_t>> images(num_images, vector<uint8_t>(rows * cols)); |
| 69 | + std::vector<Tensor> images; |
48 | 70 | for (int i = 0; i < num_images; i++) { |
49 | | - file.read(reinterpret_cast<char*>(images[i].data()), rows * cols); |
| 71 | + std::vector<float> buffer(rows * cols); |
| 72 | + file.read(reinterpret_cast<char*>(buffer.data()), rows * cols); |
| 73 | + |
| 74 | + //@@ temporary |
| 75 | + images.push_back(Tensor({rows * cols, 1}, buffer) / 255.0); |
50 | 76 | } |
51 | | - |
52 | 77 | file.close(); |
53 | 78 | return images; |
54 | 79 | } |
55 | 80 |
|
| 81 | +MNISTDataset::MNISTDataset(std::string labels_path, std::string images_path){ |
| 82 | + labels = load_mnist_labels(labels_path); |
| 83 | + images = load_mnist_images(images_path); |
| 84 | + |
| 85 | +} |
56 | 86 |
|
57 | | -size_t MNISTDataset::size(){ |
58 | 87 |
|
| 88 | +size_t MNISTDataset::size(){ |
| 89 | + return 0; |
59 | 90 | } |
60 | 91 | std::pair<Tensor, int> MNISTDataset::get_item(size_t index){ |
61 | 92 |
|
62 | 93 | } |
63 | 94 |
|
/// Picks the character used to draw one pixel.
///
/// The value is compared against descending thresholds with a strict
/// "greater than"; the first threshold it exceeds decides the glyph.
///
/// @param pixel_value  Pixel intensity to classify.
/// @return '#' above 0.8, 'O' above 0.6, '+' above 0.4, '.' above 0.2,
///         otherwise a space.
static char pixel_to_char(float pixel_value) {
    struct Shade {
        float above;  // exclusive lower bound for this glyph
        char glyph;
    };
    // Ordered from the highest threshold to the lowest.
    static const Shade kShades[] = {
        {0.8f, '#'},
        {0.6f, 'O'},
        {0.4f, '+'},
        {0.2f, '.'},
    };

    for (const Shade& shade : kShades) {
        if (pixel_value > shade.above) {
            return shade.glyph;
        }
    }
    return ' ';
}
| 102 | + |
| 103 | +void TensorSANN::print_mnist_index(int index){ |
| 104 | + auto image = images[index].data; |
| 105 | + int label = labels[index]; |
| 106 | + std::cout << "Index: " << index << " Label: " << label << std::endl; |
| 107 | + |
| 108 | + for (int i = 0; i < MINIST_IMG_DIM; i++) { |
| 109 | + for (int j = 0; j < MINIST_IMG_DIM; j++) { |
| 110 | + printf("%c ", pixel_to_char(image[i * MINIST_IMG_DIM + j])); |
| 111 | + } |
| 112 | + printf("\n"); |
| 113 | + } |
| 114 | + |
| 115 | +} |
| 116 | +} |
64 | 117 |
|
65 | 118 |
|
66 | 119 |
|
|
0 commit comments