diff --git a/docs/src/outputs/index.rst b/docs/src/outputs/index.rst index 8a68800c1..6a6f00b93 100644 --- a/docs/src/outputs/index.rst +++ b/docs/src/outputs/index.rst @@ -25,6 +25,7 @@ schema they need and add a new section to these pages. momenta velocities charges + spin_multiplicity heat_flux features variants @@ -141,6 +142,12 @@ quantities, i.e. quantities with a well-defined physical meaning. Heat flux, i.e. the amount of energy transferred per unit time, i.e. :math:`\sum_i E_i \times \vec v_i` + .. grid-item-card:: Spin multiplicity + :link: spin-multiplicity-output + :link-type: ref + + The spin multiplicity :math:`2S + 1` of the system. + Machine learning quantities ^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/src/outputs/spin_multiplicity.rst b/docs/src/outputs/spin_multiplicity.rst new file mode 100644 index 000000000..3d1aeeddc --- /dev/null +++ b/docs/src/outputs/spin_multiplicity.rst @@ -0,0 +1,65 @@ +.. _spin-multiplicity-output: + +Spin multiplicity +^^^^^^^^^^^^^^^^^ + +The spin multiplicity of the system is associated with the ``"spin_multiplicity"`` +name, and must have the following metadata: + +.. list-table:: Metadata for spin_multiplicity output + :widths: 2 3 7 + :header-rows: 1 + + * - Metadata + - Names + - Description + + * - keys + - ``"_"`` + - the keys must have a single dimension named ``"_"``, with a single entry + set to ``0``. The spin multiplicity is always a + :py:class:`metatensor.torch.TensorMap` with a single block. + + * - samples + - ``["system"]`` + - the samples must be named ``["system"]``, since the spin multiplicity is + a per-system quantity. When running a batched calculation, there will be + one row per system. + + ``"system"`` must range from 0 to the number of systems given as input + to the model. + + * - components + - + - the spin multiplicity must not have any components + + * - properties + - ``"spin_multiplicity"`` + - the spin multiplicity must have a single property dimension named + ``"spin_multiplicity"``, with a single entry set to ``0``. + +The values represent the spin multiplicity :math:`2S + 1` of the system, where +:math:`S` is the total spin quantum number. The values are dimensionless and +stored as floats (matching the model's dtype), even though they always take +positive integer values. The value must be at least ``1``. + +Common examples: + +- ``1`` for a singlet (:math:`S = 0`) +- ``2`` for a doublet (:math:`S = 1/2`, e.g. a radical with one unpaired electron) +- ``3`` for a triplet (:math:`S = 1`) + +The following simulation engines support the ``"spin_multiplicity"`` output: + +.. grid:: 1 1 1 1 + + .. grid-item-card:: + :text-align: center + :padding: 1 + :link: engine-ase + :link-type: ref + + |ase-logo| + +In ASE, the spin multiplicity is read from ``atoms.info["spin"]`` and defaults +to ``1`` (singlet) if not set. diff --git a/metatomic-torch/include/metatomic/torch/misc.hpp b/metatomic-torch/include/metatomic/torch/misc.hpp index 84fa4f215..97fc4cafa 100644 --- a/metatomic-torch/include/metatomic/torch/misc.hpp +++ b/metatomic-torch/include/metatomic/torch/misc.hpp @@ -76,7 +76,7 @@ namespace details { /// - a boolean indicating whether this is a known output/input /// - the name of the base output/input (empty if custom) /// - the name of the variant (empty if none) -std::tuple validate_name_and_check_variant( +METATOMIC_TORCH_EXPORT std::tuple validate_name_and_check_variant( const std::string& name ); } diff --git a/metatomic-torch/src/misc.cpp b/metatomic-torch/src/misc.cpp index 647b685e6..f19de58aa 100644 --- a/metatomic-torch/src/misc.cpp +++ b/metatomic-torch/src/misc.cpp @@ -427,6 +427,7 @@ inline std::unordered_set KNOWN_INPUTS_OUTPUTS = { "velocities", "masses", "charges", + "spin_multiplicity", "heat_flux", }; diff --git a/metatomic-torch/src/outputs.cpp b/metatomic-torch/src/outputs.cpp index 9ed8fcd35..16226a993 100644 --- a/metatomic-torch/src/outputs.cpp +++ b/metatomic-torch/src/outputs.cpp @@ -622,6 +622,36 @@ static void check_heat_flux( validate_no_gradients("heat_flux", heat_flux_block); } +/// Check output metadata for spin_multiplicity (per-system scalar). +static void check_spin_multiplicity( + const TensorMap& value, + const std::vector& systems, + const ModelOutput& request +) { + validate_single_block("spin_multiplicity", value); + + if (request->per_atom) { + C10_THROW_ERROR(ValueError, + "invalid 'spin_multiplicity' output: spin_multiplicity is a per-system quantity, " + "but the request indicates `per_atom=True`" + ); + } + validate_atomic_samples("spin_multiplicity", value, systems, request, torch::nullopt); + + auto tensor_options = torch::TensorOptions().device(value->device()); + auto spin_block = TensorMapHolder::block_by_id(value, 0); + + validate_components("spin_multiplicity", spin_block->components(), {}); + + auto expected_properties = torch::make_intrusive( + "spin_multiplicity", + torch::tensor({{0}}, tensor_options) + ); + validate_properties("spin_multiplicity", spin_block, expected_properties); + + validate_no_gradients("spin_multiplicity", spin_block); +} + void metatomic_torch::check_outputs( const std::vector& systems, const c10::Dict& requested, @@ -694,6 +724,8 @@ void metatomic_torch::check_outputs( check_charges(value, systems, request); } else if (base == "heat_flux") { check_heat_flux(value, systems, request); + } else if (base == "spin_multiplicity") { + check_spin_multiplicity(value, systems, request); } else if (name.find("::") != std::string::npos) { // this is a non-standard output, there is nothing to check } else { diff --git a/metatomic-torch/src/units.cpp b/metatomic-torch/src/units.cpp index bd74e1ac6..1aa50d372 100644 --- a/metatomic-torch/src/units.cpp +++ b/metatomic-torch/src/units.cpp @@ -594,7 +594,8 @@ static const auto QUANTITY_DIMS = std::unordered_map{ {"mass", DIM_MASS}, {"velocity", {{1, -1, 0, 0, 0}}}, // length/time {"charge", DIM_CHARGE}, - {"heat_flux", {{3, -3, 1, 0, 0}}}, // energy*velocity + {"heat_flux", {{3, -3, 1, 0, 0}}}, // energy*velocity + {"spin_multiplicity", {{0, 0, 0, 0, 0}}}, // dimensionless }; diff --git a/metatomic-torch/tests/misc.cpp b/metatomic-torch/tests/misc.cpp index 55aae7d1c..fb321b1e4 100644 --- a/metatomic-torch/tests/misc.cpp +++ b/metatomic-torch/tests/misc.cpp @@ -105,3 +105,11 @@ TEST_CASE("Pick variant") { " - 'energy/foo': Variant foo of the output"; CHECK_THROWS_WITH(metatomic_torch::pick_output("energy", outputs), StartsWith(err)); } + +TEST_CASE("Standard outputs") { + // "spin_multiplicity" is recognized as a standard (non-namespaced) output name + auto [known, base, variant] = metatomic_torch::details::validate_name_and_check_variant("spin_multiplicity"); + CHECK(known == true); + CHECK(base == "spin_multiplicity"); + CHECK(variant == ""); +} diff --git a/metatomic-torch/tests/models.cpp b/metatomic-torch/tests/models.cpp index c319b6593..ec74bcc98 100644 --- a/metatomic-torch/tests/models.cpp +++ b/metatomic-torch/tests/models.cpp @@ -111,7 +111,7 @@ TEST_CASE("Models metadata") { void process(const torch::Warning& warning) override { auto expected = std::string( "unknown quantity 'unknown', only [charge energy force heat_flux " - "length mass momentum pressure velocity] are supported" + "length mass momentum pressure spin_multiplicity velocity] are supported" ); CHECK(warning.msg() == expected); } diff --git a/python/metatomic_torch/tests/outputs.py b/python/metatomic_torch/tests/outputs.py index 8e9135a97..5d4e4f85c 100644 --- a/python/metatomic_torch/tests/outputs.py +++ b/python/metatomic_torch/tests/outputs.py @@ -270,3 +270,102 @@ def test_positions_momenta_model(system): assert momenta.block().properties.names == ["momenta"] assert momenta.block().components == [Labels("xyz", torch.tensor([[0], [1], [2]]))] assert len(result["momenta"].blocks()) == 1 + + +class SpinMultiplicityModel(torch.nn.Module): + """A model that requests spin_multiplicity as a system-level output.""" + + def requested_inputs(self) -> Dict[str, ModelOutput]: + return { + "spin_multiplicity": ModelOutput( + quantity="spin_multiplicity", unit="", per_atom=False + ), + } + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + system = systems[0] + spin = float(system.get_data("spin_multiplicity").block(0).values[0, 0]) + energy_value = 10.0 * spin + block = TensorBlock( + values=torch.tensor([[energy_value]] * len(systems), dtype=torch.float64), + samples=Labels("system", torch.arange(len(systems)).reshape(-1, 1)), + components=[], + properties=Labels("energy", torch.tensor([[0]])), + ) + return {"energy": TensorMap(Labels("_", torch.tensor([[0]])), [block])} + + +def _make_scalar_output(name, value): + """Create a valid per-system scalar TensorMap for spin_multiplicity.""" + block = TensorBlock( + values=torch.tensor([[value]], dtype=torch.float64), + samples=Labels("system", torch.tensor([[0]])), + components=[], + properties=Labels(name, torch.tensor([[0]])), + ) + return TensorMap(Labels("_", torch.tensor([[0]])), [block]) + + +def _make_spin_multiplicity_model(): + model = SpinMultiplicityModel() + capabilities = ModelCapabilities( + length_unit="angstrom", + atomic_types=[1, 2, 3], + interaction_range=4.3, + outputs={"energy": ModelOutput(per_atom=False, unit="eV")}, + supported_devices=["cpu"], + dtype="float64", + ) + return AtomisticModel(model.eval(), ModelMetadata(), capabilities) + + +def test_spin_multiplicity_valid(system): + """check_consistency=True passes with correctly structured spin_multiplicity.""" + atomistic = _make_spin_multiplicity_model() + + system.add_data("spin_multiplicity", _make_scalar_output("spin_multiplicity", 3.0)) + + options = ModelEvaluationOptions(outputs={"energy": ModelOutput(per_atom=False)}) + result = atomistic([system], options, check_consistency=True) + assert "energy" in result + + +def test_spin_multiplicity_wrong_property_name(system): + """check_consistency catches wrong property name for spin_multiplicity.""" + atomistic = _make_spin_multiplicity_model() + + bad_block = TensorBlock( + values=torch.tensor([[1.0]], dtype=torch.float64), + samples=Labels("system", torch.tensor([[0]])), + components=[], + properties=Labels("wrong", torch.tensor([[0]])), + ) + bad_spin = TensorMap(Labels("_", torch.tensor([[0]])), [bad_block]) + system.add_data("spin_multiplicity", bad_spin) + + options = ModelEvaluationOptions(outputs={"energy": ModelOutput(per_atom=False)}) + with pytest.raises(ValueError, match="spin_multiplicity"): + atomistic([system], options, check_consistency=True) + + +def test_spin_multiplicity_with_components_error(system): + """check_consistency catches spurious components on spin_multiplicity.""" + atomistic = _make_spin_multiplicity_model() + + bad_block = TensorBlock( + values=torch.tensor([[[1.0], [1.0], [1.0]]], dtype=torch.float64), + samples=Labels("system", torch.tensor([[0]])), + components=[Labels("xyz", torch.tensor([[0], [1], [2]]))], + properties=Labels("spin_multiplicity", torch.tensor([[0]])), + ) + bad_spin = TensorMap(Labels("_", torch.tensor([[0]])), [bad_block]) + system.add_data("spin_multiplicity", bad_spin) + + options = ModelEvaluationOptions(outputs={"energy": ModelOutput(per_atom=False)}) + with pytest.raises(ValueError, match="spin_multiplicity"): + atomistic([system], options, check_consistency=True)