|
| 1 | +// Usage: |
| 2 | +// convert_3d_data input_image_file input_label_file output_db_file class_number rgb_use |
| 3 | +// The code is designed for binary files containing image data and labels. You can |
| 4 | +// modify the selection conditions if the information used to arrange the training |
| 5 | +// data is something other than the category and pose of the object. |
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <fstream>  // NOLINT(readability/streams)
#include <string>
#include <vector>
#include "caffe/proto/caffe.pb.h"
#include "caffe/util/math_functions.hpp"
#include "glog/logging.h"
#include "google/protobuf/text_format.h"
#ifdef USE_LEVELDB
#include "leveldb/db.h"
#include "math.h"
#include "stdint.h"
| 16 | + |
// Returns `val` with its four bytes in reverse order (byte 0 <-> byte 3,
// byte 1 <-> byte 2). Used to convert the 32-bit header fields of the input
// files to the byte order expected by the rest of this program.
uint32_t swap_endian(uint32_t val) {
  const uint32_t lowest_to_top = (val & 0x000000FFu) << 24;
  const uint32_t low_mid_to_high_mid = (val & 0x0000FF00u) << 8;
  const uint32_t high_mid_to_low_mid = (val & 0x00FF0000u) >> 8;
  const uint32_t top_to_lowest = (val & 0xFF000000u) >> 24;
  return lowest_to_top | low_mid_to_high_mid | high_mid_to_low_mid |
         top_to_lowest;
}
| 21 | + |
// Reads one sample (pixel data plus its 4-byte label) from the input files.
//
//   image_file  open binary stream; pixel data starts after a 16-byte header.
//   label_file  open binary stream; labels start after an 8-byte header.
//   index       zero-based sample index.
//   rows, cols  image dimensions taken from the image-file header.
//   pixels      output buffer; must hold rows * cols bytes when rgb_use == 0
//               and 3 * rows * cols bytes otherwise.
//   label_temp  scratch buffer of at least 4 bytes; receives the raw label.
//   label       output buffer of at least 4 bytes; receives the label bytes
//               reinterpreted as signed values (1 category label followed by
//               3 coordinate labels).
//   rgb_use     0 for single-channel images, any other value for 3 channels.
//
// Stream errors are not reported here; callers that care should inspect the
// stream state after the call.
void read_image(std::ifstream* image_file, std::ifstream* label_file,
                uint32_t index, uint32_t rows, uint32_t cols,
                char* pixels, char* label_temp, signed char* label,
                int rgb_use) {
  const std::streamoff kImageHeaderBytes = 16;
  const std::streamoff kLabelHeaderBytes = 8;
  const int kLabelBytes = 4;  // 1 category label + 3 coordinate labels

  // Offsets are computed in std::streamoff rather than uint32_t so that
  // index * image size does not wrap around for image files larger than 4 GiB.
  const std::streamoff channels = (rgb_use == 0) ? 1 : 3;
  const std::streamoff image_bytes = channels * rows * cols;
  const std::streamoff sample = static_cast<std::streamoff>(index);

  image_file->seekg(sample * image_bytes + kImageHeaderBytes);
  image_file->read(pixels, image_bytes);

  label_file->seekg(sample * kLabelBytes + kLabelHeaderBytes);
  label_file->read(label_temp, kLabelBytes);
  for (int i = 0; i < kLabelBytes; ++i) {
    label[i] = static_cast<signed char>(label_temp[i]);
  }
}
| 41 | + |
| 42 | +void convert_dataset(const char* image_filename, const char* label_filename, |
| 43 | + const char* db_filename, |
| 44 | + const char* class_number, const char* rgb_use) { |
| 45 | + int rgb_use1 = atoi(rgb_use); |
| 46 | + int class_num = atoi(class_number); |
| 47 | + // Open files |
| 48 | + std::ifstream image_file(image_filename, std::ios::in | std::ios::binary); |
| 49 | + std::ifstream label_file(label_filename, std::ios::in | std::ios::binary); |
| 50 | + CHECK(image_file) << "Unable to open file " << image_filename; |
| 51 | + CHECK(label_file) << "Unable to open file " << label_filename; |
| 52 | + // Read the magic and the meta data |
| 53 | + uint32_t magic; |
| 54 | + uint32_t num_items; |
| 55 | + uint32_t num_labels; |
| 56 | + uint32_t rows; |
| 57 | + uint32_t cols; |
| 58 | + |
| 59 | + image_file.read(reinterpret_cast<char*>(&magic), 4); |
| 60 | + magic = swap_endian(magic); |
| 61 | + CHECK_EQ(magic, 2051) << "Incorrect image file magic."; |
| 62 | + label_file.read(reinterpret_cast<char*>(&magic), 4); |
| 63 | + magic = swap_endian(magic); |
| 64 | + CHECK_EQ(magic, 2050) << "Incorrect label file magic."; |
| 65 | + image_file.read(reinterpret_cast<char*>(&num_items), 4); |
| 66 | + num_items = swap_endian(num_items); |
| 67 | + label_file.read(reinterpret_cast<char*>(&num_labels), 4); |
| 68 | + num_labels = swap_endian(num_labels); |
| 69 | + CHECK_EQ(num_items, num_labels); |
| 70 | + image_file.read(reinterpret_cast<char*>(&rows), 4); |
| 71 | + rows = swap_endian(rows); |
| 72 | + image_file.read(reinterpret_cast<char*>(&cols), 4); |
| 73 | + cols = swap_endian(cols); |
| 74 | + |
| 75 | + // Open leveldb |
| 76 | + leveldb::DB* db; |
| 77 | + leveldb::Options options; |
| 78 | + options.create_if_missing = true; |
| 79 | + options.error_if_exists = true; |
| 80 | + leveldb::Status status = leveldb::DB::Open( |
| 81 | + options, db_filename, &db); |
| 82 | + CHECK(status.ok()) << "Failed to open leveldb " << db_filename |
| 83 | + << ". Is it already existing?"; |
| 84 | + |
| 85 | + char* label_temp = new char[4]; // label for unsigned char* |
| 86 | + signed char* label_i = new signed char[4]; // label for triplet |
| 87 | + signed char* label_j = new signed char[4]; |
| 88 | + signed char* label_k = new signed char[4]; |
| 89 | + signed char* label_l = new signed char[4]; // label for pair wise |
| 90 | + signed char* label_m = new signed char[4]; |
| 91 | + int db_size; |
| 92 | + if (rgb_use1 == 0) |
| 93 | + db_size = rows * cols; |
| 94 | + else |
| 95 | + db_size = 3 * rows * cols; |
| 96 | + char* pixels1 = new char[db_size]; |
| 97 | + char* pixels2 = new char[db_size]; |
| 98 | + char* pixels3 = new char[db_size]; |
| 99 | + char* pixels4 = new char[db_size]; |
| 100 | + char* pixels5 = new char[db_size]; |
| 101 | + const int kMaxKeyLength = 10; |
| 102 | + char key[kMaxKeyLength]; |
| 103 | + std::string value; |
| 104 | + caffe::Datum datum; |
| 105 | + if (rgb_use1 == 0) |
| 106 | + datum.set_channels(1); |
| 107 | + else |
| 108 | + datum.set_channels(3); |
| 109 | + datum.set_height(rows); |
| 110 | + datum.set_width(cols); |
| 111 | + LOG(INFO) << "A total of " << num_items << " items."; |
| 112 | + LOG(INFO) << "Rows: " << rows << " Cols: " << cols; |
| 113 | + int counter = 0; |
| 114 | + // This codes selecting 1 positive sample and 3 negative samples for a triplet |
| 115 | + // set. We randomly select data and decide whether concatenating data set to |
| 116 | + // DB file according to labels. |
| 117 | + for (unsigned int times = 0; times < 10; ++times) { |
| 118 | + // iteration in the samples of all class |
| 119 | + for (unsigned int itemid = 0; itemid < num_items/class_num; ++itemid) { |
| 120 | + // iteration in the samples in one class |
| 121 | + for (unsigned int class_ind = 0; class_ind < class_num; ++class_ind) { |
| 122 | + // use reference sample one by one at each iteration |
| 123 | + int i = itemid % num_items + class_ind*num_items/class_num; |
| 124 | + int j = caffe::caffe_rng_rand() % num_items; // pick triplet groups |
| 125 | + int k = caffe::caffe_rng_rand() % num_items; |
| 126 | + int l = caffe::caffe_rng_rand() % num_items; // pick pair wise groups |
| 127 | + int m = caffe::caffe_rng_rand() % num_items; |
| 128 | + read_image(&image_file, &label_file, i, rows, cols, // read triplet |
| 129 | + pixels1, label_temp, label_i, rgb_use1); |
| 130 | + read_image(&image_file, &label_file, j, rows, cols, |
| 131 | + pixels2, label_temp, label_j, rgb_use1); |
| 132 | + read_image(&image_file, &label_file, k, rows, cols, |
| 133 | + pixels3, label_temp, label_k, rgb_use1); |
| 134 | + read_image(&image_file, &label_file, l, rows, cols, // read pair wise |
| 135 | + pixels4, label_temp, label_l, rgb_use1); |
| 136 | + read_image(&image_file, &label_file, m, rows, cols, |
| 137 | + pixels5, label_temp, label_m, rgb_use1); |
| 138 | + |
| 139 | + bool pair_pass = false; |
| 140 | + bool triplet1_pass = false; |
| 141 | + bool triplet2_pass = false; |
| 142 | + bool triplet3_class_same = false; |
| 143 | + bool triplet3_pass = false; |
| 144 | + |
| 145 | + int ij_diff_x = static_cast<int>(*(label_i+1)-*(label_j+1)); |
| 146 | + int ij_diff_y = static_cast<int>(*(label_i+2)-*(label_j+2)); |
| 147 | + int ij_diff_z = static_cast<int>(*(label_i+3)-*(label_j+3)); |
| 148 | + int im_diff_x = static_cast<int>(*(label_i+1)-*(label_m+1)); |
| 149 | + int im_diff_y = static_cast<int>(*(label_i+2)-*(label_m+2)); |
| 150 | + int im_diff_z = static_cast<int>(*(label_i+3)-*(label_m+3)); |
| 151 | + |
| 152 | + int ij_x = ij_diff_x*ij_diff_x; |
| 153 | + int ij_y = ij_diff_y*ij_diff_y; |
| 154 | + int ij_z = ij_diff_z*ij_diff_z; |
| 155 | + int im_x = im_diff_x*im_diff_x; |
| 156 | + int im_y = im_diff_y*im_diff_y; |
| 157 | + int im_z = im_diff_z*im_diff_z; |
| 158 | + |
| 159 | + float dist_ij = std::sqrt(ij_x + ij_y + ij_z); |
| 160 | + float dist_im = std::sqrt(im_x + im_y + im_z); |
| 161 | + // Arrange training data according to conditionals including category |
| 162 | + // and pose of synthetic data, dist_* could be ignored if you |
| 163 | + // only concentrate on category. |
| 164 | + if (*label_i == *label_j && dist_ij < 100/3 && dist_ij != 0) |
| 165 | + pair_pass = true; |
| 166 | + if (pair_pass && (*label_i != *label_k)) |
| 167 | + triplet1_pass = true; |
| 168 | + if (pair_pass && (*label_i != *label_l)) |
| 169 | + triplet2_pass = true; |
| 170 | + if (pair_pass && (*label_i == *label_m)) |
| 171 | + triplet3_class_same = true; |
| 172 | + if (triplet3_class_same && dist_im > 100/3) |
| 173 | + triplet3_pass = true; |
| 174 | + if (pair_pass && triplet1_pass && triplet2_pass && triplet3_pass) { |
| 175 | + datum.set_data(pixels1, db_size); // set data |
| 176 | + datum.set_label(static_cast<int>(*label_i)); |
| 177 | + datum.SerializeToString(&value); |
| 178 | + snprintf(key, kMaxKeyLength, "%08d", counter); |
| 179 | + db->Put(leveldb::WriteOptions(), std::string(key), value); |
| 180 | + counter++; |
| 181 | + datum.set_data(pixels2, db_size); // set data |
| 182 | + datum.set_label(static_cast<int>(*label_j)); |
| 183 | + datum.SerializeToString(&value); |
| 184 | + snprintf(key, kMaxKeyLength, "%08d", counter); |
| 185 | + db->Put(leveldb::WriteOptions(), std::string(key), value); |
| 186 | + counter++; |
| 187 | + datum.set_data(pixels3, db_size); // set data |
| 188 | + datum.set_label(static_cast<int>(*label_k)); |
| 189 | + datum.SerializeToString(&value); |
| 190 | + snprintf(key, kMaxKeyLength, "%08d", counter); |
| 191 | + db->Put(leveldb::WriteOptions(), std::string(key), value); |
| 192 | + counter++; |
| 193 | + datum.set_data(pixels4, db_size); // set data |
| 194 | + datum.set_label(static_cast<int>(*label_l)); |
| 195 | + datum.SerializeToString(&value); |
| 196 | + snprintf(key, kMaxKeyLength, "%08d", counter); |
| 197 | + db->Put(leveldb::WriteOptions(), std::string(key), value); |
| 198 | + counter++; |
| 199 | + datum.set_data(pixels5, db_size); // set data |
| 200 | + datum.set_label(static_cast<int>(*label_m)); |
| 201 | + datum.SerializeToString(&value); |
| 202 | + snprintf(key, kMaxKeyLength, "%08d", counter); |
| 203 | + db->Put(leveldb::WriteOptions(), std::string(key), value); |
| 204 | + counter++; |
| 205 | + } else { |
| 206 | + class_ind--; |
| 207 | + } |
| 208 | + } // iteration in the samples of all class |
| 209 | + } // iteration in the samples in one class |
| 210 | + } // iteration in times |
| 211 | + delete db; |
| 212 | + delete pixels1; |
| 213 | + delete pixels2; |
| 214 | + delete pixels3; |
| 215 | + delete pixels4; |
| 216 | + delete pixels5; |
| 217 | +} |
| 218 | + |
| 219 | +int main(int argc, char** argv) { |
| 220 | + if (argc != 6) { |
| 221 | + printf("This script converts the dataset to the leveldb format used\n" |
| 222 | + "by caffe to train a triplet network.\n" |
| 223 | + "Usage:\n" |
| 224 | + " convert_3d_data input_image_file input_label_file " |
| 225 | + "output_db_file class_number rgb_use \n"); |
| 226 | + } else { |
| 227 | + google::InitGoogleLogging(argv[0]); |
| 228 | + convert_dataset(argv[1], argv[2], argv[3], argv[4], argv[5]); |
| 229 | + } |
| 230 | + return 0; |
| 231 | +} |
| 232 | +#else |
| 233 | +int main(int argc, char** argv) { |
| 234 | + LOG(FATAL) << "This example requires LevelDB; compile with USE_LEVELDB."; |
| 235 | +} |
| 236 | +#endif // USE_LEVELDB |
0 commit comments