@@ -53,6 +53,86 @@ static void read_vector_int_json(
5353
5454/* *****************************************************************************/
5555
56+ #if defined(__GNUC__) || defined(__clang__)
57+ #pragma GCC diagnostic push
58+ #pragma GCC diagnostic ignored "-Wdeprecated-declarations"
59+ #endif
60+
61+ ModelOutputHolder::ModelOutputHolder (
62+ std::string quantity,
63+ std::string unit,
64+ std::string sample_kind,
65+ std::vector<std::string> explicit_gradients_,
66+ std::string description_
67+ ):
68+ description(std::move(description_)),
69+ explicit_gradients(std::move(explicit_gradients_))
70+ {
71+ this ->set_quantity (std::move (quantity));
72+ this ->set_unit (std::move (unit));
73+ this ->set_sample_kind (std::move (sample_kind));
74+ }
75+
76+ ModelOutputHolder::ModelOutputHolder (
77+ std::string quantity,
78+ std::string unit,
79+ bool per_atom_,
80+ std::vector<std::string> explicit_gradients_,
81+ std::string description_
82+ ):
83+ description(std::move(description_)),
84+ explicit_gradients(std::move(explicit_gradients_))
85+ {
86+ this ->set_quantity (std::move (quantity));
87+ this ->set_unit (std::move (unit));
88+ this ->set_per_atom (per_atom_);
89+ }
90+
91+ ModelOutputHolder::ModelOutputHolder (
92+ std::string quantity,
93+ std::string unit,
94+ torch::IValue per_atom_or_sample_kind,
95+ std::vector<std::string> explicit_gradients_,
96+ std::string description_,
97+ torch::optional<bool > per_atom,
98+ torch::optional<std::string> sample_kind
99+ ):
100+ description(std::move(description_)),
101+ explicit_gradients(std::move(explicit_gradients_))
102+ {
103+ this ->set_quantity (std::move (quantity));
104+ this ->set_unit (std::move (unit));
105+
106+ if (per_atom_or_sample_kind.isNone ()) {
107+ // check the kwargs for backward compatibility
108+ if (sample_kind.has_value () && per_atom.has_value ()) {
109+ C10_THROW_ERROR (ValueError, " cannot specify both `per_atom` and `sample_kind`" );
110+ } else if (sample_kind.has_value ()) {
111+ this ->set_sample_kind (sample_kind.value ());
112+ } else if (per_atom.has_value ()) {
113+ this ->set_per_atom (per_atom.value ());
114+ }
115+ } else if (per_atom_or_sample_kind.isBool ()) {
116+ if (per_atom.has_value ()) {
117+ C10_THROW_ERROR (ValueError,
118+ " cannot specify `per_atom` both as a positional and keyword argument"
119+ );
120+ }
121+ this ->set_per_atom (per_atom_or_sample_kind.toBool ());
122+ } else if (per_atom_or_sample_kind.isString ()) {
123+ if (sample_kind.has_value ()) {
124+ C10_THROW_ERROR (ValueError,
125+ " cannot specify `sample_kind` both as a positional and keyword argument"
126+ );
127+ }
128+ this ->set_sample_kind (per_atom_or_sample_kind.toStringRef ());
129+ } else {
130+ C10_THROW_ERROR (ValueError,
131+ " positional argument for `per_atom`/`sample_kind` must be either a boolean or a string"
132+ );
133+ }
134+ }
135+
56136void ModelOutputHolder::set_quantity (std::string quantity) {
57137 if (valid_quantity (quantity)) {
58138 validate_unit (quantity, unit_);
@@ -72,7 +152,7 @@ static nlohmann::json model_output_to_json(const ModelOutputHolder& self) {
72152 result[" class" ] = " ModelOutput" ;
73153 result[" quantity" ] = self.quantity ();
74154 result[" unit" ] = self.unit ();
75- result[" per_atom " ] = self.per_atom ;
155+ result[" sample_kind " ] = self.sample_kind () ;
76156 result[" explicit_gradients" ] = self.explicit_gradients ;
77157 result[" description" ] = self.description ;
78158
@@ -112,11 +192,18 @@ static ModelOutput model_output_from_json(const nlohmann::json& data) {
112192 result->set_unit (data[" unit" ]);
113193 }
114194
115- if (data.contains (" per_atom" )) {
195+ if (data.contains (" sample_kind" )) {
196+ if (!data[" sample_kind" ].is_string ()) {
197+ throw std::runtime_error (" 'sample_kind' in JSON for ModelOutput must be a string" );
198+ }
199+ result->set_sample_kind (data[" sample_kind" ]);
200+ } else if (data.contains (" per_atom" )) {
116201 if (!data[" per_atom" ].is_boolean ()) {
117202 throw std::runtime_error (" 'per_atom' in JSON for ModelOutput must be a boolean" );
118203 }
119- result->per_atom = data[" per_atom" ];
204+ result->set_per_atom (data[" per_atom" ]);
205+ } else {
206+ result->set_sample_kind (" system" );
120207 }
121208
122209 if (data.contains (" explicit_gradients" )) {
@@ -145,11 +232,86 @@ ModelOutput ModelOutputHolder::from_json(std::string_view json) {
145232 return model_output_from_json (data);
146233}
147234
235+ static std::set<std::string> SUPPORTED_SAMPLE_KINDS = {
236+ " system" ,
237+ " atom" ,
238+ " atom_pair" ,
239+ };
240+
241+ void ModelOutputHolder::set_sample_kind (std::string sample_kind) {
242+ if (sample_kind == " atom" ) {
243+ this ->set_per_atom_no_deprecation (true );
244+ } else if (sample_kind == " system" ) {
245+ this ->set_per_atom_no_deprecation (false );
246+ } else {
247+ if (SUPPORTED_SAMPLE_KINDS.find (sample_kind) == SUPPORTED_SAMPLE_KINDS.end ()) {
248+ C10_THROW_ERROR (ValueError,
249+ " invalid sample_kind '" + sample_kind + " ': supported values are [" +
250+ torch::str (SUPPORTED_SAMPLE_KINDS) + " ]"
251+ );
252+ }
253+
254+ this ->sample_kind_ = std::move (sample_kind);
255+ }
256+ }
257+
258+ std::string ModelOutputHolder::sample_kind () const {
259+ if (sample_kind_.has_value ()) {
260+ return sample_kind_.value ();
261+ } else if (this ->get_per_atom_no_deprecation ()) {
262+ return " atom" ;
263+ } else {
264+ return " system" ;
265+ }
266+ }
267+
268+ void ModelOutputHolder::set_per_atom (bool per_atom_) {
269+ TORCH_WARN_DEPRECATION (
270+ " `per_atom` is deprecated, please use `sample_kind` instead"
271+ );
272+
273+ this ->set_per_atom_no_deprecation (per_atom_);
274+ }
275+
276+ bool ModelOutputHolder::get_per_atom () const {
277+ TORCH_WARN_DEPRECATION (
278+ " `per_atom` is deprecated, please use `sample_kind` instead"
279+ );
280+
281+ return this ->get_per_atom_no_deprecation ();
282+ }
283+
284+ void ModelOutputHolder::set_per_atom_no_deprecation (bool per_atom) {
285+ this ->per_atom = per_atom;
286+
287+ this ->sample_kind_ = torch::nullopt ;
288+ }
289+
290+ bool ModelOutputHolder::get_per_atom_no_deprecation () const {
291+ if (sample_kind_.has_value ()) {
292+ if (sample_kind_.value () == " atom" ) {
293+ return true ;
294+ } else if (sample_kind_.value () == " system" ) {
295+ return false ;
296+ } else {
297+ C10_THROW_ERROR (
298+ ValueError,
299+ " Can't infer `per_atom` from `sample_kind` '" + this ->sample_kind () + " '. "
300+ " `per_atom` only makes sense for `sample_kind` 'atom' and 'system'."
301+ );
302+ }
303+ }
304+ return per_atom;
305+ }
306+
307+ #if defined(__GNUC__) || defined(__clang__)
308+ #pragma GCC diagnostic pop
309+ #endif
310+
148311/* *****************************************************************************/
149312
150313
151314void ModelCapabilitiesHolder::set_outputs (torch::Dict<std::string, ModelOutput> outputs) {
152-
153315 std::unordered_map<std::string, std::vector<std::string>> variants;
154316 for (const auto & it: outputs) {
155317 auto [is_standard, base, variant] = details::validate_name_and_check_variant (it.key ());
0 commit comments