Skip to content

Commit ea5e7fd

Browse files
pfebrerLuthaf
andcommitted
Change per_atom into sample_kind while keeping backward compatibility
Co-Authored-By: Guillaume Fraux <guillaume.fraux@epfl.ch>
1 parent 724deff commit ea5e7fd

20 files changed

Lines changed: 490 additions & 111 deletions

File tree

docs/src/engines/plumed-model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def forward(
2525
if "features" not in outputs:
2626
return {}
2727

28-
if outputs["features"].per_atom:
28+
if outputs["features"].sample_kind == "atom":
2929
raise ValueError("per-atoms features are not supported in this model")
3030

3131
# PLUMED will first call the model with 0 atoms to get the size of the

metatomic-torch/include/metatomic/torch/model.hpp

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,23 +32,45 @@ using ModelMetadata = torch::intrusive_ptr<ModelMetadataHolder>;
3232
/// Information about one of the quantity a model can compute
3333
class METATOMIC_TORCH_EXPORT ModelOutputHolder: public torch::CustomClassHolder {
3434
public:
35-
ModelOutputHolder() = default;
35+
ModelOutputHolder(): ModelOutputHolder("", "", "system", {}, "") {}
3636

3737
/// Initialize `ModelOutput` with the given data
38+
ModelOutputHolder(
39+
std::string quantity,
40+
std::string unit,
41+
std::string sample_kind,
42+
std::vector<std::string> explicit_gradients_,
43+
std::string description_
44+
);
45+
46+
// for overload resolution
47+
ModelOutputHolder(
48+
std::string quantity,
49+
std::string unit,
50+
const char* sample_kind,
51+
std::vector<std::string> explicit_gradients_,
52+
std::string description_
53+
): ModelOutputHolder(std::move(quantity), std::move(unit), std::string(sample_kind), std::move(explicit_gradients_), std::move(description_)) {};
54+
55+
/// For backward compatibility in the C++ API (per_atom argument)
3856
ModelOutputHolder(
3957
std::string quantity,
4058
std::string unit,
4159
bool per_atom_,
4260
std::vector<std::string> explicit_gradients_,
4361
std::string description_
44-
):
45-
description(std::move(description_)),
46-
per_atom(per_atom_),
47-
explicit_gradients(std::move(explicit_gradients_))
48-
{
49-
this->set_quantity(std::move(quantity));
50-
this->set_unit(std::move(unit));
51-
}
62+
);
63+
64+
/// For backward compatibility in the Python API
65+
ModelOutputHolder(
66+
std::string quantity,
67+
std::string unit,
68+
torch::IValue per_atom_or_sample_kind,
69+
std::vector<std::string> explicit_gradients_,
70+
std::string description_,
71+
torch::optional<bool> per_atom = torch::nullopt,
72+
torch::optional<std::string> sample_kind = torch::nullopt
73+
);
5274

5375
~ModelOutputHolder() override = default;
5476

@@ -72,9 +94,22 @@ class METATOMIC_TORCH_EXPORT ModelOutputHolder: public torch::CustomClassHolder
7294
/// set the unit of the output
7395
void set_unit(std::string unit);
7496

75-
/// is the output defined per-atom or for the overall structure
97+
/// The setter and getter for `per_atom` that are used in TorchBind, which
98+
/// allow us to raise an error if `sample_kind` can't be mapped to a boolean
99+
/// value for `per_atom`.
100+
void set_per_atom(bool per_atom);
101+
bool get_per_atom() const;
102+
103+
/// This is deprecated in favor of `sample_kind`, and kept for backward compatibility reasons only.
104+
[[deprecated("use sample_kind instead")]]
76105
bool per_atom = false;
77106

107+
/// Get the sample kind of the output. TODO: explain
108+
std::string sample_kind() const;
109+
110+
/// Set the `sample_kind` of the output.
111+
void set_sample_kind(std::string sample_kind);
112+
78113
/// Which gradients should be computed eagerly and stored inside the output
79114
/// `TensorMap`
80115
std::vector<std::string> explicit_gradients;
@@ -85,8 +120,12 @@ class METATOMIC_TORCH_EXPORT ModelOutputHolder: public torch::CustomClassHolder
85120
static ModelOutput from_json(std::string_view json);
86121

87122
private:
123+
void set_per_atom_no_deprecation(bool per_atom);
124+
bool get_per_atom_no_deprecation() const;
125+
88126
std::string quantity_;
89127
std::string unit_;
128+
torch::optional<std::string> sample_kind_;
90129
};
91130

