diff --git a/source/api_cc/src/DeepSpinPTExpt.cc b/source/api_cc/src/DeepSpinPTExpt.cc index 75b445085f..a82006be4c 100644 --- a/source/api_cc/src/DeepSpinPTExpt.cc +++ b/source/api_cc/src/DeepSpinPTExpt.cc @@ -371,12 +371,43 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener, int nloc = nall_real - nghost_real; int nframes = 1; - // Build spin tensor for real atoms using bkw_map - std::vector dspin(static_cast(nall_real) * 3); - for (int ii = 0; ii < nall_real; ++ii) { + // Phantom-atom padding for the empty-subdomain corner case + // (``nloc_real == 0``). Multi-rank spin MD can land a rank with zero + // real local atoms when atoms migrate to other subdomains. The + // with-comm AOTI artifact, traced with ``nloc_min=1`` and lowered by + // inductor with an even stricter ``nloc >= 2`` runtime-check + // (silently bypassed because ``AOTI_RUNTIME_CHECK_INPUTS`` is unset by + // default), then SIGFPEs at runtime with an "integer divide by zero" + // inside inductor-generated shape arithmetic that uses ``nloc`` as a + // divisor. The failure is intermittent because inductor re-codegens + // across runs and only some compiles emit the offending divide. + // + // Fix: prepend two phantom atoms with no neighbours so the AOTI graph + // runs with ``nloc == 2``. The phantoms have an empty nlist row and + // so add zero force / virial (their constant bias energy is zeroed + // after inference), preserving "this rank has no real atoms" semantics. + // ``nlocal`` in the comm tensors is set to ``2`` so border_op writes + // received ghost features past the phantom slots; outputs are stripped + // of the phantom prefix before being scattered back to LAMMPS atoms + // via ``select_map``. + const int phantom_n = (nloc_real == 0 && nall_real > 0) ?
2 : 0; + if (phantom_n > 0) { + dcoord.insert(dcoord.begin(), static_cast(phantom_n) * 3, + static_cast(0)); + datype.insert(datype.begin(), static_cast(phantom_n), 0); + nall_real += phantom_n; + nloc_real = phantom_n; + nloc = nall_real - nghost_real; + } + + // Build spin tensor for real atoms using bkw_map (skip phantom prefix + // which keeps zero spin). + std::vector dspin(static_cast(nall_real) * 3, + static_cast(0)); + for (int ii = phantom_n; ii < nall_real; ++ii) { for (int dd = 0; dd < 3; ++dd) { dspin[static_cast(ii) * 3 + dd] = - spin[static_cast(bkw_map[ii]) * 3 + dd]; + spin[static_cast(bkw_map[ii - phantom_n]) * 3 + dd]; } } @@ -445,11 +476,16 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener, nlist_data.shuffle_exclude_empty(fwd_map); nlist_data.padding(); - // Rebuild mapping tensor + // Rebuild mapping tensor. Phantom slots (when phantom_n > 0) get + // identity entries — they index into their own row and never appear + // in any other atom's nlist (their nlist rows are all -1 below). if (lmp_list.mapping) { std::vector mapping(nall_real); - for (int ii = 0; ii < nall_real; ii++) { - mapping[ii] = fwd_map[lmp_list.mapping[bkw_map[ii]]]; + for (int ii = 0; ii < phantom_n; ii++) { + mapping[ii] = ii; + } + for (int ii = phantom_n; ii < nall_real; ii++) { + mapping[ii] = fwd_map[lmp_list.mapping[bkw_map[ii - phantom_n]]]; } mapping_tensor = torch::from_blob(mapping.data(), {1, nall_real}, int_option) @@ -472,8 +508,16 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener, } // Flatten raw nlist — the .pt2 model sorts by distance on-device. + // Phantom rows (all -1) are prepended below so the AOTI graph sees + // nloc == phantom_n + nloc_real_orig instead of 0. 
firstneigh_tensor = createNlistTensor(nlist_data.jlist, nnei).to(torch::kInt64).to(device); + if (phantom_n > 0) { + auto phantom_rows = torch::full( + {1, phantom_n, nnei}, static_cast(-1), + torch::TensorOptions().dtype(torch::kInt64).device(device)); + firstneigh_tensor = torch::cat({phantom_rows, firstneigh_tensor}, 1); + } } // Build fparam/aparam tensors @@ -566,6 +610,23 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener, ener.assign(flat_energy_.data_ptr(), flat_energy_.data_ptr() + flat_energy_.numel()); + // Zero the reduced energy on an empty rank. Phantoms have constant + // atomic outputs (per-type bias + zero-neighbour MLP) that flow into + // ``energy_redu`` -- and on the spin path the SpinModel doubles atoms + // so the bias contribution appears for both real and spin phantom + // halves; subtracting only the real-half exposed by + // ``output_map["energy"]`` after the ``[:, :nloc]`` slice leaves the + // spin-half leaking into the MPI-reduced LAMMPS total. The physical + // contribution of a rank with no real local atoms is zero by + // definition, so just clear ``ener`` directly. + // + // Forces, force_mag, and virial are unaffected because phantom atomic + // outputs are coord-independent (no neighbours) so their derivatives + // are zero -- no analogous correction is needed. + if (phantom_n > 0) { + std::fill(ener.begin(), ener.end(), static_cast(0)); + } + // Extract force: energy_derv_r (nf, nall, 1, 3) -> (nf, nall, 3) torch::Tensor force_tensor = output_map["energy_derv_r"].squeeze(-2).view({-1}).to(floatType); @@ -588,6 +649,17 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener, virial.assign(cpu_virial_.data_ptr(), cpu_virial_.data_ptr() + cpu_virial_.numel()); + // Strip the phantom prefix (see phantom-atom padding comment near + // ``select_real_atoms_coord``) so the ``bkw_map`` lookup below sees + // only the real / ghost atoms it was built for. 
The phantom slots + // carry zero forces because their nlist rows were all -1 — they + // produce no neighbour contributions, so dropping them is exact. + if (phantom_n > 0) { + dforce.erase(dforce.begin(), dforce.begin() + phantom_n * 3); + dforce_mag.erase(dforce_mag.begin(), dforce_mag.begin() + phantom_n * 3); + nall_real -= phantom_n; + } + // bkw map: map force from real atoms back to full atom list force.resize(static_cast(nframes) * fwd_map.size() * 3); force_mag.resize(static_cast(nframes) * fwd_map.size() * 3); @@ -612,6 +684,16 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener, cpu_atom_virial_.data_ptr(), cpu_atom_virial_.data_ptr() + cpu_atom_virial_.numel()); + // Strip the phantom prefix from atomic outputs as well. Phantom atomic + // virial is zero (nlist rows all -1); phantom atomic energy is a + // constant per-type bias that belongs to no real atom, so drop both. + if (phantom_n > 0) { + datom_energy.erase(datom_energy.begin(), + datom_energy.begin() + phantom_n); + datom_virial.erase(datom_virial.begin(), + datom_virial.begin() + phantom_n * 9); + } + atom_energy.resize(static_cast(nframes) * fwd_map.size()); atom_virial.resize(static_cast(nframes) * fwd_map.size() * 9); select_map(atom_energy, datom_energy, bkw_map, 1, nframes,