-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathscenario.cpp
More file actions
110 lines (92 loc) · 2.88 KB
/
scenario.cpp
File metadata and controls
110 lines (92 loc) · 2.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
#include "scenario.h"

#include <algorithm>
#include <cassert>
#include <functional>
#include <iostream>
#include <string>
#include <utility>
Scenario::Scenario()
:m_name ("foo") {
}
// Builds a scenario with the given name and no per-epoch keep rates.
// `name` is a by-value sink parameter: it is moved into m_name rather than
// copied a second time.
Scenario::Scenario(std::string name)
    : m_name(std::move(name)) {
}
// Builds a scenario that uses the same keep rate for every epoch.
//
// name        - scenario label (moved into m_name).
// epoch_count - number of epochs; a non-positive value yields an empty
//               scenario. Previously a negative count was handed straight to
//               assign(), where it wrapped to a huge unsigned size and threw.
// keep_rate   - keep rate stored for each of the epoch_count epochs.
Scenario::Scenario(std::string name, int epoch_count, double keep_rate)
    : m_name(std::move(name)) {
    m_dropouts.assign(std::max(epoch_count, 0), keep_rate);
}
namespace {

// Appends `count` keep rates to `rates`, ramping from `keep_begin_rate` (the
// first appended value) to `keep_end_rate` (the last appended value).
//
// The shape of the ramp comes from `generator`. For 1 <= k < count the k-th
// appended value is
//   keep_begin_rate + (keep_end_rate - keep_begin_rate) * generator(k) / generator(count - 1)
// so generator(count - 1) maps onto keep_end_rate. The first appended value is
// always exactly keep_begin_rate, and generator(0) is never consulted.
//
// Edge cases:
//   * count <= 0: nothing is appended and `generator` is not called.
//   * keep_begin_rate == keep_end_rate: every value equals keep_begin_rate.
//     This avoids 0/0 == NaN when generator(count - 1) is also 0.
//   * generator(count - 1) == 0 with distinct rates cannot be scaled. This is
//     asserted in debug builds.
template <typename Container>
void appendGeneratedRates(Container& rates,
                          int count,
                          double keep_begin_rate,
                          double keep_end_rate,
                          const std::function<double(int)>& generator) {
    if (count <= 0) {
        return;
    }
    rates.push_back(keep_begin_rate);

    const double keep_delta = keep_end_rate - keep_begin_rate;
    // Exact comparison is intentional: only a truly flat schedule skips scaling.
    if (count == 1 || keep_delta == 0.0) {
        for (int k = 1; k < count; ++k) {
            rates.push_back(keep_begin_rate);
        }
        return;
    }

    // Generator value that corresponds to reaching keep_end_rate.
    const double span = generator(count - 1);
    assert(span != 0.0 && "generator(count - 1) must be non-zero to scale the ramp");
    for (int k = 1; k < count; ++k) {
        // Fraction of the begin->end distance covered at step k. The sign of
        // keep_delta handles both increasing and decreasing schedules.
        const double fraction = generator(k) / span;
        rates.push_back(keep_begin_rate + keep_delta * fraction);
    }
}

}  // namespace

// Builds a scenario whose keep rate goes from keep_begin_rate (epoch 0) to
// keep_end_rate (epoch epoch_count - 1) following the shape of `generator`.
// A non-positive epoch_count yields an empty scenario. Previously it still
// stored one rate and called generator(-1).
Scenario::Scenario(std::string name,
                   int epoch_count,
                   double keep_begin_rate,
                   double keep_end_rate,
                   std::function<double(int)> generator)
    : m_name(std::move(name)) {
    appendGeneratedRates(m_dropouts, epoch_count, keep_begin_rate, keep_end_rate, generator);
}

// Like the constructor above, but the first epoch_to_skip epochs are given a
// keep rate of 1.0. The generated ramp then covers the remaining
// epoch_count - epoch_to_skip epochs:
//   * epoch epoch_to_skip gets keep_begin_rate;
//   * epoch epoch_count - 1 gets keep_end_rate.
// epoch_count is clamped to >= 0, and epoch_to_skip is clamped to
// [0, epoch_count]. As a result the scenario never holds more than epoch_count
// rates. Previously, skipping every epoch still appended keep_begin_rate, and a
// negative skip wrapped inside assign().
Scenario::Scenario(std::string name,
                   int epoch_count,
                   int epoch_to_skip,
                   double keep_begin_rate,
                   double keep_end_rate,
                   std::function<double(int)> generator)
    : m_name(std::move(name)) {
    const int total = std::max(epoch_count, 0);
    const int skipped = std::min(std::max(epoch_to_skip, 0), total);
    m_dropouts.assign(skipped, 1.0);
    appendGeneratedRates(m_dropouts, total - skipped, keep_begin_rate, keep_end_rate, generator);
}
double Scenario::getKeepRate(int epoch) const {
assert(epoch < m_dropouts.size());
return m_dropouts[epoch];
}
// Arithmetic mean of the values in m_dropouts. Despite the name, the
// constructors in this file store keep rates there, so this is the mean keep
// rate. With no stored values the division is 0.0 / 0, which yields NaN.
double Scenario::averageDropout() const {
    double total = 0.0;
    for (auto it = m_dropouts.begin(); it != m_dropouts.end(); ++it) {
        total += *it;
    }
    const double count = m_dropouts.size();
    return total / count;
}
// A scenario counts as enabled once it holds at least one per-epoch keep rate.
bool Scenario::isEnabled() const {
    return !m_dropouts.empty();
}
// Number of stored per-epoch keep rates, narrowed to int to match the
// declared return type.
int Scenario::size() const {
    return static_cast<int>(m_dropouts.size());
}
// Returns a copy of the scenario's name.
std::string Scenario::name() const {
    return std::string(m_name);
}
// Writes the scenario to stdout on two lines:
//   1. the scenario name;
//   2. every stored keep rate, each followed by a single space.
// std::endl is used after both lines, so each line is flushed as it is written.
void Scenario::print() const {
    std::cout << m_name << std::endl;
    for (auto it = m_dropouts.cbegin(); it != m_dropouts.cend(); ++it) {
        std::cout << *it << " ";
    }
    std::cout << std::endl;
}