92131

metatomic-torch/src/model.cpp

Lines changed: 166 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,86 @@ static void read_vector_int_json(
5353

5454
/******************************************************************************/
5555

56+
#if defined(__GNUC__) || defined(__clang__)
57+
#pragma GCC diagnostic push
58+
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
59+
#endif
60+
61+
ModelOutputHolder::ModelOutputHolder(
62+
std::string quantity,
63+
std::string unit,
64+
std::string sample_kind,
65+
std::vector<std::string> explicit_gradients_,
66+
std::string description_
67+
):
68+
description(std::move(description_)),
69+
explicit_gradients(std::move(explicit_gradients_))
70+
{
71+
this->set_quantity(std::move(quantity));
72+
this->set_unit(std::move(unit));
73+
this->set_sample_kind(std::move(sample_kind));
74+
}
75+
76+
ModelOutputHolder::ModelOutputHolder(
77+
std::string quantity,
78+
std::string unit,
79+
bool per_atom_,
80+
std::vector<std::string> explicit_gradients_,
81+
std::string description_
82+
):
83+
description(std::move(description_)),
84+
explicit_gradients(std::move(explicit_gradients_))
85+
{
86+
this->set_quantity(std::move(quantity));
87+
this->set_unit(std::move(unit));
88+
this->set_per_atom(per_atom_);
89+
}
90+
91+
ModelOutputHolder::ModelOutputHolder(
92+
std::string quantity,
93+
std::string unit,
94+
torch::IValue per_atom_or_sample_kind,
95+
std::vector<std::string> explicit_gradients_,
96+
std::string description_,
97+
torch::optional<bool> per_atom,
98+
torch::optional<std::string> sample_kind
99+
):
100+
description(std::move(description_)),
101+
explicit_gradients(std::move(explicit_gradients_))
102+
{
103+
this->set_quantity(std::move(quantity));
104+
this->set_unit(std::move(unit));
105+
106+
if (per_atom_or_sample_kind.isNone()) {
107+
// check the kwargs for backward compatibility
108+
if (sample_kind.has_value() && per_atom.has_value()) {
109+
C10_THROW_ERROR(ValueError, "cannot specify both `per_atom` and `sample_kind`");
110+
} else if (sample_kind.has_value()) {
111+
this->set_sample_kind(sample_kind.value());
112+
} else if (per_atom.has_value()) {
113+
this->set_per_atom(per_atom.value());
114+
}
115+
} else if (per_atom_or_sample_kind.isBool()) {
116+
if (per_atom.has_value()) {
117+
C10_THROW_ERROR(ValueError,
118+
"cannot specify `per_atom` both as a positional and keyword argument"
119+
);
120+
}
121+
this->set_per_atom(per_atom_or_sample_kind.toBool());
122+
} else if (per_atom_or_sample_kind.isString()) {
123+
if (sample_kind.has_value()) {
124+
C10_THROW_ERROR(ValueError,
125+
"cannot specify `sample_kind` both as a positional and keyword argument"
126+
);
127+
}
128+
this->set_sample_kind(per_atom_or_sample_kind.toStringRef());
129+
} else {
130+
C10_THROW_ERROR(ValueError,
131+
"positional argument for `per_atom`/`sample_kind` must be either a boolean or a string"
132+
);
133+
}
134+
}
135+
56136
void ModelOutputHolder::set_quantity(std::string quantity) {
57137
if (valid_quantity(quantity)) {
58138
validate_unit(quantity, unit_);
@@ -72,7 +152,7 @@ static nlohmann::json model_output_to_json(const ModelOutputHolder& self) {
72152
result["class"] = "ModelOutput";
73153
result["quantity"] = self.quantity();
74154
result["unit"] = self.unit();
75-
result["per_atom"] = self.per_atom;
155+
result["sample_kind"] = self.sample_kind();
76156
result["explicit_gradients"] = self.explicit_gradients;
77157
result["description"] = self.description;
78158

@@ -112,11 +192,18 @@ static ModelOutput model_output_from_json(const nlohmann::json& data) {
112192
result->set_unit(data["unit"]);
113193
}
114194

115-
if (data.contains("per_atom")) {
195+
if (data.contains("sample_kind")) {
196+
if (!data["sample_kind"].is_string()) {
197+
throw std::runtime_error("'sample_kind' in JSON for ModelOutput must be a string");
198+
}
199+
result->set_sample_kind(data["sample_kind"]);
200+
} else if (data.contains("per_atom")) {
116201
if (!data["per_atom"].is_boolean()) {
117202
throw std::runtime_error("'per_atom' in JSON for ModelOutput must be a boolean");
118203
}
119-
result->per_atom = data["per_atom"];
204+
result->set_per_atom(data["per_atom"]);
205+
} else {
206+
result->set_sample_kind("system");
120207
}
121208

122209
if (data.contains("explicit_gradients")) {
@@ -145,11 +232,86 @@ ModelOutput ModelOutputHolder::from_json(std::string_view json) {
145232
return model_output_from_json(data);
146233
}
147234

235+
static std::set<std::string> SUPPORTED_SAMPLE_KINDS = {
236+
"system",
237+
"atom",
238+
"atom_pair",
239+
};
240+
241+
void ModelOutputHolder::set_sample_kind(std::string sample_kind) {
242+
if (sample_kind == "atom") {
243+
this->set_per_atom_no_deprecation(true);
244+
} else if (sample_kind == "system") {
245+
this->set_per_atom_no_deprecation(false);
246+
} else {
247+
if (SUPPORTED_SAMPLE_KINDS.find(sample_kind) == SUPPORTED_SAMPLE_KINDS.end()) {
248+
C10_THROW_ERROR(ValueError,
249+
"invalid sample_kind '" + sample_kind + "': supported values are [" +
250+
torch::str(SUPPORTED_SAMPLE_KINDS) + "]"
251+
);
252+
}
253+
254+
this->sample_kind_ = std::move(sample_kind);
255+
}
256+
}
257+
258+
std::string ModelOutputHolder::sample_kind() const {
259+
if (sample_kind_.has_value()) {
260+
return sample_kind_.value();
261+
} else if (this->get_per_atom_no_deprecation()) {
262+
return "atom";
263+
} else {
264+
return "system";
265+
}
266+
}
267+
268+
void ModelOutputHolder::set_per_atom(bool per_atom_) {
269+
TORCH_WARN_DEPRECATION(
270+
"`per_atom` is deprecated, please use `sample_kind` instead"
271+
);
272+
273+
this->set_per_atom_no_deprecation(per_atom_);
274+
}
275+
276+
bool ModelOutputHolder::get_per_atom() const {
277+
TORCH_WARN_DEPRECATION(
278+
"`per_atom` is deprecated, please use `sample_kind` instead"
279+
);
280+
281+
return this->get_per_atom_no_deprecation();
282+
}
283+
284+
void ModelOutputHolder::set_per_atom_no_deprecation(bool per_atom) {
285+
this->per_atom = per_atom;
286+
287+
this->sample_kind_ = torch::nullopt;
288+
}
289+
290+
bool ModelOutputHolder::get_per_atom_no_deprecation() const {
291+
if (sample_kind_.has_value()) {
292+
if (sample_kind_.value() == "atom") {
293+
return true;
294+
} else if (sample_kind_.value() == "system") {
295+
return false;
296+
} else {
297+
C10_THROW_ERROR(
298+
ValueError,
299+
"Can't infer `per_atom` from `sample_kind` '" + this->sample_kind() + "'. "
300+
"`per_atom` only makes sense for `sample_kind` 'atom' and 'system'."
301+
);
302+
}
303+
}
304+
return per_atom;
305+
}
306+
307+
#if defined(__GNUC__) || defined(__clang__)
308+
#pragma GCC diagnostic pop
309+
#endif
310+
148311
/******************************************************************************/
149312

150313

151314
void ModelCapabilitiesHolder::set_outputs(torch::Dict<std::string, ModelOutput> outputs) {
152-
153315
std::unordered_map<std::string, std::vector<std::string>> variants;
154316
for (const auto& it: outputs) {
155317
auto [is_standard, base, variant] = details::validate_name_and_check_variant(it.key());

0 commit comments

Comments
 (0)