@@ -53,6 +53,51 @@ static void read_vector_int_json(
5353
5454/* *****************************************************************************/
5555
56+ ModelOutputHolder::ModelOutputHolder (
57+ std::string quantity,
58+ std::string unit,
59+ torch::IValue per_atom_or_sample_kind,
60+ std::vector<std::string> explicit_gradients_,
61+ std::string description_,
62+ torch::optional<bool > per_atom,
63+ torch::optional<std::string> sample_kind
64+ ):
65+ description(std::move(description_)),
66+ explicit_gradients(std::move(explicit_gradients_))
67+ {
68+ this ->set_quantity (std::move (quantity));
69+ this ->set_unit (std::move (unit));
70+
71+ if (per_atom_or_sample_kind.isNone ()) {
72+ // check the kwargs for backward compatibility
73+ if (sample_kind.has_value () && per_atom.has_value ()) {
74+ C10_THROW_ERROR (ValueError, " cannot specify both `per_atom` and `sample_kind`" );
75+ } else if (sample_kind.has_value ()) {
76+ this ->set_sample_kind (sample_kind.value ());
77+ } else if (per_atom.has_value ()) {
78+ this ->set_per_atom (per_atom.value ());
79+ }
80+ } else if (per_atom_or_sample_kind.isBool ()) {
81+ if (per_atom.has_value ()) {
82+ C10_THROW_ERROR (ValueError,
83+ " cannot specify `per_atom` both as a positional and keyword argument"
84+ );
85+ }
86+ this ->set_per_atom (per_atom_or_sample_kind.toBool ());
87+ } else if (per_atom_or_sample_kind.isString ()) {
88+ if (sample_kind.has_value ()) {
89+ C10_THROW_ERROR (ValueError,
90+ " cannot specify `sample_kind` both as a positional and keyword argument"
91+ );
92+ }
93+ this ->set_sample_kind (per_atom_or_sample_kind.toStringRef ());
94+ } else {
95+ C10_THROW_ERROR (ValueError,
96+ " positional argument for `per_atom`/`sample_kind` must be either a boolean or a string"
97+ );
98+ }
99+ }
100+
56101void ModelOutputHolder::set_quantity (std::string quantity) {
57102 if (valid_quantity (quantity)) {
58103 validate_unit (quantity, unit_);
@@ -72,7 +117,7 @@ static nlohmann::json model_output_to_json(const ModelOutputHolder& self) {
72117 result[" class" ] = " ModelOutput" ;
73118 result[" quantity" ] = self.quantity ();
74119 result[" unit" ] = self.unit ();
75- result[" per_atom " ] = self.per_atom ;
120+ result[" sample_kind " ] = self.sample_kind () ;
76121 result[" explicit_gradients" ] = self.explicit_gradients ;
77122 result[" description" ] = self.description ;
78123
@@ -112,11 +157,18 @@ static ModelOutput model_output_from_json(const nlohmann::json& data) {
112157 result->set_unit (data[" unit" ]);
113158 }
114159
115- if (data.contains (" per_atom" )) {
160+ if (data.contains (" sample_kind" )) {
161+ if (!data[" sample_kind" ].is_string ()) {
162+ throw std::runtime_error (" 'sample_kind' in JSON for ModelOutput must be a string" );
163+ }
164+ result->set_sample_kind (data[" sample_kind" ]);
165+ } else if (data.contains (" per_atom" )) {
116166 if (!data[" per_atom" ].is_boolean ()) {
117167 throw std::runtime_error (" 'per_atom' in JSON for ModelOutput must be a boolean" );
118168 }
119- result->per_atom = data[" per_atom" ];
169+ result->set_per_atom (data[" per_atom" ]);
170+ } else {
171+ result->set_sample_kind (" system" );
120172 }
121173
122174 if (data.contains (" explicit_gradients" )) {
@@ -145,6 +197,87 @@ ModelOutput ModelOutputHolder::from_json(std::string_view json) {
145197 return model_output_from_json (data);
146198}
147199
200+ static std::set<std::string> SUPPORTED_SAMPLE_KINDS = {
201+ " system" ,
202+ " atom" ,
203+ " atom_pair" ,
204+ };
205+
206+ void ModelOutputHolder::set_sample_kind (std::string sample_kind) {
207+ if (sample_kind == " atom" ) {
208+ this ->set_per_atom_no_deprecation (true );
209+ } else if (sample_kind == " system" ) {
210+ this ->set_per_atom_no_deprecation (false );
211+ } else {
212+ if (SUPPORTED_SAMPLE_KINDS.find (sample_kind) == SUPPORTED_SAMPLE_KINDS.end ()) {
213+ C10_THROW_ERROR (ValueError,
214+ " invalid sample_kind '" + sample_kind + " ': supported values are [" +
215+ torch::str (SUPPORTED_SAMPLE_KINDS) + " ]"
216+ );
217+ }
218+
219+ this ->sample_kind_ = std::move (sample_kind);
220+ }
221+ }
222+
223+ std::string ModelOutputHolder::sample_kind () const {
224+ if (sample_kind_.has_value ()) {
225+ return sample_kind_.value ();
226+ } else if (this ->get_per_atom_no_deprecation ()) {
227+ return " atom" ;
228+ } else {
229+ return " system" ;
230+ }
231+ }
232+
233+ void ModelOutputHolder::set_per_atom (bool per_atom_) {
234+ TORCH_WARN_DEPRECATION (
235+ " `per_atom` is deprecated, please use `sample_kind` instead"
236+ );
237+
238+ this ->set_per_atom_no_deprecation (per_atom_);
239+ }
240+
241+ bool ModelOutputHolder::get_per_atom () const {
242+ TORCH_WARN_DEPRECATION (
243+ " `per_atom` is deprecated, please use `sample_kind` instead"
244+ );
245+
246+ return this ->get_per_atom_no_deprecation ();
247+ }
248+
249+ #if defined(__GNUC__) || defined(__clang__)
250+ #pragma GCC diagnostic push
251+ #pragma GCC diagnostic ignored "-Wdeprecated-declarations"
252+ #endif
253+
254+ void ModelOutputHolder::set_per_atom_no_deprecation (bool per_atom) {
255+ this ->per_atom = per_atom;
256+
257+ this ->sample_kind_ = torch::nullopt ;
258+ }
259+
260+ bool ModelOutputHolder::get_per_atom_no_deprecation () const {
261+ if (sample_kind_.has_value ()) {
262+ if (sample_kind_.value () == " atom" ) {
263+ return true ;
264+ } else if (sample_kind_.value () == " system" ) {
265+ return false ;
266+ } else {
267+ C10_THROW_ERROR (
268+ ValueError,
269+ " Can't infer `per_atom` from `sample_kind` '" + this ->sample_kind () + " '. "
270+ " `per_atom` only makes sense for `sample_kind` 'atom' and 'system'."
271+ );
272+ }
273+ }
274+ return per_atom;
275+ }
276+
277+ #if defined(__GNUC__) || defined(__clang__)
278+ #pragma GCC diagnostic pop
279+ #endif
280+
148281/* *****************************************************************************/
149282
150283
0 commit comments