99#include < boost/algorithm/string.hpp>
1010#include < boost/algorithm/string/split.hpp>
1111#include < cmath>
12+ #include < numeric> // for std::accumulate
1213
14+ #include " dtlmod/DTLException.hpp"
1315#include " dtlmod/ReductionMethod.hpp"
1416#include " dtlmod/Variable.hpp"
1517
@@ -21,17 +23,16 @@ namespace dtlmod {
2123
2224class ParameterizedDecimation {
2325 friend class DecimationReductionMethod ;
26+ std::shared_ptr<Variable> var_; // The variable to which this parameterized decimation is applied
27+
2428 std::vector<size_t > stride_;
2529 std::string interpolation_method_ = " " ;
2630 double cost_per_element_;
2731
2832 std::vector<size_t > reduced_shape_;
2933 std::unordered_map<sg4::ActorPtr, std::pair<std::vector<size_t >, std::vector<size_t >>> reduced_local_start_and_count_;
30- size_t element_size_;
3134
3235protected:
33- [[nodiscard]] const std::vector<size_t >& get_stride () const { return stride_; }
34-
3536 void set_reduced_shape (const std::vector<size_t >& reduced_shape) { reduced_shape_ = reduced_shape; }
3637 void set_reduced_local_start_and_count (
3738 std::unordered_map<sg4::ActorPtr, std::pair<std::vector<size_t >, std::vector<size_t >>>
@@ -40,36 +41,48 @@ class ParameterizedDecimation {
4041 reduced_local_start_and_count_ = reduced_local_start_and_count;
4142 }
4243
44+ [[nodiscard]] const std::vector<size_t >& get_stride () const { return stride_; }
45+
46+ [[nodiscard]] const std::vector<size_t >& get_reduced_shape () const { return reduced_shape_; }
47+
4348 [[nodiscard]] size_t get_global_reduced_size () const
4449 {
45- size_t total_size = element_size_;
46- for (const auto & s : reduced_shape_)
47- total_size *= s;
48- return total_size;
50+ return std::accumulate (reduced_shape_.begin (), reduced_shape_.end (), var_->get_element_size (), std::multiplies<>{});
4951 }
5052
5153 [[nodiscard]] size_t get_local_reduced_size () const
5254 {
53- size_t total_size = element_size_;
54- auto issuer = sg4::Actor::self ();
55- for (const auto & c : reduced_local_start_and_count_.at (issuer).second )
56- total_size *= c;
57- return total_size;
55+ auto start_and_count = reduced_local_start_and_count_.at (sg4::Actor::self ()).second ;
56+ return std::accumulate (start_and_count.begin (), start_and_count.end (), var_->get_element_size (),
57+ std::multiplies<>{});
5858 }
59- [[nodiscard]] const std::vector< size_t >& get_reduced_shape () const { return reduced_shape_; }
59+
6060 [[nodiscard]] const std::pair<std::vector<size_t >, std::vector<size_t >>&
6161 get_reduced_start_and_count_for (sg4::ActorPtr publisher) const
6262 {
6363 return reduced_local_start_and_count_.at (publisher);
6464 }
6565
66+ [[nodiscard]] double get_flop_amount_to_decimate () const
67+ {
68+ double amount = cost_per_element_;
69+ if (interpolation_method_.empty ()) {
70+ amount *= var_->get_local_size ();
71+ } else if (interpolation_method_ == " linear" ) {
72+ amount = 2 * amount * var_->get_local_size ();
73+ } else if (interpolation_method_ == " quadratic" ) {
74+ amount = 4 * amount * var_->get_local_size ();
75+ } else if (interpolation_method_ == " cubic" ) {
76+ amount = 8 * amount * var_->get_local_size ();
77+ } else
78+ throw UnknownDecimationInterpolationException (XBT_THROW_POINT, interpolation_method_.c_str ());
79+ return amount;
80+ }
81+
6682public:
67- ParameterizedDecimation (const std::vector<size_t > stride, const std::string interpolation_method,
68- double cost_per_element, size_t element_size)
69- : stride_(stride)
70- , interpolation_method_(interpolation_method)
71- , cost_per_element_(cost_per_element)
72- , element_size_(element_size)
83+ ParameterizedDecimation (std::shared_ptr<Variable> var, const std::vector<size_t > stride,
84+ const std::string interpolation_method, double cost_per_element)
85+ : var_(var), stride_(stride), interpolation_method_(interpolation_method), cost_per_element_(cost_per_element)
7386 {
7487 }
7588};
@@ -89,31 +102,34 @@ class DecimationReductionMethod : public ReductionMethod {
89102 if (key == " stride" ) {
90103 std::vector<std::string> tokens;
91104 boost::split (tokens, value, boost::is_any_of (" ," ), boost::token_compress_on);
92- // TODO Add Sanity check that the number of tokens equals the number of dimension of var
93- for (const auto t : tokens)
105+ if (var->get_shape ().size () != tokens.size ())
106+ throw InconsistentDecimationStrideException (
107+ XBT_THROW_POINT, " Decimation Stride and Variable Shape vectors must have the same size. Stride: " +
108+ std::to_string (tokens.size ()) +
109+ " , Shape: " + std::to_string (var->get_shape ().size ()));
110+ for (const auto & t : tokens)
94111 stride.push_back (std::stoul (t));
95112 } else if (key == " interpolation" ) {
96113 interpolation_method = value;
97114 } else if (key == " cost_per_element" ) {
98115 cost_per_element = std::stod (value);
99- } // else
100- // TODO handle invalid key
116+ } else
117+ throw UnknownDecimationOptionException (XBT_THROW_POINT, key. c_str ());
101118 }
102119
103120 per_variable_parameterizations_.try_emplace (
104- var, std::make_shared<ParameterizedDecimation>(stride, interpolation_method, cost_per_element,
105- var->get_element_size ()));
121+ var, std::make_shared<ParameterizedDecimation>(var, stride, interpolation_method, cost_per_element));
106122 }
107123
108124 void reduce_variable (std::shared_ptr<Variable> var)
109125 {
110126 auto parameterization = per_variable_parameterizations_[var];
111- auto shape = var->get_shape ();
127+ auto original_shape = var->get_shape ();
112128 auto stride = parameterization->get_stride ();
113129
114130 std::vector<size_t > reduced_shape;
115131 size_t i = 0 ;
116- for (auto dim_size : shape )
132+ for (auto dim_size : original_shape )
117133 reduced_shape.push_back (std::ceil (dim_size / (stride[i++] * 1.0 )));
118134
119135 std::unordered_map<sg4::ActorPtr, std::pair<std::vector<size_t >, std::vector<size_t >>>
@@ -123,11 +139,11 @@ class DecimationReductionMethod : public ReductionMethod {
123139 std::vector<size_t > reduced_start;
124140 std::vector<size_t > reduced_count;
125141
126- for (size_t i = 0 ; i < shape .size (); i++) {
142+ for (size_t i = 0 ; i < original_shape .size (); i++) {
127143 // Sanity checks that shape, start, and count have the same size have already been done
128144 size_t r_start = std::ceil (start[i] / (stride[i] * 1.0 ));
129145 size_t r_next_start =
130- std::min (shape [i], static_cast <size_t >(std::ceil ((start[i] + count[i]) / (stride[i] * 1.0 ))));
146+ std::min (original_shape [i], static_cast <size_t >(std::ceil ((start[i] + count[i]) / (stride[i] * 1.0 ))));
131147 XBT_CDEBUG (dtlmod, " Dim %lu: stride = %lu, Start = %lu, r_start = %lu, Count = %lu, r_count = %lu" , i,
132148 stride[i], start[i], r_start, count[i], r_next_start - r_start);
133149 reduced_start.push_back (r_start);
@@ -149,10 +165,17 @@ class DecimationReductionMethod : public ReductionMethod {
149165 {
150166 return per_variable_parameterizations_.at (var)->get_local_reduced_size ();
151167 }
168+
169+ [[nodiscard]] double get_flop_amount_to_reduce_variable (std::shared_ptr<Variable> var) const override
170+ {
171+ return per_variable_parameterizations_.at (var)->get_flop_amount_to_decimate ();
172+ }
173+
152174 [[nodiscard]] const std::vector<size_t >& get_reduced_variable_shape (std::shared_ptr<Variable> var) const override
153175 {
154176 return per_variable_parameterizations_.at (var)->get_reduced_shape ();
155177 }
178+
156179 [[nodiscard]] const std::pair<std::vector<size_t >, std::vector<size_t >>&
157180 get_reduced_start_and_count_for (std::shared_ptr<Variable> var, sg4::ActorPtr publisher) const override
158181 {
0 commit comments