diff --git a/examples/mnist/lenet_import_conv_pool.prototxt b/examples/mnist/lenet_import_conv_pool.prototxt new file mode 100644 index 00000000000..5e2b7886e22 --- /dev/null +++ b/examples/mnist/lenet_import_conv_pool.prototxt @@ -0,0 +1,30 @@ +layers { + name: "conv" + type: CONVOLUTION + bottom: "${bottom}" + top: "conv" + blobs_lr: 1 + blobs_lr: 2 + convolution_param { + num_output: ${num_output} + kernel_size: 5 + stride: 1 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + } + } +} +layers { + name: "pool" + type: POOLING + bottom: "conv" + top: "pool" + pooling_param { + pool: MAX + kernel_size: 2 + stride: 2 + } +} diff --git a/examples/mnist/lenet_import_solver.prototxt b/examples/mnist/lenet_import_solver.prototxt new file mode 100644 index 00000000000..c332567f37b --- /dev/null +++ b/examples/mnist/lenet_import_solver.prototxt @@ -0,0 +1,25 @@ +# The train/test net protocol buffer definition +net: "examples/mnist/lenet_import_train_test.prototxt" +# test_iter specifies how many forward passes the test should carry out. +# In the case of MNIST, we have test batch size 100 and 100 test iterations, +# covering the full 10,000 testing images. +test_iter: 100 +# Carry out testing every 500 training iterations. +test_interval: 500 +# The base learning rate, momentum and the weight decay of the network. +base_lr: 0.01 +momentum: 0.9 +weight_decay: 0.0005 +# The learning rate policy +lr_policy: "inv" +gamma: 0.0001 +power: 0.75 +# Display every 100 iterations +display: 100 +# The maximum number of iterations +max_iter: 10000 +# snapshot intermediate results +snapshot: 5000 +snapshot_prefix: "examples/mnist/lenet" +# solver mode: CPU or GPU +solver_mode: GPU diff --git a/examples/mnist/lenet_import_train_test.prototxt b/examples/mnist/lenet_import_train_test.prototxt new file mode 100644 index 00000000000..4ab86e0dc49 --- /dev/null +++ b/examples/mnist/lenet_import_train_test.prototxt @@ -0,0 +1,104 @@ +name: "LeNet" +layers { + name: "mnist" + type: DATA + top: "data" + top: "label" + data_param { + source: "examples/mnist/mnist_train_lmdb" + backend: LMDB + batch_size: 64 + } + transform_param { + scale: 0.00390625 + } + include: { phase: TRAIN } +} +layers { + name: "mnist" + type: DATA + top: "data" + top: "label" + data_param { + source: "examples/mnist/mnist_test_lmdb" + backend: LMDB + batch_size: 100 + } + transform_param { + scale: 0.00390625 + } + include: { phase: TEST } +} +layers { + name: "cp1" + type: IMPORT + import_param { + net: "examples/mnist/lenet_import_conv_pool.prototxt" + var { name: "bottom" value: "/data" } + var { name: "num_output" value: "20" } + } +} +layers { + name: "cp2" + type: IMPORT + import_param { + net: "examples/mnist/lenet_import_conv_pool.prototxt" + var { name: "bottom" value: "../cp1/pool" } + var { name: "num_output" value: "50" } + } +} +layers { + name: "ip1" + type: INNER_PRODUCT + bottom: "cp2/pool" + top: "ip1" + blobs_lr: 1 + blobs_lr: 2 + inner_product_param { + num_output: 500 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + } + } +} +layers { + name: "relu1" + type: RELU + bottom: "ip1" + top: "ip1" +} +layers { + name: "ip2" + type: INNER_PRODUCT + bottom: "ip1" + top: "ip2" + blobs_lr: 1 + blobs_lr: 2 + inner_product_param { + num_output: 10 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + } + } +} +layers { + name: "accuracy" + type: ACCURACY + bottom: "ip2" + bottom: "label" + top: "accuracy" + include: { phase: TEST } +} +layers { + name: "loss" + type: SOFTMAX_LOSS + bottom: "ip2" + bottom: "label" + top: "loss" +} diff --git a/examples/mnist/train_lenet_import.sh b/examples/mnist/train_lenet_import.sh new file mode 100755 index 00000000000..6387228d368 --- /dev/null +++ b/examples/mnist/train_lenet_import.sh @@ -0,0 +1,3 @@ +#!/usr/bin/env sh + +GLOG_logtostderr=0 GLOG_log_dir=examples/mnist/ ./build/tools/caffe train --solver=examples/mnist/lenet_import_solver.prototxt --gpu=1 diff --git a/include/caffe/net.hpp b/include/caffe/net.hpp index 1d06dc45533..a5229f1df34 100644 --- a/include/caffe/net.hpp +++ b/include/caffe/net.hpp @@ -182,6 +182,13 @@ class Net { /// @brief Get misc parameters, e.g. the LR multiplier and weight decay. void GetLearningRateAndWeightDecay(); + // @brief Loads imports, for modular network definitions + static void LoadImports(const NetParameter& source, NetParameter* target); + static void LoadImports(const NetParameter& source, NetParameter* target, + const string& pwd); + // @brief Resolves a layer or blob name, e.g. "../data" + static string ResolveImportName(const string& path, const string& pwd); + /// @brief Individual layers in the net vector > > layers_; vector layer_names_; diff --git a/include/caffe/util/io.hpp b/include/caffe/util/io.hpp index e518979a75b..9ca84d1582d 100644 --- a/include/caffe/util/io.hpp +++ b/include/caffe/util/io.hpp @@ -52,6 +52,8 @@ inline void MakeTempDir(string* temp_dirname) { delete temp_dirname_cstr; } +string ReadFile(const string& filename); + bool ReadProtoFromTextFile(const char* filename, Message* proto); inline bool ReadProtoFromTextFile(const string& filename, Message* proto) { diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp index 21ab15fd31b..b011676f979 100644 --- a/src/caffe/net.cpp +++ b/src/caffe/net.cpp @@ -1,3 +1,6 @@ +#include +#include + #include #include #include @@ -16,6 +19,8 @@ #include "caffe/test/test_caffe_main.hpp" +using boost::replace_all; + namespace caffe { template @@ -32,10 +37,14 @@ Net::Net(const string& param_file) { template void Net::Init(const NetParameter& in_param) { + // Load import layers + NetParameter expanded(in_param); + LoadImports(in_param, &expanded); + // Filter layers based on their include/exclude rules and // the current NetState. NetParameter filtered_param; - FilterNet(in_param, &filtered_param); + FilterNet(expanded, &filtered_param); LOG(INFO) << "Initializing net from parameters: " << std::endl << filtered_param.DebugString(); // Create a copy of filtered_param with splits added where necessary. @@ -462,6 +471,66 @@ void Net::AppendParam(const NetParameter& param, const int layer_id, } } + +template +void Net::LoadImports(const NetParameter& source, NetParameter* target) { + target->CopyFrom(source); + target->clear_layers(); + LoadImports(source, target, ""); +} + +template +void Net::LoadImports(const NetParameter& source, NetParameter* target, + const string& pwd) { + for (int i = 0; i < source.layers_size(); ++i) { + if (source.layers(i).type() == LayerParameter_LayerType_IMPORT) { + const LayerParameter& layer = source.layers(i); + CHECK(layer.has_import_param()) << "Missing import_param"; + const ImportParameter& import = layer.import_param(); + string proto = ReadFile(import.net()); + // Replace variables and references + for (int j = 0; j < import.var_size(); ++j) { + const Pair& p = import.var(j); + replace_all(proto, "${" + p.name() + "}", p.value()); + } + NetParameter net; + bool parse = google::protobuf::TextFormat::ParseFromString(proto, &net); + CHECK(parse) << "Failed to parse NetParameter file: " << import.net(); + CHECK(layer.has_name() && layer.name().length() > 0) + << "Import layer must have a name"; + LoadImports(net, target, ResolveImportName(layer.name(), pwd)); + } else { + LayerParameter *t = target->add_layers(); + t->CopyFrom(source.layers(i)); + t->set_name(ResolveImportName(t->name(), pwd)); + for (int j = 0; j < source.layers(i).top_size(); ++j) + t->set_top(j, ResolveImportName(source.layers(i).top(j), pwd)); + for (int j = 0; j < source.layers(i).bottom_size(); ++j) + t->set_bottom(j, ResolveImportName(source.layers(i).bottom(j), pwd)); + } + } +} + +template +string Net::ResolveImportName(const string& path, const string& pwd) { + CHECK(!boost::starts_with(pwd, "/") && !boost::ends_with(pwd, "/")); + if (boost::starts_with(path, "/")) + return path.substr(1, path.size() - 1); + string cpath = path; + string cpwd = pwd; + while (boost::starts_with(cpath, "../")) { + cpath = cpath.substr(3, cpath.size() - 3); + size_t i = cpwd.find_last_of('/'); + cpwd = i == string::npos ? "" : cpwd.substr(0, i); + } + if (!cpwd.size()) + return cpath; + if (!cpath.size() || cpath == ".") + return cpwd; + return cpwd + '/' + cpath; +} + + template void Net::GetLearningRateAndWeightDecay() { LOG(INFO) << "Collecting Learning Rate and Weight Decay."; diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index 8cc18a5fd20..f5bd5ca6928 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -246,9 +246,10 @@ message LayerParameter { HINGE_LOSS = 28; IM2COL = 11; IMAGE_DATA = 12; + IMPORT = 51; INFOGAIN_LOSS = 13; INNER_PRODUCT = 14; - LOCAL = 39; + LOCAL = 50; LRN = 15; MEMORY_DATA = 29; MULTINOMIAL_LOGISTIC_LOSS = 16; @@ -308,6 +309,7 @@ message LayerParameter { optional HDF5OutputParameter hdf5_output_param = 14; optional HingeLossParameter hinge_loss_param = 29; optional ImageDataParameter image_data_param = 15; + optional ImportParameter import_param = 103; optional InfogainLossParameter infogain_loss_param = 16; optional InnerProductParameter inner_product_param = 17; optional LocalParameter local_param = 102; @@ -553,6 +555,20 @@ message ImageDataParameter { optional bool mirror = 6 [default = false]; } +message Pair { + required string name = 1; + required string value = 2; +} + +// Message that stores parameters used by ImportLayer +message ImportParameter { + // Proto file to import + required string net = 1; + // Variable names to replace before importing the file. Variables can + // be used in the file in this format: ${name} + repeated Pair var = 2; +} + // Message that stores parameters InfogainLossLayer message InfogainLossParameter { // Specify the infogain matrix source. diff --git a/src/caffe/test/test_data/module.prototxt b/src/caffe/test/test_data/module.prototxt new file mode 100644 index 00000000000..6c2d5359360 --- /dev/null +++ b/src/caffe/test/test_data/module.prototxt @@ -0,0 +1,21 @@ +layers: { + name: 'innerproduct' + type: INNER_PRODUCT + inner_product_param { + num_output: ${num_output} + weight_filler { + type: 'gaussian' + std: 0.01 + } + bias_filler { + type: 'constant' + value: 0 + } + } + blobs_lr: 1. + blobs_lr: 2. + weight_decay: 1. + weight_decay: 0. + bottom: '../data' + top: 'innerproduct' +} diff --git a/src/caffe/test/test_imports.cpp b/src/caffe/test/test_imports.cpp new file mode 100644 index 00000000000..023ae6dce7f --- /dev/null +++ b/src/caffe/test/test_imports.cpp @@ -0,0 +1,87 @@ +#include +#include +#include + +#include "google/protobuf/text_format.h" + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/net.hpp" +#include "caffe/proto/caffe.pb.h" +#include "caffe/util/io.hpp" +#include "caffe/vision_layers.hpp" + +#include "caffe/test/test_caffe_main.hpp" + +namespace caffe { + +template +class ImportsTest : public MultiDeviceTest { + typedef typename TypeParam::Dtype Dtype; + + protected: + virtual void InitNetFromProtoString(const string& proto) { + NetParameter param; + CHECK(google::protobuf::TextFormat::ParseFromString(proto, ¶m)); + net_.reset(new Net(param)); + } + + virtual void InitNet() { + string file = CMAKE_SOURCE_DIR "caffe/test/test_data/module.prototxt"; + string proto = + "name: 'TestNetwork' " + "layers: { " + " name: 'data' " + " type: DUMMY_DATA " + " dummy_data_param { " + " num: 5 " + " channels: 2 " + " height: 3 " + " width: 4 " + " num: 5 " + " channels: 1 " + " height: 1 " + " width: 1 " + " data_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " } " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " name: 'import' " + " type: IMPORT " + " import_param { " + " net: '" + file + "' " + " var { name: 'num_output' value: '1000' } " + " } " + "} " + "layers: { " + " name: 'loss' " + " type: SOFTMAX_LOSS " + " bottom: 'import/innerproduct' " + " bottom: 'label' " + " top: 'top_loss' " + "} "; + InitNetFromProtoString(proto); + } + + shared_ptr > net_; +}; + +TYPED_TEST_CASE(ImportsTest, TestDtypesAndDevices); + +TYPED_TEST(ImportsTest, ConvPool) { + this->InitNet(); + EXPECT_TRUE(this->net_->has_blob("data")); + EXPECT_TRUE(this->net_->has_blob("label")); + EXPECT_TRUE(this->net_->has_blob("import/innerproduct")); + EXPECT_FALSE(this->net_->has_blob("loss")); +} +} // namespace caffe + diff --git a/src/caffe/util/io.cpp b/src/caffe/util/io.cpp index 36510d61d40..09d4472c7bc 100644 --- a/src/caffe/util/io.cpp +++ b/src/caffe/util/io.cpp @@ -28,6 +28,18 @@ using google::protobuf::io::ZeroCopyOutputStream; using google::protobuf::io::CodedOutputStream; using google::protobuf::Message; +std::string ReadFile(const string& filename) { + std::ifstream in(filename.c_str(), std::ios::in | std::ios::binary); + CHECK(in) << "Failed to read file: " << filename; + std::string contents; + in.seekg(0, std::ios::end); + contents.resize(in.tellg()); + in.seekg(0, std::ios::beg); + in.read(&contents[0], contents.size()); + in.close(); + return contents; +} + bool ReadProtoFromTextFile(const char* filename, Message* proto) { int fd = open(filename, O_RDONLY); CHECK_NE(fd, -1) << "File not found: " << filename;