diff --git a/deepmd/pt/model/model/sezm_model.py b/deepmd/pt/model/model/sezm_model.py index 1a35196de0..e007ce4fad 100644 --- a/deepmd/pt/model/model/sezm_model.py +++ b/deepmd/pt/model/model/sezm_model.py @@ -87,12 +87,11 @@ function: * ``core_compute`` rebuilds a compact, GPU-friendly edge list from the - padded DeePMD neighbor list (``build_edge_list_from_nlist``), with - masked dummy edges appended so the edge tensor has a non-singular - symbolic lower bound (NOTE 10). Edge vectors come from - ``index_select`` on the extended coordinate tensor, which keeps the - gradient path back to coordinates explicit and safe under symbolic - shapes (NOTE 11). + padded DeePMD neighbor list (``build_edge_list_from_nlist``), with a + single masked dummy edge appended so the edge tensor is never empty + (NOTE 10). Edge vectors come from ``index_select`` on the extended + coordinate tensor, which keeps the gradient path back to coordinates + explicit and safe under symbolic shapes (NOTE 11). * The SeZM descriptor consumes the edge list and produces per-atom features. * The fitting network predicts per-atom energy; ``apply_out_stat`` adds @@ -323,20 +322,17 @@ In eval mode we merely detach; no ``create_graph`` is requested, so the compiled kernel never has to build a backward graph. -NOTE 10 -- Tail dummy edges ---------------------------- +NOTE 10 -- Tail dummy edge +-------------------------- -``build_edge_list_from_nlist`` appends two masked edges at the end of -every batch. Real edge compaction happens via +``build_edge_list_from_nlist`` appends exactly one masked edge at the +end of every batch. Real edge compaction happens via ``torch.nonzero(valid_mask)``, whose output length is data-dependent and can be zero in sparse or single-type systems. make_fx cannot trace an "if n_edges == 0: skip" branch symbolically; without the dummy it would fall back to concrete shape specialization and break -``dynamic=True``. 
A pair of dummy slots also gives Inductor's batched -matmul lowering a static ``E >= 2`` edge-axis bound, avoiding -data-dependent layout guards on ``E == 1``. Each dummy's ``edge_mask`` -is ``False`` so it contributes exactly zero to every downstream sum or -gather. +``dynamic=True``. The dummy's ``edge_mask`` is ``False`` so it +contributes exactly zero to every downstream sum or gather. NOTE 11 -- ``index_select`` for coordinate gradients ---------------------------------------------------- @@ -447,6 +443,121 @@ _dynamo_cfg.optimize_ddp = False +# --------------------------------------------------------------------------- +# Multi-task compile sharing +# --------------------------------------------------------------------------- +# Maps (structure_key..., training, do_atomic_virial, has_coord_corr) to the +# compiled callable. Tasks whose descriptor AND fitting-net first child have +# the same Python-object identity (after share_params) reuse a single compiled +# graph, avoiding Nx compile-cache OOM and N DDP graph boundaries (NCCL timeout). +_SEZM_COMPILE_CACHE: dict[tuple, Any] = {} + +# Maps structure_key -> task_buf_order so every instance in the same group +# knows which buffers were promoted and in what order. +_SEZM_TASK_BUF_ORDER: dict[tuple[int, ...], tuple[str, ...]] = {} + +# Prefix namespace for promoted buffer names. +_AM_PREFIX = "am/" # atomic_model registered buffer +_FIT_PREFIX = "fit/" # fitting_net registered buffer +_FIT_ATTR_PREFIX = "fit_attr/" # fitting_net plain tensor attribute (not in _buffers) + + +def _sezm_structure_key(model: SeZMModel) -> tuple[int, ...]: + """Return a key that is equal iff two SeZMModel instances can share a compiled graph. + + After ``share_params``, the descriptor and fitting-net module objects + themselves remain *different* Python objects per task; only their + *submodules* (``_modules`` dict entries) are replaced with shared + references. 
Using ``id(descriptor)`` or ``id(fitting_net)`` would + therefore always differ between tasks and defeat the cache. + + Fix: use the id of the *first named child* of each module. After + ``share_params(level=0)``, those children are the same Python objects + for all tasks in the same structure group, giving matching keys. + + NOTE: only the FIRST child is sampled, assuming "first child shared => + whole module shared" (true for level=0). Under ``share_params(level=1)`` + only ``type_embedding`` is shared; if it is the first child, two tasks + whose other descriptor weights differ would collapse to the same key and + wrongly reuse one compiled graph. If level=1 + compile is ever used, key + on all param ids instead, e.g. ``frozenset(id(p) for p in desc.parameters())``. + """ + try: + desc = model.atomic_model.descriptor + desc_id = 0 + for _, child in desc.named_children(): + desc_id = id(child) + break + if desc_id == 0: + # Descriptor has no named children (unlikely); fall back. + desc_id = id(desc) + except AttributeError: + desc_id = 0 + try: + fitting = model.atomic_model.fitting_net + for _, child in fitting.named_children(): + return (desc_id, id(child)) + return (desc_id, id(fitting)) + except AttributeError: + return (desc_id, id(model)) + + +def _get_sezm_task_buf_names(model: SeZMModel) -> tuple[str, ...]: + """Return the ordered names of per-task buffers to promote as FX placeholders. + + Always promotes: + * ``out_bias``, ``out_std`` on ``atomic_model`` — may be replaced + out-of-place by ``model_change_out_bias``, so the compiled graph must + never bake them as constants. + * ``bias_atom_e`` on the fitting net — task-specific per-type bias that + differs across tasks after ``share_params``. + * ``case_embd`` on the fitting net — task-identity vector used for + multi-task case conditioning; stored as a plain tensor attribute. 
+ """ + names: list[str] = [] + try: + am = model.atomic_model + for bname in ("out_bias", "out_std"): + if am._buffers.get(bname) is not None: + names.append(_AM_PREFIX + bname) + try: + fitting = am.fitting_net + for bname in ("bias_atom_e",): + if fitting._buffers.get(bname) is not None: + names.append(_FIT_PREFIX + bname) + for aname in ("case_embd",): + val = getattr(fitting, aname, None) + if val is not None and torch.is_tensor(val): + names.append(_FIT_ATTR_PREFIX + aname) + except AttributeError: + pass + except AttributeError: + pass + return tuple(names) + + +def _get_sezm_task_buf_vals( + model: SeZMModel, + names: tuple[str, ...], +) -> tuple[torch.Tensor, ...]: + """Return the current tensor values for the given promoted-buffer names.""" + if not names: + return () + am = model.atomic_model + try: + fitting = am.fitting_net + except AttributeError: + fitting = None + vals: list[torch.Tensor] = [] + for name in names: + if name.startswith(_AM_PREFIX): + vals.append(am._buffers[name[len(_AM_PREFIX) :]]) + elif name.startswith(_FIT_PREFIX): + vals.append(fitting._buffers[name[len(_FIT_PREFIX) :]]) # type: ignore[union-attr] + elif name.startswith(_FIT_ATTR_PREFIX): + vals.append(getattr(fitting, name[len(_FIT_ATTR_PREFIX) :])) + return tuple(vals) + def _parse_optional_env_bool(var_name: str) -> bool | None: """ @@ -490,6 +601,67 @@ def _check_compile_torch_version() -> None: ) +def _is_prime(n: int) -> bool: + """Return True when ``n`` is a prime integer (``n >= 2``).""" + if n < 2: + return False + if n < 4: + return True + if n % 2 == 0: + return False + k = 3 + while k * k <= n: + if n % k == 0: + return False + k += 2 + return True + + +def _next_safe_prime(start: int, forbidden: set[int]) -> int: + """Return the smallest prime ``>= max(start, 5)`` not in ``forbidden``. + + Used by :meth:`SeZMModel.trace_and_compile` to choose collision-free + trace-time sizes for ``nf``, ``nall`` and ``nloc``. 
Primes ``>= 5`` + avoid every dim PyTorch specializes on (``1`` → broadcasting, + ``2``/``3``/``9`` → Cartesian / virial / charge_spin literals baked + into model code) and guarantee distinct values, which suppresses + make_fx's duck-shape unification without needing the + ``ShapeEnv(duck_shape=False)`` patch. + """ + n = max(start, 5) + while not _is_prime(n) or n in forbidden: + n += 1 + return n + + +def _trace_pad_dim(t: torch.Tensor, dim: int, target: int) -> torch.Tensor: + """Pad or trim ``t`` along ``dim`` so ``t.shape[dim] == target``. + + Padding duplicates the last slice along ``dim``; trimming drops + trailing slices. Used to coerce real-data trace inputs into the + prime-numbered shapes chosen by :func:`_next_safe_prime`. + + Duplicating the last slice preserves valid index values inside + index-bearing tensors (``nlist`` neighbor indices, ``mapping`` + extended-to-local indices) because the duplicated row reuses the + previously-valid row's values. Trimming likewise never invalidates + indices. Only shapes flow downstream during ``make_fx`` tracing, + so the exact replicated/trimmed values do not affect the FX graph. + """ + cur = int(t.shape[dim]) + if cur == target: + return t + if cur > target: + sl: list[slice] = [slice(None)] * t.ndim + sl[dim] = slice(None, target) + return t[tuple(sl)] + sl = [slice(None)] * t.ndim + sl[dim] = slice(-1, None) + last = t[tuple(sl)] + repeats = target - cur + return torch.cat([t, *([last] * repeats)], dim=dim) + + def _strip_saved_tensor_detach(gm: torch.fx.GraphModule) -> None: """Strip ``aten.detach`` nodes that ``make_fx`` inserts for saved tensors. @@ -624,6 +796,9 @@ def __init__( # compile products instead of evicting the other mode. object.__setattr__(self, "compiled_core_compute_cache", {}) object.__setattr__(self, "compiled_dens_compute", None) + # Maps cache_key -> task_buf_order for this instance so forward() + # knows which buffers to pass and in what order. 
+ object.__setattr__(self, "_task_buf_order_cache", {}) # Training follows `use_compile`. Evaluation/inference reads # `DP_COMPILE_INFER` at init time and falls back to eager when unset. self._env_use_compile_infer: bool | None = _parse_optional_env_bool( @@ -1010,6 +1185,14 @@ def forward_common_after_nlist( extended_coord_corr=extended_coord_corr, ) compiled_core_compute = self.compiled_core_compute_cache[cache_key] + # Read current values of per-task buffers on every call, so + # in-place updates and out-of-place replacements (e.g. from + # model_change_out_bias) are both picked up rather than + # baking the values in at compile time. + _task_buf_vals = _get_sezm_task_buf_vals( + self, + getattr(self, "_task_buf_order_cache", {}).get(cache_key, ()), + ) with nvtx_range("SeZM/core_compute"): if extended_coord_corr is None: model_predict_lower = compiled_core_compute( @@ -1020,6 +1203,7 @@ def forward_common_after_nlist( fp, ap, charge_spin, + *_task_buf_vals, ) else: model_predict_lower = compiled_core_compute( @@ -1031,6 +1215,7 @@ def forward_common_after_nlist( ap, charge_spin, extended_coord_corr, + *_task_buf_vals, ) if ( self._core_compute_pending_compile_t0 is not None @@ -1524,6 +1709,31 @@ def trace_and_compile( mode = "train" if self.training else "eval" has_coord_corr = extended_coord_corr is not None + _compile_t0 = time.perf_counter() + + # --- Check module-level shared cache first --- + # Tasks sharing the same descriptor+fitting structure (after share_params) + # should share one compiled graph. If a sibling task already compiled, + # populate this instance's per-instance caches and return immediately.
+ structure_key = _sezm_structure_key(self) + cache_key = (bool(self.training), bool(do_atomic_virial), has_coord_corr) + full_cache_key = structure_key + cache_key + if full_cache_key in _SEZM_COMPILE_CACHE: + self.compiled_core_compute_cache[cache_key] = _SEZM_COMPILE_CACHE[ + full_cache_key + ] + self._task_buf_order_cache[cache_key] = _SEZM_TASK_BUF_ORDER.get( + structure_key, () + ) + log.info( + "SeZM: reusing shared compiled graph " + "(mode=%s, atomic_virial=%s, coord_corr=%s)", + mode, + do_atomic_virial, + has_coord_corr, + ) + return + log.info( "SeZM: start tracing and compiling " "(mode=%s, atomic_virial=%s, coord_corr=%s)", @@ -1531,7 +1741,71 @@ def trace_and_compile( do_atomic_virial, has_coord_corr, ) - _compile_t0 = time.perf_counter() + + # --- Detect per-task buffers to promote as FX placeholders --- + # These buffers differ across tasks in the same structure group (they are + # NOT shared by share_params) or may be replaced out-of-place after + # compilation. Passing them as explicit arguments makes the compiled + # graph reusable across all tasks in the group. + task_buf_names = _get_sezm_task_buf_names(self) + task_buf_vals_trace = _get_sezm_task_buf_vals(self, task_buf_names) + + # Resolve module references once for the buffer-patching closures. + _am_patch = self.atomic_model + try: + _fitting_patch: torch.nn.Module | None = _am_patch.fitting_net + except AttributeError: + _fitting_patch = None + + def _patch_task_bufs( + vals: tuple[torch.Tensor, ...], + ) -> dict[str, torch.Tensor | None]: + """Temporarily replace model buffers/attrs with FX proxy tensors. + + Executed at trace time inside compute_fn. make_fx records the + proxy tensors as placeholder nodes, so the compiled graph reads them + as live inputs rather than baked-in constants. The ``finally`` + block in compute_fn always calls ``_restore_task_bufs`` to leave + the model in its original state after tracing. 
+ """ + saved: dict[str, torch.Tensor | None] = {} + for name, val in zip(task_buf_names, vals): + if name.startswith(_AM_PREFIX): + actual = name[len(_AM_PREFIX) :] + saved[name] = _am_patch._buffers.get(actual) + _am_patch._buffers[actual] = val + elif name.startswith(_FIT_PREFIX): + actual = name[len(_FIT_PREFIX) :] + saved[name] = ( + _fitting_patch._buffers.get(actual) + if _fitting_patch is not None + else None + ) + if _fitting_patch is not None: + _fitting_patch._buffers[actual] = val + elif name.startswith(_FIT_ATTR_PREFIX): + actual = name[len(_FIT_ATTR_PREFIX) :] + saved[name] = getattr(_fitting_patch, actual, None) + if _fitting_patch is not None: + setattr(_fitting_patch, actual, val) + return saved + + def _restore_task_bufs( + saved: dict[str, torch.Tensor | None], + ) -> None: + """Restore original model buffers/attrs after tracing.""" + for name, orig in saved.items(): + if name.startswith(_AM_PREFIX): + actual = name[len(_AM_PREFIX) :] + _am_patch._buffers[actual] = orig + elif name.startswith(_FIT_PREFIX): + actual = name[len(_FIT_PREFIX) :] + if _fitting_patch is not None: + _fitting_patch._buffers[actual] = orig + elif name.startswith(_FIT_ATTR_PREFIX): + actual = name[len(_FIT_ATTR_PREFIX) :] + if _fitting_patch is not None: + setattr(_fitting_patch, actual, orig) need_coord_grad = self.do_grad_r() or self.do_grad_c() @@ -1552,6 +1826,13 @@ def _prepare_coord_for_trace(coord: torch.Tensor) -> torch.Tensor: else: return coord.detach() + # NOTE: compute_fn accepts *task_buf_vals after the fixed tensor args. + # make_fx treats each element as a separate placeholder so the compiled + # graph reads them as live inputs every call — not baked-in constants. + # The buffer-patching trick: at trace time the proxy tensors are written + # into _buffers / __dict__ so that downstream code (apply_out_stat, + # fitting_net.forward) reads the proxies and the ops are recorded in the + # FX graph. The finally block restores original state unconditionally. 
if extended_coord_corr is None: def compute_fn( @@ -1562,21 +1843,27 @@ def compute_fn( fp: torch.Tensor, ap: torch.Tensor, charge_spin: torch.Tensor, + *task_buf_vals: torch.Tensor, ) -> dict[str, torch.Tensor]: - return self.core_compute( - _prepare_coord_for_trace(extended_coord), - extended_atype, - nlist, - mapping=mapping, - fparam=fp, - aparam=ap, - charge_spin=charge_spin, - do_atomic_virial=do_atomic_virial, - extra_nlist_sort=self.need_sorted_nlist_for_lower(), - ) + _saved = _patch_task_bufs(task_buf_vals) + try: + return self.core_compute( + _prepare_coord_for_trace(extended_coord), + extended_atype, + nlist, + mapping=mapping, + fparam=fp, + aparam=ap, + charge_spin=charge_spin, + do_atomic_virial=do_atomic_virial, + extra_nlist_sort=self.need_sorted_nlist_for_lower(), + ) + finally: + _restore_task_bufs(_saved) + else: - def compute_fn( + def compute_fn( # type: ignore[misc] extended_coord: torch.Tensor, extended_atype: torch.Tensor, nlist: torch.Tensor, @@ -1585,47 +1872,112 @@ def compute_fn( ap: torch.Tensor, charge_spin: torch.Tensor, extended_coord_corr: torch.Tensor, + *task_buf_vals: torch.Tensor, ) -> dict[str, torch.Tensor]: # NOTE: Spin virial uses a coordinate correction derived from the # virtual-atom displacement. Keeping it as a tensor input lets the # compiled graph stay reusable across frames. 
- return self.core_compute( - _prepare_coord_for_trace(extended_coord), - extended_atype, - nlist, - mapping=mapping, - fparam=fp, - aparam=ap, - charge_spin=charge_spin, - do_atomic_virial=do_atomic_virial, - extra_nlist_sort=self.need_sorted_nlist_for_lower(), - extended_coord_corr=extended_coord_corr, - ) + _saved = _patch_task_bufs(task_buf_vals) + try: + return self.core_compute( + _prepare_coord_for_trace(extended_coord), + extended_atype, + nlist, + mapping=mapping, + fparam=fp, + aparam=ap, + charge_spin=charge_spin, + do_atomic_virial=do_atomic_virial, + extra_nlist_sort=self.need_sorted_nlist_for_lower(), + extended_coord_corr=extended_coord_corr, + ) + finally: + _restore_task_bufs(_saved) - # NOTE: Always trace with a fixed batch size that is free of known - # symbolic-shape collisions. + # NOTE: Choose trace shapes that are pairwise-distinct primes >= 5. + # + # ``make_fx(tracing_mode="symbolic")`` introduces a sympy symbol per + # input dim. Two failure modes follow if those dims accidentally + # match each other or hit a PyTorch-internal "special" value: # - # make_fx(tracing_mode="symbolic") replaces shapes with sympy - # symbols, but the moment a symbolic dim ends up equal to a - # *concrete* dim elsewhere in the same tensor it collapses into - # a constant and the graph specialises on that batch size. Known - # reserved dimensions include 1 (specialisation), 2 (charge/spin - # width), 3 (Cartesian coordinates), and 9 (virial tensor). Any - # of those collisions forces - # ``torch.compile(dynamic=True)`` to reject later batches whose - # nf differs from the traced constant. + # * Duck-shape unification: two input dims that share a concrete + # value at trace time get the SAME sympy symbol, baking an + # equality (``nloc == ntypes``, ``nloc == nall``, ...) the + # compiled graph will violate on later batches. 
+ # * Size specialization: dims equal to ``1`` are baked as literal + # ``1`` regardless of duck-shape; values ``2``/``3``/``9`` are + # commonly literals inside the model (charge/spin width, + # Cartesian, virial) and may be unified with input symbols by + # ShapeEnv even with duck-shape off. # - # If a future code change introduces a new explicit dimension of - # this size and compile starts failing with a similar shape - # mismatch, change this constant accordingly. - trace_nf = 5 - coord_for_trace = extended_coord[:1].repeat(trace_nf, 1, 1) - atype_for_trace = extended_atype[:1].repeat(trace_nf, 1) - nlist_for_trace = nlist[:1].repeat(trace_nf, 1, 1) - mapping_for_trace = mapping[:1].repeat(trace_nf, 1) - fp_for_trace = fp[:1].repeat(trace_nf, 1) - ap_for_trace = ap[:1].repeat(trace_nf, 1, 1) - charge_spin_for_trace = charge_spin[:1].repeat(trace_nf, 1) + # Picking pairwise-distinct primes ``>= 5`` for ``nf``, ``nall``, + # ``nloc`` rules out both failure modes in one stroke: no two + # symbols can fuse (distinct values), and no symbol can hit a + # special literal (``5+`` primes skip ``1``/``2``/``3``/``9``). + # ``nsel``, ``dim_fparam``, ``dim_aparam`` and ``dim_chg_spin`` are + # contractually fixed by the model and added to the forbidden set + # so the chosen primes never collide with them either. + _forbidden: set[int] = {1, 2, 3, 9} + for _tbv in task_buf_vals_trace: + for _d in _tbv.shape: + if _d > 1: + _forbidden.add(int(_d)) + # Model-contracted dims kept at their real values (changing them + # would break the model's own assertions about ``sel``, fparam / + # aparam widths, charge_spin dim). Add to forbidden so primes + # picked for free dims do not collide. + _nsel_real = int(nlist.shape[2]) + _dim_fp = int(fp.shape[1]) + _dim_ap = int(ap.shape[2]) + _dim_cs = int(charge_spin.shape[1]) + for _d in (_nsel_real, _dim_fp, _dim_ap, _dim_cs): + if _d > 1: + _forbidden.add(_d) + # Pick primes in physical order ``nf < nloc < nall``. 
The order + # ``trace_nloc < trace_nall`` matters: the model slices + # ``extended_atype[:, :nloc]`` to get local atoms; if + # ``trace_nloc > trace_nall`` the slice silently truncates at + # trace time, breaking the captured symbolic shape relation + # ``atype.shape[1] == nloc``. + trace_nf = _next_safe_prime(5, _forbidden) + _forbidden.add(trace_nf) + trace_nloc = _next_safe_prime(trace_nf + 1, _forbidden) + _forbidden.add(trace_nloc) + trace_nall = _next_safe_prime(trace_nloc + 1, _forbidden) + + # Build trace inputs by padding/trimming real-data tensors into + # the chosen prime shapes. ``_trace_pad_dim`` duplicates the + # last slice when padding so index-bearing tensors (``nlist`` + # neighbor indices, ``mapping`` extended-to-local indices) keep + # valid values -- the duplicated row references the same atoms + # the previous row referenced. + coord_for_trace = _trace_pad_dim(extended_coord[:1], 0, trace_nf) + coord_for_trace = _trace_pad_dim(coord_for_trace, 1, trace_nall) + atype_for_trace = _trace_pad_dim(extended_atype[:1], 0, trace_nf) + atype_for_trace = _trace_pad_dim(atype_for_trace, 1, trace_nall) + nlist_for_trace = _trace_pad_dim(nlist[:1], 0, trace_nf) + nlist_for_trace = _trace_pad_dim(nlist_for_trace, 1, trace_nloc) + # Real nlist values are in ``[-1, real_nall)`` (``-1`` marks + # padded slots, non-negative entries index into extended_coord). + # After trimming ``nall`` down to ``trace_nall`` some of those + # values can exceed ``trace_nall``, which would produce + # out-of-range gather indices in ``coord_flat.index_select(0, + # src_ext)`` during the trace pass. Clamp the upper bound to + # ``trace_nall - 1`` (the ``-1`` padding stays untouched since + # clamp only caps the high side). + nlist_for_trace = torch.clamp(nlist_for_trace, max=trace_nall - 1) + mapping_for_trace = _trace_pad_dim(mapping[:1], 0, trace_nf) + mapping_for_trace = _trace_pad_dim(mapping_for_trace, 1, trace_nall) + # Real mapping values are in ``[0, real_nloc)``. 
If + # ``trace_nloc < real_nloc`` they can reach or exceed ``trace_nloc`` and + # silently propagate into ``src_local`` (used as a local-atom + # index downstream). Clamp to ``trace_nloc - 1``. + mapping_for_trace = torch.clamp(mapping_for_trace, min=0, max=trace_nloc - 1) + fp_for_trace = _trace_pad_dim(fp[:1], 0, trace_nf) + ap_for_trace = _trace_pad_dim(ap[:1], 0, trace_nf) + ap_for_trace = _trace_pad_dim(ap_for_trace, 1, trace_nloc) + charge_spin_for_trace = _trace_pad_dim(charge_spin[:1], 0, trace_nf) + trace_args = [ coord_for_trace, atype_for_trace, @@ -1636,7 +1988,14 @@ def compute_fn( charge_spin_for_trace, ] if extended_coord_corr is not None: - trace_args.append(extended_coord_corr[:1].repeat(trace_nf, 1, 1)) + corr_for_trace = _trace_pad_dim(extended_coord_corr[:1], 0, trace_nf) + corr_for_trace = _trace_pad_dim(corr_for_trace, 1, trace_nall) + trace_args.append(corr_for_trace) + # Append task-buffer values last so they map to the *task_buf_vals + # varargs in compute_fn. Their shapes are static (they don't vary + # batch-to-batch), so passing the actual tensors is correct; make_fx + # will create one placeholder per element. + trace_args.extend(task_buf_vals_trace) # NOTE: Decompose ``silu_backward`` into primitive ops. # PyTorch ships forward and first-order backward for SiLU but no @@ -1735,19 +2094,42 @@ def compute_fn( # ``(training, do_atomic_virial, has_coord_corr)`` so that distinct # graph topologies coexist without evicting each other on every # ``model.eval()`` / ``model.train()`` switch. - cache_key = (bool(self.training), bool(do_atomic_virial), has_coord_corr) # NOTE: ``dynamic=True`` emits a single kernel per traced # shape symbol, so changes in ``nframes``, ``nall`` or edge # count do not trigger recompiles; and the option dict above # disables every Inductor/Triton feature that has ever # interacted badly with ``make_fx`` + double backward in # this project.
- self.compiled_core_compute_cache[cache_key] = torch.compile( + compiled = torch.compile( traced, backend="inductor", dynamic=True, options=compile_options, ) + # Populate both per-instance and module-level shared caches. + # The shared cache (_SEZM_COMPILE_CACHE) lets a second task with the + # same structure key skip re-tracing and re-compiling entirely. + self.compiled_core_compute_cache[cache_key] = compiled + self._task_buf_order_cache[cache_key] = task_buf_names + _SEZM_COMPILE_CACHE[full_cache_key] = compiled + _SEZM_TASK_BUF_ORDER[structure_key] = task_buf_names + # NOTE: No dist.barrier() here. + # The barrier premise is that all ranks reach trace_and_compile + # simultaneously. That is FALSE in several trainer code paths: + # + # 1. compute_or_load_stat (training.py:417) runs on rank 0 only. + # Rank 0 compiles → calls barrier → the other N-1 ranks are not + # inside trace_and_compile at that moment → deadlock. + # + # 2. Validation at disp_freq is rank-0-only inside the rank guard; + # if DP_COMPILE_INFER is set, same deadlock. + # + # Instead we rely on compilation being symmetric during the DDP + # training loop itself: all ranks pick the same task per step (same + # random seed), so they all hit trace_and_compile for the same task + # at the same step. The compile-time gap between ranks is on the + # order of seconds while the NCCL default timeout is 30 minutes, + # so no barrier is necessary for the training-loop case. # torch.compile is lazy; the "finished" log is emitted after the # first call triggers Inductor lowering (see forward_common). 
# ``pending_key`` pairs with ``pending_t0`` so the log is only @@ -1778,7 +2160,8 @@ def compile_dens(self) -> None: "epilogue_fusion": False, "triton.cudagraphs": False, "shape_padding": True, - "max_fusion_size": 64, + "max_fusion_size": 8, + "triton.persistent_reductions": False, }, ), ) @@ -1893,7 +2276,6 @@ def fn( charge_spin_, ) - trace_inputs = (extended_coord, extended_atype, nlist, mapping, fparam, aparam) if self.get_dim_chg_spin() > 0: charge_spin = self.convert_charge_spin( charge_spin, @@ -1901,7 +2283,18 @@ def fn( dtype=extended_coord.dtype, device=extended_coord.device, ) - trace_inputs = (*trace_inputs, charge_spin) + # Always include the charge_spin slot (possibly None) so the traced + # module's forward signature matches the 7-tuple the freeze pipeline + # passes at runtime, regardless of whether the model is conditioned. + trace_inputs = ( + extended_coord, + extended_atype, + nlist, + mapping, + fparam, + aparam, + charge_spin, + ) return self._trace_lower_exportable( fn, @@ -1960,10 +2353,9 @@ def build_edge_list_from_nlist( Build a compact edge list from DeePMD padded neighbor list. Edge vectors are computed via ``index_select`` on ``extended_coord`` - so they remain differentiable w.r.t. the input coordinates. Two - masked dummy edges are always appended to avoid data-dependent empty-edge - branches that ``make_fx`` cannot trace and singular edge-axis guards - in Inductor's batched matmul lowering. + so they remain differentiable w.r.t. the input coordinates. One + masked dummy edge is always appended to avoid data-dependent empty-edge + branches that ``make_fx`` cannot trace. Parameters ---------- @@ -1977,14 +2369,13 @@ def build_edge_list_from_nlist( Returns ------- edge_index - Edge indices with shape (2, E+2) where E is valid edge count. + Edge indices with shape (2, E+1) where E is valid edge count. edge_vec - Edge vectors with shape (E+2, 3). + Edge vectors with shape (E+1, 3). edge_mask - Boolean mask with shape (E+2). 
The trailing elements are ``False``. + Boolean mask with shape (E+1,). The trailing element is ``False``. """ nf, nloc, nsel = nlist.shape - n_actual = nf * nloc device = extended_coord.device nall = extended_coord.shape[1] descriptor_model = self.atomic_model.descriptor @@ -2002,12 +2393,23 @@ def build_edge_list_from_nlist( # ``torch.where(valid_flat, neighbor_flat, 0)`` sanitises padded # ``-1`` entries before indexing so we never hit an out-of-range # gather; the corresponding edges are filtered out below anyway. - dst_actual = torch.arange( - n_actual, device=device, dtype=torch.long - ).repeat_interleave(nsel) + neighbor_flat = nlist.reshape(-1) + # ``dst_actual = arange(N*K) // K`` produces the same value + # sequence as ``arange(N).repeat_interleave(K)`` but its length + # is derived from ``neighbor_flat.shape[0]`` -- a single symbolic + # source shared with the ``torch.where`` below. The previous + # ``arange(nf*nloc).repeat_interleave(nsel)`` chain could + # decouple from ``nlist.numel()`` in the FX graph if any + # upstream code path ever specialized ``nloc`` at trace time; + # deriving from ``neighbor_flat.shape[0]`` makes the equality + # structural and survives any future change in trace-shape + # selection in ``trace_and_compile``. + dst_actual = ( + torch.arange(neighbor_flat.shape[0], device=device, dtype=torch.long) + // nsel + ) f_idx = dst_actual // nloc dst_local = dst_actual % nloc - neighbor_flat = nlist.reshape(-1) valid_flat = neighbor_flat >= 0 neighbor_safe = torch.where( valid_flat, neighbor_flat, torch.zeros_like(neighbor_flat) @@ -2033,22 +2435,19 @@ def build_edge_list_from_nlist( valid_idx = torch.nonzero(edge_mask_actual, as_tuple=False).flatten() - # === Step 3. Compact edges + append masked dummies === - # NOTE: Always append two masked dummy edges. + # === Step 3. Compact edges + append one masked dummy === + # NOTE: Always append exactly one masked dummy edge. 
# ``torch.nonzero(edge_mask_actual)`` produces a data-dependent # number of valid edges, which can be zero on sparse or # single-type systems. make_fx cannot trace an # ``if n_edges == 0: skip`` branch symbolically; without the # dummy it would fall back to concrete shape specialisation and - # break ``torch.compile(dynamic=True)`` for later batches. Two - # dummy edges keep the symbolic edge axis statically above one, - # which avoids Inductor bmm layout guards on ``E == 1``. Each + # break ``torch.compile(dynamic=True)`` for later batches. The # dummy edge copies entry 0 (any in-range index is fine) and # carries ``edge_mask=False`` so every downstream sum, gather # or scatter ignores it. - dummy_count = 2 padded_idx = torch.cat( - [valid_idx, torch.zeros(dummy_count, dtype=torch.long, device=device)] + [valid_idx, torch.zeros(1, dtype=torch.long, device=device)] ) src_sel = src_actual.index_select(0, padded_idx) dst_sel = dst_actual.index_select(0, padded_idx) @@ -2057,7 +2456,7 @@ def build_edge_list_from_nlist( edge_mask = torch.cat( [ torch.ones(valid_idx.shape[0], dtype=torch.bool, device=device), - torch.zeros(dummy_count, dtype=torch.bool, device=device), + torch.zeros(1, dtype=torch.bool, device=device), ] ) return edge_index, edge_vec_sel, edge_mask diff --git a/source/tests/pt/model/test_sezm_model.py b/source/tests/pt/model/test_sezm_model.py index 0e682ca803..b3ebcaf2ff 100644 --- a/source/tests/pt/model/test_sezm_model.py +++ b/source/tests/pt/model/test_sezm_model.py @@ -613,10 +613,10 @@ def test_fixed_edge_geometry_matches_standard_cache(self) -> None: wigner_calc=descriptor.wigner_calc, ) - # build_edge_list_from_nlist appends masked dummy edges; + # build_edge_list_from_nlist appends exactly one masked dummy edge; # compare only the real edges before the padded tail. 
n_real = cache_std.src.shape[0] - self.assertEqual(edge_mask.shape[0] - n_real, 2) + self.assertEqual(edge_mask.shape[0] - n_real, 1) self.assertFalse(edge_mask[n_real:].any().item()) self.assertTrue(torch.equal(cache_std.src, cache_sparse.src[:n_real])) self.assertTrue(torch.equal(cache_std.dst, cache_sparse.dst[:n_real])) @@ -898,13 +898,15 @@ def _build_wrapper(use_compile: bool) -> ModelWrapper: msg=f"multitask force mismatch at {branch}", ) - # === Step 3. Each compiled branch owns its own compile cache; the - # shared descriptor weights must not collapse them into one. - # Step 2 ran every branch in training mode with the default + # === Step 3. Each branch keeps its own per-instance cache dict, but + # branches that share descriptor + fitting (same Python-object + # identity after share_params) reuse a single compiled callable via + # the module-level ``_SEZM_COMPILE_CACHE``. This avoids the + # N x compile-cache OOM / N DDP graph boundary cost on multitask + # runs. Step 2 ran every branch in training mode with the default # ``do_atomic_virial=False`` and no coordinate correction, so each - # per-branch cache dict - # should hold exactly that one slot, and the compiled callables - # at that slot must be distinct across branches. === + # per-branch cache should hold exactly that one slot, and the + # compiled callable at that slot must be the *same* object. === cache1 = wrapper_cmp.model["water_1"].compiled_core_compute_cache cache2 = wrapper_cmp.model["water_2"].compiled_core_compute_cache self.assertIsNot(cache1, cache2) @@ -915,7 +917,7 @@ def _build_wrapper(use_compile: bool) -> ModelWrapper: c2 = cache2[train_key] self.assertIsNotNone(c1) self.assertIsNotNone(c2) - self.assertIsNot(c1, c2) + self.assertIs(c1, c2) # === Step 4. Per-task case embedding must differentiate outputs. === out_e1 = wrapper_eager.model["water_1"](coord, atype, box=box)