#include <cstdlib>  // atoi
#include <cstring>  // strcmp
#include <string>
#include <vector>
#include "boost/algorithm/string.hpp"
#include "google/protobuf/text_format.h"
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/net.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/util/db.hpp"
#include "caffe/util/format.hpp"
#include "caffe/util/io.hpp"
using caffe::Blob;
using caffe::Caffe;
using caffe::Datum;
using caffe::Net;
using std::string;
namespace db = caffe::db;
template<typename Dtype>
int feature_extraction_pipeline(int argc, char** argv);
int main(int argc, char** argv) {
return feature_extraction_pipeline<float>(argc, argv);
// return feature_extraction_pipeline<double>(argc, argv);
}
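// Extracts features from a trained net into one database per blob.
// Example invocation (all paths and names below are illustrative, not
// defaults shipped with this tool):
//   extract_features lenet.caffemodel imagenet_val.prototxt fc7 \
//     fc7_features_lmdb 10 lmdb GPU 0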
template<typename Dtype>
int feature_extraction_pipeline(int argc, char** argv) {
::google::InitGoogleLogging(argv[0]);
const int num_required_args = 7;
if (argc < num_required_args) {
    LOG(ERROR) <<
    "This program takes in a trained network and an input data layer, and then"
    " extracts features of the input data produced by the net.\n"
"Usage: extract_features pretrained_net_param"
" feature_extraction_proto_file extract_feature_blob_name1[,name2,...]"
" save_feature_dataset_name1[,name2,...] num_mini_batches db_type"
" [CPU/GPU] [DEVICE_ID=0]\n"
"Note: you can extract multiple features in one pass by specifying"
" multiple feature blob names and dataset names separated by ','."
" The names cannot contain white space characters and the number of blobs"
" and datasets must be equal.";
return 1;
}
  int arg_pos = num_required_args;
if (argc > arg_pos && strcmp(argv[arg_pos], "GPU") == 0) {
LOG(ERROR)<< "Using GPU";
int device_id = 0;
if (argc > arg_pos + 1) {
device_id = atoi(argv[arg_pos + 1]);
CHECK_GE(device_id, 0);
}
LOG(ERROR) << "Using Device_id=" << device_id;
Caffe::SetDevice(device_id);
Caffe::set_mode(Caffe::GPU);
} else {
LOG(ERROR) << "Using CPU";
Caffe::set_mode(Caffe::CPU);
}
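  // The positional arguments are consumed in order: pretrained weights,
  // feature-extraction prototxt, blob names, dataset names, number of
  // mini-batches, and database type.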
arg_pos = 0; // the name of the executable
std::string pretrained_binary_proto(argv[++arg_pos]);
// Expected prototxt contains at least one data layer such as
// the layer data_layer_name and one feature blob such as the
// fc7 top blob to extract features.
/*
layers {
name: "data_layer_name"
type: DATA
data_param {
source: "/path/to/your/images/to/extract/feature/images_leveldb"
mean_file: "/path/to/your/image_mean.binaryproto"
batch_size: 128
crop_size: 227
mirror: false
}
top: "data_blob_name"
top: "label_blob_name"
}
layers {
name: "drop7"
type: DROPOUT
dropout_param {
dropout_ratio: 0.5
}
bottom: "fc7"
top: "fc7"
}
*/
std::string feature_extraction_proto(argv[++arg_pos]);
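  // Build the net in TEST phase from the prototxt and load the pretrained
  // weights into it.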
boost::shared_ptr<Net<Dtype> > feature_extraction_net(
new Net<Dtype>(feature_extraction_proto, caffe::TEST));
feature_extraction_net->CopyTrainedLayersFrom(pretrained_binary_proto);
std::string extract_feature_blob_names(argv[++arg_pos]);
std::vector<std::string> blob_names;
boost::split(blob_names, extract_feature_blob_names, boost::is_any_of(","));
std::string save_feature_dataset_names(argv[++arg_pos]);
std::vector<std::string> dataset_names;
boost::split(dataset_names, save_feature_dataset_names,
boost::is_any_of(","));
  CHECK_EQ(blob_names.size(), dataset_names.size()) <<
      "the number of blob names and dataset names must be equal";
size_t num_features = blob_names.size();
  for (size_t i = 0; i < num_features; ++i) {
CHECK(feature_extraction_net->has_blob(blob_names[i]))
<< "Unknown feature blob name " << blob_names[i]
<< " in the network " << feature_extraction_proto;
}
int num_mini_batches = atoi(argv[++arg_pos]);
std::vector<boost::shared_ptr<db::DB> > feature_dbs;
std::vector<boost::shared_ptr<db::Transaction> > txns;
const char* db_type = argv[++arg_pos];
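  // Open one output database per requested blob, each with an initial
  // write transaction; extracted features are streamed into these below.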
for (size_t i = 0; i < num_features; ++i) {
    LOG(INFO) << "Opening dataset " << dataset_names[i];
boost::shared_ptr<db::DB> db(db::GetDB(db_type));
db->Open(dataset_names.at(i), db::NEW);
feature_dbs.push_back(db);
boost::shared_ptr<db::Transaction> txn(db->NewTransaction());
txns.push_back(txn);
}
LOG(ERROR)<< "Extracting Features";
Datum datum;
std::vector<int> image_indices(num_features, 0);
for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index) {
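    // One forward pass; the net's data layer supplies the next mini-batch.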
feature_extraction_net->Forward();
    for (size_t i = 0; i < num_features; ++i) {
const boost::shared_ptr<Blob<Dtype> > feature_blob =
feature_extraction_net->blob_by_name(blob_names[i]);
int batch_size = feature_blob->num();
int dim_features = feature_blob->count() / batch_size;
const Dtype* feature_blob_data;
for (int n = 0; n < batch_size; ++n) {
datum.set_height(feature_blob->height());
datum.set_width(feature_blob->width());
datum.set_channels(feature_blob->channels());
datum.clear_data();
datum.clear_float_data();
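        // cpu_data() + offset(n) points at the start of item n's feature
        // vector within the batch.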
feature_blob_data = feature_blob->cpu_data() +
feature_blob->offset(n);
for (int d = 0; d < dim_features; ++d) {
datum.add_float_data(feature_blob_data[d]);
}
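        // Zero-padded decimal keys make lexicographic key order match
        // numeric insertion order.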
string key_str = caffe::format_int(image_indices[i], 10);
string out;
CHECK(datum.SerializeToString(&out));
txns.at(i)->Put(key_str, out);
++image_indices[i];
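        // Commit every 1000 items to keep transaction size bounded.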
if (image_indices[i] % 1000 == 0) {
txns.at(i)->Commit();
txns.at(i).reset(feature_dbs.at(i)->NewTransaction());
LOG(ERROR)<< "Extracted features of " << image_indices[i] <<
" query images for feature blob " << blob_names[i];
}
} // for (int n = 0; n < batch_size; ++n)
} // for (int i = 0; i < num_features; ++i)
} // for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index)
// write the last batch
  for (size_t i = 0; i < num_features; ++i) {
if (image_indices[i] % 1000 != 0) {
txns.at(i)->Commit();
}
LOG(ERROR)<< "Extracted features of " << image_indices[i] <<
" query images for feature blob " << blob_names[i];
feature_dbs.at(i)->Close();
}
LOG(ERROR)<< "Successfully extracted the features!";
return 0;
}
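
// A minimal sketch of reading the saved features back, assuming an lmdb
// output and the illustrative dataset name "fc7_features_lmdb" from the
// example invocation above:
//
//   boost::shared_ptr<db::DB> feature_db(db::GetDB("lmdb"));
//   feature_db->Open("fc7_features_lmdb", db::READ);
//   boost::shared_ptr<db::Cursor> cursor(feature_db->NewCursor());
//   for (cursor->SeekToFirst(); cursor->valid(); cursor->Next()) {
//     Datum datum;
//     datum.ParseFromString(cursor->value());
//     // datum.float_data(0) .. datum.float_data(datum.float_data_size() - 1)
//     // hold one item's feature vector.
//   }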