Skip to content

Commit 910f9c1

Browse files
committed
Add input_records flag that restricts the maximum number of records to be used
1 parent a9f1234 commit 910f9c1

5 files changed

Lines changed: 39 additions & 18 deletions

File tree

base/readerutil.h

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,10 @@ class InputRecordReader {
4343
template <class ProtoClass>
4444
class FileInputRecordReader : public InputRecordReader<ProtoClass> {
4545
public:
46-
explicit FileInputRecordReader(const std::string& filename) : reader(filename), has_prefetched(false) {
46+
explicit FileInputRecordReader(const std::string& filename, const int64 max_records=-1) :
47+
reader(filename),
48+
has_prefetched(false),
49+
max_records_(max_records){
4750
}
4851
virtual ~FileInputRecordReader() override {
4952
reader.Close();
@@ -57,12 +60,16 @@ class FileInputRecordReader : public InputRecordReader<ProtoClass> {
5760
}
5861
*proto = std::move(prefetched_proto);
5962
has_prefetched = false;
60-
return true;
63+
if (max_records_ != 0) {
64+
max_records_--;
65+
return true;
66+
}
67+
return false;
6168
}
6269

6370
virtual bool ReachedEnd() override {
6471
std::lock_guard<std::mutex> lock(reader_mutex);
65-
return !has_prefetched && !PrefetchProto();
72+
return (!has_prefetched && !PrefetchProto()) || max_records_ == 0;
6673
}
6774

6875
private:
@@ -79,13 +86,16 @@ class FileInputRecordReader : public InputRecordReader<ProtoClass> {
7986
ProtoClass prefetched_proto;
8087
bool has_prefetched;
8188
std::mutex reader_mutex;
89+
int64 max_records_;
8290
};
8391

8492

8593
template <>
8694
class FileInputRecordReader<std::string> : public InputRecordReader<std::string> {
8795
public:
88-
explicit FileInputRecordReader(const std::string& filename) : file(filename) {
96+
explicit FileInputRecordReader(const std::string& filename, const int64 max_records=-1) :
97+
file(filename),
98+
max_records_(max_records) {
8999
CHECK(exists(filename)) << "File '" << filename << "' does not exist!";
90100
}
91101
virtual ~FileInputRecordReader() override {
@@ -103,17 +113,23 @@ class FileInputRecordReader<std::string> : public InputRecordReader<std::string>
103113
}
104114
std::getline(file, *s); // Read until we get a non-empty line.
105115
}
106-
return true;
116+
if (max_records_ != 0) {
117+
max_records_--;
118+
return true;
119+
}
120+
return false;
107121
}
108122

109123
virtual bool ReachedEnd() override {
110124
std::lock_guard<std::mutex> lock(filemutex);
111-
return file.eof();
125+
return file.eof() || max_records_ == 0;
112126
}
113127
private:
114128
inline bool exists (const std::string& name) {
115129
return ( access( name.c_str(), F_OK ) != -1 );
116130
}
131+
132+
int64 max_records_;
117133
};
118134

119135
template <class T>
@@ -188,17 +204,20 @@ class RecordInput {
188204
template <class T>
189205
class FileRecordInput : public RecordInput<T> {
190206
public:
191-
explicit FileRecordInput(const std::string& filename) : filename_(filename) {
207+
explicit FileRecordInput(const std::string& filename, const int64 max_records=-1) :
208+
filename_(filename),
209+
max_records_(max_records) {
192210
}
193211
virtual ~FileRecordInput() override {
194212
}
195213

196214
virtual InputRecordReader<T>* CreateReader() override {
197-
return new FileInputRecordReader<T>(filename_);
215+
return new FileInputRecordReader<T>(filename_, max_records_);
198216
}
199217

200218
private:
201219
std::string filename_;
220+
int64 max_records_;
202221
};
203222

204223
/**
@@ -261,7 +280,6 @@ class CrossValidationReader : public InputRecordReader<T> {
261280
if ((training_ && (row_id_ % num_folds_) != fold_id_) ||
262281
(!training_ && (row_id_ % num_folds_) == fold_id_)) {
263282
return underlying_reader_->Read(s);
264-
break;
265283
} else {
266284
T tmp;
267285
underlying_reader_->Read(&tmp);

n2p/training/eval.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ int main(int argc, char** argv) {
2626
google::ParseCommandLineFlags(&argc, &argv, true);
2727
google::InitGoogleLogging(argv[0]);
2828

29-
return LearningMain<Query>([](const Query &record) {
29+
return EvalMain<Query>([](const Query &record) {
3030
return record;
3131
});
3232
}

n2p/training/eval_internal.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
using nice2protos::Query;
3636

3737
DEFINE_string(model, "model", "File prefix for model to evaluate.");
38+
DEFINE_int64(input_records, -1, "Number of input records to use.");
3839

3940
DEFINE_string(input, "testdata", "Input file with objects to be used for evaluation.");
4041
DEFINE_bool(debug_stats, false, "If specifies, only outputs debug stats of a trained model.");
@@ -117,7 +118,7 @@ void Evaluate(RecordInput<InputType>* evaluation_data, GraphInference* inference
117118
}
118119

119120
template <class InputType>
120-
int LearningMain(Adapter<InputType> adapter) {
121+
int EvalMain(Adapter<InputType> adapter) {
121122
if (FLAGS_debug_stats) {
122123
GraphInference inference;
123124
inference.LoadModel(FLAGS_model);
@@ -128,7 +129,7 @@ int LearningMain(Adapter<InputType> adapter) {
128129
GraphInference inference;
129130
std::unique_ptr<RecordInput<InputType>> input;
130131

131-
input.reset(new FileRecordInput<InputType>(FLAGS_input));
132+
input.reset(new FileRecordInput<InputType>(FLAGS_input, FLAGS_input_records));
132133
inference.LoadModel(FLAGS_model);
133134
PrecisionStats total_stats;
134135
Evaluate(input.get(), &inference, &total_stats, error_stats.get(), adapter);

n2p/training/eval_json.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ int main(int argc, char** argv) {
2626
google::ParseCommandLineFlags(&argc, &argv, true);
2727
google::InitGoogleLogging(argv[0]);
2828

29-
return LearningMain<Query>([](const Query &record) {
29+
return EvalMain<Query>([](const Query &record) {
3030
return record;
3131
});
3232
}

n2p/training/train_internal.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ const std::string PROP_INITIAL_LEARN_RATE_AND_PASS_LEARN_RATE_UPDATE_PL = "prop_
4545
DEFINE_string(input, "testdata", "Input file with training data objects");
4646
DEFINE_string(out_model, "model", "File prefix for output models");
4747
DEFINE_int32(num_training_passes, 24, "Number of passes in training.");
48+
DEFINE_int64(input_records, -1, "Number of input records to use.");
4849

4950
DEFINE_double(start_learning_rate, 0.1, "Initial learning rate");
5051
DEFINE_double(stop_learning_rate, 0.0001, "Stop learning if learning rate falls below the value");
@@ -249,11 +250,11 @@ int LearningMain(Adapter<InputType> adapter) {
249250
for (int fold_id = 0; fold_id < FLAGS_cross_validation_folds; ++fold_id) {
250251
GraphInference inference;
251252
std::unique_ptr<RecordInput<InputType>> training_data(
252-
new ShuffledCacheInput<InputType>(new CrossValidationInput<InputType>(new FileRecordInput<InputType>(FLAGS_input),
253-
fold_id, FLAGS_cross_validation_folds, true)));
253+
new ShuffledCacheInput<InputType>(new CrossValidationInput<InputType>(new FileRecordInput<InputType>(
254+
FLAGS_input, FLAGS_input_records), fold_id, FLAGS_cross_validation_folds, true)));
254255
std::unique_ptr<RecordInput<InputType>> validation_data(
255-
new ShuffledCacheInput<InputType>(new CrossValidationInput<InputType>(new FileRecordInput<InputType>(FLAGS_input),
256-
fold_id, FLAGS_cross_validation_folds, false)));
256+
new ShuffledCacheInput<InputType>(new CrossValidationInput<InputType>(new FileRecordInput<InputType>(
257+
FLAGS_input, FLAGS_input_records), fold_id, FLAGS_cross_validation_folds, false)));
257258
LOG(INFO) << "Training fold " << fold_id;
258259
InitTrain(training_data.get(), &inference, adapter);
259260
if (FLAGS_training_method.compare(PL_TRAIN_NAME) == 0) {
@@ -286,7 +287,8 @@ int LearningMain(Adapter<InputType> adapter) {
286287
LOG(INFO) << "Running structured training...";
287288
// Structured training.
288289
GraphInference inference;
289-
std::unique_ptr<RecordInput<InputType>> input(new ShuffledCacheInput<InputType>(new FileRecordInput<InputType>(FLAGS_input)));
290+
std::unique_ptr<RecordInput<InputType>> input(new ShuffledCacheInput<InputType>(
291+
new FileRecordInput<InputType>(FLAGS_input, FLAGS_input_records)));
290292
InitTrain(input.get(), &inference, adapter);
291293
LOG(INFO) << "Training inited...";
292294
if (FLAGS_training_method.compare(PL_TRAIN_NAME) == 0) {

0 commit comments

Comments
 (0)