Skip to content

Commit ca5a8cd

Browse files
committed
Implement CNN Triplet training.
1 parent 923e7e8 commit ca5a8cd

18 files changed

Lines changed: 2920 additions & 1 deletion

cmake/Modules/FindvecLib.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ find_package_handle_standard_args(vecLib DEFAULT_MSG vecLib_INCLUDE_DIR)
2323

2424
if(VECLIB_FOUND)
2525
if(vecLib_INCLUDE_DIR MATCHES "^/System/Library/Frameworks/vecLib.framework.*")
26-
set(vecLib_LINKER_LIBS -lcblas "-framework vecLib")
26+
set(vecLib_LINKER_LIBS -lcblas "-framework Accelerate")
2727
message(STATUS "Found standalone vecLib.framework")
2828
else()
2929
set(vecLib_LINKER_LIBS -lcblas "-framework Accelerate")
Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
// Usage:
2+
// convert_3d_data input_image_file input_label_file output_db_file
3+
// Codes are disigned for binary files including data and label. You can modify
4+
// the condition if information for arranging training data is not the same with
5+
// category and pose of object.
6+
#include <fstream> // NOLINT(readability/streams)
7+
#include <string>
8+
#include "caffe/proto/caffe.pb.h"
9+
#include "caffe/util/math_functions.hpp"
10+
#include "glog/logging.h"
11+
#include "google/protobuf/text_format.h"
12+
#ifdef USE_LEVELDB
13+
#include "leveldb/db.h"
14+
#include "math.h"
15+
#include "stdint.h"
16+
17+
uint32_t swap_endian(uint32_t val) {
18+
val = ((val << 8) & 0xFF00FF00) | ((val >> 8) & 0xFF00FF);
19+
return (val << 16) | (val >> 16);
20+
}
21+
22+
void read_image(std::ifstream* image_file, std::ifstream* label_file,
23+
uint32_t index, uint32_t rows, uint32_t cols,
24+
char* pixels, char* label_temp, signed char* label, int rgb_use) {
25+
if (rgb_use == 0) {
26+
image_file->seekg(index * rows * cols + 16);
27+
image_file->read(pixels, rows * cols);
28+
label_file->seekg(index * 4 + 8); // 4 = 1 catory label+3 coordinate label
29+
label_file->read(label_temp, 4);
30+
for (int i = 0; i < 4; i++)
31+
*(label+i) = (signed char)*(label_temp+i);
32+
} else {
33+
image_file->seekg(3 * index * rows * cols + 16);
34+
image_file->read(pixels, 3 * rows * cols);
35+
label_file->seekg(index * 4 + 8); // 4 = 1 catory label+3 coordinate label
36+
label_file->read(label_temp, 4);
37+
for (int i = 0; i < 4; i++)
38+
*(label+i) = (signed char)*(label_temp+i);
39+
}
40+
}
41+
42+
void convert_dataset(const char* image_filename, const char* label_filename,
43+
const char* db_filename,
44+
const char* class_number, const char* rgb_use) {
45+
int rgb_use1 = atoi(rgb_use);
46+
int class_num = atoi(class_number);
47+
// Open files
48+
std::ifstream image_file(image_filename, std::ios::in | std::ios::binary);
49+
std::ifstream label_file(label_filename, std::ios::in | std::ios::binary);
50+
CHECK(image_file) << "Unable to open file " << image_filename;
51+
CHECK(label_file) << "Unable to open file " << label_filename;
52+
// Read the magic and the meta data
53+
uint32_t magic;
54+
uint32_t num_items;
55+
uint32_t num_labels;
56+
uint32_t rows;
57+
uint32_t cols;
58+
59+
image_file.read(reinterpret_cast<char*>(&magic), 4);
60+
magic = swap_endian(magic);
61+
CHECK_EQ(magic, 2051) << "Incorrect image file magic.";
62+
label_file.read(reinterpret_cast<char*>(&magic), 4);
63+
magic = swap_endian(magic);
64+
CHECK_EQ(magic, 2050) << "Incorrect label file magic.";
65+
image_file.read(reinterpret_cast<char*>(&num_items), 4);
66+
num_items = swap_endian(num_items);
67+
label_file.read(reinterpret_cast<char*>(&num_labels), 4);
68+
num_labels = swap_endian(num_labels);
69+
CHECK_EQ(num_items, num_labels);
70+
image_file.read(reinterpret_cast<char*>(&rows), 4);
71+
rows = swap_endian(rows);
72+
image_file.read(reinterpret_cast<char*>(&cols), 4);
73+
cols = swap_endian(cols);
74+
75+
// Open leveldb
76+
leveldb::DB* db;
77+
leveldb::Options options;
78+
options.create_if_missing = true;
79+
options.error_if_exists = true;
80+
leveldb::Status status = leveldb::DB::Open(
81+
options, db_filename, &db);
82+
CHECK(status.ok()) << "Failed to open leveldb " << db_filename
83+
<< ". Is it already existing?";
84+
85+
char* label_temp = new char[4]; // label for unsigned char*
86+
signed char* label_i = new signed char[4]; // label for triplet
87+
signed char* label_j = new signed char[4];
88+
signed char* label_k = new signed char[4];
89+
signed char* label_l = new signed char[4]; // label for pair wise
90+
signed char* label_m = new signed char[4];
91+
int db_size;
92+
if (rgb_use1 == 0)
93+
db_size = rows * cols;
94+
else
95+
db_size = 3 * rows * cols;
96+
char* pixels1 = new char[db_size];
97+
char* pixels2 = new char[db_size];
98+
char* pixels3 = new char[db_size];
99+
char* pixels4 = new char[db_size];
100+
char* pixels5 = new char[db_size];
101+
const int kMaxKeyLength = 10;
102+
char key[kMaxKeyLength];
103+
std::string value;
104+
caffe::Datum datum;
105+
if (rgb_use1 == 0)
106+
datum.set_channels(1);
107+
else
108+
datum.set_channels(3);
109+
datum.set_height(rows);
110+
datum.set_width(cols);
111+
LOG(INFO) << "A total of " << num_items << " items.";
112+
LOG(INFO) << "Rows: " << rows << " Cols: " << cols;
113+
int counter = 0;
114+
// This codes selecting 1 positive sample and 3 negative samples for a triplet
115+
// set. We randomly select data and decide whether concatenating data set to
116+
// DB file according to labels.
117+
for (unsigned int times = 0; times < 10; ++times) {
118+
// iteration in the samples of all class
119+
for (unsigned int itemid = 0; itemid < num_items/class_num; ++itemid) {
120+
// iteration in the samples in one class
121+
for (unsigned int class_ind = 0; class_ind < class_num; ++class_ind) {
122+
// use reference sample one by one at each iteration
123+
int i = itemid % num_items + class_ind*num_items/class_num;
124+
int j = caffe::caffe_rng_rand() % num_items; // pick triplet groups
125+
int k = caffe::caffe_rng_rand() % num_items;
126+
int l = caffe::caffe_rng_rand() % num_items; // pick pair wise groups
127+
int m = caffe::caffe_rng_rand() % num_items;
128+
read_image(&image_file, &label_file, i, rows, cols, // read triplet
129+
pixels1, label_temp, label_i, rgb_use1);
130+
read_image(&image_file, &label_file, j, rows, cols,
131+
pixels2, label_temp, label_j, rgb_use1);
132+
read_image(&image_file, &label_file, k, rows, cols,
133+
pixels3, label_temp, label_k, rgb_use1);
134+
read_image(&image_file, &label_file, l, rows, cols, // read pair wise
135+
pixels4, label_temp, label_l, rgb_use1);
136+
read_image(&image_file, &label_file, m, rows, cols,
137+
pixels5, label_temp, label_m, rgb_use1);
138+
139+
bool pair_pass = false;
140+
bool triplet1_pass = false;
141+
bool triplet2_pass = false;
142+
bool triplet3_class_same = false;
143+
bool triplet3_pass = false;
144+
145+
int ij_diff_x = static_cast<int>(*(label_i+1)-*(label_j+1));
146+
int ij_diff_y = static_cast<int>(*(label_i+2)-*(label_j+2));
147+
int ij_diff_z = static_cast<int>(*(label_i+3)-*(label_j+3));
148+
int im_diff_x = static_cast<int>(*(label_i+1)-*(label_m+1));
149+
int im_diff_y = static_cast<int>(*(label_i+2)-*(label_m+2));
150+
int im_diff_z = static_cast<int>(*(label_i+3)-*(label_m+3));
151+
152+
int ij_x = ij_diff_x*ij_diff_x;
153+
int ij_y = ij_diff_y*ij_diff_y;
154+
int ij_z = ij_diff_z*ij_diff_z;
155+
int im_x = im_diff_x*im_diff_x;
156+
int im_y = im_diff_y*im_diff_y;
157+
int im_z = im_diff_z*im_diff_z;
158+
159+
float dist_ij = std::sqrt(ij_x + ij_y + ij_z);
160+
float dist_im = std::sqrt(im_x + im_y + im_z);
161+
// Arrange training data according to conditionals including category
162+
// and pose of synthetic data, dist_* could be ignored if you
163+
// only concentrate on category.
164+
if (*label_i == *label_j && dist_ij < 100/3 && dist_ij != 0)
165+
pair_pass = true;
166+
if (pair_pass && (*label_i != *label_k))
167+
triplet1_pass = true;
168+
if (pair_pass && (*label_i != *label_l))
169+
triplet2_pass = true;
170+
if (pair_pass && (*label_i == *label_m))
171+
triplet3_class_same = true;
172+
if (triplet3_class_same && dist_im > 100/3)
173+
triplet3_pass = true;
174+
if (pair_pass && triplet1_pass && triplet2_pass && triplet3_pass) {
175+
datum.set_data(pixels1, db_size); // set data
176+
datum.set_label(static_cast<int>(*label_i));
177+
datum.SerializeToString(&value);
178+
snprintf(key, kMaxKeyLength, "%08d", counter);
179+
db->Put(leveldb::WriteOptions(), std::string(key), value);
180+
counter++;
181+
datum.set_data(pixels2, db_size); // set data
182+
datum.set_label(static_cast<int>(*label_j));
183+
datum.SerializeToString(&value);
184+
snprintf(key, kMaxKeyLength, "%08d", counter);
185+
db->Put(leveldb::WriteOptions(), std::string(key), value);
186+
counter++;
187+
datum.set_data(pixels3, db_size); // set data
188+
datum.set_label(static_cast<int>(*label_k));
189+
datum.SerializeToString(&value);
190+
snprintf(key, kMaxKeyLength, "%08d", counter);
191+
db->Put(leveldb::WriteOptions(), std::string(key), value);
192+
counter++;
193+
datum.set_data(pixels4, db_size); // set data
194+
datum.set_label(static_cast<int>(*label_l));
195+
datum.SerializeToString(&value);
196+
snprintf(key, kMaxKeyLength, "%08d", counter);
197+
db->Put(leveldb::WriteOptions(), std::string(key), value);
198+
counter++;
199+
datum.set_data(pixels5, db_size); // set data
200+
datum.set_label(static_cast<int>(*label_m));
201+
datum.SerializeToString(&value);
202+
snprintf(key, kMaxKeyLength, "%08d", counter);
203+
db->Put(leveldb::WriteOptions(), std::string(key), value);
204+
counter++;
205+
} else {
206+
class_ind--;
207+
}
208+
} // iteration in the samples of all class
209+
} // iteration in the samples in one class
210+
} // iteration in times
211+
delete db;
212+
delete pixels1;
213+
delete pixels2;
214+
delete pixels3;
215+
delete pixels4;
216+
delete pixels5;
217+
}
218+
219+
int main(int argc, char** argv) {
220+
if (argc != 6) {
221+
printf("This script converts the dataset to the leveldb format used\n"
222+
"by caffe to train a triplet network.\n"
223+
"Usage:\n"
224+
" convert_3d_data input_image_file input_label_file "
225+
"output_db_file class_number rgb_use \n");
226+
} else {
227+
google::InitGoogleLogging(argv[0]);
228+
convert_dataset(argv[1], argv[2], argv[3], argv[4], argv[5]);
229+
}
230+
return 0;
231+
}
232+
#else
233+
int main(int argc, char** argv) {
234+
LOG(FATAL) << "This example requires LevelDB; compile with USE_LEVELDB.";
235+
}
236+
#endif // USE_LEVELDB
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#!/usr/bin/env sh
2+
# This script converts the mnist data into leveldb format.
3+
4+
EXAMPLES=./build/examples/triplet
5+
DATA=./data/linemod
6+
7+
echo "Creating leveldb..."
8+
9+
# this script taking data which consist of 6 categories to leveldb format for
10+
# tripplet training.
11+
12+
rm -rf ./examples/triplet/3d_triplet_train_leveldb
13+
rm -rf ./examples/triplet/3d_triplet_test_leveldb
14+
15+
$EXAMPLES/convert_3d_triplet_data.bin \
16+
$DATA/binary_image_train \
17+
$DATA/binary_label_train \
18+
./examples/triplet/3d_triplet_train_leveldb \
19+
6 \
20+
0
21+
$EXAMPLES/convert_3d_triplet_data.bin \
22+
$DATA/binary_image_test \
23+
$DATA/binary_label_test \
24+
./examples/triplet/3d_triplet_test_leveldb \
25+
6 \
26+
0
27+
echo "Done."
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
name: "multipie_triplet"
2+
input: "data"
3+
input_dim: 1
4+
input_dim: 1
5+
input_dim: 75
6+
input_dim: 65
7+
layer {
8+
name: "conv1"
9+
type: "Convolution"
10+
bottom: "data"
11+
top: "conv1"
12+
param {
13+
lr_mult: 1
14+
}
15+
param {
16+
lr_mult: 2
17+
}
18+
convolution_param {
19+
num_output: 16
20+
kernel_size: 8
21+
stride: 1
22+
}
23+
}
24+
layer {
25+
name: "pool1"
26+
type: "Pooling"
27+
bottom: "conv1"
28+
top: "pool1"
29+
pooling_param {
30+
pool: MAX
31+
kernel_size: 2
32+
stride: 2
33+
}
34+
}
35+
layer {
36+
name: "relu1"
37+
type: "ReLU"
38+
bottom: "pool1"
39+
top: "pool1"
40+
}
41+
layer {
42+
name: "conv2"
43+
type: "Convolution"
44+
bottom: "pool1"
45+
top: "conv2"
46+
param {
47+
lr_mult: 1
48+
}
49+
param {
50+
lr_mult: 2
51+
}
52+
convolution_param {
53+
num_output: 7
54+
kernel_size: 5
55+
stride: 1
56+
}
57+
}
58+
layer {
59+
name: "pool2"
60+
type: "Pooling"
61+
bottom: "conv2"
62+
top: "pool2"
63+
pooling_param {
64+
pool: MAX
65+
kernel_size: 2
66+
stride: 2
67+
}
68+
}
69+
layer {
70+
name: "relu2"
71+
type: "ReLU"
72+
bottom: "pool2"
73+
top: "pool2"
74+
}
75+
layer {
76+
name: "ip1"
77+
type: "InnerProduct"
78+
bottom: "pool2"
79+
top: "ip1"
80+
param {
81+
lr_mult: 1
82+
}
83+
param {
84+
lr_mult: 2
85+
}
86+
inner_product_param {
87+
num_output: 256
88+
}
89+
}
90+
layer {
91+
name: "relu3"
92+
type: "ReLU"
93+
bottom: "ip1"
94+
top: "ip1"
95+
}
96+
layer {
97+
name: "feat"
98+
type: "InnerProduct"
99+
bottom: "ip1"
100+
top: "feat"
101+
param {
102+
lr_mult: 1
103+
}
104+
param {
105+
lr_mult: 2
106+
}
107+
inner_product_param {
108+
num_output: 150
109+
}
110+
}

0 commit comments

Comments
 (0)