@@ -43,7 +43,10 @@ class InputRecordReader {
4343template <class ProtoClass >
4444class FileInputRecordReader : public InputRecordReader <ProtoClass> {
4545 public:
46- explicit FileInputRecordReader (const std::string& filename) : reader(filename), has_prefetched(false ) {
46+ explicit FileInputRecordReader (const std::string& filename, const int64 max_records=-1 ) :
47+ reader(filename),
48+ has_prefetched(false ),
49+ max_records_(max_records){
4750 }
4851 virtual ~FileInputRecordReader () override {
4952 reader.Close ();
@@ -57,12 +60,16 @@ class FileInputRecordReader : public InputRecordReader<ProtoClass> {
5760 }
5861 *proto = std::move (prefetched_proto);
5962 has_prefetched = false ;
60- return true ;
63+ if (max_records_ != 0 ) {
64+ max_records_--;
65+ return true ;
66+ }
67+ return false ;
6168 }
6269
6370 virtual bool ReachedEnd () override {
6471 std::lock_guard<std::mutex> lock (reader_mutex);
65- return !has_prefetched && !PrefetchProto ();
72+ return ( !has_prefetched && !PrefetchProto ()) || max_records_ == 0 ;
6673 }
6774
6875 private:
@@ -79,13 +86,16 @@ class FileInputRecordReader : public InputRecordReader<ProtoClass> {
7986 ProtoClass prefetched_proto;
8087 bool has_prefetched;
8188 std::mutex reader_mutex;
89+ int64 max_records_;
8290};
8391
8492
8593template <>
8694class FileInputRecordReader <std::string> : public InputRecordReader<std::string> {
8795 public:
88- explicit FileInputRecordReader (const std::string& filename) : file(filename) {
96+ explicit FileInputRecordReader (const std::string& filename, const int64 max_records=-1 ) :
97+ file(filename),
98+ max_records_(max_records) {
8999 CHECK (exists (filename)) << " File '" << filename << " ' does not exist!" ;
90100 }
91101 virtual ~FileInputRecordReader () override {
@@ -103,17 +113,23 @@ class FileInputRecordReader<std::string> : public InputRecordReader<std::string>
103113 }
104114 std::getline (file, *s); // Read until we get a non-empty line.
105115 }
106- return true ;
116+ if (max_records_ != 0 ) {
117+ max_records_--;
118+ return true ;
119+ }
120+ return false ;
107121 }
108122
109123 virtual bool ReachedEnd () override {
110124 std::lock_guard<std::mutex> lock (filemutex);
111- return file.eof ();
125+ return file.eof () || max_records_ == 0 ;
112126 }
113127 private:
114128 inline bool exists (const std::string& name) {
115129 return ( access ( name.c_str (), F_OK ) != -1 );
116130 }
131+
132+ int64 max_records_;
117133};
118134
119135template <class T >
@@ -188,17 +204,20 @@ class RecordInput {
188204template <class T >
189205class FileRecordInput : public RecordInput <T> {
190206public:
191- explicit FileRecordInput (const std::string& filename) : filename_(filename) {
207+ explicit FileRecordInput (const std::string& filename, const int64 max_records=-1 ) :
208+ filename_(filename),
209+ max_records_(max_records) {
192210 }
193211 virtual ~FileRecordInput () override {
194212 }
195213
196214 virtual InputRecordReader<T>* CreateReader () override {
197- return new FileInputRecordReader<T>(filename_);
215+ return new FileInputRecordReader<T>(filename_, max_records_ );
198216 }
199217
200218private:
201219 std::string filename_;
220+ int64 max_records_;
202221};
203222
204223/* *
@@ -261,7 +280,6 @@ class CrossValidationReader : public InputRecordReader<T> {
261280 if ((training_ && (row_id_ % num_folds_) != fold_id_) ||
262281 (!training_ && (row_id_ % num_folds_) == fold_id_)) {
263282 return underlying_reader_->Read (s);
264- break ;
265283 } else {
266284 T tmp;
267285 underlying_reader_->Read (&tmp);
0 commit comments